From 2e63a6cfc48469015464fa5c41c6c47b0b736903 Mon Sep 17 00:00:00 2001 From: Alexeyzem Date: Wed, 18 Dec 2024 15:02:32 +0300 Subject: [PATCH] detect type with http.Detected --- internal/fileService/controller/controller.go | 35 ++++++++++++++----- internal/fileService/controller/mock.go | 6 ++-- internal/fileService/service/fileService.go | 5 ++- 3 files changed, 32 insertions(+), 14 deletions(-) diff --git a/internal/fileService/controller/controller.go b/internal/fileService/controller/controller.go index 0e551121..ab5baebc 100644 --- a/internal/fileService/controller/controller.go +++ b/internal/fileService/controller/controller.go @@ -1,9 +1,11 @@ package controller import ( + "bytes" "context" "errors" "fmt" + "io" "mime/multipart" "net/http" "strings" @@ -26,8 +28,8 @@ var fileFormat = map[string]struct{}{ //go:generate mockgen -destination=mock.go -source=$GOFILE -package=${GOPACKAGE} type fileService interface { Upload(ctx context.Context, name string) ([]byte, error) - Download(ctx context.Context, file multipart.File, format string) (string, error) - DownloadNonImage(ctx context.Context, file multipart.File, format, realName string) (string, error) + Download(ctx context.Context, file io.Reader, format string) (string, error) + DownloadNonImage(ctx context.Context, file io.Reader, format, realName string) (string, error) UploadNonImage(ctx context.Context, name string) ([]byte, error) } @@ -51,6 +53,17 @@ func NewFileController(fileService fileService, responder responder) *FileContro } } +func getFormat(buf []byte) string { + formats := http.DetectContentType(buf) + format := strings.Split(formats, "/")[1] + + if strings.HasPrefix(format, "plain") { + format = "txt" + } + + return format +} + func sanitize(input string) string { sanitizer := bluemonday.UGCPolicy() cleaned := sanitizer.Sanitize(input) @@ -138,23 +151,29 @@ func (fc *FileController) Download(w http.ResponseWriter, r *http.Request) { } }(file) - formats := strings.Split(header.Header.Get("Content-Type"), "/") - if len(formats) != 2 { - fc.responder.ErrorBadRequest(w, my_err.ErrWrongFiletype, reqID) + buf := bytes.NewBuffer(make([]byte, 20)) + n, err := file.Read(buf.Bytes()) + if err != nil { + fc.responder.ErrorBadRequest(w, err, reqID) return } + format := getFormat(buf.Bytes()[:n]) + + _, err = io.Copy(buf, file) + if err != nil { + fc.responder.ErrorBadRequest(w, err, reqID) + } var url string - format := formats[1] if _, ok := fileFormat[format]; ok { - url, err = fc.fileService.Download(r.Context(), file, format) + url, err = fc.fileService.Download(r.Context(), buf, format) } else { name := header.Filename if len(name+format) > 55 { fc.responder.ErrorBadRequest(w, errors.New("file name is too big"), reqID) return } - url, err = fc.fileService.DownloadNonImage(r.Context(), file, format, name) + url, err = fc.fileService.DownloadNonImage(r.Context(), buf, format, name) } if err != nil { diff --git a/internal/fileService/controller/mock.go b/internal/fileService/controller/mock.go index 816a55a1..76b8237e 100644 --- a/internal/fileService/controller/mock.go +++ b/internal/fileService/controller/mock.go @@ -11,7 +11,7 @@ package controller import ( context "context" - multipart "mime/multipart" + io "io" http "net/http" reflect "reflect" @@ -43,7 +43,7 @@ func (m *MockfileService) EXPECT() *MockfileServiceMockRecorder { } // Download mocks base method. -func (m *MockfileService) Download(ctx context.Context, file multipart.File, format string) (string, error) { +func (m *MockfileService) Download(ctx context.Context, file io.Reader, format string) (string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Download", ctx, file, format) ret0, _ := ret[0].(string) @@ -58,7 +58,7 @@ func (mr *MockfileServiceMockRecorder) Download(ctx, file, format any) *gomock.C } // DownloadNonImage mocks base method. -func (m *MockfileService) DownloadNonImage(ctx context.Context, file multipart.File, format, realName string) (string, error) { +func (m *MockfileService) DownloadNonImage(ctx context.Context, file io.Reader, format, realName string) (string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "DownloadNonImage", ctx, file, format, realName) ret0, _ := ret[0].(string) diff --git a/internal/fileService/service/fileService.go b/internal/fileService/service/fileService.go index b2eb0cc3..fab9845c 100644 --- a/internal/fileService/service/fileService.go +++ b/internal/fileService/service/fileService.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "io" - "mime/multipart" "os" "github.com/google/uuid" @@ -16,7 +15,7 @@ func NewFileService() *FileService { return &FileService{} } -func (f *FileService) Download(ctx context.Context, file multipart.File, format string) (string, error) { +func (f *FileService) Download(ctx context.Context, file io.Reader, format string) (string, error) { var ( fileName = uuid.New().String() filePath = fmt.Sprintf("/image/%s.%s", fileName, format) @@ -38,7 +37,7 @@ func (f *FileService) Download(ctx context.Context, file multipart.File, format } func (f *FileService) DownloadNonImage( - ctx context.Context, file multipart.File, format, realName string, + ctx context.Context, file io.Reader, format, realName string, ) (string, error) { var ( fileName = uuid.New().String()