diff --git a/server/main.go b/server/main.go index a4520b9..1eb1b46 100644 --- a/server/main.go +++ b/server/main.go @@ -6,6 +6,7 @@ import ( "crypto/rand" "database/sql" "encoding/hex" + "encoding/json" "errors" "fmt" "io" @@ -31,6 +32,7 @@ func main() { router := mux.NewRouter() router.HandleFunc("/upload", handleUpload).Methods("POST") router.HandleFunc("/download/{id}", handleDownload).Methods("GET") + router.HandleFunc("/get/{id}", handleGetFileInfo).Methods("GET") handler := cors.New(cors.Options{ AllowedOrigins: []string{"*"}, @@ -129,6 +131,49 @@ func handleDownload(w http.ResponseWriter, r *http.Request) { } } +func handleGetFileInfo(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + id := vars["id"] + + keyHex := r.URL.Query().Get("key") + key, err := hex.DecodeString(keyHex) + if err != nil { + handleError(w, errors.New("invalid key"), http.StatusBadRequest) + return + } + + fileName, encryptedData, err := getFileFromDB(id) + if err != nil { + handleError(w, err, http.StatusInternalServerError) + return + } + + plaintext, err := decryptFile(encryptedData, key) + if err != nil { + handleError(w, errors.New("error decrypting file"), http.StatusInternalServerError) + return + } + + fileSizeBytes := len(plaintext) + var fileSize string + if fileSizeBytes >= 1024*1024 { + fileSize = fmt.Sprintf("%.2f MB", float64(fileSizeBytes)/(1024*1024)) + } else { + fileSize = fmt.Sprintf("%.2f KB", float64(fileSizeBytes)/1024) + } + + fileInfo := struct { + FileName string `json:"fileName"` + FileSize string `json:"fileSize"` + }{ + FileName: fileName, + FileSize: fileSize, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(fileInfo) +} + func storeFileInDB(id, fileName string, encryptedData []byte) error { tx, err := db.Begin() if err != nil {