refactor: Optimize code
- Use context for database operations - Improve error handeling - Stream files instead of copying into memory
This commit is contained in:
parent
390f3e849a
commit
ec37d3b2ce
1 changed files with 76 additions and 50 deletions
126
server/main.go
126
server/main.go
|
@ -1,6 +1,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
|
@ -11,6 +12,7 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
_ "github.com/lib/pq"
|
||||
|
@ -26,7 +28,11 @@ const (
|
|||
var db *sql.DB
|
||||
|
||||
func main() {
|
||||
db = initDB()
|
||||
var err error
|
||||
db, err = initDB()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
router := mux.NewRouter()
|
||||
|
@ -43,19 +49,24 @@ func main() {
|
|||
http.ListenAndServe(":8080", handler)
|
||||
}
|
||||
|
||||
func initDB() *sql.DB {
|
||||
func initDB() (*sql.DB, error) {
|
||||
db, err := sql.Open("postgres", "postgres://file:password@localhost/filedb?sslmode=disable")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return nil, err
|
||||
}
|
||||
if err := createFilesTable(db); err != nil {
|
||||
panic(err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := createFilesTable(ctx, db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func createFilesTable(db *sql.DB) error {
|
||||
_, err := db.Exec(`
|
||||
func createFilesTable(ctx context.Context, db *sql.DB) error {
|
||||
_, err := db.ExecContext(ctx, `
|
||||
CREATE TABLE IF NOT EXISTS files (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT,
|
||||
|
@ -66,17 +77,20 @@ func createFilesTable(db *sql.DB) error {
|
|||
}
|
||||
|
||||
func handleUpload(w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseMultipartForm(maxUploadSize)
|
||||
if err := r.ParseMultipartForm(maxUploadSize); err != nil {
|
||||
handleError(w, fmt.Errorf("error parsing multipart form: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
key, err := generateRandomKey()
|
||||
if err != nil {
|
||||
handleError(w, err, http.StatusInternalServerError)
|
||||
handleError(w, fmt.Errorf("error generating encryption key: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
file, handler, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
handleError(w, errors.New("error parsing uploaded file"), http.StatusBadRequest)
|
||||
handleError(w, fmt.Errorf("error getting form file: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
@ -85,13 +99,12 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
encryptedData, err := encryptFile(file, key)
|
||||
if err != nil {
|
||||
handleError(w, errors.New("error encrypting file"), http.StatusInternalServerError)
|
||||
handleError(w, fmt.Errorf("error encrypting file: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
err = storeFileInDB(id, handler.Filename, encryptedData)
|
||||
if err != nil {
|
||||
handleError(w, errors.New("error storing file in database"), http.StatusInternalServerError)
|
||||
if err := storeFileInDB(r.Context(), id, handler.Filename, encryptedData); err != nil {
|
||||
handleError(w, fmt.Errorf("error storing file in database: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -101,56 +114,49 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func handleDownload(w http.ResponseWriter, r *http.Request) {
|
||||
vars := mux.Vars(r)
|
||||
id := vars["id"]
|
||||
|
||||
id := mux.Vars(r)["id"]
|
||||
keyHex := r.URL.Query().Get("key")
|
||||
|
||||
key, err := hex.DecodeString(keyHex)
|
||||
if err != nil {
|
||||
handleError(w, errors.New("invalid key"), http.StatusBadRequest)
|
||||
handleError(w, fmt.Errorf("invalid key: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
fileName, encryptedData, err := getFileFromDB(id)
|
||||
fileName, encryptedData, err := getFileFromDB(r.Context(), id)
|
||||
if err != nil {
|
||||
handleError(w, err, http.StatusInternalServerError)
|
||||
handleError(w, fmt.Errorf("error getting file from database: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, fileName))
|
||||
|
||||
plaintext, err := decryptFile(encryptedData, key)
|
||||
err = decryptAndStreamFile(w, encryptedData, key)
|
||||
if err != nil {
|
||||
handleError(w, errors.New("error decrypting file"), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := w.Write(plaintext); err != nil {
|
||||
handleError(w, errors.New("error writing response"), http.StatusInternalServerError)
|
||||
handleError(w, fmt.Errorf("error decrypting and streaming file: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func handleGetFileInfo(w http.ResponseWriter, r *http.Request) {
|
||||
vars := mux.Vars(r)
|
||||
id := vars["id"]
|
||||
|
||||
id := mux.Vars(r)["id"]
|
||||
keyHex := r.URL.Query().Get("key")
|
||||
|
||||
key, err := hex.DecodeString(keyHex)
|
||||
if err != nil {
|
||||
handleError(w, errors.New("invalid key"), http.StatusBadRequest)
|
||||
handleError(w, fmt.Errorf("invalid key: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
fileName, encryptedData, err := getFileFromDB(id)
|
||||
fileName, encryptedData, err := getFileFromDB(r.Context(), id)
|
||||
if err != nil {
|
||||
handleError(w, err, http.StatusInternalServerError)
|
||||
handleError(w, fmt.Errorf("error getting file from database: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
plaintext, err := decryptFile(encryptedData, key)
|
||||
if err != nil {
|
||||
handleError(w, errors.New("error decrypting file"), http.StatusInternalServerError)
|
||||
handleError(w, fmt.Errorf("error decrypting file: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -174,14 +180,14 @@ func handleGetFileInfo(w http.ResponseWriter, r *http.Request) {
|
|||
json.NewEncoder(w).Encode(fileInfo)
|
||||
}
|
||||
|
||||
func storeFileInDB(id, fileName string, encryptedData []byte) error {
|
||||
tx, err := db.Begin()
|
||||
func storeFileInDB(ctx context.Context, id, fileName string, encryptedData []byte) error {
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
_, err = tx.Exec("INSERT INTO files (id, name, data) VALUES ($1, $2, $3)", id, fileName, encryptedData)
|
||||
_, err = tx.ExecContext(ctx, "INSERT INTO files (id, name, data) VALUES ($1, $2, $3)", id, fileName, encryptedData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -189,8 +195,8 @@ func storeFileInDB(id, fileName string, encryptedData []byte) error {
|
|||
return tx.Commit()
|
||||
}
|
||||
|
||||
func getFileFromDB(id string) (fileName string, encryptedData []byte, err error) {
|
||||
err = db.QueryRow("SELECT name, data FROM files WHERE id = $1", id).Scan(&fileName, &encryptedData)
|
||||
func getFileFromDB(ctx context.Context, id string) (fileName string, encryptedData []byte, err error) {
|
||||
err = db.QueryRowContext(ctx, "SELECT name, data FROM files WHERE id = $1", id).Scan(&fileName, &encryptedData)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", nil, errors.New("file not found")
|
||||
}
|
||||
|
@ -226,7 +232,36 @@ func encryptFile(in io.Reader, key []byte) ([]byte, error) {
|
|||
return append(nonce, ciphertext...), nil
|
||||
}
|
||||
|
||||
func decryptAndStreamFile(w io.Writer, encryptedData []byte, key []byte) error {
|
||||
if len(encryptedData) < nonceSize {
|
||||
return errors.New("ciphertext too short")
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
aesgcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nonce, ciphertext := encryptedData[:nonceSize], encryptedData[nonceSize:]
|
||||
plaintext, err := aesgcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = w.Write(plaintext)
|
||||
return err
|
||||
}
|
||||
|
||||
func decryptFile(encryptedData []byte, key []byte) ([]byte, error) {
|
||||
if len(encryptedData) < nonceSize {
|
||||
return nil, errors.New("ciphertext too short")
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -237,17 +272,8 @@ func decryptFile(encryptedData []byte, key []byte) ([]byte, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if len(encryptedData) < nonceSize {
|
||||
return nil, errors.New("ciphertext too short")
|
||||
}
|
||||
|
||||
nonce, ciphertext := encryptedData[:nonceSize], encryptedData[nonceSize:]
|
||||
plaintext, err := aesgcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
return aesgcm.Open(nil, nonce, ciphertext, nil)
|
||||
}
|
||||
|
||||
func generateRandomKey() ([]byte, error) {
|
||||
|
|
Loading…
Reference in a new issue