summaryrefslogtreecommitdiff
path: root/server/main.go
diff options
context:
space:
mode:
Diffstat (limited to 'server/main.go')
-rw-r--r--server/main.go118
1 files changed, 72 insertions, 46 deletions
diff --git a/server/main.go b/server/main.go
index 1eb1b46..bfc827a 100644
--- a/server/main.go
+++ b/server/main.go
@@ -1,6 +1,7 @@
package main
import (
+ "context"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
@@ -11,6 +12,7 @@ import (
"fmt"
"io"
"net/http"
+ "time"
"github.com/gorilla/mux"
_ "github.com/lib/pq"
@@ -26,7 +28,11 @@ const (
var db *sql.DB
func main() {
- db = initDB()
+ var err error
+ db, err = initDB()
+ if err != nil {
+ panic(err)
+ }
defer db.Close()
router := mux.NewRouter()
@@ -43,19 +49,24 @@ func main() {
http.ListenAndServe(":8080", handler)
}
-func initDB() *sql.DB {
+func initDB() (*sql.DB, error) {
db, err := sql.Open("postgres", "postgres://file:password@localhost/filedb?sslmode=disable")
if err != nil {
- panic(err)
+ return nil, err
}
- if err := createFilesTable(db); err != nil {
- panic(err)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ if err := createFilesTable(ctx, db); err != nil {
+ return nil, err
}
- return db
+
+ return db, nil
}
-func createFilesTable(db *sql.DB) error {
- _, err := db.Exec(`
+func createFilesTable(ctx context.Context, db *sql.DB) error {
+ _, err := db.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS files (
id TEXT PRIMARY KEY,
name TEXT,
@@ -66,17 +77,20 @@ func createFilesTable(db *sql.DB) error {
}
func handleUpload(w http.ResponseWriter, r *http.Request) {
- r.ParseMultipartForm(maxUploadSize)
+ if err := r.ParseMultipartForm(maxUploadSize); err != nil {
+ handleError(w, fmt.Errorf("error parsing multipart form: %v", err), http.StatusBadRequest)
+ return
+ }
key, err := generateRandomKey()
if err != nil {
- handleError(w, err, http.StatusInternalServerError)
+ handleError(w, fmt.Errorf("error generating encryption key: %v", err), http.StatusInternalServerError)
return
}
file, handler, err := r.FormFile("file")
if err != nil {
- handleError(w, errors.New("error parsing uploaded file"), http.StatusBadRequest)
+ handleError(w, fmt.Errorf("error getting form file: %v", err), http.StatusBadRequest)
return
}
defer file.Close()
@@ -85,13 +99,12 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
encryptedData, err := encryptFile(file, key)
if err != nil {
- handleError(w, errors.New("error encrypting file"), http.StatusInternalServerError)
+ handleError(w, fmt.Errorf("error encrypting file: %v", err), http.StatusInternalServerError)
return
}
- err = storeFileInDB(id, handler.Filename, encryptedData)
- if err != nil {
- handleError(w, errors.New("error storing file in database"), http.StatusInternalServerError)
+ if err := storeFileInDB(r.Context(), id, handler.Filename, encryptedData); err != nil {
+ handleError(w, fmt.Errorf("error storing file in database: %v", err), http.StatusInternalServerError)
return
}
@@ -101,56 +114,49 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
}
func handleDownload(w http.ResponseWriter, r *http.Request) {
- vars := mux.Vars(r)
- id := vars["id"]
-
+ id := mux.Vars(r)["id"]
keyHex := r.URL.Query().Get("key")
+
key, err := hex.DecodeString(keyHex)
if err != nil {
- handleError(w, errors.New("invalid key"), http.StatusBadRequest)
+ handleError(w, fmt.Errorf("invalid key: %v", err), http.StatusBadRequest)
return
}
- fileName, encryptedData, err := getFileFromDB(id)
+ fileName, encryptedData, err := getFileFromDB(r.Context(), id)
if err != nil {
- handleError(w, err, http.StatusInternalServerError)
+ handleError(w, fmt.Errorf("error getting file from database: %v", err), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, fileName))
- plaintext, err := decryptFile(encryptedData, key)
+ err = decryptAndStreamFile(w, encryptedData, key)
if err != nil {
- handleError(w, errors.New("error decrypting file"), http.StatusInternalServerError)
- return
- }
-
- if _, err := w.Write(plaintext); err != nil {
- handleError(w, errors.New("error writing response"), http.StatusInternalServerError)
+ handleError(w, fmt.Errorf("error decrypting and streaming file: %v", err), http.StatusInternalServerError)
return
}
}
func handleGetFileInfo(w http.ResponseWriter, r *http.Request) {
- vars := mux.Vars(r)
- id := vars["id"]
-
+ id := mux.Vars(r)["id"]
keyHex := r.URL.Query().Get("key")
+
key, err := hex.DecodeString(keyHex)
if err != nil {
- handleError(w, errors.New("invalid key"), http.StatusBadRequest)
+ handleError(w, fmt.Errorf("invalid key: %v", err), http.StatusBadRequest)
return
}
- fileName, encryptedData, err := getFileFromDB(id)
+ fileName, encryptedData, err := getFileFromDB(r.Context(), id)
if err != nil {
- handleError(w, err, http.StatusInternalServerError)
+ handleError(w, fmt.Errorf("error getting file from database: %v", err), http.StatusInternalServerError)
return
}
plaintext, err := decryptFile(encryptedData, key)
if err != nil {
- handleError(w, errors.New("error decrypting file"), http.StatusInternalServerError)
+ handleError(w, fmt.Errorf("error decrypting file: %v", err), http.StatusInternalServerError)
return
}
@@ -174,14 +180,14 @@ func handleGetFileInfo(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(fileInfo)
}
-func storeFileInDB(id, fileName string, encryptedData []byte) error {
- tx, err := db.Begin()
+func storeFileInDB(ctx context.Context, id, fileName string, encryptedData []byte) error {
+ tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
- _, err = tx.Exec("INSERT INTO files (id, name, data) VALUES ($1, $2, $3)", id, fileName, encryptedData)
+ _, err = tx.ExecContext(ctx, "INSERT INTO files (id, name, data) VALUES ($1, $2, $3)", id, fileName, encryptedData)
if err != nil {
return err
}
@@ -189,8 +195,8 @@ func storeFileInDB(id, fileName string, encryptedData []byte) error {
return tx.Commit()
}
-func getFileFromDB(id string) (fileName string, encryptedData []byte, err error) {
- err = db.QueryRow("SELECT name, data FROM files WHERE id = $1", id).Scan(&fileName, &encryptedData)
+func getFileFromDB(ctx context.Context, id string) (fileName string, encryptedData []byte, err error) {
+ err = db.QueryRowContext(ctx, "SELECT name, data FROM files WHERE id = $1", id).Scan(&fileName, &encryptedData)
if err == sql.ErrNoRows {
return "", nil, errors.New("file not found")
}
@@ -226,28 +232,48 @@ func encryptFile(in io.Reader, key []byte) ([]byte, error) {
return append(nonce, ciphertext...), nil
}
-func decryptFile(encryptedData []byte, key []byte) ([]byte, error) {
+func decryptAndStreamFile(w io.Writer, encryptedData []byte, key []byte) error {
+ if len(encryptedData) < nonceSize {
+ return errors.New("ciphertext too short")
+ }
+
block, err := aes.NewCipher(key)
if err != nil {
- return nil, err
+ return err
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
- return nil, err
+ return err
+ }
+
+ nonce, ciphertext := encryptedData[:nonceSize], encryptedData[nonceSize:]
+ plaintext, err := aesgcm.Open(nil, nonce, ciphertext, nil)
+ if err != nil {
+ return err
}
+ _, err = w.Write(plaintext)
+ return err
+}
+
+func decryptFile(encryptedData []byte, key []byte) ([]byte, error) {
if len(encryptedData) < nonceSize {
return nil, errors.New("ciphertext too short")
}
- nonce, ciphertext := encryptedData[:nonceSize], encryptedData[nonceSize:]
- plaintext, err := aesgcm.Open(nil, nonce, ciphertext, nil)
+ block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
- return plaintext, nil
+ aesgcm, err := cipher.NewGCM(block)
+ if err != nil {
+ return nil, err
+ }
+
+ nonce, ciphertext := encryptedData[:nonceSize], encryptedData[nonceSize:]
+ return aesgcm.Open(nil, nonce, ciphertext, nil)
}
func generateRandomKey() ([]byte, error) {