diff --git a/server/main.go b/server/main.go index 1eb1b46..bfc827a 100644 --- a/server/main.go +++ b/server/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "crypto/aes" "crypto/cipher" "crypto/rand" @@ -11,6 +12,7 @@ import ( "fmt" "io" "net/http" + "time" "github.com/gorilla/mux" _ "github.com/lib/pq" @@ -26,7 +28,11 @@ const ( var db *sql.DB func main() { - db = initDB() + var err error + db, err = initDB() + if err != nil { + panic(err) + } defer db.Close() router := mux.NewRouter() @@ -43,19 +49,24 @@ func main() { http.ListenAndServe(":8080", handler) } -func initDB() *sql.DB { +func initDB() (*sql.DB, error) { db, err := sql.Open("postgres", "postgres://file:password@localhost/filedb?sslmode=disable") if err != nil { - panic(err) + return nil, err } - if err := createFilesTable(db); err != nil { - panic(err) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := createFilesTable(ctx, db); err != nil { + return nil, err } - return db + + return db, nil } -func createFilesTable(db *sql.DB) error { - _, err := db.Exec(` +func createFilesTable(ctx context.Context, db *sql.DB) error { + _, err := db.ExecContext(ctx, ` CREATE TABLE IF NOT EXISTS files ( id TEXT PRIMARY KEY, name TEXT, @@ -66,17 +77,20 @@ func createFilesTable(db *sql.DB) error { } func handleUpload(w http.ResponseWriter, r *http.Request) { - r.ParseMultipartForm(maxUploadSize) + if err := r.ParseMultipartForm(maxUploadSize); err != nil { + handleError(w, fmt.Errorf("error parsing multipart form: %v", err), http.StatusBadRequest) + return + } key, err := generateRandomKey() if err != nil { - handleError(w, err, http.StatusInternalServerError) + handleError(w, fmt.Errorf("error generating encryption key: %v", err), http.StatusInternalServerError) return } file, handler, err := r.FormFile("file") if err != nil { - handleError(w, errors.New("error parsing uploaded file"), http.StatusBadRequest) + handleError(w, fmt.Errorf("error getting form file: %v", err), http.StatusBadRequest) return } defer file.Close() @@ -85,13 +99,12 @@ func handleUpload(w http.ResponseWriter, r *http.Request) { encryptedData, err := encryptFile(file, key) if err != nil { - handleError(w, errors.New("error encrypting file"), http.StatusInternalServerError) + handleError(w, fmt.Errorf("error encrypting file: %v", err), http.StatusInternalServerError) return } - err = storeFileInDB(id, handler.Filename, encryptedData) - if err != nil { - handleError(w, errors.New("error storing file in database"), http.StatusInternalServerError) + if err := storeFileInDB(r.Context(), id, handler.Filename, encryptedData); err != nil { + handleError(w, fmt.Errorf("error storing file in database: %v", err), http.StatusInternalServerError) return } @@ -101,56 +114,49 @@ func handleUpload(w http.ResponseWriter, r *http.Request) { } func handleDownload(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - id := vars["id"] - + id := mux.Vars(r)["id"] keyHex := r.URL.Query().Get("key") + key, err := hex.DecodeString(keyHex) if err != nil { - handleError(w, errors.New("invalid key"), http.StatusBadRequest) + handleError(w, fmt.Errorf("invalid key: %v", err), http.StatusBadRequest) return } - fileName, encryptedData, err := getFileFromDB(id) + fileName, encryptedData, err := getFileFromDB(r.Context(), id) if err != nil { - handleError(w, err, http.StatusInternalServerError) + handleError(w, fmt.Errorf("error getting file from database: %v", err), http.StatusInternalServerError) return } w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, fileName)) - plaintext, err := decryptFile(encryptedData, key) + err = decryptAndStreamFile(w, encryptedData, key) if err != nil { - handleError(w, errors.New("error decrypting file"), http.StatusInternalServerError) - return - } - - if _, err := w.Write(plaintext); err != nil { - handleError(w, errors.New("error writing response"), http.StatusInternalServerError) + handleError(w, fmt.Errorf("error decrypting and streaming file: %v", err), http.StatusInternalServerError) return } } func handleGetFileInfo(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - id := vars["id"] - + id := mux.Vars(r)["id"] keyHex := r.URL.Query().Get("key") + key, err := hex.DecodeString(keyHex) if err != nil { - handleError(w, errors.New("invalid key"), http.StatusBadRequest) + handleError(w, fmt.Errorf("invalid key: %v", err), http.StatusBadRequest) return } - fileName, encryptedData, err := getFileFromDB(id) + fileName, encryptedData, err := getFileFromDB(r.Context(), id) if err != nil { - handleError(w, err, http.StatusInternalServerError) + handleError(w, fmt.Errorf("error getting file from database: %v", err), http.StatusInternalServerError) return } plaintext, err := decryptFile(encryptedData, key) if err != nil { - handleError(w, errors.New("error decrypting file"), http.StatusInternalServerError) + handleError(w, fmt.Errorf("error decrypting file: %v", err), http.StatusInternalServerError) return } @@ -174,14 +180,14 @@ func handleGetFileInfo(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(fileInfo) } -func storeFileInDB(id, fileName string, encryptedData []byte) error { - tx, err := db.Begin() +func storeFileInDB(ctx context.Context, id, fileName string, encryptedData []byte) error { + tx, err := db.BeginTx(ctx, nil) if err != nil { return err } defer tx.Rollback() - _, err = tx.Exec("INSERT INTO files (id, name, data) VALUES ($1, $2, $3)", id, fileName, encryptedData) + _, err = tx.ExecContext(ctx, "INSERT INTO files (id, name, data) VALUES ($1, $2, $3)", id, fileName, encryptedData) if err != nil { return err } @@ -189,8 +195,8 @@ func storeFileInDB(id, fileName string, encryptedData []byte) error { return tx.Commit() } -func getFileFromDB(id string) (fileName string, encryptedData []byte, err error) { - err = db.QueryRow("SELECT name, data FROM files WHERE id = $1", id).Scan(&fileName, &encryptedData) +func getFileFromDB(ctx context.Context, id string) (fileName string, encryptedData []byte, err error) { + err = db.QueryRowContext(ctx, "SELECT name, data FROM files WHERE id = $1", id).Scan(&fileName, &encryptedData) if err == sql.ErrNoRows { return "", nil, errors.New("file not found") } @@ -226,7 +232,36 @@ func encryptFile(in io.Reader, key []byte) ([]byte, error) { return append(nonce, ciphertext...), nil } +func decryptAndStreamFile(w io.Writer, encryptedData []byte, key []byte) error { + if len(encryptedData) < nonceSize { + return errors.New("ciphertext too short") + } + + block, err := aes.NewCipher(key) + if err != nil { + return err + } + + aesgcm, err := cipher.NewGCM(block) + if err != nil { + return err + } + + nonce, ciphertext := encryptedData[:nonceSize], encryptedData[nonceSize:] + plaintext, err := aesgcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return err + } + + _, err = w.Write(plaintext) + return err +} + func decryptFile(encryptedData []byte, key []byte) ([]byte, error) { + if len(encryptedData) < nonceSize { + return nil, errors.New("ciphertext too short") + } + block, err := aes.NewCipher(key) if err != nil { return nil, err @@ -237,17 +272,8 @@ func decryptFile(encryptedData []byte, key []byte) ([]byte, error) { return nil, err } - if len(encryptedData) < nonceSize { - return nil, errors.New("ciphertext too short") - } - nonce, ciphertext := encryptedData[:nonceSize], encryptedData[nonceSize:] - plaintext, err := aesgcm.Open(nil, nonce, ciphertext, nil) - if err != nil { - return nil, err - } - - return plaintext, nil + return aesgcm.Open(nil, nonce, ciphertext, nil) } func generateRandomKey() ([]byte, error) {