diff options
Diffstat (limited to 'server/main.go')
-rw-r--r-- | server/main.go | 103 |
1 files changed, 43 insertions, 60 deletions
diff --git a/server/main.go b/server/main.go index 8189750..683e0d0 100644 --- a/server/main.go +++ b/server/main.go @@ -7,16 +7,15 @@ import ( "crypto/rand" "database/sql" "encoding/hex" - "encoding/json" "errors" "fmt" "io" "net/http" "time" - "github.com/gorilla/mux" + "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" _ "github.com/lib/pq" - "github.com/rs/cors" ) const ( @@ -27,6 +26,15 @@ const ( var db *sql.DB +func registerHandlers(e *echo.Echo) { + e.Use(middleware.Logger()) + e.Use(middleware.Recover()) + e.Use(middleware.CORS()) + e.POST("/upload", handleUpload) + e.GET("/download/:id", handleDownload) + e.GET("/get/:id", handleGetFileInfo) +} + func main() { var err error db, err = initDB() @@ -35,18 +43,9 @@ func main() { } defer db.Close() - router := mux.NewRouter() - router.HandleFunc("/upload", handleUpload).Methods("POST") - router.HandleFunc("/download/{id}", handleDownload).Methods("GET") - router.HandleFunc("/get/{id}", handleGetFileInfo).Methods("GET") - - handler := cors.New(cors.Options{ - AllowedOrigins: []string{"*"}, - AllowedMethods: []string{"GET", "POST"}, - AllowedHeaders: []string{"*"}, - }).Handler(router) - - http.ListenAndServe(":8080", handler) + e := echo.New() + registerHandlers(e) + e.Logger.Fatal(e.Start(":8080")) } func initDB() (*sql.DB, error) { @@ -76,22 +75,20 @@ func createFilesTable(ctx context.Context, db *sql.DB) error { return err } -func handleUpload(w http.ResponseWriter, r *http.Request) { +func handleUpload(c echo.Context) error { + r := c.Request() if err := r.ParseMultipartForm(maxUploadSize); err != nil { - handleError(w, fmt.Errorf("error parsing multipart form: %v", err), http.StatusBadRequest) - return + return handleError(c, fmt.Errorf("error parsing multipart form: %v", err), http.StatusBadRequest) } key, err := generateRandomKey() if err != nil { - handleError(w, fmt.Errorf("error generating encryption key: %v", err), http.StatusInternalServerError) - return + return handleError(c, fmt.Errorf("error generating encryption key: %v", err), http.StatusInternalServerError) } file, handler, err := r.FormFile("file") if err != nil { - handleError(w, fmt.Errorf("error getting form file: %v", err), http.StatusBadRequest) - return + return handleError(c, fmt.Errorf("error getting form file: %v", err), http.StatusBadRequest) } defer file.Close() @@ -99,13 +96,11 @@ func handleUpload(w http.ResponseWriter, r *http.Request) { encryptedData, err := encryptFile(file, key) if err != nil { - handleError(w, fmt.Errorf("error encrypting file: %v", err), http.StatusInternalServerError) - return + return handleError(c, fmt.Errorf("error encrypting file: %v", err), http.StatusInternalServerError) } if err := storeFileInDB(r.Context(), id, handler.Filename, encryptedData); err != nil { - handleError(w, fmt.Errorf("error storing file in database: %v", err), http.StatusInternalServerError) - return + return handleError(c, fmt.Errorf("error storing file in database: %v", err), http.StatusInternalServerError) } encodedKey := hex.EncodeToString(key) @@ -120,58 +115,50 @@ func handleUpload(w http.ResponseWriter, r *http.Request) { Key: encodedKey, } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(response); err != nil { - handleError(w, fmt.Errorf("error encoding response: %v", err), http.StatusInternalServerError) - return - } + return c.JSON(http.StatusOK, response) } -func handleDownload(w http.ResponseWriter, r *http.Request) { - id := mux.Vars(r)["id"] - keyHex := r.URL.Query().Get("key") +func handleDownload(c echo.Context) error { + id := c.Param("id") + keyHex := c.QueryParam("key") key, err := hex.DecodeString(keyHex) if err != nil { - handleError(w, fmt.Errorf("invalid key: %v", err), http.StatusBadRequest) - return + return handleError(c, fmt.Errorf("invalid key: %v", err), http.StatusBadRequest) } - fileName, encryptedData, err := getFileFromDB(r.Context(), id) + fileName, encryptedData, err := getFileFromDB(c.Request().Context(), id) if err != nil { - handleError(w, fmt.Errorf("error getting file from database: %v", err), http.StatusInternalServerError) - return + return handleError(c, fmt.Errorf("error getting file from database: %v", err), http.StatusInternalServerError) } - w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, fileName)) + c.Response().Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, fileName)) - err = decryptAndStreamFile(w, encryptedData, key) + err = decryptAndStreamFile(c.Response(), encryptedData, key) if err != nil { - handleError(w, fmt.Errorf("error decrypting and streaming file: %v", err), http.StatusInternalServerError) - return + return handleError(c, fmt.Errorf("error decrypting and streaming file: %v", err), http.StatusInternalServerError) } + + return nil } -func handleGetFileInfo(w http.ResponseWriter, r *http.Request) { - id := mux.Vars(r)["id"] - keyHex := r.URL.Query().Get("key") +func handleGetFileInfo(c echo.Context) error { + id := c.Param("id") + keyHex := c.QueryParam("key") key, err := hex.DecodeString(keyHex) if err != nil { - handleError(w, fmt.Errorf("invalid key: %v", err), http.StatusBadRequest) - return + return handleError(c, fmt.Errorf("invalid key: %v", err), http.StatusBadRequest) } - fileName, encryptedData, err := getFileFromDB(r.Context(), id) + fileName, encryptedData, err := getFileFromDB(c.Request().Context(), id) if err != nil { - handleError(w, fmt.Errorf("error getting file from database: %v", err), http.StatusInternalServerError) - return + return handleError(c, fmt.Errorf("error getting file from database: %v", err), http.StatusInternalServerError) } plaintext, err := decryptFile(encryptedData, key) if err != nil { - handleError(w, fmt.Errorf("error decrypting file: %v", err), http.StatusInternalServerError) - return + return handleError(c, fmt.Errorf("error decrypting file: %v", err), http.StatusInternalServerError) } fileSizeBytes := len(plaintext) @@ -190,11 +177,7 @@ func handleGetFileInfo(w http.ResponseWriter, r *http.Request) { FileSize: fileSize, } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(fileInfo); err != nil { - handleError(w, fmt.Errorf("error encoding response: %v", err), http.StatusInternalServerError) - return - } + return c.JSON(http.StatusOK, fileInfo) } func storeFileInDB(ctx context.Context, id, fileName string, encryptedData []byte) error { @@ -220,8 +203,8 @@ func getFileFromDB(ctx context.Context, id string) (fileName string, encryptedDa return fileName, encryptedData, err } -func handleError(w http.ResponseWriter, err error, code int) { - http.Error(w, err.Error(), code) +func handleError(c echo.Context, err error, code int) error { + return c.JSON(code, map[string]string{"error": err.Error()}) } func encryptFile(in io.Reader, key []byte) ([]byte, error) { |