diff options
author | Blaster4385 <venkatesh@tablaster.dev> | 2024-07-25 11:40:09 +0530 |
---|---|---|
committer | Blaster4385 <venkatesh@tablaster.dev> | 2024-07-25 17:16:07 +0530 |
commit | b933d6ab405fdda250a26c86f23586da82f66fe9 (patch) | |
tree | f974f2b276ff4ed4a1cfb84f4bed6c91cad2c526 /server/main.go | |
parent | f21cffacc308a9d43efca0185adce274d69e9d4d (diff) |
feat: upload and store file in chunks to bypass network and postgres limits
Diffstat (limited to 'server/main.go')
-rw-r--r-- | server/main.go | 296 |
1 files changed, 194 insertions, 102 deletions
diff --git a/server/main.go b/server/main.go index 8060b41..b7bd5b7 100644 --- a/server/main.go +++ b/server/main.go @@ -1,6 +1,7 @@ package main import ( + "bytes" "context" "crypto/aes" "crypto/cipher" @@ -14,6 +15,7 @@ import ( "io" "net/http" "os" + "strconv" "time" "github.com/labstack/echo/v4" @@ -28,7 +30,6 @@ const ( ) var db *sql.DB - var port string //go:embed all:dist @@ -45,7 +46,8 @@ func registerHandlers(e *echo.Echo) { e.Use(middleware.Logger()) e.Use(middleware.Recover()) e.Use(middleware.CORS()) - e.POST("/upload", handleUpload) + e.POST("/upload_chunk", handleUploadChunk) + e.POST("/upload_complete", handleUploadComplete) e.GET("/download/:id", handleDownload) e.GET("/get/:id", handleGetFileInfo) } @@ -62,6 +64,9 @@ func main() { e := echo.New() registerHandlers(e) + + startCleanupScheduler() + e.Logger.Fatal(e.Start(":" + port)) } @@ -71,8 +76,6 @@ func initDB() (*sql.DB, error) { dbname := os.Getenv("POSTGRES_DB") dbURL := fmt.Sprintf("postgres://%s:%s@localhost/%s?sslmode=disable", user, password, dbname) - // print dbURL for debugging - fmt.Println(dbURL) db, err := sql.Open("postgres", dbURL) if err != nil { return nil, err @@ -81,60 +84,103 @@ func initDB() (*sql.DB, error) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - if err := createFilesTable(ctx, db); err != nil { + if err := createTables(ctx, db); err != nil { return nil, err } return db, nil } -func createFilesTable(ctx context.Context, db *sql.DB) error { +func createTables(ctx context.Context, db *sql.DB) error { _, err := db.ExecContext(ctx, ` + CREATE TABLE IF NOT EXISTS chunks ( + upload_id TEXT, + chunk_index INT, + chunk_data BYTEA, + created_at TIMESTAMPTZ DEFAULT NOW(), + PRIMARY KEY (upload_id, chunk_index) + ); CREATE TABLE IF NOT EXISTS files ( - id TEXT PRIMARY KEY, + id TEXT, name TEXT, - data BYTEA + chunk_index INT, + chunk_data BYTEA, + created_at TIMESTAMPTZ DEFAULT NOW(), + PRIMARY KEY (id, chunk_index) ); `) return err } -func handleUpload(c echo.Context) error { - r := c.Request() - if err := r.ParseMultipartForm(maxUploadSize); err != nil { - return handleError(c, fmt.Errorf("error parsing multipart form: %v", err), http.StatusBadRequest) +func handleUploadChunk(c echo.Context) error { + uploadId := c.FormValue("uploadId") + chunkIndex, err := strconv.Atoi(c.FormValue("chunkIndex")) + if err != nil { + return handleError(c, fmt.Errorf("invalid chunk index: %v", err), http.StatusBadRequest) + } + chunk, err := c.FormFile("chunk") + if err != nil { + return handleError(c, fmt.Errorf("error getting form file: %v", err), http.StatusBadRequest) } - key, err := generateRandomKey() + src, err := chunk.Open() if err != nil { - return handleError(c, fmt.Errorf("error generating encryption key: %v", err), http.StatusInternalServerError) + return handleError(c, fmt.Errorf("error opening chunk: %v", err), http.StatusInternalServerError) } + defer src.Close() - file, handler, err := r.FormFile("file") + chunkData, err := io.ReadAll(src) if err != nil { - return handleError(c, fmt.Errorf("error getting form file: %v", err), http.StatusBadRequest) + return handleError(c, fmt.Errorf("error reading chunk data: %v", err), http.StatusInternalServerError) } - defer file.Close() - id := generateID() + if err := storeChunkInDB(c.Request().Context(), uploadId, chunkIndex, chunkData); err != nil { + return handleError(c, fmt.Errorf("error storing chunk in database: %v", err), http.StatusInternalServerError) + } + + return c.NoContent(http.StatusOK) +} + +func storeChunkInDB(ctx context.Context, uploadId string, chunkIndex int, chunkData []byte) error { + _, err := db.ExecContext(ctx, "INSERT INTO chunks (upload_id, chunk_index, chunk_data, created_at) VALUES ($1, $2, $3, NOW())", uploadId, chunkIndex, chunkData) + return err +} - encryptedData, err := encryptFile(file, key) +func handleUploadComplete(c echo.Context) error { + uploadId := c.FormValue("uploadId") + chunkCount, err := strconv.Atoi(c.FormValue("chunkCount")) if err != nil { - return handleError(c, fmt.Errorf("error encrypting file: %v", err), http.StatusInternalServerError) + return handleError(c, fmt.Errorf("invalid chunk count: %v", err), http.StatusBadRequest) } + fileName := c.FormValue("fileName") - if err := storeFileInDB(r.Context(), id, handler.Filename, encryptedData); err != nil { - return handleError(c, fmt.Errorf("error storing file in database: %v", err), http.StatusInternalServerError) + key, err := generateRandomKey() + if err != nil { + return handleError(c, fmt.Errorf("error generating encryption key: %v", err), http.StatusInternalServerError) } - encodedKey := hex.EncodeToString(key) + id := generateID() + for i := 0; i < chunkCount; i++ { + chunkData, err := getChunkFromDB(c.Request().Context(), uploadId, i) + if err != nil { + return handleError(c, fmt.Errorf("error retrieving chunk data: %v", err), http.StatusInternalServerError) + } - type UploadResponse struct { - ID string `json:"id"` - Key string `json:"key"` + encryptedData, err := encryptFile(bytes.NewReader(chunkData), key) + if err != nil { + return handleError(c, fmt.Errorf("error encrypting chunk: %v", err), http.StatusInternalServerError) + } + + if err := storeChunkInFilesTable(c.Request().Context(), id, fileName, i, encryptedData); err != nil { + return handleError(c, fmt.Errorf("error storing chunk in database: %v", err), http.StatusInternalServerError) + } } - response := UploadResponse{ + encodedKey := hex.EncodeToString(key) + response := struct { + ID string `json:"id"` + Key string `json:"key"` + }{ ID: id, Key: encodedKey, } @@ -142,6 +188,17 @@ func handleUpload(c echo.Context) error { return c.JSON(http.StatusOK, response) } +func storeChunkInFilesTable(ctx context.Context, id, fileName string, chunkIndex int, encryptedData []byte) error { + _, err := db.ExecContext(ctx, "INSERT INTO files (id, name, chunk_index, chunk_data, created_at) VALUES ($1, $2, $3, $4, NOW())", id, fileName, chunkIndex, encryptedData) + return err +} + +func getChunkFromDB(ctx context.Context, uploadId string, chunkIndex int) ([]byte, error) { + var chunkData []byte + err := db.QueryRowContext(ctx, "SELECT chunk_data FROM chunks WHERE upload_id = $1 AND chunk_index = $2", uploadId, chunkIndex).Scan(&chunkData) + return chunkData, err +} + func handleDownload(c echo.Context) error { id := c.Param("id") keyHex := c.QueryParam("key") @@ -151,14 +208,14 @@ func handleDownload(c echo.Context) error { return handleError(c, fmt.Errorf("invalid key: %v", err), http.StatusBadRequest) } - fileName, encryptedData, err := getFileFromDB(c.Request().Context(), id) + fileName, err := getFileNameFromDB(c.Request().Context(), id) if err != nil { - return handleError(c, fmt.Errorf("error getting file from database: %v", err), http.StatusInternalServerError) + return handleError(c, fmt.Errorf("error getting file name from database: %v", err), http.StatusInternalServerError) } c.Response().Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, fileName)) - err = decryptAndStreamFile(c.Response(), encryptedData, key) + err = decryptAndStreamChunks(c.Response(), id, key) if err != nil { return handleError(c, fmt.Errorf("error decrypting and streaming file: %v", err), http.StatusInternalServerError) } @@ -166,6 +223,41 @@ func handleDownload(c echo.Context) error { return nil } +func getFileNameFromDB(ctx context.Context, id string) (fileName string, err error) { + err = db.QueryRowContext(ctx, "SELECT name FROM files WHERE id = $1 LIMIT 1", id).Scan(&fileName) + if err == sql.ErrNoRows { + return "", errors.New("file not found") + } + return fileName, err +} + +func decryptAndStreamChunks(w io.Writer, id string, key []byte) error { + rows, err := db.Query("SELECT chunk_data FROM files WHERE id = $1 ORDER BY chunk_index", id) + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var encryptedData []byte + if err := rows.Scan(&encryptedData); err != nil { + return err + } + + plaintext, err := decryptFile(encryptedData, key) + if err != nil { + return err + } + + _, err = w.Write(plaintext) + if err != nil { + return err + } + } + + return rows.Err() +} + func handleGetFileInfo(c echo.Context) error { id := c.Param("id") keyHex := c.QueryParam("key") @@ -175,22 +267,14 @@ func handleGetFileInfo(c echo.Context) error { return handleError(c, fmt.Errorf("invalid key: %v", err), http.StatusBadRequest) } - fileName, encryptedData, err := getFileFromDB(c.Request().Context(), id) + fileName, err := getFileNameFromDB(c.Request().Context(), id) if err != nil { - return handleError(c, fmt.Errorf("error getting file from database: %v", err), http.StatusInternalServerError) + return handleError(c, fmt.Errorf("error getting file name from database: %v", err), http.StatusInternalServerError) } - plaintext, err := decryptFile(encryptedData, key) + fileSize, err := getTotalFileSize(id, key) if err != nil { - return handleError(c, fmt.Errorf("error decrypting file: %v", err), http.StatusInternalServerError) - } - - fileSizeBytes := len(plaintext) - var fileSize string - if fileSizeBytes >= 1024*1024 { - fileSize = fmt.Sprintf("%.2f MB", float64(fileSizeBytes)/(1024*1024)) - } else { - fileSize = fmt.Sprintf("%.2f KB", float64(fileSizeBytes)/1024) + return handleError(c, fmt.Errorf("error getting file size: %v", err), http.StatusInternalServerError) } fileInfo := struct { @@ -204,112 +288,120 @@ func handleGetFileInfo(c echo.Context) error { return c.JSON(http.StatusOK, fileInfo) } -func storeFileInDB(ctx context.Context, id, fileName string, encryptedData []byte) error { - tx, err := db.BeginTx(ctx, nil) +func getTotalFileSize(id string, key []byte) (string, error) { + var totalSize int64 + rows, err := db.Query("SELECT chunk_data FROM files WHERE id = $1 ORDER BY chunk_index", id) if err != nil { - return err + return "", err } - defer tx.Rollback() + defer rows.Close() - _, err = tx.ExecContext(ctx, "INSERT INTO files (id, name, data) VALUES ($1, $2, $3)", id, fileName, encryptedData) - if err != nil { - return err + for rows.Next() { + var encryptedData []byte + if err := rows.Scan(&encryptedData); err != nil { + return "", err + } + + plaintext, err := decryptFile(encryptedData, key) + if err != nil { + return "", err + } + + totalSize += int64(len(plaintext)) } - return tx.Commit() + var fileSize string + if totalSize >= 1024*1024 { + fileSize = fmt.Sprintf("%.2f MB", float64(totalSize)/(1024*1024)) + } else { + fileSize = fmt.Sprintf("%.2f KB", float64(totalSize)/1024) + } + + return fileSize, rows.Err() } -func getFileFromDB(ctx context.Context, id string) (fileName string, encryptedData []byte, err error) { - err = db.QueryRowContext(ctx, "SELECT name, data FROM files WHERE id = $1", id).Scan(&fileName, &encryptedData) - if err == sql.ErrNoRows { - return "", nil, errors.New("file not found") - } - return fileName, encryptedData, err +func handleError(c echo.Context, err error, status int) error { + fmt.Printf("error: %v\n", err) + return c.JSON(status, map[string]string{"error": err.Error()}) +} + +func generateRandomKey() ([]byte, error) { + key := make([]byte, keySize) + _, err := rand.Read(key) + return key, err } -func handleError(c echo.Context, err error, code int) error { - return c.JSON(code, map[string]string{"error": err.Error()}) +func generateID() string { + b := make([]byte, 16) + rand.Read(b) + return hex.EncodeToString(b) } -func encryptFile(in io.Reader, key []byte) ([]byte, error) { +func encryptFile(plaintext io.Reader, key []byte) ([]byte, error) { block, err := aes.NewCipher(key) if err != nil { return nil, err } - aesgcm, err := cipher.NewGCM(block) - if err != nil { + nonce := make([]byte, nonceSize) + if _, err := rand.Read(nonce); err != nil { return nil, err } - nonce := make([]byte, nonceSize) - if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + aesgcm, err := cipher.NewGCM(block) + if err != nil { return nil, err } - plaintext, err := io.ReadAll(in) + plaintextBytes, err := io.ReadAll(plaintext) if err != nil { return nil, err } - ciphertext := aesgcm.Seal(nil, nonce, plaintext, nil) - return append(nonce, ciphertext...), nil + ciphertext := aesgcm.Seal(nonce, nonce, plaintextBytes, nil) + return ciphertext, nil } -func decryptAndStreamFile(w io.Writer, encryptedData []byte, key []byte) error { - if len(encryptedData) < nonceSize { - return errors.New("ciphertext too short") - } - +func decryptFile(ciphertext, key []byte) ([]byte, error) { block, err := aes.NewCipher(key) if err != nil { - return err - } - - aesgcm, err := cipher.NewGCM(block) - if err != nil { - return err - } - - nonce, ciphertext := encryptedData[:nonceSize], encryptedData[nonceSize:] - plaintext, err := aesgcm.Open(nil, nonce, ciphertext, nil) - if err != nil { - return err + return nil, err } - _, err = w.Write(plaintext) - return err -} - -func decryptFile(encryptedData []byte, key []byte) ([]byte, error) { - if len(encryptedData) < nonceSize { + if len(ciphertext) < nonceSize { return nil, errors.New("ciphertext too short") } - block, err := aes.NewCipher(key) + nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:] + + aesgcm, err := cipher.NewGCM(block) if err != nil { return nil, err } - aesgcm, err := cipher.NewGCM(block) + plaintext, err := aesgcm.Open(nil, nonce, ciphertext, nil) if err != nil { return nil, err } - nonce, ciphertext := encryptedData[:nonceSize], encryptedData[nonceSize:] - return aesgcm.Open(nil, nonce, ciphertext, nil) + return plaintext, nil } -func generateRandomKey() ([]byte, error) { - key := make([]byte, keySize) - _, err := rand.Read(key) - return key, err +func startCleanupScheduler() { + ticker := time.NewTicker(24 * time.Hour) + go func() { + for range ticker.C { + cleanupChunks() + } + }() } -func generateID() string { - b := make([]byte, 16) - if _, err := rand.Read(b); err != nil { - panic(err) +func cleanupChunks() { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) + defer cancel() + + _, err := db.ExecContext(ctx, "DELETE FROM chunks WHERE created_at < NOW() - INTERVAL '1 day'") + if err != nil { + fmt.Printf("error cleaning up chunks: %v\n", err) } - return hex.EncodeToString(b) } |