summaryrefslogtreecommitdiff
path: root/server/main.go
diff options
context:
space:
mode:
Diffstat (limited to 'server/main.go')
-rw-r--r--server/main.go103
1 files changed, 43 insertions, 60 deletions
diff --git a/server/main.go b/server/main.go
index 8189750..683e0d0 100644
--- a/server/main.go
+++ b/server/main.go
@@ -7,16 +7,15 @@ import (
"crypto/rand"
"database/sql"
"encoding/hex"
- "encoding/json"
"errors"
"fmt"
"io"
"net/http"
"time"
- "github.com/gorilla/mux"
+ "github.com/labstack/echo/v4"
+ "github.com/labstack/echo/v4/middleware"
_ "github.com/lib/pq"
- "github.com/rs/cors"
)
const (
@@ -27,6 +26,15 @@ const (
var db *sql.DB
+func registerHandlers(e *echo.Echo) {
+ e.Use(middleware.Logger())
+ e.Use(middleware.Recover())
+ e.Use(middleware.CORS())
+ e.POST("/upload", handleUpload)
+ e.GET("/download/:id", handleDownload)
+ e.GET("/get/:id", handleGetFileInfo)
+}
+
func main() {
var err error
db, err = initDB()
@@ -35,18 +43,9 @@ func main() {
}
defer db.Close()
- router := mux.NewRouter()
- router.HandleFunc("/upload", handleUpload).Methods("POST")
- router.HandleFunc("/download/{id}", handleDownload).Methods("GET")
- router.HandleFunc("/get/{id}", handleGetFileInfo).Methods("GET")
-
- handler := cors.New(cors.Options{
- AllowedOrigins: []string{"*"},
- AllowedMethods: []string{"GET", "POST"},
- AllowedHeaders: []string{"*"},
- }).Handler(router)
-
- http.ListenAndServe(":8080", handler)
+ e := echo.New()
+ registerHandlers(e)
+ e.Logger.Fatal(e.Start(":8080"))
}
func initDB() (*sql.DB, error) {
@@ -76,22 +75,20 @@ func createFilesTable(ctx context.Context, db *sql.DB) error {
return err
}
-func handleUpload(w http.ResponseWriter, r *http.Request) {
+func handleUpload(c echo.Context) error {
+ r := c.Request()
if err := r.ParseMultipartForm(maxUploadSize); err != nil {
- handleError(w, fmt.Errorf("error parsing multipart form: %v", err), http.StatusBadRequest)
- return
+ return handleError(c, fmt.Errorf("error parsing multipart form: %v", err), http.StatusBadRequest)
}
key, err := generateRandomKey()
if err != nil {
- handleError(w, fmt.Errorf("error generating encryption key: %v", err), http.StatusInternalServerError)
- return
+ return handleError(c, fmt.Errorf("error generating encryption key: %v", err), http.StatusInternalServerError)
}
file, handler, err := r.FormFile("file")
if err != nil {
- handleError(w, fmt.Errorf("error getting form file: %v", err), http.StatusBadRequest)
- return
+ return handleError(c, fmt.Errorf("error getting form file: %v", err), http.StatusBadRequest)
}
defer file.Close()
@@ -99,13 +96,11 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
encryptedData, err := encryptFile(file, key)
if err != nil {
- handleError(w, fmt.Errorf("error encrypting file: %v", err), http.StatusInternalServerError)
- return
+ return handleError(c, fmt.Errorf("error encrypting file: %v", err), http.StatusInternalServerError)
}
if err := storeFileInDB(r.Context(), id, handler.Filename, encryptedData); err != nil {
- handleError(w, fmt.Errorf("error storing file in database: %v", err), http.StatusInternalServerError)
- return
+ return handleError(c, fmt.Errorf("error storing file in database: %v", err), http.StatusInternalServerError)
}
encodedKey := hex.EncodeToString(key)
@@ -120,58 +115,50 @@ func handleUpload(w http.ResponseWriter, r *http.Request) {
Key: encodedKey,
}
- w.Header().Set("Content-Type", "application/json")
- if err := json.NewEncoder(w).Encode(response); err != nil {
- handleError(w, fmt.Errorf("error encoding response: %v", err), http.StatusInternalServerError)
- return
- }
+ return c.JSON(http.StatusOK, response)
}
-func handleDownload(w http.ResponseWriter, r *http.Request) {
- id := mux.Vars(r)["id"]
- keyHex := r.URL.Query().Get("key")
+func handleDownload(c echo.Context) error {
+ id := c.Param("id")
+ keyHex := c.QueryParam("key")
key, err := hex.DecodeString(keyHex)
if err != nil {
- handleError(w, fmt.Errorf("invalid key: %v", err), http.StatusBadRequest)
- return
+ return handleError(c, fmt.Errorf("invalid key: %v", err), http.StatusBadRequest)
}
- fileName, encryptedData, err := getFileFromDB(r.Context(), id)
+ fileName, encryptedData, err := getFileFromDB(c.Request().Context(), id)
if err != nil {
- handleError(w, fmt.Errorf("error getting file from database: %v", err), http.StatusInternalServerError)
- return
+ return handleError(c, fmt.Errorf("error getting file from database: %v", err), http.StatusInternalServerError)
}
- w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, fileName))
+ c.Response().Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, fileName))
- err = decryptAndStreamFile(w, encryptedData, key)
+ err = decryptAndStreamFile(c.Response(), encryptedData, key)
if err != nil {
- handleError(w, fmt.Errorf("error decrypting and streaming file: %v", err), http.StatusInternalServerError)
- return
+ return handleError(c, fmt.Errorf("error decrypting and streaming file: %v", err), http.StatusInternalServerError)
}
+
+ return nil
}
-func handleGetFileInfo(w http.ResponseWriter, r *http.Request) {
- id := mux.Vars(r)["id"]
- keyHex := r.URL.Query().Get("key")
+func handleGetFileInfo(c echo.Context) error {
+ id := c.Param("id")
+ keyHex := c.QueryParam("key")
key, err := hex.DecodeString(keyHex)
if err != nil {
- handleError(w, fmt.Errorf("invalid key: %v", err), http.StatusBadRequest)
- return
+ return handleError(c, fmt.Errorf("invalid key: %v", err), http.StatusBadRequest)
}
- fileName, encryptedData, err := getFileFromDB(r.Context(), id)
+ fileName, encryptedData, err := getFileFromDB(c.Request().Context(), id)
if err != nil {
- handleError(w, fmt.Errorf("error getting file from database: %v", err), http.StatusInternalServerError)
- return
+ return handleError(c, fmt.Errorf("error getting file from database: %v", err), http.StatusInternalServerError)
}
plaintext, err := decryptFile(encryptedData, key)
if err != nil {
- handleError(w, fmt.Errorf("error decrypting file: %v", err), http.StatusInternalServerError)
- return
+ return handleError(c, fmt.Errorf("error decrypting file: %v", err), http.StatusInternalServerError)
}
fileSizeBytes := len(plaintext)
@@ -190,11 +177,7 @@ func handleGetFileInfo(w http.ResponseWriter, r *http.Request) {
FileSize: fileSize,
}
- w.Header().Set("Content-Type", "application/json")
- if err := json.NewEncoder(w).Encode(fileInfo); err != nil {
- handleError(w, fmt.Errorf("error encoding response: %v", err), http.StatusInternalServerError)
- return
- }
+ return c.JSON(http.StatusOK, fileInfo)
}
func storeFileInDB(ctx context.Context, id, fileName string, encryptedData []byte) error {
@@ -220,8 +203,8 @@ func getFileFromDB(ctx context.Context, id string) (fileName string, encryptedDa
return fileName, encryptedData, err
}
-func handleError(w http.ResponseWriter, err error, code int) {
- http.Error(w, err.Error(), code)
+func handleError(c echo.Context, err error, code int) error {
+ return c.JSON(code, map[string]string{"error": err.Error()})
}
func encryptFile(in io.Reader, key []byte) ([]byte, error) {