summaryrefslogtreecommitdiff
path: root/server/main.go
diff options
context:
space:
mode:
Diffstat (limited to 'server/main.go')
-rw-r--r--server/main.go296
1 files changed, 194 insertions, 102 deletions
diff --git a/server/main.go b/server/main.go
index 8060b41..b7bd5b7 100644
--- a/server/main.go
+++ b/server/main.go
@@ -1,6 +1,7 @@
package main
import (
+ "bytes"
"context"
"crypto/aes"
"crypto/cipher"
@@ -14,6 +15,7 @@ import (
"io"
"net/http"
"os"
+ "strconv"
"time"
"github.com/labstack/echo/v4"
@@ -28,7 +30,6 @@ const (
)
var db *sql.DB
-
var port string
//go:embed all:dist
@@ -45,7 +46,8 @@ func registerHandlers(e *echo.Echo) {
e.Use(middleware.Logger())
e.Use(middleware.Recover())
e.Use(middleware.CORS())
- e.POST("/upload", handleUpload)
+ e.POST("/upload_chunk", handleUploadChunk)
+ e.POST("/upload_complete", handleUploadComplete)
e.GET("/download/:id", handleDownload)
e.GET("/get/:id", handleGetFileInfo)
}
@@ -62,6 +64,9 @@ func main() {
e := echo.New()
registerHandlers(e)
+
+ startCleanupScheduler()
+
e.Logger.Fatal(e.Start(":" + port))
}
@@ -71,8 +76,6 @@ func initDB() (*sql.DB, error) {
dbname := os.Getenv("POSTGRES_DB")
dbURL := fmt.Sprintf("postgres://%s:%s@localhost/%s?sslmode=disable", user, password, dbname)
- // print dbURL for debugging
- fmt.Println(dbURL)
db, err := sql.Open("postgres", dbURL)
if err != nil {
return nil, err
@@ -81,60 +84,103 @@ func initDB() (*sql.DB, error) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
- if err := createFilesTable(ctx, db); err != nil {
+ if err := createTables(ctx, db); err != nil {
return nil, err
}
return db, nil
}
-func createFilesTable(ctx context.Context, db *sql.DB) error {
+func createTables(ctx context.Context, db *sql.DB) error {
_, err := db.ExecContext(ctx, `
+ CREATE TABLE IF NOT EXISTS chunks (
+ upload_id TEXT,
+ chunk_index INT,
+ chunk_data BYTEA,
+ created_at TIMESTAMPTZ DEFAULT NOW(),
+ PRIMARY KEY (upload_id, chunk_index)
+ );
CREATE TABLE IF NOT EXISTS files (
- id TEXT PRIMARY KEY,
+ id TEXT,
name TEXT,
- data BYTEA
+ chunk_index INT,
+ chunk_data BYTEA,
+ created_at TIMESTAMPTZ DEFAULT NOW(),
+ PRIMARY KEY (id, chunk_index)
);
`)
return err
}
-func handleUpload(c echo.Context) error {
- r := c.Request()
- if err := r.ParseMultipartForm(maxUploadSize); err != nil {
- return handleError(c, fmt.Errorf("error parsing multipart form: %v", err), http.StatusBadRequest)
+func handleUploadChunk(c echo.Context) error {
+ uploadId := c.FormValue("uploadId")
+ chunkIndex, err := strconv.Atoi(c.FormValue("chunkIndex"))
+ if err != nil {
+ return handleError(c, fmt.Errorf("invalid chunk index: %v", err), http.StatusBadRequest)
+ }
+ chunk, err := c.FormFile("chunk")
+ if err != nil {
+ return handleError(c, fmt.Errorf("error getting form file: %v", err), http.StatusBadRequest)
}
- key, err := generateRandomKey()
+ src, err := chunk.Open()
if err != nil {
- return handleError(c, fmt.Errorf("error generating encryption key: %v", err), http.StatusInternalServerError)
+ return handleError(c, fmt.Errorf("error opening chunk: %v", err), http.StatusInternalServerError)
}
+ defer src.Close()
- file, handler, err := r.FormFile("file")
+ chunkData, err := io.ReadAll(src)
if err != nil {
- return handleError(c, fmt.Errorf("error getting form file: %v", err), http.StatusBadRequest)
+ return handleError(c, fmt.Errorf("error reading chunk data: %v", err), http.StatusInternalServerError)
}
- defer file.Close()
- id := generateID()
+ if err := storeChunkInDB(c.Request().Context(), uploadId, chunkIndex, chunkData); err != nil {
+ return handleError(c, fmt.Errorf("error storing chunk in database: %v", err), http.StatusInternalServerError)
+ }
+
+ return c.NoContent(http.StatusOK)
+}
+
+func storeChunkInDB(ctx context.Context, uploadId string, chunkIndex int, chunkData []byte) error {
+ _, err := db.ExecContext(ctx, "INSERT INTO chunks (upload_id, chunk_index, chunk_data, created_at) VALUES ($1, $2, $3, NOW())", uploadId, chunkIndex, chunkData)
+ return err
+}
- encryptedData, err := encryptFile(file, key)
+func handleUploadComplete(c echo.Context) error {
+ uploadId := c.FormValue("uploadId")
+ chunkCount, err := strconv.Atoi(c.FormValue("chunkCount"))
if err != nil {
- return handleError(c, fmt.Errorf("error encrypting file: %v", err), http.StatusInternalServerError)
+ return handleError(c, fmt.Errorf("invalid chunk count: %v", err), http.StatusBadRequest)
}
+ fileName := c.FormValue("fileName")
- if err := storeFileInDB(r.Context(), id, handler.Filename, encryptedData); err != nil {
- return handleError(c, fmt.Errorf("error storing file in database: %v", err), http.StatusInternalServerError)
+ key, err := generateRandomKey()
+ if err != nil {
+ return handleError(c, fmt.Errorf("error generating encryption key: %v", err), http.StatusInternalServerError)
}
- encodedKey := hex.EncodeToString(key)
+ id := generateID()
+ for i := 0; i < chunkCount; i++ {
+ chunkData, err := getChunkFromDB(c.Request().Context(), uploadId, i)
+ if err != nil {
+ return handleError(c, fmt.Errorf("error retrieving chunk data: %v", err), http.StatusInternalServerError)
+ }
- type UploadResponse struct {
- ID string `json:"id"`
- Key string `json:"key"`
+ encryptedData, err := encryptFile(bytes.NewReader(chunkData), key)
+ if err != nil {
+ return handleError(c, fmt.Errorf("error encrypting chunk: %v", err), http.StatusInternalServerError)
+ }
+
+ if err := storeChunkInFilesTable(c.Request().Context(), id, fileName, i, encryptedData); err != nil {
+ return handleError(c, fmt.Errorf("error storing chunk in database: %v", err), http.StatusInternalServerError)
+ }
}
- response := UploadResponse{
+ encodedKey := hex.EncodeToString(key)
+ response := struct {
+ ID string `json:"id"`
+ Key string `json:"key"`
+ }{
ID: id,
Key: encodedKey,
}
@@ -142,6 +188,17 @@ func handleUpload(c echo.Context) error {
return c.JSON(http.StatusOK, response)
}
+func storeChunkInFilesTable(ctx context.Context, id, fileName string, chunkIndex int, encryptedData []byte) error {
+ _, err := db.ExecContext(ctx, "INSERT INTO files (id, name, chunk_index, chunk_data, created_at) VALUES ($1, $2, $3, $4, NOW())", id, fileName, chunkIndex, encryptedData)
+ return err
+}
+
+func getChunkFromDB(ctx context.Context, uploadId string, chunkIndex int) ([]byte, error) {
+ var chunkData []byte
+ err := db.QueryRowContext(ctx, "SELECT chunk_data FROM chunks WHERE upload_id = $1 AND chunk_index = $2", uploadId, chunkIndex).Scan(&chunkData)
+ return chunkData, err
+}
+
func handleDownload(c echo.Context) error {
id := c.Param("id")
keyHex := c.QueryParam("key")
@@ -151,14 +208,14 @@ func handleDownload(c echo.Context) error {
return handleError(c, fmt.Errorf("invalid key: %v", err), http.StatusBadRequest)
}
- fileName, encryptedData, err := getFileFromDB(c.Request().Context(), id)
+ fileName, err := getFileNameFromDB(c.Request().Context(), id)
if err != nil {
- return handleError(c, fmt.Errorf("error getting file from database: %v", err), http.StatusInternalServerError)
+ return handleError(c, fmt.Errorf("error getting file name from database: %v", err), http.StatusInternalServerError)
}
c.Response().Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, fileName))
- err = decryptAndStreamFile(c.Response(), encryptedData, key)
+ err = decryptAndStreamChunks(c.Response(), id, key)
if err != nil {
return handleError(c, fmt.Errorf("error decrypting and streaming file: %v", err), http.StatusInternalServerError)
}
@@ -166,6 +223,41 @@ func handleDownload(c echo.Context) error {
return nil
}
+func getFileNameFromDB(ctx context.Context, id string) (fileName string, err error) {
+ err = db.QueryRowContext(ctx, "SELECT name FROM files WHERE id = $1 LIMIT 1", id).Scan(&fileName)
+ if err == sql.ErrNoRows {
+ return "", errors.New("file not found")
+ }
+ return fileName, err
+}
+
+func decryptAndStreamChunks(w io.Writer, id string, key []byte) error {
+ rows, err := db.Query("SELECT chunk_data FROM files WHERE id = $1 ORDER BY chunk_index", id)
+ if err != nil {
+ return err
+ }
+ defer rows.Close()
+
+ for rows.Next() {
+ var encryptedData []byte
+ if err := rows.Scan(&encryptedData); err != nil {
+ return err
+ }
+
+ plaintext, err := decryptFile(encryptedData, key)
+ if err != nil {
+ return err
+ }
+
+ _, err = w.Write(plaintext)
+ if err != nil {
+ return err
+ }
+ }
+
+ return rows.Err()
+}
+
func handleGetFileInfo(c echo.Context) error {
id := c.Param("id")
keyHex := c.QueryParam("key")
@@ -175,22 +267,14 @@ func handleGetFileInfo(c echo.Context) error {
return handleError(c, fmt.Errorf("invalid key: %v", err), http.StatusBadRequest)
}
- fileName, encryptedData, err := getFileFromDB(c.Request().Context(), id)
+ fileName, err := getFileNameFromDB(c.Request().Context(), id)
if err != nil {
- return handleError(c, fmt.Errorf("error getting file from database: %v", err), http.StatusInternalServerError)
+ return handleError(c, fmt.Errorf("error getting file name from database: %v", err), http.StatusInternalServerError)
}
- plaintext, err := decryptFile(encryptedData, key)
+ fileSize, err := getTotalFileSize(id, key)
if err != nil {
- return handleError(c, fmt.Errorf("error decrypting file: %v", err), http.StatusInternalServerError)
- }
-
- fileSizeBytes := len(plaintext)
- var fileSize string
- if fileSizeBytes >= 1024*1024 {
- fileSize = fmt.Sprintf("%.2f MB", float64(fileSizeBytes)/(1024*1024))
- } else {
- fileSize = fmt.Sprintf("%.2f KB", float64(fileSizeBytes)/1024)
+ return handleError(c, fmt.Errorf("error getting file size: %v", err), http.StatusInternalServerError)
}
fileInfo := struct {
@@ -204,112 +288,120 @@ func handleGetFileInfo(c echo.Context) error {
return c.JSON(http.StatusOK, fileInfo)
}
-func storeFileInDB(ctx context.Context, id, fileName string, encryptedData []byte) error {
- tx, err := db.BeginTx(ctx, nil)
+func getTotalFileSize(id string, key []byte) (string, error) {
+ var totalSize int64
+ rows, err := db.Query("SELECT chunk_data FROM files WHERE id = $1 ORDER BY chunk_index", id)
if err != nil {
- return err
+ return "", err
}
- defer tx.Rollback()
+ defer rows.Close()
- _, err = tx.ExecContext(ctx, "INSERT INTO files (id, name, data) VALUES ($1, $2, $3)", id, fileName, encryptedData)
- if err != nil {
- return err
+ for rows.Next() {
+ var encryptedData []byte
+ if err := rows.Scan(&encryptedData); err != nil {
+ return "", err
+ }
+
+ plaintext, err := decryptFile(encryptedData, key)
+ if err != nil {
+ return "", err
+ }
+
+ totalSize += int64(len(plaintext))
}
- return tx.Commit()
+ var fileSize string
+ if totalSize >= 1024*1024 {
+ fileSize = fmt.Sprintf("%.2f MB", float64(totalSize)/(1024*1024))
+ } else {
+ fileSize = fmt.Sprintf("%.2f KB", float64(totalSize)/1024)
+ }
+
+ return fileSize, rows.Err()
}
-func getFileFromDB(ctx context.Context, id string) (fileName string, encryptedData []byte, err error) {
- err = db.QueryRowContext(ctx, "SELECT name, data FROM files WHERE id = $1", id).Scan(&fileName, &encryptedData)
- if err == sql.ErrNoRows {
- return "", nil, errors.New("file not found")
- }
- return fileName, encryptedData, err
+func handleError(c echo.Context, err error, status int) error {
+ fmt.Printf("error: %v\n", err)
+ return c.JSON(status, map[string]string{"error": err.Error()})
+}
+
+func generateRandomKey() ([]byte, error) {
+ key := make([]byte, keySize)
+ _, err := rand.Read(key)
+ return key, err
}
-func handleError(c echo.Context, err error, code int) error {
- return c.JSON(code, map[string]string{"error": err.Error()})
+func generateID() string {
+ b := make([]byte, 16)
+ rand.Read(b)
+ return hex.EncodeToString(b)
}
-func encryptFile(in io.Reader, key []byte) ([]byte, error) {
+func encryptFile(plaintext io.Reader, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
- aesgcm, err := cipher.NewGCM(block)
- if err != nil {
+ nonce := make([]byte, nonceSize)
+ if _, err := rand.Read(nonce); err != nil {
return nil, err
}
- nonce := make([]byte, nonceSize)
- if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
+ aesgcm, err := cipher.NewGCM(block)
+ if err != nil {
return nil, err
}
- plaintext, err := io.ReadAll(in)
+ plaintextBytes, err := io.ReadAll(plaintext)
if err != nil {
return nil, err
}
- ciphertext := aesgcm.Seal(nil, nonce, plaintext, nil)
- return append(nonce, ciphertext...), nil
+ ciphertext := aesgcm.Seal(nonce, nonce, plaintextBytes, nil)
+ return ciphertext, nil
}
-func decryptAndStreamFile(w io.Writer, encryptedData []byte, key []byte) error {
- if len(encryptedData) < nonceSize {
- return errors.New("ciphertext too short")
- }
-
+func decryptFile(ciphertext, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
- return err
- }
-
- aesgcm, err := cipher.NewGCM(block)
- if err != nil {
- return err
- }
-
- nonce, ciphertext := encryptedData[:nonceSize], encryptedData[nonceSize:]
- plaintext, err := aesgcm.Open(nil, nonce, ciphertext, nil)
- if err != nil {
- return err
+ return nil, err
}
- _, err = w.Write(plaintext)
- return err
-}
-
-func decryptFile(encryptedData []byte, key []byte) ([]byte, error) {
- if len(encryptedData) < nonceSize {
+ if len(ciphertext) < nonceSize {
return nil, errors.New("ciphertext too short")
}
- block, err := aes.NewCipher(key)
+ nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
+
+ aesgcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
- aesgcm, err := cipher.NewGCM(block)
+ plaintext, err := aesgcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return nil, err
}
- nonce, ciphertext := encryptedData[:nonceSize], encryptedData[nonceSize:]
- return aesgcm.Open(nil, nonce, ciphertext, nil)
+ return plaintext, nil
}
-func generateRandomKey() ([]byte, error) {
- key := make([]byte, keySize)
- _, err := rand.Read(key)
- return key, err
+func startCleanupScheduler() {
+ ticker := time.NewTicker(24 * time.Hour)
+ go func() {
+ for range ticker.C {
+ cleanupChunks()
+ }
+ }()
}
-func generateID() string {
- b := make([]byte, 16)
- if _, err := rand.Read(b); err != nil {
- panic(err)
+func cleanupChunks() {
+ ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
+ defer cancel()
+
+ _, err := db.ExecContext(ctx, "DELETE FROM chunks WHERE created_at < NOW() - INTERVAL '1 day'")
+ if err != nil {
+ fmt.Printf("error cleaning up chunks: %v\n", err)
}
- return hex.EncodeToString(b)
}