diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml new file mode 100644 index 00000000..70dbf5c5 --- /dev/null +++ b/.github/workflows/go.yml @@ -0,0 +1,73 @@ +# This workflow will build a golang project +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go + +name: Go + +on: + push: + branches: [ "dev" ] + pull_request: + branches: [ "dev" ] + +jobs: + + linter: + runs-on: ubuntu-latest + steps: + - run: + sudo apt-get install libwebp-dev + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: '1.22' + cache: false + - name: Set environment variables + run: | + echo "GOROOT=$(go env GOROOT)" >> $GITHUB_ENV + echo "GOBIN=$(go env GOBIN)" >> $GITHUB_ENV + - name: Install golangci-lint + run: | + go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest + - name: golangci-lint + run: golangci-lint run --new-from-rev origin/dev + + tests: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: '1.22' + - name: Build + run: go build -v ./... + - name: Test + run: go test -v ./... + + build: + needs: [tests, linter] + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Checkout code + uses: actions/checkout@v2 + - name: Set up SSH + uses: webfactory/ssh-agent@v0.5.3 + with: + ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }} + - name: Fetch .env file from server + run: | + ssh -o StrictHostKeyChecking=no ubuntu@185.241.194.197 ' + # Read the contents of the .env file and output it + cat ~/2024_2_BetterCallFirewall/.env + ' > .env + - name: Login to DockerHub Registry + run: echo ${{ secrets.DOCKERHUB_PASSWORD }} | docker login -u ${{ secrets.DOCKERHUB_LOGIN }} --password-stdin + - name: Build Docker images + run: | + for service in auth authgrpc chat community file post postgrpc profile profilegrpc sticker; do + docker build -t slashlight/${service}:${GITHUB_SHA::8} -t slashlight/${service}:latest -f Dockerfile${service} . + docker push slashlight/${service}:${GITHUB_SHA::8} + docker push slashlight/${service}:latest + done + diff --git a/DB/migrations/000001_init_schema.up.sql b/DB/migrations/000001_init_schema.up.sql index 82b19dac..82633c2a 100644 --- a/DB/migrations/000001_init_schema.up.sql +++ b/DB/migrations/000001_init_schema.up.sql @@ -46,8 +46,8 @@ CREATE TABLE IF NOT EXISTS post ( id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, author_id INT REFERENCES profile(id) ON DELETE CASCADE, community_id INT REFERENCES community(id) ON DELETE CASCADE DEFAULT NULL, - content TEXT CONSTRAINT content_post_length CHECK (CHAR_LENGTH(content) <= 500) DEFAULT '', - file_path TEXT CONSTRAINT file_path_length CHECK (CHAR_LENGTH(file_path) <= 100) DEFAULT '', + content TEXT CONSTRAINT content_post_length CHECK (CHAR_LENGTH(content) <= 1000) DEFAULT '', + file_path TEXT CONSTRAINT file_path_length CHECK (CHAR_LENGTH(file_path) <= 1000) DEFAULT '', created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() ); @@ -57,6 +57,8 @@ CREATE TABLE IF NOT EXISTS message ( receiver INT REFERENCES profile(id) ON DELETE CASCADE , sender INT REFERENCES profile(id) ON DELETE CASCADE , content TEXT CONSTRAINT content_message_length CHECK (CHAR_LENGTH(content) <= 500) DEFAULT '', + file_path TEXT CONSTRAINT file_path_message_length CHECK (CHAR_LENGTH(file_path) <= 1000) DEFAULT '', + sticker_path TEXT CONSTRAINT sticker_path_message_length CHECK (CHAR_LENGTH(sticker_path) <= 100) DEFAULT '', is_read BOOLEAN DEFAULT FALSE, created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() @@ -67,6 +69,8 @@ CREATE TABLE IF NOT EXISTS comment ( user_id INT REFERENCES profile(id) ON DELETE CASCADE , post_id INT REFERENCES post(id) ON DELETE CASCADE , content TEXT CONSTRAINT content_comment_length CHECK (CHAR_LENGTH(content) <= 500) DEFAULT '', + file_path TEXT CONSTRAINT file_path_comment CHECK ( CHAR_LENGTH(file_path) <= 1000 ) DEFAULT '', + sticker_path TEXT CONSTRAINT sticker_path_comment_length CHECK (CHAR_LENGTH(sticker_path) <= 200) DEFAULT '', created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() ); @@ -79,6 +83,12 @@ CREATE TABLE IF NOT EXISTS reaction ( updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() ); +CREATE TABLE IF NOT EXISTS sticker ( + id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + file_path TEXT CONSTRAINT file_path_length CHECK (CHAR_LENGTH(file_path) <= 100) DEFAULT '', + profile_id INT REFERENCES profile(id) +); + ALTER TABLE friend ADD FOREIGN KEY ("sender") REFERENCES profile(id) ON DELETE CASCADE, ADD FOREIGN KEY ("receiver") REFERENCES profile(id) ON DELETE CASCADE ; diff --git a/Dockerfilesticker b/Dockerfilesticker new file mode 100644 index 00000000..c6aa08d4 --- /dev/null +++ b/Dockerfilesticker @@ -0,0 +1,25 @@ +FROM golang:alpine AS build + +WORKDIR /stickers + +COPY go.mod . +COPY go.sum . + +RUN go mod download +RUN go mod vendor + +COPY . . + +RUN go build cmd/stickers/main.go + +FROM alpine:latest + +WORKDIR /stickers + +EXPOSE 8088 + +COPY .env . + +COPY --from=build /stickers/main /stickers/main + +CMD ["./main"] \ No newline at end of file diff --git a/Makefile b/Makefile index 9be8968d..41949239 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,8 @@ test: - go test -v ./... -coverprofile=cover.out && go tool cover -html=cover.out -o cover.html + go test ./... -coverprofile=cover.out \ + && go tool cover -func=cover.out | grep -vE "*mock.go|*easyjson.go|*pb.go|*mock_helper.go" \ + && go tool cover -html=cover.out -o cover.html + start: docker compose up --build @@ -16,4 +19,7 @@ gen-proto: proto/*.proto lint: - golangci-lint run \ No newline at end of file + golangci-lint run + +gen-easy-json: + easyjson -all internal/models/*.go \ No newline at end of file diff --git a/cmd/gRPC/auth/main.go b/cmd/gRPC/auth/main.go index 50247e3a..ccc51384 100644 --- a/cmd/gRPC/auth/main.go +++ b/cmd/gRPC/auth/main.go @@ -40,7 +40,9 @@ func main() { go func() { http.Handle("/api/v1/metrics", promhttp.Handler()) - http.ListenAndServe(":6001", nil) + if err = http.ListenAndServe(":6001", nil); err != nil { + panic(err) + } }() log.Printf("Listening on :%s with protocol gRPC", cfg.AUTHGRPC.Port) diff --git a/cmd/gRPC/post/main.go b/cmd/gRPC/post/main.go index 1d233378..f2bd12f9 100644 --- a/cmd/gRPC/post/main.go +++ b/cmd/gRPC/post/main.go @@ -38,7 +38,9 @@ func main() { } go func() { http.Handle("/api/v1/metrics", promhttp.Handler()) - http.ListenAndServe(":6002", nil) + if err = http.ListenAndServe(":6002", nil); err != nil { + panic(err) + } }() log.Printf("Listening on :%s with protocol gRPC", cfg.POSTGRPC.Port) diff --git a/cmd/gRPC/profile/main.go b/cmd/gRPC/profile/main.go index a7d244d2..4a3cd20a 100644 --- a/cmd/gRPC/profile/main.go +++ b/cmd/gRPC/profile/main.go @@ -39,7 +39,9 @@ func main() { go func() { http.Handle("/api/v1/metrics", promhttp.Handler()) - http.ListenAndServe(":6003", nil) + if err = http.ListenAndServe(":6003", nil); err != nil { + panic(err) + } }() log.Printf("Listening on :%s with protocol gRPC", cfg.PROFILEGRPC.Port) diff --git a/cmd/stickers/main.go b/cmd/stickers/main.go new file mode 100644 index 00000000..2596a54a --- /dev/null +++ b/cmd/stickers/main.go @@ -0,0 +1,29 @@ +package main + +import ( + "flag" + "log" + + "github.com/2024_2_BetterCallFirewall/internal/app/stickers" + "github.com/2024_2_BetterCallFirewall/internal/config" +) + +func main() { + confPath := flag.String("c", ".env", "path to config file") + flag.Parse() + + cfg, err := config.GetConfig(*confPath) + if err != nil { + panic(err) + } + + server, err := stickers.GetHTTPServer(cfg) + if err != nil { + panic(err) + } + + log.Printf("Starting server on port %s", cfg.STICKER.Port) + if err := server.ListenAndServe(); err != nil { + panic(err) + } +} diff --git a/docker-compose.yml b/docker-compose.yml index 7eedb08f..f14a1aa4 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -2,9 +2,7 @@ version: '3.9' services: authgrpc: - build: - context: . - dockerfile: Dockerfileauthgrpc + image: slashlight/authgrpc:latest restart: always ports: - "7072:7072" @@ -12,9 +10,7 @@ services: - redis profilegrpc: - build: - context: . - dockerfile: Dockerfileprofilegrpc + image: slashlight/profilegrpc:latest restart: always ports: - "7074:7074" @@ -22,9 +18,7 @@ services: - db - authgrpc postgrpc: - build: - context: . - dockerfile: Dockerfilepostgrpc + image: slashlight/postgrpc:latest restart: always ports: - "7075:7075" @@ -33,9 +27,7 @@ services: - profilegrpc community: - build: - context: . - dockerfile: Dockerfilecommunity + image: slashlight/community:latest restart: always ports: - "8086:8086" @@ -45,9 +37,7 @@ services: - authgrpc auth: - build: - context: . - dockerfile: Dockerfileauth + image: slashlight/auth:latest restart: always ports: - "8082:8082" @@ -57,11 +47,10 @@ services: - community file: - build: - context: . - dockerfile: Dockerfilefile + image: slashlight/file:latest volumes: - ./image:/image + - ./files:/files restart: always ports: - "8083:8083" @@ -70,9 +59,7 @@ services: - auth profile: - build: - context: . - dockerfile: Dockerfileprofile + image: slashlight/profile:latest restart: always ports: - "8084:8084" @@ -83,9 +70,7 @@ services: - file post: - build: - context: . - dockerfile: Dockerfilepost + image: slashlight/post:latest restart: always ports: - "8085:8085" @@ -96,9 +81,7 @@ services: - community - profile chat: - build: - context: . - dockerfile: Dockerfilechat + image: slashlight/chat:latest restart: always ports: - "8087:8087" @@ -107,6 +90,15 @@ services: - authgrpc - post + stickers: + image: slashlight/sticker:latest + restart: always + ports: + - "8088:8088" + depends_on: + - db + - authgrpc + db: image: postgres:latest command: -c config_file=/etc/postgresql/postgresql.conf @@ -179,5 +171,15 @@ services: ports: - "9100:9100" + watchtower: + image: containrrr/watchtower + volumes: + - /var/run/docker.sock:/var/run/docker.sock + environment: + - WATCHTOWER_POLL_INTERVAL=60 + - WATCHTOWER_USERNAME=${DOCKER_USERNAME} + - WATCHTOWER_PASSWORD=${DOCKER_PASSWORD} + + volumes: - postgres_data: \ No newline at end of file + postgres_data: diff --git a/go.mod b/go.mod index f28e9261..95021e12 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,8 @@ require ( github.com/jackc/pgx/v5 v5.7.1 github.com/joho/godotenv v1.5.1 github.com/lib/pq v1.10.9 + github.com/mailru/easyjson v0.7.7 + github.com/microcosm-cc/bluemonday v1.0.27 github.com/prometheus/client_golang v1.20.5 github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.9.0 @@ -24,15 +26,18 @@ require ( ) require ( + github.com/aymerick/douceur v0.2.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cockroachdb/apd v1.1.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/gofrs/uuid v4.4.0+incompatible // indirect + github.com/gorilla/css v1.0.1 // indirect github.com/jackc/fake v0.0.0-20150926172116-812a484cc733 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/josharian/intern v1.0.0 // indirect github.com/klauspost/compress v1.17.9 // indirect github.com/kr/text v0.2.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect diff --git a/go.sum b/go.sum index ff66a4f6..ba7f8257 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= +github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk= +github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= @@ -22,6 +24,8 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8= +github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= @@ -40,6 +44,8 @@ github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= @@ -51,6 +57,10 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0 github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwXFM08ygZfk= +github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= diff --git a/internal/api/grpc/post_api/grpc_server.go b/internal/api/grpc/post_api/grpc_server.go index b98cc54c..99f2a6c6 100644 --- a/internal/api/grpc/post_api/grpc_server.go +++ b/internal/api/grpc/post_api/grpc_server.go @@ -40,24 +40,33 @@ func (a *Adapter) GetAuthorsPosts(ctx context.Context, req *Request) (*Response, resp := &Response{ Posts: make([]*Post, 0, len(res)), } + for _, post := range res { - resp.Posts = append(resp.Posts, &Post{ - ID: post.ID, - Head: &Header{ - AuthorID: post.Header.AuthorID, - CommunityID: post.Header.CommunityID, - Author: post.Header.Author, - Avatar: string(post.Header.Avatar), - }, - PostContent: &Content{ - Text: post.PostContent.Text, - File: string(post.PostContent.File), - CreatedAt: post.PostContent.CreatedAt.Unix(), - UpdatedAt: post.PostContent.UpdatedAt.Unix(), + files := make([]string, 0, len(post.PostContent.File)) + for _, f := range post.PostContent.File { + files = append(files, string(f)) + } + + resp.Posts = append( + resp.Posts, &Post{ + ID: post.ID, + Head: &Header{ + AuthorID: post.Header.AuthorID, + CommunityID: post.Header.CommunityID, + Author: post.Header.Author, + Avatar: string(post.Header.Avatar), + }, + PostContent: &Content{ + Text: post.PostContent.Text, + File: files, + CreatedAt: post.PostContent.CreatedAt.Unix(), + UpdatedAt: post.PostContent.UpdatedAt.Unix(), + }, + LikesCount: post.LikesCount, + IsLiked: post.IsLiked, + CommentCount: post.CommentCount, }, - LikesCount: post.LikesCount, - IsLiked: post.IsLiked, - }) + ) } return resp, nil diff --git a/internal/api/grpc/post_api/grpc_server_test.go b/internal/api/grpc/post_api/grpc_server_test.go index 8a3f1649..fcbfcb71 100644 --- a/internal/api/grpc/post_api/grpc_server_test.go +++ b/internal/api/grpc/post_api/grpc_server_test.go @@ -68,9 +68,12 @@ func TestGetAuthorsPosts(t *testing.T) { return &Response{ Posts: []*Post{ { - ID: 1, - PostContent: &Content{Text: "New Post", CreatedAt: createTime.Unix(), UpdatedAt: createTime.Unix()}, - Head: &Header{AuthorID: 1, Author: "Alexey Zemliakov"}, + ID: 1, + PostContent: &Content{ + Text: "New Post", CreatedAt: createTime.Unix(), UpdatedAt: createTime.Unix(), + File: []string{}, + }, + Head: &Header{AuthorID: 1, Author: "Alexey Zemliakov"}, }, }, }, @@ -82,9 +85,11 @@ func TestGetAuthorsPosts(t *testing.T) { Return( []*models.Post{ { - ID: 1, - PostContent: models.Content{Text: "New Post", CreatedAt: createTime, UpdatedAt: createTime}, - Header: models.Header{AuthorID: 1, Author: "Alexey Zemliakov"}, + ID: 1, + PostContent: models.Content{ + Text: "New Post", CreatedAt: createTime, UpdatedAt: createTime, + }, + Header: models.Header{AuthorID: 1, Author: "Alexey Zemliakov"}, }, }, nil, @@ -94,29 +99,31 @@ func TestGetAuthorsPosts(t *testing.T) { } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() - adapter, mock := getAdapter(ctrl) - ctx := context.Background() + adapter, mock := getAdapter(ctrl) + ctx := context.Background() - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } - v.SetupMock(input, mock) + v.SetupMock(input, mock) - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } - actual, err := v.Run(ctx, adapter, input) - assert.Equal(t, res, actual) - assert.Equal(t, status.Code(err), v.ExpectedErrCode) - }) + actual, err := v.Run(ctx, adapter, input) + assert.Equal(t, res, actual) + assert.Equal(t, status.Code(err), v.ExpectedErrCode) + }, + ) } } diff --git a/internal/api/grpc/post_api/post.pb.go b/internal/api/grpc/post_api/post.pb.go index b8f5db25..b4241ae2 100644 --- a/internal/api/grpc/post_api/post.pb.go +++ b/internal/api/grpc/post_api/post.pb.go @@ -198,11 +198,12 @@ type Post struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - ID uint32 `protobuf:"varint,1,opt,name=ID,proto3" json:"ID,omitempty"` - PostContent *Content `protobuf:"bytes,2,opt,name=PostContent,proto3" json:"PostContent,omitempty"` - Head *Header `protobuf:"bytes,3,opt,name=Head,proto3" json:"Head,omitempty"` - LikesCount uint32 `protobuf:"varint,4,opt,name=LikesCount,proto3" json:"LikesCount,omitempty"` - IsLiked bool `protobuf:"varint,5,opt,name=IsLiked,proto3" json:"IsLiked,omitempty"` + ID uint32 `protobuf:"varint,1,opt,name=ID,proto3" json:"ID,omitempty"` + PostContent *Content `protobuf:"bytes,2,opt,name=PostContent,proto3" json:"PostContent,omitempty"` + Head *Header `protobuf:"bytes,3,opt,name=Head,proto3" json:"Head,omitempty"` + LikesCount uint32 `protobuf:"varint,4,opt,name=LikesCount,proto3" json:"LikesCount,omitempty"` + IsLiked bool `protobuf:"varint,5,opt,name=IsLiked,proto3" json:"IsLiked,omitempty"` + CommentCount uint32 `protobuf:"varint,6,opt,name=CommentCount,proto3" json:"CommentCount,omitempty"` } func (x *Post) Reset() { @@ -272,15 +273,22 @@ func (x *Post) GetIsLiked() bool { return false } +func (x *Post) GetCommentCount() uint32 { + if x != nil { + return x.CommentCount + } + return 0 +} + type Content struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Text string `protobuf:"bytes,1,opt,name=Text,proto3" json:"Text,omitempty"` - File string `protobuf:"bytes,2,opt,name=File,proto3" json:"File,omitempty"` - CreatedAt int64 `protobuf:"varint,3,opt,name=CreatedAt,proto3" json:"CreatedAt,omitempty"` - UpdatedAt int64 `protobuf:"varint,4,opt,name=UpdatedAt,proto3" json:"UpdatedAt,omitempty"` + Text string `protobuf:"bytes,1,opt,name=Text,proto3" json:"Text,omitempty"` + File []string `protobuf:"bytes,2,rep,name=File,proto3" json:"File,omitempty"` + CreatedAt int64 `protobuf:"varint,3,opt,name=CreatedAt,proto3" json:"CreatedAt,omitempty"` + UpdatedAt int64 `protobuf:"varint,4,opt,name=UpdatedAt,proto3" json:"UpdatedAt,omitempty"` } func (x *Content) Reset() { @@ -322,11 +330,11 @@ func (x *Content) GetText() string { return "" } -func (x *Content) GetFile() string { +func (x *Content) GetFile() []string { if x != nil { return x.File } - return "" + return nil } func (x *Content) GetCreatedAt() int64 { @@ -363,7 +371,7 @@ var file_proto_post_proto_rawDesc = []byte{ 0x08, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x24, 0x0a, 0x05, 0x50, 0x6f, 0x73, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x70, 0x6f, 0x73, 0x74, 0x5f, 0x61, 0x70, 0x69, 0x2e, 0x50, 0x6f, 0x73, 0x74, 0x52, 0x05, 0x50, 0x6f, 0x73, 0x74, 0x73, 0x22, - 0xab, 0x01, 0x0a, 0x04, 0x50, 0x6f, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, + 0xcf, 0x01, 0x0a, 0x04, 0x50, 0x6f, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x02, 0x49, 0x44, 0x12, 0x33, 0x0a, 0x0b, 0x50, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x70, 0x6f, 0x73, 0x74, 0x5f, 0x61, 0x70, 0x69, 0x2e, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, @@ -373,24 +381,26 @@ var file_proto_post_proto_rawDesc = []byte{ 0x65, 0x61, 0x64, 0x12, 0x1e, 0x0a, 0x0a, 0x4c, 0x69, 0x6b, 0x65, 0x73, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x4c, 0x69, 0x6b, 0x65, 0x73, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x49, 0x73, 0x4c, 0x69, 0x6b, 0x65, 0x64, 0x18, 0x05, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x49, 0x73, 0x4c, 0x69, 0x6b, 0x65, 0x64, 0x22, 0x6d, 0x0a, - 0x07, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x65, 0x78, 0x74, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x54, 0x65, 0x78, 0x74, 0x12, 0x12, 0x0a, 0x04, - 0x46, 0x69, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x46, 0x69, 0x6c, 0x65, - 0x12, 0x1c, 0x0a, 0x09, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x41, 0x74, 0x18, 0x03, 0x20, - 0x01, 0x28, 0x03, 0x52, 0x09, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x41, 0x74, 0x12, 0x1c, - 0x0a, 0x09, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x64, 0x41, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x03, 0x52, 0x09, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x64, 0x41, 0x74, 0x32, 0x49, 0x0a, 0x0b, - 0x50, 0x6f, 0x73, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x3a, 0x0a, 0x0f, 0x47, - 0x65, 0x74, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x73, 0x50, 0x6f, 0x73, 0x74, 0x73, 0x12, 0x11, - 0x2e, 0x70, 0x6f, 0x73, 0x74, 0x5f, 0x61, 0x70, 0x69, 0x2e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x1a, 0x12, 0x2e, 0x70, 0x6f, 0x73, 0x74, 0x5f, 0x61, 0x70, 0x69, 0x2e, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x41, 0x5a, 0x3f, 0x67, 0x69, 0x74, 0x68, 0x75, - 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x32, 0x30, 0x32, 0x34, 0x5f, 0x32, 0x5f, 0x42, 0x65, 0x74, - 0x74, 0x65, 0x72, 0x43, 0x61, 0x6c, 0x6c, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x2f, - 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x67, 0x72, 0x70, - 0x63, 0x2f, 0x70, 0x6f, 0x73, 0x74, 0x5f, 0x61, 0x70, 0x69, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x33, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x49, 0x73, 0x4c, 0x69, 0x6b, 0x65, 0x64, 0x12, 0x22, 0x0a, + 0x0c, 0x43, 0x6f, 0x6d, 0x6d, 0x65, 0x6e, 0x74, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x18, 0x06, 0x20, + 0x01, 0x28, 0x0d, 0x52, 0x0c, 0x43, 0x6f, 0x6d, 0x6d, 0x65, 0x6e, 0x74, 0x43, 0x6f, 0x75, 0x6e, + 0x74, 0x22, 0x6d, 0x0a, 0x07, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x12, 0x12, 0x0a, 0x04, + 0x54, 0x65, 0x78, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x54, 0x65, 0x78, 0x74, + 0x12, 0x12, 0x0a, 0x04, 0x46, 0x69, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x04, + 0x46, 0x69, 0x6c, 0x65, 0x12, 0x1c, 0x0a, 0x09, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x41, + 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x09, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, + 0x41, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x64, 0x41, 0x74, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x09, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x64, 0x41, 0x74, + 0x32, 0x49, 0x0a, 0x0b, 0x50, 0x6f, 0x73, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, + 0x3a, 0x0a, 0x0f, 0x47, 0x65, 0x74, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x73, 0x50, 0x6f, 0x73, + 0x74, 0x73, 0x12, 0x11, 0x2e, 0x70, 0x6f, 0x73, 0x74, 0x5f, 0x61, 0x70, 0x69, 0x2e, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x70, 0x6f, 0x73, 0x74, 0x5f, 0x61, 0x70, 0x69, + 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x41, 0x5a, 0x3f, 0x67, + 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x32, 0x30, 0x32, 0x34, 0x5f, 0x32, + 0x5f, 0x42, 0x65, 0x74, 0x74, 0x65, 0x72, 0x43, 0x61, 0x6c, 0x6c, 0x46, 0x69, 0x72, 0x65, 0x77, + 0x61, 0x6c, 0x6c, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x61, 0x70, 0x69, + 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x70, 0x6f, 0x73, 0x74, 0x5f, 0x61, 0x70, 0x69, 0x62, 0x06, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/internal/app/post/app.go b/internal/app/post/app.go index 617a3c3f..fc16d0c8 100644 --- a/internal/app/post/app.go +++ b/internal/app/post/app.go @@ -73,7 +73,8 @@ func GetHTTPServer(cfg *config.Config, postMetric *metrics.HttpMetrics) (*http.S cp := community.New(communityProvider) postService := service.NewPostServiceImpl(repo, pp, cp) - postController := controller.NewPostController(postService, responder) + commentService := service.NewCommentService(repo, pp) + postController := controller.NewPostController(postService, commentService, responder) rout := post.NewRouter(postController, sm, logger, postMetric) server := &http.Server{ diff --git a/internal/app/stickers/app.go b/internal/app/stickers/app.go new file mode 100644 index 00000000..f3e47550 --- /dev/null +++ b/internal/app/stickers/app.go @@ -0,0 +1,64 @@ +package stickers + +import ( + "fmt" + "net/http" + + "github.com/sirupsen/logrus" + + "github.com/2024_2_BetterCallFirewall/internal/config" + "github.com/2024_2_BetterCallFirewall/internal/ext_grpc" + "github.com/2024_2_BetterCallFirewall/internal/ext_grpc/adapter/auth" + "github.com/2024_2_BetterCallFirewall/internal/router" + "github.com/2024_2_BetterCallFirewall/internal/router/stickers" + "github.com/2024_2_BetterCallFirewall/internal/stickers/controller" + "github.com/2024_2_BetterCallFirewall/internal/stickers/repository" + "github.com/2024_2_BetterCallFirewall/internal/stickers/service" + "github.com/2024_2_BetterCallFirewall/pkg/start_postgres" +) + +func GetHTTPServer(cfg *config.Config) (*http.Server, error) { + logger := logrus.New() + logger.Formatter = &logrus.TextFormatter{ + FullTimestamp: true, + DisableColors: false, + TimestampFormat: "2006-01-02 15:04:05", + ForceColors: true, + } + + connStr := fmt.Sprintf( + "host=%s port=%s user=%s password=%s dbname=%s sslmode=%s", + cfg.DB.Host, + cfg.DB.Port, + cfg.DB.User, + cfg.DB.Pass, + cfg.DB.DBName, + cfg.DB.SSLMode, + ) + + postgresDB, err := start_postgres.StartPostgres(connStr, logger) + if err != nil { + return nil, err + } + + responder := router.NewResponder(logger) + provider, err := ext_grpc.GetGRPCProvider(cfg.AUTHGRPC.Host, cfg.AUTHGRPC.Port) + if err != nil { + return nil, err + } + sm := auth.New(provider) + + repo := repository.NewStickerRepo(postgresDB) + stickerService := service.NewStickerUsecase(repo) + stickerController := controller.NewStickerController(stickerService, responder) + + rout := stickers.NewRouter(stickerController, sm, logger) + server := &http.Server{ + Handler: rout, + Addr: fmt.Sprintf(":%s", cfg.STICKER.Port), + ReadTimeout: cfg.STICKER.ReadTimeout, + WriteTimeout: cfg.STICKER.WriteTimeout, + } + + return server, nil +} diff --git a/internal/app/stickers/app_test.go b/internal/app/stickers/app_test.go new file mode 100644 index 00000000..35751bec --- /dev/null +++ b/internal/app/stickers/app_test.go @@ -0,0 +1,26 @@ +package stickers + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/2024_2_BetterCallFirewall/internal/config" +) + +func TestGetServer(t *testing.T) { + server, err := GetHTTPServer( + &config.Config{ + DB: config.DBConnect{ + Port: "test", + Host: "test", + DBName: "test", + User: "test", + Pass: "test", + SSLMode: "test", + }, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, server) +} diff --git a/internal/auth/controller/controller.go b/internal/auth/controller/controller.go index 18cb4777..96aaca2f 100644 --- a/internal/auth/controller/controller.go +++ b/internal/auth/controller/controller.go @@ -2,20 +2,25 @@ package controller import ( "context" - "encoding/json" "errors" "fmt" "net/http" + "regexp" "time" + "github.com/mailru/easyjson" + "github.com/microcosm-cc/bluemonday" "golang.org/x/crypto/bcrypt" "github.com/2024_2_BetterCallFirewall/internal/auth" + "github.com/2024_2_BetterCallFirewall/internal/middleware" "github.com/2024_2_BetterCallFirewall/internal/models" "github.com/2024_2_BetterCallFirewall/pkg/my_err" ) +var emailRegex = regexp.MustCompile(`^[\w-.]+@([\w-]+\.)\w{2,4}$`) + type AuthService interface { Register(user models.User, ctx context.Context) (uint32, error) Auth(user models.User, ctx context.Context) (uint32, error) @@ -35,7 +40,9 @@ type AuthController struct { SessionManager auth.SessionManager } -func NewAuthController(responder Responder, serviceAuth AuthService, sessionManager auth.SessionManager) *AuthController { +func NewAuthController( + responder Responder, serviceAuth AuthService, sessionManager auth.SessionManager, +) *AuthController { return &AuthController{ responder: responder, serviceAuth: serviceAuth, @@ -43,21 +50,38 @@ func NewAuthController(responder Responder, serviceAuth AuthService, sessionMana } } +func sanitize(input string) string { + sanitizer := bluemonday.UGCPolicy() + cleaned := sanitizer.Sanitize(input) + return cleaned +} + func (c *AuthController) Register(w http.ResponseWriter, r *http.Request) { - reqID, ok := r.Context().Value("requestID").(string) + reqID, ok := r.Context().Value(middleware.RequestKey).(string) if !ok { c.responder.LogError(my_err.ErrInvalidContext, "") } user := models.User{} - err := json.NewDecoder(r.Body).Decode(&user) + err := easyjson.UnmarshalFromReader(r.Body, &user) if err != nil { c.responder.ErrorBadRequest(w, fmt.Errorf("router register: %w", err), reqID) return } + user.FirstName = sanitize(user.FirstName) + user.LastName = sanitize(user.LastName) + user.Email = sanitize(user.Email) + + if !validate(user) { + c.responder.ErrorBadRequest(w, my_err.ErrBadUserInfo, reqID) + return + } + user.ID, err = c.serviceAuth.Register(user, r.Context()) - if errors.Is(err, my_err.ErrUserAlreadyExists) || errors.Is(err, my_err.ErrNonValidEmail) || errors.Is(err, bcrypt.ErrPasswordTooLong) { + if errors.Is(err, my_err.ErrUserAlreadyExists) || errors.Is(err, my_err.ErrNonValidEmail) || errors.Is( + err, bcrypt.ErrPasswordTooLong, + ) { c.responder.ErrorBadRequest(w, err, reqID) return } @@ -78,6 +102,8 @@ func (c *AuthController) Register(w http.ResponseWriter, r *http.Request) { Value: sess.ID, Path: "/", HttpOnly: true, + Secure: true, + SameSite: http.SameSiteStrictMode, Expires: time.Now().AddDate(0, 0, 1), } @@ -87,17 +113,24 @@ func (c *AuthController) Register(w http.ResponseWriter, r *http.Request) { } func (c *AuthController) Auth(w http.ResponseWriter, r *http.Request) { - reqID, ok := r.Context().Value("requestID").(string) + reqID, ok := r.Context().Value(middleware.RequestKey).(string) if !ok { c.responder.LogError(my_err.ErrInvalidContext, "") } user := models.User{} - err := json.NewDecoder(r.Body).Decode(&user) + err := easyjson.UnmarshalFromReader(r.Body, &user) if err != nil { c.responder.ErrorBadRequest(w, fmt.Errorf("router auth: %w", err), reqID) return } + user.Email = sanitize(user.Email) + user.Password = sanitize(user.Password) + + if !validateAuth(user) { + c.responder.ErrorBadRequest(w, my_err.ErrBadUserInfo, reqID) + return + } id, err := c.serviceAuth.Auth(user, r.Context()) @@ -121,6 +154,8 @@ func (c *AuthController) Auth(w http.ResponseWriter, r *http.Request) { Value: sess.ID, Path: "/", HttpOnly: true, + Secure: true, + SameSite: http.SameSiteStrictMode, Expires: time.Now().AddDate(0, 0, 1), } http.SetCookie(w, cookie) @@ -129,7 +164,7 @@ func (c *AuthController) Auth(w http.ResponseWriter, r *http.Request) { } func (c *AuthController) Logout(w http.ResponseWriter, r *http.Request) { - reqID, ok := r.Context().Value("requestID").(string) + reqID, ok := r.Context().Value(middleware.RequestKey).(string) if !ok { c.responder.LogError(my_err.ErrInvalidContext, "") } @@ -157,9 +192,27 @@ func (c *AuthController) Logout(w http.ResponseWriter, r *http.Request) { Value: sess.ID, Path: "/", HttpOnly: true, + Secure: true, + SameSite: http.SameSiteStrictMode, Expires: time.Now().AddDate(0, 0, -1), } http.SetCookie(w, cookie) c.responder.OutputJSON(w, "user logout", reqID) } + +func validate(user models.User) bool { + if len([]rune(user.FirstName)) < 3 || len([]rune(user.LastName)) < 3 || len([]rune(user.Password)) < 6 || + len([]rune(user.FirstName)) > 30 || len([]rune(user.LastName)) > 30 { + return false + } + return true +} + +func validateAuth(user models.User) bool { + if len([]rune(user.Password)) < 6 || !emailRegex.MatchString(user.Email) { + return false + } + + return true +} diff --git a/internal/auth/controller/controller_test.go b/internal/auth/controller/controller_test.go index 200c50b6..10310b77 100644 --- a/internal/auth/controller/controller_test.go +++ b/internal/auth/controller/controller_test.go @@ -10,6 +10,8 @@ import ( "strings" "testing" + "github.com/stretchr/testify/assert" + "github.com/2024_2_BetterCallFirewall/internal/models" "github.com/2024_2_BetterCallFirewall/pkg/my_err" ) @@ -110,12 +112,19 @@ type TestCase struct { func TestRegister(t *testing.T) { controller := NewAuthController(&MockResponder{}, MockAuthService{}, MockSessionManager{}) - jsonUser0, _ := json.Marshal(models.User{ID: 0}) - jsonUser1, _ := json.Marshal(models.User{ID: 1}) - jsonUser2, _ := json.Marshal(models.User{ID: 2}) - jsonUser3, _ := json.Marshal(models.User{ID: 3}) + jsonUser0, _ := json.Marshal(models.User{ID: 0, FirstName: "Alex", LastName: "Zem", Password: "password"}) + jsonUser1, _ := json.Marshal(models.User{ID: 1, FirstName: "Alex", LastName: "Zem", Password: "password"}) + jsonUser2, _ := json.Marshal(models.User{ID: 2, FirstName: "Alex", LastName: "Zem", Password: "password"}) + jsonUser3, _ := json.Marshal(models.User{ID: 3, FirstName: "Alex", LastName: "Zem", Password: "password"}) + jsonUser, _ := json.Marshal(models.User{ID: 3}) testCases := []TestCase{ + { + w: httptest.NewRecorder(), + r: httptest.NewRequest(http.MethodPost, "/", bytes.NewBuffer(jsonUser)), + wantCode: http.StatusBadRequest, + wantBody: "bad request error", + }, { w: httptest.NewRecorder(), r: httptest.NewRequest(http.MethodPost, "/", bytes.NewBuffer([]byte("wrong json"))), @@ -159,14 +168,28 @@ func TestRegister(t *testing.T) { } } +func TestSanitize(t *testing.T) { + test := "" + expected := "" + res := sanitize(test) + assert.Equal(t, expected, res) +} + func TestAuth(t *testing.T) { controller := NewAuthController(&MockResponder{}, MockAuthService{}, MockSessionManager{}) - jsonUser0, _ := json.Marshal(models.User{ID: 0}) - jsonUser1, _ := json.Marshal(models.User{ID: 1}) - jsonUser2, _ := json.Marshal(models.User{ID: 2}) - jsonUser3, _ := json.Marshal(models.User{ID: 3}) + jsonUser0, _ := json.Marshal(models.User{ID: 0, Email: "test@test.ru", Password: "password"}) + jsonUser1, _ := json.Marshal(models.User{ID: 1, Email: "test@test.ru", Password: "password"}) + jsonUser2, _ := json.Marshal(models.User{ID: 2, Email: "test@test.ru", Password: "password"}) + jsonUser3, _ := json.Marshal(models.User{ID: 3, Email: "test@test.ru", Password: "password"}) + jsonUser, _ := json.Marshal(models.User{ID: 3}) testCases := []TestCase{ + { + w: httptest.NewRecorder(), + r: httptest.NewRequest(http.MethodPost, "/", bytes.NewBuffer(jsonUser)), + wantCode: http.StatusBadRequest, + wantBody: "bad request error", + }, { w: httptest.NewRecorder(), r: httptest.NewRequest(http.MethodPost, "/", bytes.NewBuffer([]byte("wrong json"))), diff --git a/internal/auth/service/auth.go b/internal/auth/service/auth.go index abe73157..3c3bf7e7 100644 --- a/internal/auth/service/auth.go +++ b/internal/auth/service/auth.go @@ -13,6 +13,8 @@ import ( "github.com/2024_2_BetterCallFirewall/pkg/my_err" ) +var emailRegex = regexp.MustCompile(`^[\w-.]+@([\w-]+\.)\w{2,4}$`) + type UserRepo interface { Create(ctx context.Context, user *models.User) (uint32, error) GetByEmail(ctx context.Context, email string) (*models.User, error) @@ -70,6 +72,5 @@ func (a *AuthServiceImpl) Auth(user models.User, ctx context.Context) (uint32, e } func (a *AuthServiceImpl) validateEmail(email string) bool { - emailRegex := regexp.MustCompile(`^[\w-.]+@([\w-]+\.)\w{2,4}$`) - return emailRegex.MatchString(email) + return emailRegex.MatchString(email) && len(email) < 50 } diff --git a/internal/chat/controller/client.go b/internal/chat/controller/client.go index 5ae3558b..35d2fd19 100644 --- a/internal/chat/controller/client.go +++ b/internal/chat/controller/client.go @@ -1,32 +1,60 @@ package controller import ( - "encoding/json" + "fmt" "time" "github.com/gorilla/websocket" + "github.com/mailru/easyjson" + "github.com/microcosm-cc/bluemonday" "github.com/2024_2_BetterCallFirewall/internal/models" ) +const wc = "websocket" + type Client struct { Socket *websocket.Conn Receive chan *models.Message chatController *ChatController } +func sanitize(input string) string { + sanitizer := bluemonday.UGCPolicy() + cleaned := sanitizer.Sanitize(input) + return cleaned +} + +func sanitizeFiles(input []string) []string { + var output []string + for _, f := range input { + res := sanitize(f) + if res != "" { + output = append(output, res) + } + } + + return output +} + func (c *Client) Read(userID uint32) { defer c.Socket.Close() for { msg := &models.Message{} _, jsonMessage, err := c.Socket.ReadMessage() if err != nil { + c.chatController.responder.LogError(fmt.Errorf("read message: %w", err), wc) return } - err = json.Unmarshal(jsonMessage, msg) + + err = easyjson.Unmarshal(jsonMessage, msg) if err != nil { + c.chatController.responder.LogError(err, wc) return } + msg.Content.Text = sanitize(msg.Content.Text) + msg.Content.FilePath = sanitizeFiles(msg.Content.FilePath) + msg.Content.StickerPath = sanitize(msg.Content.StickerPath) msg.Sender = userID c.chatController.Messages <- msg } @@ -36,12 +64,17 @@ func (c *Client) Write() { defer c.Socket.Close() for msg := range c.Receive { msg.CreatedAt = time.Now() - jsonForSend, err := json.Marshal(msg) + msg.Content.Text = sanitize(msg.Content.Text) + msg.Content.FilePath = sanitizeFiles(msg.Content.FilePath) + msg.Content.StickerPath = sanitize(msg.Content.StickerPath) + jsonForSend, err := easyjson.Marshal(msg) if err != nil { + c.chatController.responder.LogError(err, wc) return } err = c.Socket.WriteMessage(websocket.TextMessage, jsonForSend) if err != nil { + c.chatController.responder.LogError(fmt.Errorf("write message: %w", err), wc) return } } diff --git a/internal/chat/controller/controller.go b/internal/chat/controller/controller.go index 4cb7220c..577eeaeb 100644 --- a/internal/chat/controller/controller.go +++ b/internal/chat/controller/controller.go @@ -12,6 +12,7 @@ import ( "github.com/gorilla/websocket" "github.com/2024_2_BetterCallFirewall/internal/chat" + "github.com/2024_2_BetterCallFirewall/internal/middleware" "github.com/2024_2_BetterCallFirewall/internal/models" "github.com/2024_2_BetterCallFirewall/pkg/my_err" ) @@ -55,7 +56,7 @@ var ( ) func (cc *ChatController) SetConnection(w http.ResponseWriter, r *http.Request) { - reqID, ok := r.Context().Value("requestID").(string) + reqID, ok := r.Context().Value(middleware.RequestKey).(string) if !ok { cc.responder.LogError(my_err.ErrInvalidContext, "") return @@ -86,29 +87,49 @@ func (cc *ChatController) SetConnection(w http.ResponseWriter, r *http.Request) }() go client.Write() go client.Read(sess.UserID) - cc.SendChatMsg(ctx, reqID) + cc.SendChatMsg(ctx, reqID, w) } -func (cc *ChatController) SendChatMsg(ctx context.Context, reqID string) { +func validate(content models.MessageContent) bool { + if len(content.FilePath) > 10 || len([]rune(content.StickerPath)) > 100 || len([]rune(content.Text)) > 500 { + return false + } + if content.StickerPath != "" && (len(content.FilePath) > 0 || content.Text != "") { + return false + } + if content.Text == "" && content.StickerPath == "" && len(content.FilePath) == 0 { + return false + } + + return true +} + +func (cc *ChatController) SendChatMsg(ctx context.Context, reqID string, w http.ResponseWriter) { for msg := range cc.Messages { - err := cc.chatService.SendNewMessage(ctx, msg.Receiver, msg.Sender, msg.Content) + if !validate(msg.Content) { + cc.responder.ErrorBadRequest(w, my_err.ErrBadMessageContent, reqID) + return + } + msg := msg.ToDto() + err := cc.chatService.SendNewMessage(ctx, msg.Receiver, msg.Sender, &msg.Content) if err != nil { - cc.responder.LogError(err, reqID) + cc.responder.ErrorInternal(w, err, reqID) return } resConn, ok := mapUserConn[msg.Receiver] if ok { //resConn.Socket.ReadMessage() - resConn.Receive <- msg + m := msg.FromDto() + resConn.Receive <- &m } } } func (cc *ChatController) GetAllChats(w http.ResponseWriter, r *http.Request) { var ( - reqID, ok = r.Context().Value("requestID").(string) - lastTimeQuery = r.URL.Query().Get("lastTime") + reqID, ok = r.Context().Value(middleware.RequestKey).(string) + lastTimeQuery = sanitize(r.URL.Query().Get("lastTime")) lastTime time.Time err error ) @@ -149,7 +170,7 @@ func (cc *ChatController) GetAllChats(w http.ResponseWriter, r *http.Request) { func GetIdFromURL(r *http.Request) (uint32, error) { vars := mux.Vars(r) - id := vars["id"] + id := sanitize(vars["id"]) if id == "" { return 0, my_err.ErrEmptyId } @@ -166,8 +187,8 @@ func GetIdFromURL(r *http.Request) (uint32, error) { func (cc *ChatController) GetChat(w http.ResponseWriter, r *http.Request) { var ( - reqID, ok = r.Context().Value("requestID").(string) - lastTimeQuery = r.URL.Query().Get("lastTime") + reqID, ok = r.Context().Value(middleware.RequestKey).(string) + lastTimeQuery = sanitize(r.URL.Query().Get("lastTime")) lastTime time.Time err error ) diff --git a/internal/chat/controller/controller_test.go b/internal/chat/controller/controller_test.go index 037173eb..b69cf3ab 100644 --- a/internal/chat/controller/controller_test.go +++ b/internal/chat/controller/controller_test.go @@ -36,6 +36,20 @@ func TestNewController(t *testing.T) { assert.NotNil(t, res) } +func TestSanitize(t *testing.T) { + test := "" + expected := "" + res := sanitize(test) + assert.Equal(t, expected, res) +} + +func TestSanitizeFiles(t *testing.T) { + test := []string{""} + var expected []string + res := sanitizeFiles(test) + assert.Equal(t, expected, res) +} + func TestGetAllChat(t *testing.T) { tests := []TableTest[Response, Request]{ { @@ -60,7 +74,9 @@ func TestGetAllChat(t *testing.T) { m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( func(w, err, req any) { request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } }, ) }, @@ -87,7 +103,9 @@ func TestGetAllChat(t *testing.T) { m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( func(w, err, req any) { request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } }, ) }, @@ -148,7 +166,9 @@ func TestGetAllChat(t *testing.T) { m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( func(w, data, req any) { request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } }, ) }, @@ -179,7 +199,9 @@ func TestGetAllChat(t *testing.T) { m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( func(w, data, req any) { request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } }, ) }, @@ -241,7 +263,9 @@ func TestGetChat(t *testing.T) { m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( func(w, err, req any) { request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } }, ) }, @@ -268,7 +292,9 @@ func TestGetChat(t *testing.T) { m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( func(w, err, req any) { request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } }, ) }, @@ -296,7 +322,9 @@ func TestGetChat(t *testing.T) { m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( func(w, err, req any) { request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } }, ) }, @@ -325,7 +353,9 @@ func TestGetChat(t *testing.T) { m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( func(w, err, req any) { request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } }, ) }, @@ -356,7 +386,9 @@ func TestGetChat(t *testing.T) { m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( func(w, err, req any) { request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } }, ) }, @@ -417,7 +449,9 @@ func TestGetChat(t *testing.T) { m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( func(w, data, req any) { request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } }, ) }, diff --git a/internal/chat/controller/mock.go b/internal/chat/controller/mock.go index 85872a35..348a7797 100644 --- a/internal/chat/controller/mock.go +++ b/internal/chat/controller/mock.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: controller.go +// +// Generated by this command: +// +// mockgen -destination=mock.go -source=controller.go -package=controller +// // Package controller is a generated GoMock package. package controller @@ -14,10 +19,79 @@ import ( gomock "github.com/golang/mock/gomock" ) +// MockChatService is a mock of ChatService interface. +type MockChatService struct { + ctrl *gomock.Controller + recorder *MockChatServiceMockRecorder + isgomock struct{} +} + +// MockChatServiceMockRecorder is the mock recorder for MockChatService. +type MockChatServiceMockRecorder struct { + mock *MockChatService +} + +// NewMockChatService creates a new mock instance. +func NewMockChatService(ctrl *gomock.Controller) *MockChatService { + mock := &MockChatService{ctrl: ctrl} + mock.recorder = &MockChatServiceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockChatService) EXPECT() *MockChatServiceMockRecorder { + return m.recorder +} + +// GetAllChats mocks base method. +func (m *MockChatService) GetAllChats(ctx context.Context, userID uint32, lastUpdateTime time.Time) ([]*models.Chat, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAllChats", ctx, userID, lastUpdateTime) + ret0, _ := ret[0].([]*models.Chat) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAllChats indicates an expected call of GetAllChats. +func (mr *MockChatServiceMockRecorder) GetAllChats(ctx, userID, lastUpdateTime any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllChats", reflect.TypeOf((*MockChatService)(nil).GetAllChats), ctx, userID, lastUpdateTime) +} + +// GetChat mocks base method. +func (m *MockChatService) GetChat(ctx context.Context, userID, chatID uint32, lastSentTime time.Time) ([]*models.Message, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetChat", ctx, userID, chatID, lastSentTime) + ret0, _ := ret[0].([]*models.Message) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetChat indicates an expected call of GetChat. +func (mr *MockChatServiceMockRecorder) GetChat(ctx, userID, chatID, lastSentTime any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChat", reflect.TypeOf((*MockChatService)(nil).GetChat), ctx, userID, chatID, lastSentTime) +} + +// SendNewMessage mocks base method. +func (m *MockChatService) SendNewMessage(ctx context.Context, receiver, sender uint32, message *models.MessageContentDto) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendNewMessage", ctx, receiver, sender, message) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendNewMessage indicates an expected call of SendNewMessage. +func (mr *MockChatServiceMockRecorder) SendNewMessage(ctx, receiver, sender, message any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendNewMessage", reflect.TypeOf((*MockChatService)(nil).SendNewMessage), ctx, receiver, sender, message) +} + // MockResponder is a mock of Responder interface. type MockResponder struct { ctrl *gomock.Controller recorder *MockResponderMockRecorder + isgomock struct{} } // MockResponderMockRecorder is the mock recorder for MockResponder. @@ -44,7 +118,7 @@ func (m *MockResponder) ErrorBadRequest(w http.ResponseWriter, err error, reques } // ErrorBadRequest indicates an expected call of ErrorBadRequest. -func (mr *MockResponderMockRecorder) ErrorBadRequest(w, err, requestId interface{}) *gomock.Call { +func (mr *MockResponderMockRecorder) ErrorBadRequest(w, err, requestId any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ErrorBadRequest", reflect.TypeOf((*MockResponder)(nil).ErrorBadRequest), w, err, requestId) } @@ -56,7 +130,7 @@ func (m *MockResponder) ErrorInternal(w http.ResponseWriter, err error, requestI } // ErrorInternal indicates an expected call of ErrorInternal. -func (mr *MockResponderMockRecorder) ErrorInternal(w, err, requestId interface{}) *gomock.Call { +func (mr *MockResponderMockRecorder) ErrorInternal(w, err, requestId any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ErrorInternal", reflect.TypeOf((*MockResponder)(nil).ErrorInternal), w, err, requestId) } @@ -68,7 +142,7 @@ func (m *MockResponder) LogError(err error, requestId string) { } // LogError indicates an expected call of LogError. -func (mr *MockResponderMockRecorder) LogError(err, requestId interface{}) *gomock.Call { +func (mr *MockResponderMockRecorder) LogError(err, requestId any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LogError", reflect.TypeOf((*MockResponder)(nil).LogError), err, requestId) } @@ -80,7 +154,7 @@ func (m *MockResponder) OutputJSON(w http.ResponseWriter, data any, requestId st } // OutputJSON indicates an expected call of OutputJSON. -func (mr *MockResponderMockRecorder) OutputJSON(w, data, requestId interface{}) *gomock.Call { +func (mr *MockResponderMockRecorder) OutputJSON(w, data, requestId any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OutputJSON", reflect.TypeOf((*MockResponder)(nil).OutputJSON), w, data, requestId) } @@ -92,74 +166,7 @@ func (m *MockResponder) OutputNoMoreContentJSON(w http.ResponseWriter, requestId } // OutputNoMoreContentJSON indicates an expected call of OutputNoMoreContentJSON. -func (mr *MockResponderMockRecorder) OutputNoMoreContentJSON(w, requestId interface{}) *gomock.Call { +func (mr *MockResponderMockRecorder) OutputNoMoreContentJSON(w, requestId any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OutputNoMoreContentJSON", reflect.TypeOf((*MockResponder)(nil).OutputNoMoreContentJSON), w, requestId) } - -// MockChatService is a mock of ChatService interface. -type MockChatService struct { - ctrl *gomock.Controller - recorder *MockChatServiceMockRecorder -} - -// MockChatServiceMockRecorder is the mock recorder for MockChatService. -type MockChatServiceMockRecorder struct { - mock *MockChatService -} - -// NewMockChatService creates a new mock instance. -func NewMockChatService(ctrl *gomock.Controller) *MockChatService { - mock := &MockChatService{ctrl: ctrl} - mock.recorder = &MockChatServiceMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockChatService) EXPECT() *MockChatServiceMockRecorder { - return m.recorder -} - -// GetAllChats mocks base method. -func (m *MockChatService) GetAllChats(ctx context.Context, userID uint32, lastUpdateTime time.Time) ([]*models.Chat, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAllChats", ctx, userID, lastUpdateTime) - ret0, _ := ret[0].([]*models.Chat) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetAllChats indicates an expected call of GetAllChats. -func (mr *MockChatServiceMockRecorder) GetAllChats(ctx, userID, lastUpdateTime interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllChats", reflect.TypeOf((*MockChatService)(nil).GetAllChats), ctx, userID, lastUpdateTime) -} - -// GetChat mocks base method. -func (m *MockChatService) GetChat(ctx context.Context, userID, chatID uint32, lastSentTime time.Time) ([]*models.Message, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetChat", ctx, userID, chatID, lastSentTime) - ret0, _ := ret[0].([]*models.Message) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetChat indicates an expected call of GetChat. -func (mr *MockChatServiceMockRecorder) GetChat(ctx, userID, chatID, lastSentTime interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChat", reflect.TypeOf((*MockChatService)(nil).GetChat), ctx, userID, chatID, lastSentTime) -} - -// SendNewMessage mocks base method. -func (m *MockChatService) SendNewMessage(ctx context.Context, receiver, sender uint32, message string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SendNewMessage", ctx, receiver, sender, message) - ret0, _ := ret[0].(error) - return ret0 -} - -// SendNewMessage indicates an expected call of SendNewMessage. -func (mr *MockChatServiceMockRecorder) SendNewMessage(ctx, receiver, sender, message interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendNewMessage", reflect.TypeOf((*MockChatService)(nil).SendNewMessage), ctx, receiver, sender, message) -} diff --git a/internal/chat/repository.go b/internal/chat/repository.go index f9816362..f430453a 100644 --- a/internal/chat/repository.go +++ b/internal/chat/repository.go @@ -9,6 +9,6 @@ import ( type ChatRepository interface { GetChats(ctx context.Context, userID uint32, lastUpdateTime time.Time) ([]*models.Chat, error) - GetMessages(ctx context.Context, userID uint32, chatID uint32, lastSentTime time.Time) ([]*models.Message, error) - SendNewMessage(ctx context.Context, receiver uint32, sender uint32, message string) error + GetMessages(ctx context.Context, userID uint32, chatID uint32, lastSentTime time.Time) ([]*models.MessageDto, error) + SendNewMessage(ctx context.Context, receiver uint32, sender uint32, message *models.MessageContentDto) error } diff --git a/internal/chat/repository/postgres/queryConst.go b/internal/chat/repository/postgres/queryConst.go index e37fd3bf..5da9bda9 100644 --- a/internal/chat/repository/postgres/queryConst.go +++ b/internal/chat/repository/postgres/queryConst.go @@ -36,12 +36,12 @@ ORDER BY last_messages.created_at DESC LIMIT 15;` - getLatestMessagesBatch = `SELECT sender, receiver, content, created_at + getLatestMessagesBatch = `SELECT sender, receiver, content, file_path, sticker_path, created_at FROM message WHERE ((sender = $1 AND receiver = $2) OR (sender = $2 AND receiver = $1)) AND created_at < $3 ORDER BY created_at DESC LIMIT 20;` - sendNewMessage = `INSERT INTO message(receiver, sender, content) VALUES ($1, $2, $3)` + sendNewMessage = `INSERT INTO message(receiver, sender, content, file_path, sticker_path) VALUES ($1, $2, $3, $4, $5)` ) diff --git a/internal/chat/repository/postgres/repository.go b/internal/chat/repository/postgres/repository.go index 81c2eacc..41ce2825 100644 --- a/internal/chat/repository/postgres/repository.go +++ b/internal/chat/repository/postgres/repository.go @@ -38,7 +38,9 @@ func (cr *Repo) GetChats(ctx context.Context, userID uint32, lastUpdateTime time for rows.Next() { chat := &models.Chat{} - if err := rows.Scan(&chat.Receiver.AuthorID, &chat.Receiver.Author, &chat.Receiver.Avatar, &chat.LastMessage, &chat.LastDate); err != nil { + if err := rows.Scan( + &chat.Receiver.AuthorID, &chat.Receiver.Author, &chat.Receiver.Avatar, &chat.LastMessage, &chat.LastDate, + ); err != nil { return nil, fmt.Errorf("postgres get chats: %w", err) } chats = append(chats, chat) @@ -51,8 +53,10 @@ func (cr *Repo) GetChats(ctx context.Context, userID uint32, lastUpdateTime time return chats, nil } -func (cr *Repo) GetMessages(ctx context.Context, userID uint32, chatID uint32, lastSentTime time.Time) ([]*models.Message, error) { - var messages []*models.Message +func (cr *Repo) GetMessages( + ctx context.Context, userID uint32, chatID uint32, lastSentTime time.Time, +) ([]*models.MessageDto, error) { + var messages []*models.MessageDto rows, err := cr.db.QueryContext(ctx, getLatestMessagesBatch, userID, chatID, pq.FormatTimestamp(lastSentTime)) @@ -65,8 +69,15 @@ func (cr *Repo) GetMessages(ctx context.Context, userID uint32, chatID uint32, l defer rows.Close() for rows.Next() { - msg := &models.Message{} - if err := rows.Scan(&msg.Sender, &msg.Receiver, &msg.Content, &msg.CreatedAt); err != nil { + msg := &models.MessageDto{} + if err := rows.Scan( + &msg.Sender, + &msg.Receiver, + &msg.Content.Text, + &msg.Content.FilePath, + &msg.Content.StickerPath, + &msg.CreatedAt, + ); err != nil { return nil, fmt.Errorf("postgres get messages: %w", err) } messages = append(messages, msg) @@ -79,8 +90,12 @@ func (cr *Repo) GetMessages(ctx context.Context, userID uint32, chatID uint32, l } -func (cr *Repo) SendNewMessage(ctx context.Context, receiver uint32, sender uint32, message string) error { - _, err := cr.db.ExecContext(ctx, sendNewMessage, receiver, sender, message) +func (cr *Repo) SendNewMessage( + ctx context.Context, receiver uint32, sender uint32, message *models.MessageContentDto, +) error { + _, err := cr.db.ExecContext( + ctx, sendNewMessage, receiver, sender, message.Text, message.FilePath, message.StickerPath, + ) if err != nil { return fmt.Errorf("postgres send new message: %w", err) } diff --git a/internal/chat/service/chat.go b/internal/chat/service/chat.go index 207a5b39..c5203620 100644 --- a/internal/chat/service/chat.go +++ b/internal/chat/service/chat.go @@ -19,7 +19,9 @@ func NewChatService(repo chat.ChatRepository) *ChatService { } } -func (cs *ChatService) GetAllChats(ctx context.Context, userID uint32, lastUpdateTime time.Time) ([]*models.Chat, error) { +func (cs *ChatService) GetAllChats( + ctx context.Context, userID uint32, lastUpdateTime time.Time, +) ([]*models.Chat, error) { chats, err := cs.repo.GetChats(ctx, userID, lastUpdateTime) if err != nil { @@ -29,7 +31,9 @@ func (cs *ChatService) GetAllChats(ctx context.Context, userID uint32, lastUpdat return chats, nil } -func (cs *ChatService) GetChat(ctx context.Context, userID uint32, chatID uint32, lastSent time.Time) ([]*models.Message, error) { +func (cs *ChatService) GetChat( + ctx context.Context, userID uint32, chatID uint32, lastSent time.Time, +) ([]*models.Message, error) { messages, err := cs.repo.GetMessages(ctx, userID, chatID, lastSent) if err != nil { return nil, fmt.Errorf("get all messages: %w", err) @@ -39,10 +43,18 @@ func (cs *ChatService) GetChat(ctx context.Context, userID uint32, chatID uint32 messages[i].CreatedAt = convertTime(m.CreatedAt) } - return messages, nil + res := make([]*models.Message, 0, len(messages)) + for _, m := range messages { + mes := m.FromDto() + res = append(res, &mes) + } + + return res, nil } -func (cs *ChatService) SendNewMessage(ctx context.Context, receiver uint32, sender uint32, message string) error { +func (cs *ChatService) SendNewMessage( + ctx context.Context, receiver uint32, sender uint32, message *models.MessageContentDto, +) error { err := cs.repo.SendNewMessage(ctx, receiver, sender, message) if err != nil { return fmt.Errorf("send new message: %w", err) diff --git a/internal/chat/service/chat_test.go b/internal/chat/service/chat_test.go index 6f9e8c00..68ccf7a2 100644 --- a/internal/chat/service/chat_test.go +++ b/internal/chat/service/chat_test.go @@ -25,17 +25,21 @@ func (m MockRepo) GetChats(ctx context.Context, userID uint32, lastUpdateTime ti return []*models.Chat{}, nil } -func (m MockRepo) GetMessages(ctx context.Context, userID uint32, chatID uint32, lastSentTime time.Time) ([]*models.Message, error) { +func (m MockRepo) GetMessages( + ctx context.Context, userID uint32, chatID uint32, lastSentTime time.Time, +) ([]*models.MessageDto, error) { if userID == 0 || chatID == 0 { return nil, errMock } - return []*models.Message{ + return []*models.MessageDto{ {CreatedAt: createTime}, }, nil } -func (m MockRepo) SendNewMessage(ctx context.Context, receiver uint32, sender uint32, message string) error { - if receiver == 0 || sender == 0 || message == "" { +func (m MockRepo) SendNewMessage( + ctx context.Context, receiver uint32, sender uint32, message *models.MessageContentDto, +) error { + if receiver == 0 || sender == 0 || message.Text == "" { return errMock } return nil @@ -115,7 +119,7 @@ func TestGetChat(t *testing.T) { type TestStructSendNewMessage struct { sender uint32 receiver uint32 - message string + message *models.MessageContentDto wantErr error } @@ -125,25 +129,25 @@ func TestSendNewMessage(t *testing.T) { { sender: 0, receiver: 10, - message: "hello", + message: &models.MessageContentDto{Text: "hello"}, wantErr: errMock, }, { sender: 10, receiver: 0, - message: "hello", + message: &models.MessageContentDto{Text: "hello"}, wantErr: errMock, }, { sender: 1, receiver: 10, - message: "", + message: &models.MessageContentDto{Text: ""}, wantErr: errMock, }, { sender: 1, receiver: 10, - message: "hello", + message: &models.MessageContentDto{Text: "hello"}, wantErr: nil, }, } diff --git a/internal/chat/usecase.go b/internal/chat/usecase.go index c2dc25b0..7449ea1b 100644 --- a/internal/chat/usecase.go +++ b/internal/chat/usecase.go @@ -10,5 +10,5 @@ import ( type ChatService interface { GetAllChats(ctx context.Context, userID uint32, lastUpdateTime time.Time) ([]*models.Chat, error) GetChat(ctx context.Context, userID uint32, chatID uint32, lastSentTime time.Time) ([]*models.Message, error) - SendNewMessage(ctx context.Context, receiver uint32, sender uint32, message string) error + SendNewMessage(ctx context.Context, receiver uint32, sender uint32, message *models.MessageContentDto) error } diff --git a/internal/community/controller/controller.go b/internal/community/controller/controller.go index aaf0cf6f..74cd71b3 100644 --- a/internal/community/controller/controller.go +++ b/internal/community/controller/controller.go @@ -9,7 +9,10 @@ import ( "strconv" "github.com/gorilla/mux" + "github.com/mailru/easyjson" + "github.com/microcosm-cc/bluemonday" + "github.com/2024_2_BetterCallFirewall/internal/middleware" "github.com/2024_2_BetterCallFirewall/internal/models" "github.com/2024_2_BetterCallFirewall/pkg/my_err" ) @@ -48,9 +51,14 @@ func NewCommunityController(responder responder, service communityService) *Cont service: service, } } +func sanitize(input string) string { + sanitizer := bluemonday.UGCPolicy() + cleaned := sanitizer.Sanitize(input) + return cleaned +} func (c *Controller) GetOne(w http.ResponseWriter, r *http.Request) { - reqID, ok := r.Context().Value("requestID").(string) + reqID, ok := r.Context().Value(middleware.RequestKey).(string) if !ok { c.responder.LogError(my_err.ErrInvalidContext, "") } @@ -69,17 +77,27 @@ func (c *Controller) GetOne(w http.ResponseWriter, r *http.Request) { community, err := c.service.GetOne(r.Context(), id, sess.UserID) if err != nil { + if errors.Is(err, my_err.ErrWrongCommunity) { + c.responder.ErrorBadRequest(w, err, reqID) + return + } c.responder.ErrorInternal(w, err, reqID) return } + if community != nil { + community.Name = sanitize(community.Name) + community.Avatar = models.Picture(sanitize(string(community.Avatar))) + community.About = sanitize(community.About) + } + c.responder.OutputJSON(w, community, reqID) } func (c *Controller) GetAll(w http.ResponseWriter, r *http.Request) { var ( - reqID, ok = r.Context().Value("requestID").(string) - lastID = r.URL.Query().Get("id") + reqID, ok = r.Context().Value(middleware.RequestKey).(string) + lastID = sanitize(r.URL.Query().Get("id")) intLastID uint64 err error ) @@ -115,11 +133,17 @@ func (c *Controller) GetAll(w http.ResponseWriter, r *http.Request) { return } + for _, card := range res { + card.Name = sanitize(card.Name) + card.Avatar = models.Picture(sanitize(string(card.Avatar))) + card.About = sanitize(card.About) + } + c.responder.OutputJSON(w, res, reqID) } func (c *Controller) Update(w http.ResponseWriter, r *http.Request) { - reqID, ok := r.Context().Value("requestID").(string) + reqID, ok := r.Context().Value(middleware.RequestKey).(string) if !ok { c.responder.LogError(my_err.ErrInvalidContext, "") } @@ -149,15 +173,23 @@ func (c *Controller) Update(w http.ResponseWriter, r *http.Request) { err = c.service.Update(r.Context(), id, &newCommunity) if err != nil { + if errors.Is(err, my_err.ErrWrongCommunity) { + c.responder.ErrorBadRequest(w, err, reqID) + return + } c.responder.ErrorInternal(w, err, reqID) return } + newCommunity.Avatar = models.Picture(sanitize(string(newCommunity.Avatar))) + newCommunity.Name = sanitize(newCommunity.Name) + newCommunity.About = sanitize(newCommunity.About) + c.responder.OutputJSON(w, newCommunity, reqID) } func (c *Controller) Delete(w http.ResponseWriter, r *http.Request) { - reqID, ok := r.Context().Value("requestID").(string) + reqID, ok := r.Context().Value(middleware.RequestKey).(string) if !ok { c.responder.LogError(my_err.ErrInvalidContext, "") } @@ -181,6 +213,10 @@ func (c *Controller) Delete(w http.ResponseWriter, r *http.Request) { err = c.service.Delete(r.Context(), id) if err != nil { + if errors.Is(err, my_err.ErrWrongCommunity) { + c.responder.ErrorBadRequest(w, err, reqID) + return + } c.responder.ErrorInternal(w, err, reqID) return } @@ -189,7 +225,7 @@ func (c *Controller) Delete(w http.ResponseWriter, r *http.Request) { } func (c *Controller) Create(w http.ResponseWriter, r *http.Request) { - reqID, ok := r.Context().Value("requestID").(string) + reqID, ok := r.Context().Value(middleware.RequestKey).(string) if !ok { c.responder.LogError(my_err.ErrInvalidContext, "") } @@ -216,7 +252,7 @@ func (c *Controller) Create(w http.ResponseWriter, r *http.Request) { } func (c *Controller) JoinToCommunity(w http.ResponseWriter, r *http.Request) { - reqID, ok := r.Context().Value("requestID").(string) + reqID, ok := r.Context().Value(middleware.RequestKey).(string) if !ok { c.responder.LogError(my_err.ErrInvalidContext, "") } @@ -235,6 +271,10 @@ func (c *Controller) JoinToCommunity(w http.ResponseWriter, r *http.Request) { err = c.service.JoinCommunity(r.Context(), id, sess.UserID) if err != nil { + if errors.Is(err, my_err.ErrWrongCommunity) { + c.responder.ErrorBadRequest(w, err, reqID) + return + } c.responder.ErrorInternal(w, err, reqID) return } @@ -243,7 +283,7 @@ func (c *Controller) JoinToCommunity(w http.ResponseWriter, r *http.Request) { } func (c *Controller) LeaveFromCommunity(w http.ResponseWriter, r *http.Request) { - reqID, ok := r.Context().Value("requestID").(string) + reqID, ok := r.Context().Value(middleware.RequestKey).(string) if !ok { c.responder.LogError(my_err.ErrInvalidContext, "") } @@ -262,6 +302,10 @@ func (c *Controller) LeaveFromCommunity(w http.ResponseWriter, r *http.Request) err = c.service.LeaveCommunity(r.Context(), id, sess.UserID) if err != nil { + if errors.Is(err, my_err.ErrWrongCommunity) { + c.responder.ErrorBadRequest(w, err, reqID) + return + } c.responder.ErrorInternal(w, err, reqID) return } @@ -270,7 +314,7 @@ func (c *Controller) LeaveFromCommunity(w http.ResponseWriter, r *http.Request) } func (c *Controller) AddAdmin(w http.ResponseWriter, r *http.Request) { - reqID, ok := r.Context().Value("requestID").(string) + reqID, ok := r.Context().Value(middleware.RequestKey).(string) if !ok { c.responder.LogError(my_err.ErrInvalidContext, "") } @@ -314,9 +358,9 @@ func (c *Controller) AddAdmin(w http.ResponseWriter, r *http.Request) { func (c *Controller) SearchCommunity(w http.ResponseWriter, r *http.Request) { var ( - reqID, ok = r.Context().Value("requestID").(string) - subStr = r.URL.Query().Get("q") - lastID = r.URL.Query().Get("id") + reqID, ok = r.Context().Value(middleware.RequestKey).(string) + subStr = sanitize(r.URL.Query().Get("q")) + lastID = sanitize(r.URL.Query().Get("id")) id uint64 err error ) @@ -352,16 +396,29 @@ func (c *Controller) SearchCommunity(w http.ResponseWriter, r *http.Request) { return } + for _, card := range cards { + card.Name = sanitize(card.Name) + card.Avatar = models.Picture(sanitize(string(card.Avatar))) + card.About = sanitize(card.About) + } + c.responder.OutputJSON(w, cards, reqID) } func (c *Controller) getCommunityFromBody(r *http.Request) (models.Community, error) { var res models.Community - err := json.NewDecoder(r.Body).Decode(&res) + err := easyjson.UnmarshalFromReader(r.Body, &res) if err != nil { return models.Community{}, err } + res.Avatar = models.Picture(sanitize(string(res.Avatar))) + res.Name = sanitize(res.Name) + res.About = sanitize(res.About) + + if !validate(res) { + return models.Community{}, my_err.ErrBadCommunity + } return res, nil } @@ -369,7 +426,7 @@ func (c *Controller) getCommunityFromBody(r *http.Request) (models.Community, er func getIDFromQuery(r *http.Request) (uint32, error) { vars := mux.Vars(r) - id := vars["id"] + id := sanitize(vars["id"]) if id == "" { return 0, errors.New("id is empty") } @@ -381,3 +438,11 @@ func getIDFromQuery(r *http.Request) (uint32, error) { return uint32(uid), nil } + +func validate(data models.Community) bool { + if len([]rune(data.Name)) < 3 || len([]rune(data.Name)) >= 50 || len([]rune(data.About)) >= 60 { + return false + } + + return true +} diff --git a/internal/community/controller/controller_test.go b/internal/community/controller/controller_test.go index 383fdd53..80f2edc8 100644 --- a/internal/community/controller/controller_test.go +++ b/internal/community/controller/controller_test.go @@ -58,10 +58,14 @@ func TestGetOne(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -85,11 +89,17 @@ func TestGetOne(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.communityService.EXPECT().GetOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.communityService.EXPECT().GetOne(gomock.Any(), gomock.Any(), gomock.Any()).Return( + nil, errors.New("error"), + ) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -113,10 +123,14 @@ func TestGetOne(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -141,10 +155,14 @@ func TestGetOne(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.communityService.EXPECT().GetOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, data, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, data, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -167,43 +185,56 @@ func TestGetOne(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } +func TestSanitize(t *testing.T) { + test := "" + expected := "" + res := sanitize(test) + assert.Equal(t, expected, res) +} + func TestGetAll(t *testing.T) { tests := []TableTest[Response, Request]{ { @@ -225,10 +256,14 @@ func TestGetAll(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -251,11 +286,17 @@ func TestGetAll(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.communityService.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.communityService.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return( + nil, errors.New("error"), + ) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -279,9 +320,11 @@ func TestGetAll(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.communityService.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) - m.responder.EXPECT().OutputNoMoreContentJSON(request.w, gomock.Any()).Do(func(w, req any) { - request.w.WriteHeader(http.StatusNoContent) - }) + m.responder.EXPECT().OutputNoMoreContentJSON(request.w, gomock.Any()).Do( + func(w, req any) { + request.w.WriteHeader(http.StatusNoContent) + }, + ) }, }, { @@ -306,11 +349,16 @@ func TestGetAll(t *testing.T) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.communityService.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return( []*models.CommunityCard{{ID: 1}}, - nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, data, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + nil, + ) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, data, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -332,40 +380,46 @@ func TestGetAll(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -390,10 +444,14 @@ func TestUpdate(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -416,10 +474,14 @@ func TestUpdate(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -443,16 +505,22 @@ func TestUpdate(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { name: "4", SetupInput: func() (*Request, error) { - req := httptest.NewRequest(http.MethodPut, "/api/v1/community/1", bytes.NewBuffer([]byte(`{"id":1}`))) + req := httptest.NewRequest( + http.MethodPut, "/api/v1/community/1", bytes.NewBuffer([]byte(`{"id":1, "name":"name"}`)), + ) w := httptest.NewRecorder() req = mux.SetURLVars(req, map[string]string{"id": "1"}) req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) @@ -471,16 +539,22 @@ func TestUpdate(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.communityService.EXPECT().CheckAccess(gomock.Any(), gomock.Any(), gomock.Any()).Return(false) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { name: "5", SetupInput: func() (*Request, error) { - req := httptest.NewRequest(http.MethodPut, "/api/v1/community/1", bytes.NewBuffer([]byte(`{"id":1}`))) + req := httptest.NewRequest( + http.MethodPut, "/api/v1/community/1", bytes.NewBuffer([]byte(`{"id":1, "name":"name"}`)), + ) w := httptest.NewRecorder() req = mux.SetURLVars(req, map[string]string{"id": "1"}) req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) @@ -500,16 +574,22 @@ func TestUpdate(t *testing.T) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.communityService.EXPECT().CheckAccess(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) m.communityService.EXPECT().Update(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { name: "6", SetupInput: func() (*Request, error) { - req := httptest.NewRequest(http.MethodPut, "/api/v1/community/1", bytes.NewBuffer([]byte(`{"id":1}`))) + req := httptest.NewRequest( + http.MethodPut, "/api/v1/community/1", bytes.NewBuffer([]byte(`{"id":1, "name":"name"}`)), + ) w := httptest.NewRecorder() req = mux.SetURLVars(req, map[string]string{"id": "1"}) req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) @@ -529,40 +609,79 @@ func TestUpdate(t *testing.T) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.communityService.EXPECT().CheckAccess(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) m.communityService.EXPECT().Update(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, data, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, data, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) + }, + }, + { + name: "7", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest( + http.MethodPut, "/api/v1/community/1", bytes.NewBuffer([]byte(`{"id":1, "name":"a"}`)), + ) + w := httptest.NewRecorder() + req = mux.SetURLVars(req, map[string]string{"id": "1"}) + req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *Controller, request Request) (Response, error) { + implementation.Update(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -587,10 +706,14 @@ func TestDelete(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -613,10 +736,14 @@ func TestDelete(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -641,10 +768,14 @@ func TestDelete(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.communityService.EXPECT().CheckAccess(gomock.Any(), gomock.Any(), gomock.Any()).Return(false) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -670,10 +801,14 @@ func TestDelete(t *testing.T) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.communityService.EXPECT().CheckAccess(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) m.communityService.EXPECT().Delete(gomock.Any(), gomock.Any()).Return(errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -699,40 +834,46 @@ func TestDelete(t *testing.T) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.communityService.EXPECT().CheckAccess(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) m.communityService.EXPECT().Delete(gomock.Any(), gomock.Any()).Return(nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, data, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, data, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -757,10 +898,14 @@ func TestCreate(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -783,16 +928,22 @@ func TestCreate(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { name: "3", SetupInput: func() (*Request, error) { - req := httptest.NewRequest(http.MethodPost, "/api/v1/community/1", bytes.NewBuffer([]byte(`{"id":1}`))) + req := httptest.NewRequest( + http.MethodPost, "/api/v1/community/1", bytes.NewBuffer([]byte(`{"id":1, "name":"name"}`)), + ) w := httptest.NewRecorder() req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) res := &Request{r: req, w: w} @@ -810,16 +961,22 @@ func TestCreate(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.communityService.EXPECT().Create(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { name: "4", SetupInput: func() (*Request, error) { - req := httptest.NewRequest(http.MethodPost, "/api/v1/community/1", bytes.NewBuffer([]byte(`{"id":1}`))) + req := httptest.NewRequest( + http.MethodPost, "/api/v1/community/1", bytes.NewBuffer([]byte(`{"id":1, "name":"name"}`)), + ) w := httptest.NewRecorder() req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) res := &Request{r: req, w: w} @@ -837,40 +994,46 @@ func TestCreate(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.communityService.EXPECT().Create(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, data, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, data, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -895,10 +1058,14 @@ func TestJoinToCommunity(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -921,10 +1088,14 @@ func TestJoinToCommunity(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -948,11 +1119,17 @@ func TestJoinToCommunity(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.communityService.EXPECT().JoinCommunity(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.communityService.EXPECT().JoinCommunity( + gomock.Any(), gomock.Any(), gomock.Any(), + ).Return(errors.New("error")) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -977,40 +1154,46 @@ func TestJoinToCommunity(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.communityService.EXPECT().JoinCommunity(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -1035,10 +1218,14 @@ func TestLeaveFromCommunity(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1061,10 +1248,14 @@ func TestLeaveFromCommunity(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1088,11 +1279,17 @@ func TestLeaveFromCommunity(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.communityService.EXPECT().LeaveCommunity(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.communityService.EXPECT().LeaveCommunity( + gomock.Any(), gomock.Any(), gomock.Any(), + ).Return(errors.New("error")) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1117,40 +1314,46 @@ func TestLeaveFromCommunity(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.communityService.EXPECT().LeaveCommunity(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -1175,10 +1378,14 @@ func TestAddAdmin(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1201,10 +1408,14 @@ func TestAddAdmin(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1229,16 +1440,22 @@ func TestAddAdmin(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.communityService.EXPECT().CheckAccess(gomock.Any(), gomock.Any(), gomock.Any()).Return(false) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { name: "4", SetupInput: func() (*Request, error) { - req := httptest.NewRequest(http.MethodPost, "/api/v1/community/1/add_admin", bytes.NewBuffer([]byte(`{kj`))) + req := httptest.NewRequest( + http.MethodPost, "/api/v1/community/1/add_admin", bytes.NewBuffer([]byte(`{kj`)), + ) w := httptest.NewRecorder() req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) req = mux.SetURLVars(req, map[string]string{"id": "1"}) @@ -1258,16 +1475,22 @@ func TestAddAdmin(t *testing.T) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.communityService.EXPECT().CheckAccess(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { name: "5", SetupInput: func() (*Request, error) { - req := httptest.NewRequest(http.MethodPost, "/api/v1/community/1/add_admin", bytes.NewBuffer([]byte(`1`))) + req := httptest.NewRequest( + http.MethodPost, "/api/v1/community/1/add_admin", bytes.NewBuffer([]byte(`1`)), + ) w := httptest.NewRecorder() req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) req = mux.SetURLVars(req, map[string]string{"id": "1"}) @@ -1286,17 +1509,25 @@ func TestAddAdmin(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.communityService.EXPECT().CheckAccess(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) - m.communityService.EXPECT().AddAdmin(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.communityService.EXPECT().AddAdmin( + gomock.Any(), gomock.Any(), gomock.Any(), + ).Return(errors.New("error")) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { name: "6", SetupInput: func() (*Request, error) { - req := httptest.NewRequest(http.MethodPost, "/api/v1/community/1/add_admin", bytes.NewBuffer([]byte(`1`))) + req := httptest.NewRequest( + http.MethodPost, "/api/v1/community/1/add_admin", bytes.NewBuffer([]byte(`1`)), + ) w := httptest.NewRecorder() req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) req = mux.SetURLVars(req, map[string]string{"id": "1"}) @@ -1316,16 +1547,22 @@ func TestAddAdmin(t *testing.T) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.communityService.EXPECT().CheckAccess(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) m.communityService.EXPECT().AddAdmin(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, { name: "7", SetupInput: func() (*Request, error) { - req := httptest.NewRequest(http.MethodPost, "/api/v1/community/1/add_admin", bytes.NewBuffer([]byte(`1`))) + req := httptest.NewRequest( + http.MethodPost, "/api/v1/community/1/add_admin", bytes.NewBuffer([]byte(`1`)), + ) w := httptest.NewRecorder() req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) req = mux.SetURLVars(req, map[string]string{"id": "1"}) @@ -1344,41 +1581,49 @@ func TestAddAdmin(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.communityService.EXPECT().CheckAccess(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) - m.communityService.EXPECT().AddAdmin(gomock.Any(), gomock.Any(), gomock.Any()).Return(my_err.ErrWrongCommunity) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.communityService.EXPECT().AddAdmin( + gomock.Any(), gomock.Any(), gomock.Any(), + ).Return(my_err.ErrWrongCommunity) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -1403,10 +1648,14 @@ func TestSearch(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1429,11 +1678,17 @@ func TestSearch(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.communityService.EXPECT().Search(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.communityService.EXPECT().Search(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return( + nil, errors.New("error"), + ) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1456,11 +1711,17 @@ func TestSearch(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.communityService.EXPECT().Search(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + m.communityService.EXPECT().Search(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return( + nil, nil, + ) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1483,10 +1744,14 @@ func TestSearch(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1508,39 +1773,45 @@ func TestSearch(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } diff --git a/internal/community/repository/repository.go b/internal/community/repository/repository.go index 751b09ce..397470b2 100644 --- a/internal/community/repository/repository.go +++ b/internal/community/repository/repository.go @@ -45,8 +45,13 @@ func (c CommunityRepository) GetBatch(ctx context.Context, lastID uint32) ([]*mo func (c CommunityRepository) GetOne(ctx context.Context, id uint32) (*models.Community, error) { res := &models.Community{} - err := c.db.QueryRowContext(ctx, GetOne, id).Scan(&res.ID, &res.Name, &res.Avatar, &res.About, &res.CountSubscribers) + err := c.db.QueryRowContext(ctx, GetOne, id).Scan( + &res.ID, &res.Name, &res.Avatar, &res.About, &res.CountSubscribers, + ) if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, my_err.ErrWrongCommunity + } return nil, fmt.Errorf("get community db: %w", err) } @@ -59,7 +64,9 @@ func (c CommunityRepository) Create(ctx context.Context, community *models.Commu if community.Avatar == "" { res = c.db.QueryRowContext(ctx, CreateNewCommunity, community.Name, community.About, author) } else { - res = c.db.QueryRowContext(ctx, CreateNewCommunityWithAvatar, community.Name, community.About, community.Avatar, author) + res = c.db.QueryRowContext( + ctx, CreateNewCommunityWithAvatar, community.Name, community.About, community.Avatar, author, + ) } err := res.Err() @@ -80,9 +87,14 @@ func (c CommunityRepository) Update(ctx context.Context, community *models.Commu if community.Avatar == "" { _, err = c.db.ExecContext(ctx, UpdateWithoutAvatar, community.Name, community.About, community.ID) } else { - _, err = c.db.ExecContext(ctx, UpdateWithAvatar, community.Name, community.Avatar, community.About, community.ID) + _, err = c.db.ExecContext( + ctx, UpdateWithAvatar, community.Name, community.Avatar, community.About, community.ID, + ) } if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return my_err.ErrWrongCommunity + } return fmt.Errorf("update community: %w", err) } @@ -92,6 +104,9 @@ func (c CommunityRepository) Update(ctx context.Context, community *models.Commu func (c CommunityRepository) Delete(ctx context.Context, id uint32) error { _, err := c.db.ExecContext(ctx, Delete, id) if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return my_err.ErrWrongCommunity + } return fmt.Errorf("delete community: %w", err) } @@ -101,6 +116,9 @@ func (c CommunityRepository) Delete(ctx context.Context, id uint32) error { func (c CommunityRepository) JoinCommunity(ctx context.Context, communityId, author uint32) error { _, err := c.db.ExecContext(ctx, JoinCommunity, communityId, author) if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return my_err.ErrWrongCommunity + } return fmt.Errorf("join community: %w", err) } @@ -110,6 +128,9 @@ func (c CommunityRepository) JoinCommunity(ctx context.Context, communityId, aut func (c CommunityRepository) LeaveCommunity(ctx context.Context, communityId, author uint32) error { _, err := c.db.ExecContext(ctx, LeaveCommunity, communityId, author) if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return my_err.ErrWrongCommunity + } return fmt.Errorf("leave community: %w", err) } access := c.CheckAccess(ctx, communityId, author) @@ -117,6 +138,9 @@ func (c CommunityRepository) LeaveCommunity(ctx context.Context, communityId, au if access { _, err := c.db.ExecContext(ctx, DeleteAdmin, communityId, author) if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return my_err.ErrWrongCommunity + } return fmt.Errorf("delete admin: %w", err) } } @@ -127,6 +151,9 @@ func (c CommunityRepository) LeaveCommunity(ctx context.Context, communityId, au func (c CommunityRepository) NewAdmin(ctx context.Context, communityId uint32, author uint32) error { _, err := c.db.ExecContext(ctx, InsertNewAdmin, communityId, author) if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return my_err.ErrWrongCommunity + } return fmt.Errorf("insert new admin: %w", err) } return nil diff --git a/internal/config/config.go b/internal/config/config.go index 92272dad..53f704d4 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -48,6 +48,7 @@ type Config struct { POST Server PROFILE Server COMMUNITY Server + STICKER Server AUTHGRPC GRPCServer PROFILEGRPC GRPCServer POSTGRPC GRPCServer @@ -128,6 +129,11 @@ func GetConfig(configFilePath string) (*Config, error) { ReadTimeout: time.Duration(getIntEnv("SERVER_READ_TIMEOUT")) * time.Second, WriteTimeout: time.Duration(getIntEnv("SERVER_WRITE_TIMEOUT")) * time.Second, }, + STICKER: Server{ + Port: os.Getenv("STICKER_HTTP_PORT"), + ReadTimeout: time.Duration(getIntEnv("SERVER_READ_TIMEOUT")) * time.Second, + WriteTimeout: time.Duration(getIntEnv("SERVER_WRITE_TIMEOUT")) * time.Second, + }, PROFILEGRPC: GRPCServer{ Port: os.Getenv("PROFILE_GRPC_PORT"), Host: os.Getenv("PROFILE_GRPC_HOST"), diff --git a/internal/ext_grpc/adapter/post/post_test.go b/internal/ext_grpc/adapter/post/post_test.go index d4f2a8c6..0ba0a3dd 100644 --- a/internal/ext_grpc/adapter/post/post_test.go +++ b/internal/ext_grpc/adapter/post/post_test.go @@ -65,6 +65,7 @@ func TestGetAuthorsPost(t *testing.T) { Text: "new post", CreatedAt: time.Unix(createTime.Unix(), 0), UpdatedAt: time.Unix(createTime.Unix(), 0), + File: []models.Picture{}, }, Header: models.Header{ AuthorID: 1, @@ -75,51 +76,55 @@ func TestGetAuthorsPost(t *testing.T) { ExpectedErr: nil, SetupMock: func(request *models.Header, m *mocks) { m.client.EXPECT().GetAuthorsPosts(gomock.Any(), gomock.Any()). - Return(&post_api.Response{ - Posts: []*post_api.Post{ - { - ID: 1, - PostContent: &post_api.Content{ - Text: "new post", - CreatedAt: createTime.Unix(), - UpdatedAt: createTime.Unix(), - }, - Head: &post_api.Header{ - AuthorID: 1, + Return( + &post_api.Response{ + Posts: []*post_api.Post{ + { + ID: 1, + PostContent: &post_api.Content{ + Text: "new post", + CreatedAt: createTime.Unix(), + UpdatedAt: createTime.Unix(), + }, + Head: &post_api.Header{ + AuthorID: 1, + }, }, }, - }, - }, nil) + }, nil, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() - adapter, mock := getAdapter(ctrl) - ctx := context.Background() + adapter, mock := getAdapter(ctrl) + ctx := context.Background() - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } - v.SetupMock(input, mock) + v.SetupMock(input, mock) - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } - actual, err := v.Run(ctx, adapter, input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + actual, err := v.Run(ctx, adapter, input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } diff --git a/internal/ext_grpc/port/post/post.go b/internal/ext_grpc/port/post/post.go index 72100e66..a5c5264b 100644 --- a/internal/ext_grpc/port/post/post.go +++ b/internal/ext_grpc/port/post/post.go @@ -22,23 +22,31 @@ func NewRequest(header *models.Header, userID uint32) *post_api.Request { func UnmarshalResponse(response *post_api.Response) []*models.Post { res := make([]*models.Post, 0, len(response.Posts)) for _, post := range response.Posts { - res = append(res, &models.Post{ - ID: post.ID, - Header: models.Header{ - AuthorID: post.Head.AuthorID, - CommunityID: post.Head.CommunityID, - Avatar: models.Picture(post.Head.Avatar), - Author: post.Head.Author, - }, - PostContent: models.Content{ - Text: post.PostContent.Text, - File: models.Picture(post.PostContent.File), - CreatedAt: time.Unix(post.PostContent.CreatedAt, 0), - UpdatedAt: time.Unix(post.PostContent.UpdatedAt, 0), + files := make([]models.Picture, 0, len(post.PostContent.File)) + for _, file := range post.PostContent.File { + files = append(files, models.Picture(file)) + } + + res = append( + res, &models.Post{ + ID: post.ID, + Header: models.Header{ + AuthorID: post.Head.AuthorID, + CommunityID: post.Head.CommunityID, + Avatar: models.Picture(post.Head.Avatar), + Author: post.Head.Author, + }, + PostContent: models.Content{ + Text: post.PostContent.Text, + File: files, + CreatedAt: time.Unix(post.PostContent.CreatedAt, 0), + UpdatedAt: time.Unix(post.PostContent.UpdatedAt, 0), + }, + IsLiked: post.IsLiked, + LikesCount: post.LikesCount, + CommentCount: post.CommentCount, }, - IsLiked: post.IsLiked, - LikesCount: post.LikesCount, - }) + ) } return res diff --git a/internal/fileService/controller/controller.go b/internal/fileService/controller/controller.go index 24f03c0c..51ad745a 100644 --- a/internal/fileService/controller/controller.go +++ b/internal/fileService/controller/controller.go @@ -1,28 +1,43 @@ package controller import ( + "bytes" "context" "errors" "fmt" + "io" "mime/multipart" "net/http" + "strings" "github.com/gorilla/mux" + "github.com/microcosm-cc/bluemonday" + "github.com/2024_2_BetterCallFirewall/internal/middleware" "github.com/2024_2_BetterCallFirewall/pkg/my_err" ) var fileFormat = map[string]struct{}{ - "image/jpeg": {}, - "image/jpg": {}, - "image/png": {}, - "image/webp": {}, + "jpeg": {}, + "jpg": {}, + "png": {}, + "webp": {}, + "gif": {}, } +const ( + charset = "charset=utf" + txt = "txt" + plain = "plain" + maxMemory = 100 * 1024 * 1024 +) + //go:generate mockgen -destination=mock.go -source=$GOFILE -package=${GOPACKAGE} type fileService interface { Upload(ctx context.Context, name string) ([]byte, error) - Download(ctx context.Context, file multipart.File) (string, error) + Download(ctx context.Context, file io.Reader, format string) (string, error) + DownloadNonImage(ctx context.Context, file io.Reader, format, realName string) (string, error) + UploadNonImage(ctx context.Context, name string) ([]byte, error) } type responder interface { @@ -45,11 +60,57 @@ func NewFileController(fileService fileService, responder responder) *FileContro } } -func (fc *FileController) Upload(w http.ResponseWriter, r *http.Request) { +func getFormat(buf []byte) string { + formats := http.DetectContentType(buf) + format := strings.Split(formats, "/")[1] + + if strings.Contains(format, charset) { + format = strings.Split(format, ";")[0] + } + + if format == plain { + format = txt + } + + return format +} + +func sanitize(input string) string { + sanitizer := bluemonday.UGCPolicy() + cleaned := sanitizer.Sanitize(input) + return cleaned +} + +func (fc *FileController) UploadNonImage(w http.ResponseWriter, r *http.Request) { var ( reqID, ok = r.Context().Value("requestID").(string) vars = mux.Vars(r) - name = vars["name"] + name = sanitize(vars["name"]) + ) + + if !ok { + fc.responder.LogError(my_err.ErrInvalidContext, "") + } + + if name == "" { + fc.responder.ErrorBadRequest(w, errors.New("name is empty"), reqID) + return + } + + res, err := fc.fileService.UploadNonImage(r.Context(), name) + if err != nil { + fc.responder.ErrorBadRequest(w, fmt.Errorf("%w: %w", err, my_err.ErrWrongFile), reqID) + return + } + + fc.responder.OutputBytes(w, res, reqID) +} + +func (fc *FileController) Upload(w http.ResponseWriter, r *http.Request) { + var ( + reqID, ok = r.Context().Value(middleware.RequestKey).(string) + vars = mux.Vars(r) + name = sanitize(vars["name"]) ) if !ok { @@ -71,35 +132,65 @@ func (fc *FileController) Upload(w http.ResponseWriter, r *http.Request) { } func (fc *FileController) Download(w http.ResponseWriter, r *http.Request) { - reqID, ok := r.Context().Value("requestID").(string) + reqID, ok := r.Context().Value(middleware.RequestKey).(string) if !ok { fc.responder.LogError(my_err.ErrInvalidContext, "") } - err := r.ParseMultipartForm(10 << 20) // 10Mbyte - defer r.MultipartForm.RemoveAll() + err := r.ParseMultipartForm(maxMemory) // 100Mbyte if err != nil { fc.responder.ErrorBadRequest(w, my_err.ErrToLargeFile, reqID) return } + defer func() { + err = r.MultipartForm.RemoveAll() + if err != nil { + fc.responder.LogError(err, reqID) + } + }() file, header, err := r.FormFile("file") if err != nil { - file = nil + fc.responder.ErrorBadRequest(w, err, reqID) + return + } + + defer func(file multipart.File) { + err = file.Close() + if err != nil { + fc.responder.LogError(err, reqID) + } + }(file) + + buf := bytes.NewBuffer(make([]byte, 20)) + n, err := file.Read(buf.Bytes()) + if err != nil { + fc.responder.ErrorBadRequest(w, err, reqID) + return + } + format := getFormat(buf.Bytes()[:n]) + + _, err = io.Copy(buf, file) + if err != nil { + fc.responder.ErrorBadRequest(w, err, reqID) + return + } + var url string + + if _, ok := fileFormat[format]; ok { + url, err = fc.fileService.Download(r.Context(), buf, format) } else { - format := header.Header.Get("Content-Type") - if _, ok := fileFormat[format]; !ok { - fc.responder.ErrorBadRequest(w, my_err.ErrWrongFiletype, reqID) + name := header.Filename + if len([]rune(name+format)) > 53 { + fc.responder.ErrorBadRequest(w, errors.New("file name is too big"), reqID) return } + url, err = fc.fileService.DownloadNonImage(r.Context(), buf, format, name) } - defer file.Close() - url, err := fc.fileService.Download(r.Context(), file) if err != nil { fc.responder.ErrorBadRequest(w, err, reqID) return } - fc.responder.OutputJSON(w, url, reqID) } diff --git a/internal/fileService/controller/controller_test.go b/internal/fileService/controller/controller_test.go index f21ecf3f..c5b4546f 100644 --- a/internal/fileService/controller/controller_test.go +++ b/internal/fileService/controller/controller_test.go @@ -3,6 +3,7 @@ package controller import ( "context" "errors" + "mime/multipart" "net/http" "net/http/httptest" "testing" @@ -56,10 +57,14 @@ func TestUpload(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -83,10 +88,14 @@ func TestUpload(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.fileService.EXPECT().Upload(gomock.Any(), gomock.Any()).Return(nil, errMock) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -110,40 +119,292 @@ func TestUpload(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.fileService.EXPECT().Upload(gomock.Any(), gomock.Any()).Return(nil, nil) - m.responder.EXPECT().OutputBytes(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + m.responder.EXPECT().OutputBytes(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) + } +} + +func TestUploadNonImage(t *testing.T) { + tests := []TableTest[Response, Request]{ + { + name: "1", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodGet, "/files/default", nil) + w := httptest.NewRecorder() + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *FileController, request Request) (Response, error) { + implementation.UploadNonImage(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) + }, + }, + { + name: "2", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodGet, "/files/default", nil) + w := httptest.NewRecorder() + req = mux.SetURLVars(req, map[string]string{"name": "default"}) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *FileController, request Request) (Response, error) { + implementation.UploadNonImage(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.fileService.EXPECT().UploadNonImage(gomock.Any(), gomock.Any()).Return(nil, errMock) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) + }, + }, + { + name: "3", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodGet, "/files/default", nil) + w := httptest.NewRecorder() + req = mux.SetURLVars(req, map[string]string{"name": "default"}) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *FileController, request Request) (Response, error) { + implementation.UploadNonImage(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusOK, Body: "OK"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.fileService.EXPECT().UploadNonImage(gomock.Any(), gomock.Any()).Return(nil, nil) + m.responder.EXPECT().OutputBytes(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) + }, + }, + } + + for _, v := range tests { + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) + } +} + +func TestSanitize(t *testing.T) { + test := "" + expected := "" + res := sanitize(test) + assert.Equal(t, expected, res) +} + +func TestDownload(t *testing.T) { + tests := []TableTest[Response, Request]{ + { + name: "1", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodPost, "/image", nil) + w := httptest.NewRecorder() + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *FileController, request Request) (Response, error) { + implementation.Download(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) + }, + }, + { + name: "2", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodPost, "/image", nil) + w := httptest.NewRecorder() + req.MultipartForm = &multipart.Form{ + File: make(map[string][]*multipart.FileHeader), + } + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *FileController, request Request) (Response, error) { + implementation.Download(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) + }, + }, + } + + for _, v := range tests { + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) + } +} + +func TestGetFormat(t *testing.T) { + tests := []struct { + input []byte + output string + }{ + {input: []byte("some text"), output: "txt"}, + {input: []byte(""), output: "html"}, + } + + for i, tt := range tests { + actual := getFormat(tt.input) + if actual != tt.output { + t.Errorf("%d: expected %s, got %s", i, tt.output, actual) + } } } diff --git a/internal/fileService/controller/mock.go b/internal/fileService/controller/mock.go index ad31747d..76b8237e 100644 --- a/internal/fileService/controller/mock.go +++ b/internal/fileService/controller/mock.go @@ -1,12 +1,17 @@ // Code generated by MockGen. DO NOT EDIT. // Source: controller.go +// +// Generated by this command: +// +// mockgen -destination=mock.go -source=controller.go -package=controller +// // Package controller is a generated GoMock package. package controller import ( context "context" - multipart "mime/multipart" + io "io" http "net/http" reflect "reflect" @@ -17,6 +22,7 @@ import ( type MockfileService struct { ctrl *gomock.Controller recorder *MockfileServiceMockRecorder + isgomock struct{} } // MockfileServiceMockRecorder is the mock recorder for MockfileService. @@ -37,18 +43,33 @@ func (m *MockfileService) EXPECT() *MockfileServiceMockRecorder { } // Download mocks base method. -func (m *MockfileService) Download(ctx context.Context, file multipart.File) (string, error) { +func (m *MockfileService) Download(ctx context.Context, file io.Reader, format string) (string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Download", ctx, file) + ret := m.ctrl.Call(m, "Download", ctx, file, format) ret0, _ := ret[0].(string) ret1, _ := ret[1].(error) return ret0, ret1 } // Download indicates an expected call of Download. -func (mr *MockfileServiceMockRecorder) Download(ctx, file interface{}) *gomock.Call { +func (mr *MockfileServiceMockRecorder) Download(ctx, file, format any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Download", reflect.TypeOf((*MockfileService)(nil).Download), ctx, file) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Download", reflect.TypeOf((*MockfileService)(nil).Download), ctx, file, format) +} + +// DownloadNonImage mocks base method. +func (m *MockfileService) DownloadNonImage(ctx context.Context, file io.Reader, format, realName string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DownloadNonImage", ctx, file, format, realName) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DownloadNonImage indicates an expected call of DownloadNonImage. +func (mr *MockfileServiceMockRecorder) DownloadNonImage(ctx, file, format, realName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DownloadNonImage", reflect.TypeOf((*MockfileService)(nil).DownloadNonImage), ctx, file, format, realName) } // Upload mocks base method. @@ -61,15 +82,31 @@ func (m *MockfileService) Upload(ctx context.Context, name string) ([]byte, erro } // Upload indicates an expected call of Upload. -func (mr *MockfileServiceMockRecorder) Upload(ctx, name interface{}) *gomock.Call { +func (mr *MockfileServiceMockRecorder) Upload(ctx, name any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Upload", reflect.TypeOf((*MockfileService)(nil).Upload), ctx, name) } +// UploadNonImage mocks base method. +func (m *MockfileService) UploadNonImage(ctx context.Context, name string) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UploadNonImage", ctx, name) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UploadNonImage indicates an expected call of UploadNonImage. +func (mr *MockfileServiceMockRecorder) UploadNonImage(ctx, name any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UploadNonImage", reflect.TypeOf((*MockfileService)(nil).UploadNonImage), ctx, name) +} + // Mockresponder is a mock of responder interface. type Mockresponder struct { ctrl *gomock.Controller recorder *MockresponderMockRecorder + isgomock struct{} } // MockresponderMockRecorder is the mock recorder for Mockresponder. @@ -96,7 +133,7 @@ func (m *Mockresponder) ErrorBadRequest(w http.ResponseWriter, err error, reques } // ErrorBadRequest indicates an expected call of ErrorBadRequest. -func (mr *MockresponderMockRecorder) ErrorBadRequest(w, err, requestID interface{}) *gomock.Call { +func (mr *MockresponderMockRecorder) ErrorBadRequest(w, err, requestID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ErrorBadRequest", reflect.TypeOf((*Mockresponder)(nil).ErrorBadRequest), w, err, requestID) } @@ -108,7 +145,7 @@ func (m *Mockresponder) LogError(err error, requestID string) { } // LogError indicates an expected call of LogError. -func (mr *MockresponderMockRecorder) LogError(err, requestID interface{}) *gomock.Call { +func (mr *MockresponderMockRecorder) LogError(err, requestID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LogError", reflect.TypeOf((*Mockresponder)(nil).LogError), err, requestID) } @@ -120,7 +157,7 @@ func (m *Mockresponder) OutputBytes(w http.ResponseWriter, data []byte, requestI } // OutputBytes indicates an expected call of OutputBytes. -func (mr *MockresponderMockRecorder) OutputBytes(w, data, requestID interface{}) *gomock.Call { +func (mr *MockresponderMockRecorder) OutputBytes(w, data, requestID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OutputBytes", reflect.TypeOf((*Mockresponder)(nil).OutputBytes), w, data, requestID) } @@ -132,7 +169,7 @@ func (m *Mockresponder) OutputJSON(w http.ResponseWriter, data any, requestId st } // OutputJSON indicates an expected call of OutputJSON. -func (mr *MockresponderMockRecorder) OutputJSON(w, data, requestId interface{}) *gomock.Call { +func (mr *MockresponderMockRecorder) OutputJSON(w, data, requestId any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OutputJSON", reflect.TypeOf((*Mockresponder)(nil).OutputJSON), w, data, requestId) } diff --git a/internal/fileService/service/fileService.go b/internal/fileService/service/fileService.go index 6a0b7d5e..fab9845c 100644 --- a/internal/fileService/service/fileService.go +++ b/internal/fileService/service/fileService.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "io" - "mime/multipart" "os" "github.com/google/uuid" @@ -16,14 +15,16 @@ func NewFileService() *FileService { return &FileService{} } -func (f *FileService) Download(ctx context.Context, file multipart.File) (string, error) { +func (f *FileService) Download(ctx context.Context, file io.Reader, format string) (string, error) { var ( fileName = uuid.New().String() - filePath = fmt.Sprintf("/image/%s", fileName) + filePath = fmt.Sprintf("/image/%s.%s", fileName, format) dst, err = os.Create(filePath) ) - defer dst.Close() + defer func(dst *os.File) { + _ = dst.Close() + }(dst) if err != nil { return "", fmt.Errorf("save file: %w", err) } @@ -35,6 +36,27 @@ func (f *FileService) Download(ctx context.Context, file multipart.File) (string return filePath, nil } +func (f *FileService) DownloadNonImage( + ctx context.Context, file io.Reader, format, realName string, +) (string, error) { + var ( + fileName = uuid.New().String() + filePath = fmt.Sprintf("/files/%s|%s.%s", fileName, realName, format) + dst, err = os.Create(filePath) + ) + defer func(dst *os.File) { + _ = dst.Close() + }(dst) + if err != nil { + return "", fmt.Errorf("save file: %w", err) + } + if _, err := io.Copy(dst, file); err != nil { + return "", fmt.Errorf("save file: %w", err) + } + + return filePath, nil +} + func (f *FileService) Upload(ctx context.Context, name string) ([]byte, error) { var ( file, err = os.Open(fmt.Sprintf("/image/%s", name)) @@ -42,7 +64,30 @@ func (f *FileService) Upload(ctx context.Context, name string) ([]byte, error) { sl = make([]byte, 1024) ) - defer file.Close() + defer func(file *os.File) { + _ = file.Close() + }(file) + if err != nil { + return nil, fmt.Errorf("open file: %w", err) + } + + for n, err := file.Read(sl); err != io.EOF; n, err = file.Read(sl) { + res = append(res, sl[:n]...) + } + + return res, nil +} + +func (f *FileService) UploadNonImage(ctx context.Context, name string) ([]byte, error) { + var ( + file, err = os.Open(fmt.Sprintf("/files/%s", name)) + res []byte + sl = make([]byte, 1024) + ) + + defer func(file *os.File) { + _ = file.Close() + }(file) if err != nil { return nil, fmt.Errorf("open file: %w", err) } diff --git a/internal/middleware/accesslog.go b/internal/middleware/accesslog.go index 4b808642..9da97f26 100644 --- a/internal/middleware/accesslog.go +++ b/internal/middleware/accesslog.go @@ -9,12 +9,21 @@ import ( log "github.com/sirupsen/logrus" ) +type requestID string + +var RequestKey requestID = "requestID" + func AccessLog(logger *log.Logger, next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - id := uuid.New().String() - ctx := context.WithValue(r.Context(), "requestID", id) - start := time.Now() - next.ServeHTTP(w, r.WithContext(ctx)) - logger.Infof("New request:%s\n \tMethod: %v\n\tRemote addr: %v\n\tURL: %v\n\tTime: %v", id, r.Method, r.RemoteAddr, r.URL.String(), time.Since(start)) - }) + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + id := uuid.New().String() + ctx := context.WithValue(r.Context(), RequestKey, id) + start := time.Now() + next.ServeHTTP(w, r.WithContext(ctx)) + logger.Infof( + "New request:%s\n \tMethod: %v\n\tRemote addr: %v\n\tURL: %v\n\tTime: %v", id, r.Method, r.RemoteAddr, + r.URL.String(), time.Since(start), + ) + }, + ) } diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go index 8e07a7a3..c9b3d69a 100644 --- a/internal/middleware/auth.go +++ b/internal/middleware/auth.go @@ -21,52 +21,56 @@ type SessionManager interface { } func Auth(sm SessionManager, next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if _, ok := noAuthUrls[r.URL.Path]; ok { - logout(w, r, sm) - next.ServeHTTP(w, r) - return - } - - sessionCookie, err := r.Cookie("session_id") - if err != nil { - unauthorized(w, r, err) - return - } - - sess, err := sm.Check(sessionCookie.Value) - if err != nil { - unauthorized(w, r, err) - return - } - - if sess.CreatedAt <= time.Now().Add(-time.Hour).Unix() { - if err := sm.Destroy(sess); err != nil { - log.Println(r.Context().Value("requestID"), err) - internalErr(w) + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if _, ok := noAuthUrls[r.URL.Path]; ok { + logout(w, r, sm) + next.ServeHTTP(w, r) return } - sess, err = sm.Create(sess.UserID) + sessionCookie, err := r.Cookie("session_id") if err != nil { - log.Println(r.Context().Value("requestID"), err) - internalErr(w) + unauthorized(w, r, err) return } - cookie := &http.Cookie{ - Name: "session_id", - Value: sess.ID, - Path: "/", - HttpOnly: true, - Expires: time.Now().AddDate(0, 0, 1), + sess, err := sm.Check(sessionCookie.Value) + if err != nil { + unauthorized(w, r, err) + return + } + + if sess.CreatedAt <= time.Now().Add(-time.Hour).Unix() { + if err := sm.Destroy(sess); err != nil { + log.Println(r.Context().Value("requestID"), err) + internalErr(w) + return + } + + sess, err = sm.Create(sess.UserID) + if err != nil { + log.Println(r.Context().Value("requestID"), err) + internalErr(w) + return + } + + cookie := &http.Cookie{ + Name: "session_id", + Value: sess.ID, + Path: "/", + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteStrictMode, + Expires: time.Now().AddDate(0, 0, 1), + } + http.SetCookie(w, cookie) } - http.SetCookie(w, cookie) - } - ctx := models.ContextWithSession(r.Context(), sess) - next.ServeHTTP(w, r.WithContext(ctx)) - }) + ctx := models.ContextWithSession(r.Context(), sess) + next.ServeHTTP(w, r.WithContext(ctx)) + }, + ) } func logout(w http.ResponseWriter, r *http.Request, sm SessionManager) { @@ -90,6 +94,8 @@ func logout(w http.ResponseWriter, r *http.Request, sm SessionManager) { Value: sess.ID, Path: "/", HttpOnly: true, + Secure: true, + SameSite: http.SameSiteStrictMode, Expires: time.Now().AddDate(0, 0, -1), } @@ -98,7 +104,7 @@ func logout(w http.ResponseWriter, r *http.Request, sm SessionManager) { func unauthorized(w http.ResponseWriter, r *http.Request, err error) { w.Header().Set("Content-Type", "application/json:charset=UTF-8") - w.Header().Set("Access-Control-Allow-Origin", "http://vilka.online") + w.Header().Set("Access-Control-Allow-Origin", "https://vilka.online") w.Header().Set("Access-Control-Allow-Credentials", "true") w.WriteHeader(http.StatusUnauthorized) @@ -109,7 +115,7 @@ func unauthorized(w http.ResponseWriter, r *http.Request, err error) { func internalErr(w http.ResponseWriter) { w.Header().Set("Content-Type", "application/json:charset=UTF-8") - w.Header().Set("Access-Control-Allow-Origin", "http://vilka.online") + w.Header().Set("Access-Control-Allow-Origin", "https://vilka.online") w.Header().Set("Access-Control-Allow-Credentials", "true") w.WriteHeader(http.StatusInternalServerError) diff --git a/internal/middleware/fileMetrics.go b/internal/middleware/fileMetrics.go index 9a55a646..315d47ad 100644 --- a/internal/middleware/fileMetrics.go +++ b/internal/middleware/fileMetrics.go @@ -41,35 +41,41 @@ func (rw *fileResponseWriter) Write(b []byte) (int, error) { } func FileMetricsMiddleware(metr *metrics.FileMetrics, next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - start := time.Now() - respWithCode := NewFileResponseWriter(w) - next.ServeHTTP(respWithCode, r) - statusCode := respWithCode.statusCode - path := r.URL.Path - method := r.Method - var ( - err error - format string - size int64 - ) - if r.Method == http.MethodPost { - format, size, err = getFormatAndSize(r) - } else if r.Method == http.MethodGet { - file := respWithCode.file - format = http.DetectContentType(file[:512]) - size = int64(len(file)) - } - if err != nil { - format = "error" - size = 0 - } - if statusCode != http.StatusOK && statusCode != http.StatusNoContent { - metr.IncErrors(path, strconv.Itoa(statusCode), method, format, size) - } - metr.IncHits(path, strconv.Itoa(statusCode), method, format, size) - metr.ObserveTiming(path, strconv.Itoa(statusCode), method, format, size, time.Since(start).Seconds()) - }) + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + respWithCode := NewFileResponseWriter(w) + next.ServeHTTP(respWithCode, r) + statusCode := respWithCode.statusCode + path := r.URL.Path + method := r.Method + var ( + err error + format string + size int64 + ) + if r.Method == http.MethodPost { + format, size, err = getFormatAndSize(r) + } else if r.Method == http.MethodGet { + file := respWithCode.file + size = int64(len(file)) + if size <= 512 { + format = http.DetectContentType(file[:size]) + } else { + format = http.DetectContentType(file[:512]) + } + } + if err != nil { + format = "error" + size = 0 + } + if statusCode != http.StatusOK && statusCode != http.StatusNoContent { + metr.IncErrors(path, strconv.Itoa(statusCode), method, format, size) + } + metr.IncHits(path, strconv.Itoa(statusCode), method, format, size) + metr.ObserveTiming(path, strconv.Itoa(statusCode), method, format, size, time.Since(start).Seconds()) + }, + ) } func getFormatAndSize(r *http.Request) (string, int64, error) { diff --git a/internal/middleware/preflite.go b/internal/middleware/preflite.go index 86149bc4..df3af25f 100644 --- a/internal/middleware/preflite.go +++ b/internal/middleware/preflite.go @@ -5,19 +5,21 @@ import ( ) func Preflite(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodOptions { - w.Header().Set("Access-Control-Allow-Origin", "http://vilka.online") - w.Header().Set("Access-Control-Allow-Methods", "POST, GET, PUT, DELETE") - w.Header().Set("Access-Control-Max-Age", "3600") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Accept") - w.Header().Set("Content-Type", "application/json:charset=UTF-8") - w.Header().Set("Access-Control-Allow-Credentials", "true") + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodOptions { + w.Header().Set("Access-Control-Allow-Origin", "https://vilka.online") + w.Header().Set("Access-Control-Allow-Methods", "POST, GET, PUT, DELETE") + w.Header().Set("Access-Control-Max-Age", "3600") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Accept") + w.Header().Set("Content-Type", "application/json:charset=UTF-8") + w.Header().Set("Access-Control-Allow-Credentials", "true") - w.WriteHeader(http.StatusOK) - return - } else { - next.ServeHTTP(w, r) - } - }) + w.WriteHeader(http.StatusOK) + return + } else { + next.ServeHTTP(w, r) + } + }, + ) } diff --git a/internal/models/chat.go b/internal/models/chat.go index 56c68911..41a25bd8 100644 --- a/internal/models/chat.go +++ b/internal/models/chat.go @@ -1,18 +1,98 @@ package models import ( + "strings" "time" ) +//easyjson:json type Chat struct { LastMessage string `json:"last_message"` LastDate time.Time `json:"last_date"` Receiver Header `json:"receiver"` } +//easyjson:json type Message struct { - Sender uint32 `json:"sender"` - Receiver uint32 `json:"receiver"` - Content string `json:"content"` - CreatedAt time.Time `json:"created_at"` + Sender uint32 `json:"sender"` + Receiver uint32 `json:"receiver"` + Content MessageContent `json:"content"` + CreatedAt time.Time `json:"created_at"` +} + +//easyjson:skip +type MessageDto struct { + Sender uint32 + Receiver uint32 + Content MessageContentDto + CreatedAt time.Time +} + +func (m *Message) ToDto() MessageDto { + return MessageDto{ + Sender: m.Sender, + Receiver: m.Receiver, + Content: m.Content.ToDto(), + CreatedAt: m.CreatedAt, + } +} + +func (m *MessageDto) FromDto() Message { + return Message{ + Sender: m.Sender, + Receiver: m.Receiver, + Content: m.Content.FromDto(), + CreatedAt: m.CreatedAt, + } +} + +//easyjson:json +type MessageContent struct { + Text string `json:"text"` + FilePath []string `json:"file_path"` + StickerPath string `json:"sticker_path"` +} + +//easyjson:skip +type MessageContentDto struct { + Text string + FilePath string + StickerPath string +} + +func (mc *MessageContent) ToDto() MessageContentDto { + files := make([]string, 0, len(mc.FilePath)) + for _, f := range mc.FilePath { + if f == "" { + continue + } + files = append(files, f) + } + + return MessageContentDto{ + Text: mc.Text, + FilePath: strings.Join(files, "||;||"), + StickerPath: mc.StickerPath, + } +} + +func (mc *MessageContentDto) FromDto() MessageContent { + files := strings.Split(mc.FilePath, "||;||") + contentFiles := make([]string, 0, len(files)) + for _, f := range files { + if f == "" { + continue + } + contentFiles = append(contentFiles, f) + } + + if len(contentFiles) == 0 { + contentFiles = nil + } + + return MessageContent{ + Text: mc.Text, + FilePath: contentFiles, + StickerPath: mc.StickerPath, + } } diff --git a/internal/models/chat_easyjson.go b/internal/models/chat_easyjson.go new file mode 100644 index 00000000..ca2d3f8a --- /dev/null +++ b/internal/models/chat_easyjson.go @@ -0,0 +1,365 @@ +// Code generated by easyjson for marshaling/unmarshaling. DO NOT EDIT. + +package models + +import ( + json "encoding/json" + easyjson "github.com/mailru/easyjson" + jlexer "github.com/mailru/easyjson/jlexer" + jwriter "github.com/mailru/easyjson/jwriter" +) + +// suppress unused package warning +var ( + _ *json.RawMessage + _ *jlexer.Lexer + _ *jwriter.Writer + _ easyjson.Marshaler +) + +func easyjson9b8f5552DecodeGithubCom20242BetterCallFirewallInternalModels(in *jlexer.Lexer, out *MessageContent) { + isTopLevel := in.IsStart() + if in.IsNull() { + if isTopLevel { + in.Consumed() + } + in.Skip() + return + } + in.Delim('{') + for !in.IsDelim('}') { + key := in.UnsafeFieldName(false) + in.WantColon() + if in.IsNull() { + in.Skip() + in.WantComma() + continue + } + switch key { + case "text": + out.Text = string(in.String()) + case "file_path": + if in.IsNull() { + in.Skip() + out.FilePath = nil + } else { + in.Delim('[') + if out.FilePath == nil { + if !in.IsDelim(']') { + out.FilePath = make([]string, 0, 4) + } else { + out.FilePath = []string{} + } + } else { + out.FilePath = (out.FilePath)[:0] + } + for !in.IsDelim(']') { + var v1 string + v1 = string(in.String()) + out.FilePath = append(out.FilePath, v1) + in.WantComma() + } + in.Delim(']') + } + case "sticker_path": + out.StickerPath = string(in.String()) + default: + in.SkipRecursive() + } + in.WantComma() + } + in.Delim('}') + if isTopLevel { + in.Consumed() + } +} +func easyjson9b8f5552EncodeGithubCom20242BetterCallFirewallInternalModels(out *jwriter.Writer, in MessageContent) { + out.RawByte('{') + first := true + _ = first + { + const prefix string = ",\"text\":" + out.RawString(prefix[1:]) + out.String(string(in.Text)) + } + { + const prefix string = ",\"file_path\":" + out.RawString(prefix) + if in.FilePath == nil && (out.Flags&jwriter.NilSliceAsEmpty) == 0 { + out.RawString("null") + } else { + out.RawByte('[') + for v2, v3 := range in.FilePath { + if v2 > 0 { + out.RawByte(',') + } + out.String(string(v3)) + } + out.RawByte(']') + } + } + { + const prefix string = ",\"sticker_path\":" + out.RawString(prefix) + out.String(string(in.StickerPath)) + } + out.RawByte('}') +} + +// MarshalJSON supports json.Marshaler interface +func (v MessageContent) MarshalJSON() ([]byte, error) { + w := jwriter.Writer{} + easyjson9b8f5552EncodeGithubCom20242BetterCallFirewallInternalModels(&w, v) + return w.Buffer.BuildBytes(), w.Error +} + +// MarshalEasyJSON supports easyjson.Marshaler interface +func (v MessageContent) MarshalEasyJSON(w *jwriter.Writer) { + easyjson9b8f5552EncodeGithubCom20242BetterCallFirewallInternalModels(w, v) +} + +// UnmarshalJSON supports json.Unmarshaler interface +func (v *MessageContent) UnmarshalJSON(data []byte) error { + r := jlexer.Lexer{Data: data} + easyjson9b8f5552DecodeGithubCom20242BetterCallFirewallInternalModels(&r, v) + return r.Error() +} + +// UnmarshalEasyJSON supports easyjson.Unmarshaler interface +func (v *MessageContent) UnmarshalEasyJSON(l *jlexer.Lexer) { + easyjson9b8f5552DecodeGithubCom20242BetterCallFirewallInternalModels(l, v) +} +func easyjson9b8f5552DecodeGithubCom20242BetterCallFirewallInternalModels1(in *jlexer.Lexer, out *Message) { + isTopLevel := in.IsStart() + if in.IsNull() { + if isTopLevel { + in.Consumed() + } + in.Skip() + return + } + in.Delim('{') + for !in.IsDelim('}') { + key := in.UnsafeFieldName(false) + in.WantColon() + if in.IsNull() { + in.Skip() + in.WantComma() + continue + } + switch key { + case "sender": + out.Sender = uint32(in.Uint32()) + case "receiver": + out.Receiver = uint32(in.Uint32()) + case "content": + (out.Content).UnmarshalEasyJSON(in) + case "created_at": + if data := in.Raw(); in.Ok() { + in.AddError((out.CreatedAt).UnmarshalJSON(data)) + } + default: + in.SkipRecursive() + } + in.WantComma() + } + in.Delim('}') + if isTopLevel { + in.Consumed() + } +} +func easyjson9b8f5552EncodeGithubCom20242BetterCallFirewallInternalModels1(out *jwriter.Writer, in Message) { + out.RawByte('{') + first := true + _ = first + { + const prefix string = ",\"sender\":" + out.RawString(prefix[1:]) + out.Uint32(uint32(in.Sender)) + } + { + const prefix string = ",\"receiver\":" + out.RawString(prefix) + out.Uint32(uint32(in.Receiver)) + } + { + const prefix string = ",\"content\":" + out.RawString(prefix) + (in.Content).MarshalEasyJSON(out) + } + { + const prefix string = ",\"created_at\":" + out.RawString(prefix) + out.Raw((in.CreatedAt).MarshalJSON()) + } + out.RawByte('}') +} + +// MarshalJSON supports json.Marshaler interface +func (v Message) MarshalJSON() ([]byte, error) { + w := jwriter.Writer{} + easyjson9b8f5552EncodeGithubCom20242BetterCallFirewallInternalModels1(&w, v) + return w.Buffer.BuildBytes(), w.Error +} + +// MarshalEasyJSON supports easyjson.Marshaler interface +func (v Message) MarshalEasyJSON(w *jwriter.Writer) { + easyjson9b8f5552EncodeGithubCom20242BetterCallFirewallInternalModels1(w, v) +} + +// UnmarshalJSON supports json.Unmarshaler interface +func (v *Message) UnmarshalJSON(data []byte) error { + r := jlexer.Lexer{Data: data} + easyjson9b8f5552DecodeGithubCom20242BetterCallFirewallInternalModels1(&r, v) + return r.Error() +} + +// UnmarshalEasyJSON supports easyjson.Unmarshaler interface +func (v *Message) UnmarshalEasyJSON(l *jlexer.Lexer) { + easyjson9b8f5552DecodeGithubCom20242BetterCallFirewallInternalModels1(l, v) +} +func easyjson9b8f5552DecodeGithubCom20242BetterCallFirewallInternalModels2(in *jlexer.Lexer, out *Chat) { + isTopLevel := in.IsStart() + if in.IsNull() { + if isTopLevel { + in.Consumed() + } + in.Skip() + return + } + in.Delim('{') + for !in.IsDelim('}') { + key := in.UnsafeFieldName(false) + in.WantColon() + if in.IsNull() { + in.Skip() + in.WantComma() + continue + } + switch key { + case "last_message": + out.LastMessage = string(in.String()) + case "last_date": + if data := in.Raw(); in.Ok() { + in.AddError((out.LastDate).UnmarshalJSON(data)) + } + case "receiver": + easyjson9b8f5552DecodeGithubCom20242BetterCallFirewallInternalModels3(in, &out.Receiver) + default: + in.SkipRecursive() + } + in.WantComma() + } + in.Delim('}') + if isTopLevel { + in.Consumed() + } +} +func easyjson9b8f5552EncodeGithubCom20242BetterCallFirewallInternalModels2(out *jwriter.Writer, in Chat) { + out.RawByte('{') + first := true + _ = first + { + const prefix string = ",\"last_message\":" + out.RawString(prefix[1:]) + out.String(string(in.LastMessage)) + } + { + const prefix string = ",\"last_date\":" + out.RawString(prefix) + out.Raw((in.LastDate).MarshalJSON()) + } + { + const prefix string = ",\"receiver\":" + out.RawString(prefix) + easyjson9b8f5552EncodeGithubCom20242BetterCallFirewallInternalModels3(out, in.Receiver) + } + out.RawByte('}') +} + +// MarshalJSON supports json.Marshaler interface +func (v Chat) MarshalJSON() ([]byte, error) { + w := jwriter.Writer{} + easyjson9b8f5552EncodeGithubCom20242BetterCallFirewallInternalModels2(&w, v) + return w.Buffer.BuildBytes(), w.Error +} + +// MarshalEasyJSON supports easyjson.Marshaler interface +func (v Chat) MarshalEasyJSON(w *jwriter.Writer) { + easyjson9b8f5552EncodeGithubCom20242BetterCallFirewallInternalModels2(w, v) +} + +// UnmarshalJSON supports json.Unmarshaler interface +func (v *Chat) UnmarshalJSON(data []byte) error { + r := jlexer.Lexer{Data: data} + easyjson9b8f5552DecodeGithubCom20242BetterCallFirewallInternalModels2(&r, v) + return r.Error() +} + +// UnmarshalEasyJSON supports easyjson.Unmarshaler interface +func (v *Chat) UnmarshalEasyJSON(l *jlexer.Lexer) { + easyjson9b8f5552DecodeGithubCom20242BetterCallFirewallInternalModels2(l, v) +} +func easyjson9b8f5552DecodeGithubCom20242BetterCallFirewallInternalModels3(in *jlexer.Lexer, out *Header) { + isTopLevel := in.IsStart() + if in.IsNull() { + if isTopLevel { + in.Consumed() + } + in.Skip() + return + } + in.Delim('{') + for !in.IsDelim('}') { + key := in.UnsafeFieldName(false) + in.WantColon() + if in.IsNull() { + in.Skip() + in.WantComma() + continue + } + switch key { + case "author_id": + out.AuthorID = uint32(in.Uint32()) + case "community_id": + out.CommunityID = uint32(in.Uint32()) + case "author": + out.Author = string(in.String()) + case "avatar": + out.Avatar = Picture(in.String()) + default: + in.SkipRecursive() + } + in.WantComma() + } + in.Delim('}') + if isTopLevel { + in.Consumed() + } +} +func easyjson9b8f5552EncodeGithubCom20242BetterCallFirewallInternalModels3(out *jwriter.Writer, in Header) { + out.RawByte('{') + first := true + _ = first + { + const prefix string = ",\"author_id\":" + out.RawString(prefix[1:]) + out.Uint32(uint32(in.AuthorID)) + } + { + const prefix string = ",\"community_id\":" + out.RawString(prefix) + out.Uint32(uint32(in.CommunityID)) + } + { + const prefix string = ",\"author\":" + out.RawString(prefix) + out.String(string(in.Author)) + } + { + const prefix string = ",\"avatar\":" + out.RawString(prefix) + out.String(string(in.Avatar)) + } + out.RawByte('}') +} diff --git a/internal/models/chat_test.go b/internal/models/chat_test.go new file mode 100644 index 00000000..e4c9d96a --- /dev/null +++ b/internal/models/chat_test.go @@ -0,0 +1,160 @@ +package models + +import ( + "encoding/json" + "testing" + "time" + + "github.com/mailru/easyjson" + "github.com/stretchr/testify/assert" +) + +//easyjson:skip +type TestCaseMessageContent struct { + content MessageContent + contentDto MessageContentDto +} + +func TestFromDtoMessageContent(t *testing.T) { + tests := []TestCaseMessageContent{ + {content: MessageContent{}, contentDto: MessageContentDto{}}, + {content: MessageContent{Text: "text"}, contentDto: MessageContentDto{Text: "text"}}, + {content: MessageContent{FilePath: []string{"image"}}, contentDto: MessageContentDto{FilePath: "image"}}, + { + content: MessageContent{FilePath: []string{"image", "second image"}}, + contentDto: MessageContentDto{FilePath: "image||;||second image"}, + }, + } + + for _, test := range tests { + res := test.contentDto.FromDto() + assert.Equal(t, test.content, res) + } +} + +func TestToDtoMessageContent(t *testing.T) { + tests := []TestCaseMessageContent{ + {content: MessageContent{}, contentDto: MessageContentDto{}}, + {content: MessageContent{Text: "text"}, contentDto: MessageContentDto{Text: "text"}}, + {content: MessageContent{FilePath: []string{"image"}}, contentDto: MessageContentDto{FilePath: "image"}}, + { + content: MessageContent{FilePath: []string{"image", "second image"}}, + contentDto: MessageContentDto{FilePath: "image||;||second image"}, + }, + { + content: MessageContent{FilePath: []string{""}}, + contentDto: MessageContentDto{FilePath: ""}, + }, + } + + for _, test := range tests { + res := test.content.ToDto() + assert.Equal(t, test.contentDto, res) + } +} + +//easyjson:skip +type TestCaseMessage struct { + message Message + messageDto MessageDto +} + +func TestMessageFromDto(t *testing.T) { + tests := []TestCaseMessage{ + {message: Message{}, messageDto: MessageDto{}}, + { + message: Message{Content: MessageContent{FilePath: []string{"image"}}}, + messageDto: MessageDto{Content: MessageContentDto{FilePath: "image"}}, + }, + } + + for _, test := range tests { + res := test.messageDto.FromDto() + assert.Equal(t, test.message, res) + } +} + +func TestMessageToDto(t *testing.T) { + tests := []TestCaseMessage{ + {message: Message{}, messageDto: MessageDto{}}, + { + message: Message{Content: MessageContent{FilePath: []string{"image"}}}, + messageDto: MessageDto{Content: MessageContentDto{FilePath: "image"}}, + }, + } + + for _, test := range tests { + res := test.message.ToDto() + assert.Equal(t, test.messageDto, res) + } +} + +func TestMarshal(t *testing.T) { + createTime := time.Time{} + m := &Message{ + Content: MessageContent{Text: "new message", FilePath: []string{"image"}}, Sender: 1, Receiver: 2, + CreatedAt: createTime, + } + + want := []byte(`{"sender":1,"receiver":2,"content":{"text":"new message","file_path":["image"],"sticker_path":""},"created_at":"0001-01-01T00:00:00Z"}`) + res, err := easyjson.Marshal(m) + assert.NoError(t, err) + assert.Equal(t, string(want), string(res)) + + res, err = json.Marshal(m) + assert.NoError(t, err) + assert.Equal(t, string(want), string(res)) + +} + +func TestUnmarshal(t *testing.T) { + sl := []byte(`{"sender":1,"receiver":2,"content":{"text":"new message","file_path":["image"],"sticker_path":""},"created_at":"0001-01-01T00:00:00Z"}`) + m := &Message{} + err := easyjson.Unmarshal(sl, m) + assert.NoError(t, err) + createTime := time.Time{} + want := &Message{ + Content: MessageContent{Text: "new message", FilePath: []string{"image"}}, Sender: 1, Receiver: 2, + CreatedAt: createTime, + } + + assert.Equal(t, want, m) +} + +func TestMarshalChat(t *testing.T) { + c := &Chat{ + LastMessage: "message", + LastDate: time.Time{}, + Receiver: Header{ + Author: "Andrew Savvateev", + }, + } + want := []byte(`{"last_message":"message","last_date":"0001-01-01T00:00:00Z","receiver":{"author_id":0,"community_id":0,"author":"Andrew Savvateev","avatar":""}}`) + res, err := easyjson.Marshal(c) + assert.NoError(t, err) + assert.Equal(t, string(want), string(res)) + + res, err = json.Marshal(c) + assert.NoError(t, err) + assert.Equal(t, string(want), string(res)) +} + +func TestUnmarshalChat(t *testing.T) { + sl := []byte(`{"last_message":"message","last_date":"0001-01-01T00:00:00Z","receiver":{"author_id":0,"community_id":0,"author":"Andrew Savvateev","avatar":""}}`) + c := &Chat{} + want := &Chat{ + LastMessage: "message", + LastDate: time.Time{}, + Receiver: Header{ + Author: "Andrew Savvateev", + }, + } + + err := easyjson.Unmarshal(sl, c) + assert.NoError(t, err) + assert.Equal(t, want, c) + + err = json.Unmarshal(sl, c) + assert.NoError(t, err) + assert.Equal(t, want, c) +} diff --git a/internal/models/community.go b/internal/models/community.go index f87dbe45..8630c16a 100644 --- a/internal/models/community.go +++ b/internal/models/community.go @@ -1,5 +1,6 @@ package models +//easyjson:json type Community struct { ID uint32 `json:"id"` Name string `json:"name"` @@ -10,6 +11,7 @@ type Community struct { IsFollowed bool `json:"is_followed,omitempty"` } +//easyjson:json type CommunityCard struct { ID uint32 `json:"id"` Name string `json:"name"` diff --git a/internal/models/community_easyjson.go b/internal/models/community_easyjson.go new file mode 100644 index 00000000..4d9699ac --- /dev/null +++ b/internal/models/community_easyjson.go @@ -0,0 +1,221 @@ +// Code generated by easyjson for marshaling/unmarshaling. DO NOT EDIT. + +package models + +import ( + json "encoding/json" + easyjson "github.com/mailru/easyjson" + jlexer "github.com/mailru/easyjson/jlexer" + jwriter "github.com/mailru/easyjson/jwriter" +) + +// suppress unused package warning +var ( + _ *json.RawMessage + _ *jlexer.Lexer + _ *jwriter.Writer + _ easyjson.Marshaler +) + +func easyjson798dd0c9DecodeGithubCom20242BetterCallFirewallInternalModels(in *jlexer.Lexer, out *CommunityCard) { + isTopLevel := in.IsStart() + if in.IsNull() { + if isTopLevel { + in.Consumed() + } + in.Skip() + return + } + in.Delim('{') + for !in.IsDelim('}') { + key := in.UnsafeFieldName(false) + in.WantColon() + if in.IsNull() { + in.Skip() + in.WantComma() + continue + } + switch key { + case "id": + out.ID = uint32(in.Uint32()) + case "name": + out.Name = string(in.String()) + case "avatar": + out.Avatar = Picture(in.String()) + case "about": + out.About = string(in.String()) + case "is_followed": + out.IsFollowed = bool(in.Bool()) + default: + in.SkipRecursive() + } + in.WantComma() + } + in.Delim('}') + if isTopLevel { + in.Consumed() + } +} +func easyjson798dd0c9EncodeGithubCom20242BetterCallFirewallInternalModels(out *jwriter.Writer, in CommunityCard) { + out.RawByte('{') + first := true + _ = first + { + const prefix string = ",\"id\":" + out.RawString(prefix[1:]) + out.Uint32(uint32(in.ID)) + } + { + const prefix string = ",\"name\":" + out.RawString(prefix) + out.String(string(in.Name)) + } + { + const prefix string = ",\"avatar\":" + out.RawString(prefix) + out.String(string(in.Avatar)) + } + { + const prefix string = ",\"about\":" + out.RawString(prefix) + out.String(string(in.About)) + } + if in.IsFollowed { + const prefix string = ",\"is_followed\":" + out.RawString(prefix) + out.Bool(bool(in.IsFollowed)) + } + out.RawByte('}') +} + +// MarshalJSON supports json.Marshaler interface +func (v CommunityCard) MarshalJSON() ([]byte, error) { + w := jwriter.Writer{} + easyjson798dd0c9EncodeGithubCom20242BetterCallFirewallInternalModels(&w, v) + return w.Buffer.BuildBytes(), w.Error +} + +// MarshalEasyJSON supports easyjson.Marshaler interface +func (v CommunityCard) MarshalEasyJSON(w *jwriter.Writer) { + easyjson798dd0c9EncodeGithubCom20242BetterCallFirewallInternalModels(w, v) +} + +// UnmarshalJSON supports json.Unmarshaler interface +func (v *CommunityCard) UnmarshalJSON(data []byte) error { + r := jlexer.Lexer{Data: data} + easyjson798dd0c9DecodeGithubCom20242BetterCallFirewallInternalModels(&r, v) + return r.Error() +} + +// UnmarshalEasyJSON supports easyjson.Unmarshaler interface +func (v *CommunityCard) UnmarshalEasyJSON(l *jlexer.Lexer) { + easyjson798dd0c9DecodeGithubCom20242BetterCallFirewallInternalModels(l, v) +} +func easyjson798dd0c9DecodeGithubCom20242BetterCallFirewallInternalModels1(in *jlexer.Lexer, out *Community) { + isTopLevel := in.IsStart() + if in.IsNull() { + if isTopLevel { + in.Consumed() + } + in.Skip() + return + } + in.Delim('{') + for !in.IsDelim('}') { + key := in.UnsafeFieldName(false) + in.WantColon() + if in.IsNull() { + in.Skip() + in.WantComma() + continue + } + switch key { + case "id": + out.ID = uint32(in.Uint32()) + case "name": + out.Name = string(in.String()) + case "avatar": + out.Avatar = Picture(in.String()) + case "about": + out.About = string(in.String()) + case "count_subscribers": + out.CountSubscribers = uint32(in.Uint32()) + case "is_admin": + out.IsAdmin = bool(in.Bool()) + case "is_followed": + out.IsFollowed = bool(in.Bool()) + default: + in.SkipRecursive() + } + in.WantComma() + } + in.Delim('}') + if isTopLevel { + in.Consumed() + } +} +func easyjson798dd0c9EncodeGithubCom20242BetterCallFirewallInternalModels1(out *jwriter.Writer, in Community) { + out.RawByte('{') + first := true + _ = first + { + const prefix string = ",\"id\":" + out.RawString(prefix[1:]) + out.Uint32(uint32(in.ID)) + } + { + const prefix string = ",\"name\":" + out.RawString(prefix) + out.String(string(in.Name)) + } + { + const prefix string = ",\"avatar\":" + out.RawString(prefix) + out.String(string(in.Avatar)) + } + { + const prefix string = ",\"about\":" + out.RawString(prefix) + out.String(string(in.About)) + } + { + const prefix string = ",\"count_subscribers\":" + out.RawString(prefix) + out.Uint32(uint32(in.CountSubscribers)) + } + if in.IsAdmin { + const prefix string = ",\"is_admin\":" + out.RawString(prefix) + out.Bool(bool(in.IsAdmin)) + } + if in.IsFollowed { + const prefix string = ",\"is_followed\":" + out.RawString(prefix) + out.Bool(bool(in.IsFollowed)) + } + out.RawByte('}') +} + +// MarshalJSON supports json.Marshaler interface +func (v Community) MarshalJSON() ([]byte, error) { + w := jwriter.Writer{} + easyjson798dd0c9EncodeGithubCom20242BetterCallFirewallInternalModels1(&w, v) + return w.Buffer.BuildBytes(), w.Error +} + +// MarshalEasyJSON supports easyjson.Marshaler interface +func (v Community) MarshalEasyJSON(w *jwriter.Writer) { + easyjson798dd0c9EncodeGithubCom20242BetterCallFirewallInternalModels1(w, v) +} + +// UnmarshalJSON supports json.Unmarshaler interface +func (v *Community) UnmarshalJSON(data []byte) error { + r := jlexer.Lexer{Data: data} + easyjson798dd0c9DecodeGithubCom20242BetterCallFirewallInternalModels1(&r, v) + return r.Error() +} + +// UnmarshalEasyJSON supports easyjson.Unmarshaler interface +func (v *Community) UnmarshalEasyJSON(l *jlexer.Lexer) { + easyjson798dd0c9DecodeGithubCom20242BetterCallFirewallInternalModels1(l, v) +} diff --git a/internal/models/content.go b/internal/models/content.go index 2629c5ce..cef48ac9 100644 --- a/internal/models/content.go +++ b/internal/models/content.go @@ -1,12 +1,61 @@ package models import ( + "strings" "time" ) +//easyjson:json type Content struct { Text string `json:"text"` - File Picture `json:"file,omitempty"` + File []Picture `json:"file,omitempty"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } + +func (c *Content) ToDto() ContentDto { + files := make([]string, 0, len(c.File)) + for _, f := range c.File { + if f == "" { + continue + } + files = append(files, string(f)) + } + + return ContentDto{ + Text: c.Text, + File: Picture(strings.Join(files, "||;||")), + CreatedAt: c.CreatedAt, + UpdatedAt: c.UpdatedAt, + } +} + +//easyjson:skip +type ContentDto struct { + Text string + File Picture + CreatedAt time.Time + UpdatedAt time.Time +} + +func (c *ContentDto) FromDto() Content { + files := strings.Split(string(c.File), "||;||") + contentFiles := make([]Picture, 0, len(files)) + for _, f := range files { + if f == "" { + continue + } + contentFiles = append(contentFiles, Picture(f)) + } + + if len(contentFiles) == 0 { + contentFiles = nil + } + + return Content{ + Text: c.Text, + File: contentFiles, + CreatedAt: c.CreatedAt, + UpdatedAt: c.UpdatedAt, + } +} diff --git a/internal/models/content_easyjson.go b/internal/models/content_easyjson.go new file mode 100644 index 00000000..fd293128 --- /dev/null +++ b/internal/models/content_easyjson.go @@ -0,0 +1,140 @@ +// Code generated by easyjson for marshaling/unmarshaling. DO NOT EDIT. + +package models + +import ( + json "encoding/json" + easyjson "github.com/mailru/easyjson" + jlexer "github.com/mailru/easyjson/jlexer" + jwriter "github.com/mailru/easyjson/jwriter" +) + +// suppress unused package warning +var ( + _ *json.RawMessage + _ *jlexer.Lexer + _ *jwriter.Writer + _ easyjson.Marshaler +) + +func easyjson344736e9DecodeGithubCom20242BetterCallFirewallInternalModels(in *jlexer.Lexer, out *Content) { + isTopLevel := in.IsStart() + if in.IsNull() { + if isTopLevel { + in.Consumed() + } + in.Skip() + return + } + in.Delim('{') + for !in.IsDelim('}') { + key := in.UnsafeFieldName(false) + in.WantColon() + if in.IsNull() { + in.Skip() + in.WantComma() + continue + } + switch key { + case "text": + out.Text = string(in.String()) + case "file": + if in.IsNull() { + in.Skip() + out.File = nil + } else { + in.Delim('[') + if out.File == nil { + if !in.IsDelim(']') { + out.File = make([]Picture, 0, 4) + } else { + out.File = []Picture{} + } + } else { + out.File = (out.File)[:0] + } + for !in.IsDelim(']') { + var v1 Picture + v1 = Picture(in.String()) + out.File = append(out.File, v1) + in.WantComma() + } + in.Delim(']') + } + case "created_at": + if data := in.Raw(); in.Ok() { + in.AddError((out.CreatedAt).UnmarshalJSON(data)) + } + case "updated_at": + if data := in.Raw(); in.Ok() { + in.AddError((out.UpdatedAt).UnmarshalJSON(data)) + } + default: + in.SkipRecursive() + } + in.WantComma() + } + in.Delim('}') + if isTopLevel { + in.Consumed() + } +} +func easyjson344736e9EncodeGithubCom20242BetterCallFirewallInternalModels(out *jwriter.Writer, in Content) { + out.RawByte('{') + first := true + _ = first + { + const prefix string = ",\"text\":" + out.RawString(prefix[1:]) + out.String(string(in.Text)) + } + if len(in.File) != 0 { + const prefix string = ",\"file\":" + out.RawString(prefix) + { + out.RawByte('[') + for v2, v3 := range in.File { + if v2 > 0 { + out.RawByte(',') + } + out.String(string(v3)) + } + out.RawByte(']') + } + } + { + const prefix string = ",\"created_at\":" + out.RawString(prefix) + out.Raw((in.CreatedAt).MarshalJSON()) + } + { + const prefix string = ",\"updated_at\":" + out.RawString(prefix) + out.Raw((in.UpdatedAt).MarshalJSON()) + } + out.RawByte('}') +} + +// MarshalJSON supports json.Marshaler interface +func (v Content) MarshalJSON() ([]byte, error) { + w := jwriter.Writer{} + easyjson344736e9EncodeGithubCom20242BetterCallFirewallInternalModels(&w, v) + return w.Buffer.BuildBytes(), w.Error +} + +// MarshalEasyJSON supports easyjson.Marshaler interface +func (v Content) MarshalEasyJSON(w *jwriter.Writer) { + easyjson344736e9EncodeGithubCom20242BetterCallFirewallInternalModels(w, v) +} + +// UnmarshalJSON supports json.Unmarshaler interface +func (v *Content) UnmarshalJSON(data []byte) error { + r := jlexer.Lexer{Data: data} + easyjson344736e9DecodeGithubCom20242BetterCallFirewallInternalModels(&r, v) + return r.Error() +} + +// UnmarshalEasyJSON supports easyjson.Unmarshaler interface +func (v *Content) UnmarshalEasyJSON(l *jlexer.Lexer) { + easyjson344736e9DecodeGithubCom20242BetterCallFirewallInternalModels(l, v) +} diff --git a/internal/models/content_test.go b/internal/models/content_test.go new file mode 100644 index 00000000..57eeb97a --- /dev/null +++ b/internal/models/content_test.go @@ -0,0 +1,91 @@ +package models + +import ( + "encoding/json" + "testing" + "time" + + "github.com/mailru/easyjson" + "github.com/stretchr/testify/assert" +) + +//easyjson:skip +type TestCase struct { + content Content + contentDto ContentDto +} + +func TestFromDto(t *testing.T) { + tests := []TestCase{ + {content: Content{}, contentDto: ContentDto{}}, + {content: Content{Text: "text"}, contentDto: ContentDto{Text: "text"}}, + {content: Content{File: []Picture{"image"}}, contentDto: ContentDto{File: "image"}}, + { + content: Content{File: []Picture{"image", "second image"}}, + contentDto: ContentDto{File: "image||;||second image"}, + }, + } + + for _, test := range tests { + res := test.contentDto.FromDto() + assert.Equal(t, test.content, res) + } +} + +func TestToDto(t *testing.T) { + tests := []TestCase{ + {content: Content{}, contentDto: ContentDto{}}, + {content: Content{Text: "text"}, contentDto: ContentDto{Text: "text"}}, + {content: Content{File: []Picture{"image"}}, contentDto: ContentDto{File: "image"}}, + { + content: Content{File: []Picture{"image", "second image"}}, + contentDto: ContentDto{File: "image||;||second image"}, + }, + { + content: Content{File: []Picture{""}}, + contentDto: ContentDto{File: ""}, + }, + } + + for _, test := range tests { + res := test.content.ToDto() + assert.Equal(t, test.contentDto, res) + } +} + +func TestMarshalJson(t *testing.T) { + c := &Content{ + Text: "comment", + File: []Picture{Picture("image")}, + CreatedAt: time.Time{}, + UpdatedAt: time.Time{}, + } + want := []byte(`{"text":"comment","file":["image"],"created_at":"0001-01-01T00:00:00Z","updated_at":"0001-01-01T00:00:00Z"}`) + + res, err := easyjson.Marshal(c) + assert.NoError(t, err) + assert.Equal(t, want, res) + + res, err = json.Marshal(c) + assert.NoError(t, err) + assert.Equal(t, want, res) +} + +func TestUnmarshallJson(t *testing.T) { + sl := []byte(`{"text":"comment","file":["image"],"created_at":"0001-01-01T00:00:00Z","updated_at":"0001-01-01T00:00:00Z"}`) + c := &Content{} + want := &Content{ + Text: "comment", + File: []Picture{Picture("image")}, + CreatedAt: time.Time{}, + UpdatedAt: time.Time{}, + } + + err := easyjson.Unmarshal(sl, c) + assert.NoError(t, err) + assert.Equal(t, want, c) + + err = json.Unmarshal(sl, want) + assert.NoError(t, err) + assert.Equal(t, want, c) +} diff --git a/internal/models/post.go b/internal/models/post.go index 77b01e23..b2930e77 100644 --- a/internal/models/post.go +++ b/internal/models/post.go @@ -1,16 +1,89 @@ package models +//easyjson:json type Post struct { - ID uint32 `json:"id"` - Header Header `json:"header"` - PostContent Content `json:"post_content"` - LikesCount uint32 `json:"likes_count"` - IsLiked bool `json:"is_liked"` + ID uint32 `json:"id"` + Header Header `json:"header"` + PostContent Content `json:"post_content"` + LikesCount uint32 `json:"likes_count"` + IsLiked bool `json:"is_liked"` + CommentCount uint32 `json:"comment_count"` } +func (p *Post) ToDto() PostDto { + return PostDto{ + ID: p.ID, + Header: p.Header, + PostContent: p.PostContent.ToDto(), + LikesCount: p.LikesCount, + IsLiked: p.IsLiked, + CommentCount: p.CommentCount, + } +} + +//easyjson:skip +type PostDto struct { + ID uint32 + Header Header + PostContent ContentDto + LikesCount uint32 + IsLiked bool + CommentCount uint32 +} + +func (p *PostDto) FromDto() Post { + return Post{ + ID: p.ID, + Header: p.Header, + PostContent: p.PostContent.FromDto(), + LikesCount: p.LikesCount, + IsLiked: p.IsLiked, + CommentCount: p.CommentCount, + } +} + +//easyjson:json type Header struct { AuthorID uint32 `json:"author_id"` CommunityID uint32 `json:"community_id"` Author string `json:"author"` Avatar Picture `json:"avatar"` } + +//easyjson:json +type Comment struct { + ID uint32 `json:"id"` + Header Header `json:"header"` + Content Content `json:"content"` + LikesCount uint32 `json:"likes_count"` + IsLiked bool `json:"is_liked"` +} + +func (c *Comment) ToDto() CommentDto { + return CommentDto{ + ID: c.ID, + Header: c.Header, + Content: c.Content.ToDto(), + LikesCount: c.LikesCount, + IsLiked: c.IsLiked, + } +} + +//easyjson:skip +type CommentDto struct { + ID uint32 + Header Header + Content ContentDto + LikesCount uint32 + IsLiked bool +} + +func (c *CommentDto) FromDto() Comment { + return Comment{ + ID: c.ID, + Header: c.Header, + Content: c.Content.FromDto(), + LikesCount: c.LikesCount, + IsLiked: c.IsLiked, + } +} diff --git a/internal/models/post_easyjson.go b/internal/models/post_easyjson.go new file mode 100644 index 00000000..be42cefa --- /dev/null +++ b/internal/models/post_easyjson.go @@ -0,0 +1,301 @@ +// Code generated by easyjson for marshaling/unmarshaling. DO NOT EDIT. + +package models + +import ( + json "encoding/json" + easyjson "github.com/mailru/easyjson" + jlexer "github.com/mailru/easyjson/jlexer" + jwriter "github.com/mailru/easyjson/jwriter" +) + +// suppress unused package warning +var ( + _ *json.RawMessage + _ *jlexer.Lexer + _ *jwriter.Writer + _ easyjson.Marshaler +) + +func easyjson5a72dc82DecodeGithubCom20242BetterCallFirewallInternalModels(in *jlexer.Lexer, out *Post) { + isTopLevel := in.IsStart() + if in.IsNull() { + if isTopLevel { + in.Consumed() + } + in.Skip() + return + } + in.Delim('{') + for !in.IsDelim('}') { + key := in.UnsafeFieldName(false) + in.WantColon() + if in.IsNull() { + in.Skip() + in.WantComma() + continue + } + switch key { + case "id": + out.ID = uint32(in.Uint32()) + case "header": + (out.Header).UnmarshalEasyJSON(in) + case "post_content": + (out.PostContent).UnmarshalEasyJSON(in) + case "likes_count": + out.LikesCount = uint32(in.Uint32()) + case "is_liked": + out.IsLiked = bool(in.Bool()) + case "comment_count": + out.CommentCount = uint32(in.Uint32()) + default: + in.SkipRecursive() + } + in.WantComma() + } + in.Delim('}') + if isTopLevel { + in.Consumed() + } +} +func easyjson5a72dc82EncodeGithubCom20242BetterCallFirewallInternalModels(out *jwriter.Writer, in Post) { + out.RawByte('{') + first := true + _ = first + { + const prefix string = ",\"id\":" + out.RawString(prefix[1:]) + out.Uint32(uint32(in.ID)) + } + { + const prefix string = ",\"header\":" + out.RawString(prefix) + (in.Header).MarshalEasyJSON(out) + } + { + const prefix string = ",\"post_content\":" + out.RawString(prefix) + (in.PostContent).MarshalEasyJSON(out) + } + { + const prefix string = ",\"likes_count\":" + out.RawString(prefix) + out.Uint32(uint32(in.LikesCount)) + } + { + const prefix string = ",\"is_liked\":" + out.RawString(prefix) + out.Bool(bool(in.IsLiked)) + } + { + const prefix string = ",\"comment_count\":" + out.RawString(prefix) + out.Uint32(uint32(in.CommentCount)) + } + out.RawByte('}') +} + +// MarshalJSON supports json.Marshaler interface +func (v Post) MarshalJSON() ([]byte, error) { + w := jwriter.Writer{} + easyjson5a72dc82EncodeGithubCom20242BetterCallFirewallInternalModels(&w, v) + return w.Buffer.BuildBytes(), w.Error +} + +// MarshalEasyJSON supports easyjson.Marshaler interface +func (v Post) MarshalEasyJSON(w *jwriter.Writer) { + easyjson5a72dc82EncodeGithubCom20242BetterCallFirewallInternalModels(w, v) +} + +// UnmarshalJSON supports json.Unmarshaler interface +func (v *Post) UnmarshalJSON(data []byte) error { + r := jlexer.Lexer{Data: data} + easyjson5a72dc82DecodeGithubCom20242BetterCallFirewallInternalModels(&r, v) + return r.Error() +} + +// UnmarshalEasyJSON supports easyjson.Unmarshaler interface +func (v *Post) UnmarshalEasyJSON(l *jlexer.Lexer) { + easyjson5a72dc82DecodeGithubCom20242BetterCallFirewallInternalModels(l, v) +} +func easyjson5a72dc82DecodeGithubCom20242BetterCallFirewallInternalModels1(in *jlexer.Lexer, out *Header) { + isTopLevel := in.IsStart() + if in.IsNull() { + if isTopLevel { + in.Consumed() + } + in.Skip() + return + } + in.Delim('{') + for !in.IsDelim('}') { + key := in.UnsafeFieldName(false) + in.WantColon() + if in.IsNull() { + in.Skip() + in.WantComma() + continue + } + switch key { + case "author_id": + out.AuthorID = uint32(in.Uint32()) + case "community_id": + out.CommunityID = uint32(in.Uint32()) + case "author": + out.Author = string(in.String()) + case "avatar": + out.Avatar = Picture(in.String()) + default: + in.SkipRecursive() + } + in.WantComma() + } + in.Delim('}') + if isTopLevel { + in.Consumed() + } +} +func easyjson5a72dc82EncodeGithubCom20242BetterCallFirewallInternalModels1(out *jwriter.Writer, in Header) { + out.RawByte('{') + first := true + _ = first + { + const prefix string = ",\"author_id\":" + out.RawString(prefix[1:]) + out.Uint32(uint32(in.AuthorID)) + } + { + const prefix string = ",\"community_id\":" + out.RawString(prefix) + out.Uint32(uint32(in.CommunityID)) + } + { + const prefix string = ",\"author\":" + out.RawString(prefix) + out.String(string(in.Author)) + } + { + const prefix string = ",\"avatar\":" + out.RawString(prefix) + out.String(string(in.Avatar)) + } + out.RawByte('}') +} + +// MarshalJSON supports json.Marshaler interface +func (v Header) MarshalJSON() ([]byte, error) { + w := jwriter.Writer{} + easyjson5a72dc82EncodeGithubCom20242BetterCallFirewallInternalModels1(&w, v) + return w.Buffer.BuildBytes(), w.Error +} + +// MarshalEasyJSON supports easyjson.Marshaler interface +func (v Header) MarshalEasyJSON(w *jwriter.Writer) { + easyjson5a72dc82EncodeGithubCom20242BetterCallFirewallInternalModels1(w, v) +} + +// UnmarshalJSON supports json.Unmarshaler interface +func (v *Header) UnmarshalJSON(data []byte) error { + r := jlexer.Lexer{Data: data} + easyjson5a72dc82DecodeGithubCom20242BetterCallFirewallInternalModels1(&r, v) + return r.Error() +} + +// UnmarshalEasyJSON supports easyjson.Unmarshaler interface +func (v *Header) UnmarshalEasyJSON(l *jlexer.Lexer) { + easyjson5a72dc82DecodeGithubCom20242BetterCallFirewallInternalModels1(l, v) +} +func easyjson5a72dc82DecodeGithubCom20242BetterCallFirewallInternalModels2(in *jlexer.Lexer, out *Comment) { + isTopLevel := in.IsStart() + if in.IsNull() { + if isTopLevel { + in.Consumed() + } + in.Skip() + return + } + in.Delim('{') + for !in.IsDelim('}') { + key := in.UnsafeFieldName(false) + in.WantColon() + if in.IsNull() { + in.Skip() + in.WantComma() + continue + } + switch key { + case "id": + out.ID = uint32(in.Uint32()) + case "header": + (out.Header).UnmarshalEasyJSON(in) + case "content": + (out.Content).UnmarshalEasyJSON(in) + case "likes_count": + out.LikesCount = uint32(in.Uint32()) + case "is_liked": + out.IsLiked = bool(in.Bool()) + default: + in.SkipRecursive() + } + in.WantComma() + } + in.Delim('}') + if isTopLevel { + in.Consumed() + } +} +func easyjson5a72dc82EncodeGithubCom20242BetterCallFirewallInternalModels2(out *jwriter.Writer, in Comment) { + out.RawByte('{') + first := true + _ = first + { + const prefix string = ",\"id\":" + out.RawString(prefix[1:]) + out.Uint32(uint32(in.ID)) + } + { + const prefix string = ",\"header\":" + out.RawString(prefix) + (in.Header).MarshalEasyJSON(out) + } + { + const prefix string = ",\"content\":" + out.RawString(prefix) + (in.Content).MarshalEasyJSON(out) + } + { + const prefix string = ",\"likes_count\":" + out.RawString(prefix) + out.Uint32(uint32(in.LikesCount)) + } + { + const prefix string = ",\"is_liked\":" + out.RawString(prefix) + out.Bool(bool(in.IsLiked)) + } + out.RawByte('}') +} + +// MarshalJSON supports json.Marshaler interface +func (v Comment) MarshalJSON() ([]byte, error) { + w := jwriter.Writer{} + easyjson5a72dc82EncodeGithubCom20242BetterCallFirewallInternalModels2(&w, v) + return w.Buffer.BuildBytes(), w.Error +} + +// MarshalEasyJSON supports easyjson.Marshaler interface +func (v Comment) MarshalEasyJSON(w *jwriter.Writer) { + easyjson5a72dc82EncodeGithubCom20242BetterCallFirewallInternalModels2(w, v) +} + +// UnmarshalJSON supports json.Unmarshaler interface +func (v *Comment) UnmarshalJSON(data []byte) error { + r := jlexer.Lexer{Data: data} + easyjson5a72dc82DecodeGithubCom20242BetterCallFirewallInternalModels2(&r, v) + return r.Error() +} + +// UnmarshalEasyJSON supports easyjson.Unmarshaler interface +func (v *Comment) UnmarshalEasyJSON(l *jlexer.Lexer) { + easyjson5a72dc82DecodeGithubCom20242BetterCallFirewallInternalModels2(l, v) +} diff --git a/internal/models/post_test.go b/internal/models/post_test.go new file mode 100644 index 00000000..62f6b9e3 --- /dev/null +++ b/internal/models/post_test.go @@ -0,0 +1,212 @@ +package models + +import ( + "encoding/json" + "testing" + "time" + + "github.com/mailru/easyjson" + "github.com/stretchr/testify/assert" +) + +//easyjson:skip +type TestCasePost struct { + post Post + postDto PostDto +} + +func TestPostFromDto(t *testing.T) { + tests := []TestCasePost{ + {post: Post{}, postDto: PostDto{}}, + {post: Post{PostContent: Content{Text: "text"}}, postDto: PostDto{PostContent: ContentDto{Text: "text"}}}, + { + post: Post{PostContent: Content{File: []Picture{"image"}}}, + postDto: PostDto{PostContent: ContentDto{File: "image"}}, + }, + { + post: Post{PostContent: Content{File: []Picture{"image", "second image"}}}, + postDto: PostDto{PostContent: ContentDto{File: "image||;||second image"}}, + }, + } + + for _, test := range tests { + res := test.postDto.FromDto() + assert.Equal(t, test.post, res) + } +} + +func TestPostToDto(t *testing.T) { + tests := []TestCasePost{ + {post: Post{}, postDto: PostDto{}}, + {post: Post{PostContent: Content{Text: "text"}}, postDto: PostDto{PostContent: ContentDto{Text: "text"}}}, + { + post: Post{PostContent: Content{File: []Picture{"image"}}}, + postDto: PostDto{PostContent: ContentDto{File: "image"}}, + }, + { + post: Post{PostContent: Content{File: []Picture{"image", "second image"}}}, + postDto: PostDto{PostContent: ContentDto{File: "image||;||second image"}}, + }, + } + + for _, test := range tests { + res := test.post.ToDto() + assert.Equal(t, test.postDto, res) + } +} + +//easyjson:skip +type TestCaseComment struct { + comment Comment + commentDto CommentDto +} + +func TestCommentFromDto(t *testing.T) { + tests := []TestCaseComment{ + {comment: Comment{}, commentDto: CommentDto{}}, + { + comment: Comment{Content: Content{File: []Picture{"image"}}}, + commentDto: CommentDto{Content: ContentDto{File: "image"}}, + }, + } + + for _, test := range tests { + res := test.commentDto.FromDto() + assert.Equal(t, test.comment, res) + } +} + +func TestCommentToDto(t *testing.T) { + tests := []TestCaseComment{ + {comment: Comment{}, commentDto: CommentDto{}}, + { + comment: Comment{Content: Content{File: []Picture{"image"}}}, + commentDto: CommentDto{Content: ContentDto{File: "image"}}, + }, + } + + for _, test := range tests { + res := test.comment.ToDto() + assert.Equal(t, test.commentDto, res) + } +} + +func TestMarshalPost(t *testing.T) { + p := &Post{ + ID: 1, + Header: Header{ + AuthorID: 10, + CommunityID: 0, + Author: "Alexey", + Avatar: "/image", + }, + PostContent: Content{ + Text: "text", + File: []Picture{"image"}, + CreatedAt: time.Time{}, + UpdatedAt: time.Time{}, + }, + LikesCount: 1, + IsLiked: true, + CommentCount: 10, + } + want := []byte(`{"id":1,"header":{"author_id":10,"community_id":0,"author":"Alexey","avatar":"/image"},"post_content":{"text":"text","file":["image"],"created_at":"0001-01-01T00:00:00Z","updated_at":"0001-01-01T00:00:00Z"},"likes_count":1,"is_liked":true,"comment_count":10}`) + + res, err := easyjson.Marshal(p) + assert.NoError(t, err) + assert.Equal(t, want, res) + + res, err = json.Marshal(p) + assert.NoError(t, err) + assert.Equal(t, want, res) +} + +func TestUnmarshallPost(t *testing.T) { + want := &Post{ + ID: 1, + Header: Header{ + AuthorID: 10, + CommunityID: 0, + Author: "Alexey", + Avatar: "/image", + }, + PostContent: Content{ + Text: "text", + File: []Picture{"image"}, + CreatedAt: time.Time{}, + UpdatedAt: time.Time{}, + }, + LikesCount: 1, + IsLiked: true, + CommentCount: 10, + } + sl := []byte(`{"id":1,"header":{"author_id":10,"community_id":0,"author":"Alexey","avatar":"/image"},"post_content":{"text":"text","file":["image"],"created_at":"0001-01-01T00:00:00Z","updated_at":"0001-01-01T00:00:00Z"},"likes_count":1,"is_liked":true,"comment_count":10}`) + p := &Post{} + + err := easyjson.Unmarshal(sl, p) + assert.NoError(t, err) + assert.Equal(t, want, p) + + err = json.Unmarshal(sl, p) + assert.NoError(t, err) + assert.Equal(t, want, p) +} + +func TestMarshallComment(t *testing.T) { + c := &Comment{ + ID: 10, + Header: Header{ + AuthorID: 10, + CommunityID: 0, + Author: "Alexey", + Avatar: "/image", + }, + Content: Content{ + Text: "comment", + File: nil, + CreatedAt: time.Time{}, + UpdatedAt: time.Time{}, + }, + LikesCount: 0, + IsLiked: false, + } + want := []byte(`{"id":10,"header":{"author_id":10,"community_id":0,"author":"Alexey","avatar":"/image"},"content":{"text":"comment","created_at":"0001-01-01T00:00:00Z","updated_at":"0001-01-01T00:00:00Z"},"likes_count":0,"is_liked":false}`) + + res, err := easyjson.Marshal(c) + assert.NoError(t, err) + assert.Equal(t, res, want) + + res, err = json.Marshal(c) + assert.NoError(t, err) + assert.Equal(t, want, res) +} + +func TestUnmarshallComment(t *testing.T) { + want := &Comment{ + ID: 10, + Header: Header{ + AuthorID: 10, + CommunityID: 0, + Author: "Alexey", + Avatar: "/image", + }, + Content: Content{ + Text: "comment", + File: nil, + CreatedAt: time.Time{}, + UpdatedAt: time.Time{}, + }, + LikesCount: 0, + IsLiked: false, + } + sl := []byte(`{"id":10,"header":{"author_id":10,"community_id":0,"author":"Alexey","avatar":"/image"},"content":{"text":"comment","created_at":"0001-01-01T00:00:00Z","updated_at":"0001-01-01T00:00:00Z"},"likes_count":0,"is_liked":false}`) + c := &Comment{} + + err := easyjson.Unmarshal(sl, c) + assert.NoError(t, err) + assert.Equal(t, want, c) + + err = json.Unmarshal(sl, c) + assert.NoError(t, err) + assert.Equal(t, want, c) +} diff --git a/internal/models/profile.go b/internal/models/profile.go index 72133fc5..aa50bece 100644 --- a/internal/models/profile.go +++ b/internal/models/profile.go @@ -1,5 +1,6 @@ package models +//easyjson:json type FullProfile struct { ID uint32 `json:"id"` FirstName string `json:"first_name"` @@ -14,6 +15,7 @@ type FullProfile struct { Posts []*Post `json:"posts"` } +//easyjson:json type ShortProfile struct { ID uint32 `json:"id"` FirstName string `json:"first_name"` @@ -24,3 +26,9 @@ type ShortProfile struct { IsSubscription bool `json:"is_subscription"` Avatar Picture `json:"avatar"` } + +//easyjson:json +type ChangePasswordReq struct { + OldPassword string `json:"old_password"` + NewPassword string `json:"new_password"` +} diff --git a/internal/models/profile_easyjson.go b/internal/models/profile_easyjson.go new file mode 100644 index 00000000..b921add6 --- /dev/null +++ b/internal/models/profile_easyjson.go @@ -0,0 +1,419 @@ +// Code generated by easyjson for marshaling/unmarshaling. DO NOT EDIT. + +package models + +import ( + json "encoding/json" + easyjson "github.com/mailru/easyjson" + jlexer "github.com/mailru/easyjson/jlexer" + jwriter "github.com/mailru/easyjson/jwriter" +) + +// suppress unused package warning +var ( + _ *json.RawMessage + _ *jlexer.Lexer + _ *jwriter.Writer + _ easyjson.Marshaler +) + +func easyjson521a5691DecodeGithubCom20242BetterCallFirewallInternalModels(in *jlexer.Lexer, out *ShortProfile) { + isTopLevel := in.IsStart() + if in.IsNull() { + if isTopLevel { + in.Consumed() + } + in.Skip() + return + } + in.Delim('{') + for !in.IsDelim('}') { + key := in.UnsafeFieldName(false) + in.WantColon() + if in.IsNull() { + in.Skip() + in.WantComma() + continue + } + switch key { + case "id": + out.ID = uint32(in.Uint32()) + case "first_name": + out.FirstName = string(in.String()) + case "last_name": + out.LastName = string(in.String()) + case "is_author": + out.IsAuthor = bool(in.Bool()) + case "is_friend": + out.IsFriend = bool(in.Bool()) + case "is_subscriber": + out.IsSubscriber = bool(in.Bool()) + case "is_subscription": + out.IsSubscription = bool(in.Bool()) + case "avatar": + out.Avatar = Picture(in.String()) + default: + in.SkipRecursive() + } + in.WantComma() + } + in.Delim('}') + if isTopLevel { + in.Consumed() + } +} +func easyjson521a5691EncodeGithubCom20242BetterCallFirewallInternalModels(out *jwriter.Writer, in ShortProfile) { + out.RawByte('{') + first := true + _ = first + { + const prefix string = ",\"id\":" + out.RawString(prefix[1:]) + out.Uint32(uint32(in.ID)) + } + { + const prefix string = ",\"first_name\":" + out.RawString(prefix) + out.String(string(in.FirstName)) + } + { + const prefix string = ",\"last_name\":" + out.RawString(prefix) + out.String(string(in.LastName)) + } + { + const prefix string = ",\"is_author\":" + out.RawString(prefix) + out.Bool(bool(in.IsAuthor)) + } + { + const prefix string = ",\"is_friend\":" + out.RawString(prefix) + out.Bool(bool(in.IsFriend)) + } + { + const prefix string = ",\"is_subscriber\":" + out.RawString(prefix) + out.Bool(bool(in.IsSubscriber)) + } + { + const prefix string = ",\"is_subscription\":" + out.RawString(prefix) + out.Bool(bool(in.IsSubscription)) + } + { + const prefix string = ",\"avatar\":" + out.RawString(prefix) + out.String(string(in.Avatar)) + } + out.RawByte('}') +} + +// MarshalJSON supports json.Marshaler interface +func (v ShortProfile) MarshalJSON() ([]byte, error) { + w := jwriter.Writer{} + easyjson521a5691EncodeGithubCom20242BetterCallFirewallInternalModels(&w, v) + return w.Buffer.BuildBytes(), w.Error +} + +// MarshalEasyJSON supports easyjson.Marshaler interface +func (v ShortProfile) MarshalEasyJSON(w *jwriter.Writer) { + easyjson521a5691EncodeGithubCom20242BetterCallFirewallInternalModels(w, v) +} + +// UnmarshalJSON supports json.Unmarshaler interface +func (v *ShortProfile) UnmarshalJSON(data []byte) error { + r := jlexer.Lexer{Data: data} + easyjson521a5691DecodeGithubCom20242BetterCallFirewallInternalModels(&r, v) + return r.Error() +} + +// UnmarshalEasyJSON supports easyjson.Unmarshaler interface +func (v *ShortProfile) UnmarshalEasyJSON(l *jlexer.Lexer) { + easyjson521a5691DecodeGithubCom20242BetterCallFirewallInternalModels(l, v) +} +func easyjson521a5691DecodeGithubCom20242BetterCallFirewallInternalModels1(in *jlexer.Lexer, out *FullProfile) { + isTopLevel := in.IsStart() + if in.IsNull() { + if isTopLevel { + in.Consumed() + } + in.Skip() + return + } + in.Delim('{') + for !in.IsDelim('}') { + key := in.UnsafeFieldName(false) + in.WantColon() + if in.IsNull() { + in.Skip() + in.WantComma() + continue + } + switch key { + case "id": + out.ID = uint32(in.Uint32()) + case "first_name": + out.FirstName = string(in.String()) + case "last_name": + out.LastName = string(in.String()) + case "bio": + out.Bio = string(in.String()) + case "is_author": + out.IsAuthor = bool(in.Bool()) + case "is_friend": + out.IsFriend = bool(in.Bool()) + case "is_subscriber": + out.IsSubscriber = bool(in.Bool()) + case "is_subscription": + out.IsSubscription = bool(in.Bool()) + case "avatar": + out.Avatar = Picture(in.String()) + case "pics": + if in.IsNull() { + in.Skip() + out.Pics = nil + } else { + in.Delim('[') + if out.Pics == nil { + if !in.IsDelim(']') { + out.Pics = make([]Picture, 0, 4) + } else { + out.Pics = []Picture{} + } + } else { + out.Pics = (out.Pics)[:0] + } + for !in.IsDelim(']') { + var v1 Picture + v1 = Picture(in.String()) + out.Pics = append(out.Pics, v1) + in.WantComma() + } + in.Delim(']') + } + case "posts": + if in.IsNull() { + in.Skip() + out.Posts = nil + } else { + in.Delim('[') + if out.Posts == nil { + if !in.IsDelim(']') { + out.Posts = make([]*Post, 0, 8) + } else { + out.Posts = []*Post{} + } + } else { + out.Posts = (out.Posts)[:0] + } + for !in.IsDelim(']') { + var v2 *Post + if in.IsNull() { + in.Skip() + v2 = nil + } else { + if v2 == nil { + v2 = new(Post) + } + (*v2).UnmarshalEasyJSON(in) + } + out.Posts = append(out.Posts, v2) + in.WantComma() + } + in.Delim(']') + } + default: + in.SkipRecursive() + } + in.WantComma() + } + in.Delim('}') + if isTopLevel { + in.Consumed() + } +} +func easyjson521a5691EncodeGithubCom20242BetterCallFirewallInternalModels1(out *jwriter.Writer, in FullProfile) { + out.RawByte('{') + first := true + _ = first + { + const prefix string = ",\"id\":" + out.RawString(prefix[1:]) + out.Uint32(uint32(in.ID)) + } + { + const prefix string = ",\"first_name\":" + out.RawString(prefix) + out.String(string(in.FirstName)) + } + { + const prefix string = ",\"last_name\":" + out.RawString(prefix) + out.String(string(in.LastName)) + } + { + const prefix string = ",\"bio\":" + out.RawString(prefix) + out.String(string(in.Bio)) + } + { + const prefix string = ",\"is_author\":" + out.RawString(prefix) + out.Bool(bool(in.IsAuthor)) + } + { + const prefix string = ",\"is_friend\":" + out.RawString(prefix) + out.Bool(bool(in.IsFriend)) + } + { + const prefix string = ",\"is_subscriber\":" + out.RawString(prefix) + out.Bool(bool(in.IsSubscriber)) + } + { + const prefix string = ",\"is_subscription\":" + out.RawString(prefix) + out.Bool(bool(in.IsSubscription)) + } + { + const prefix string = ",\"avatar\":" + out.RawString(prefix) + out.String(string(in.Avatar)) + } + { + const prefix string = ",\"pics\":" + out.RawString(prefix) + if in.Pics == nil && (out.Flags&jwriter.NilSliceAsEmpty) == 0 { + out.RawString("null") + } else { + out.RawByte('[') + for v3, v4 := range in.Pics { + if v3 > 0 { + out.RawByte(',') + } + out.String(string(v4)) + } + out.RawByte(']') + } + } + { + const prefix string = ",\"posts\":" + out.RawString(prefix) + if in.Posts == nil && (out.Flags&jwriter.NilSliceAsEmpty) == 0 { + out.RawString("null") + } else { + out.RawByte('[') + for v5, v6 := range in.Posts { + if v5 > 0 { + out.RawByte(',') + } + if v6 == nil { + out.RawString("null") + } else { + (*v6).MarshalEasyJSON(out) + } + } + out.RawByte(']') + } + } + out.RawByte('}') +} + +// MarshalJSON supports json.Marshaler interface +func (v FullProfile) MarshalJSON() ([]byte, error) { + w := jwriter.Writer{} + easyjson521a5691EncodeGithubCom20242BetterCallFirewallInternalModels1(&w, v) + return w.Buffer.BuildBytes(), w.Error +} + +// MarshalEasyJSON supports easyjson.Marshaler interface +func (v FullProfile) MarshalEasyJSON(w *jwriter.Writer) { + easyjson521a5691EncodeGithubCom20242BetterCallFirewallInternalModels1(w, v) +} + +// UnmarshalJSON supports json.Unmarshaler interface +func (v *FullProfile) UnmarshalJSON(data []byte) error { + r := jlexer.Lexer{Data: data} + easyjson521a5691DecodeGithubCom20242BetterCallFirewallInternalModels1(&r, v) + return r.Error() +} + +// UnmarshalEasyJSON supports easyjson.Unmarshaler interface +func (v *FullProfile) UnmarshalEasyJSON(l *jlexer.Lexer) { + easyjson521a5691DecodeGithubCom20242BetterCallFirewallInternalModels1(l, v) +} +func easyjson521a5691DecodeGithubCom20242BetterCallFirewallInternalModels2(in *jlexer.Lexer, out *ChangePasswordReq) { + isTopLevel := in.IsStart() + if in.IsNull() { + if isTopLevel { + in.Consumed() + } + in.Skip() + return + } + in.Delim('{') + for !in.IsDelim('}') { + key := in.UnsafeFieldName(false) + in.WantColon() + if in.IsNull() { + in.Skip() + in.WantComma() + continue + } + switch key { + case "old_password": + out.OldPassword = string(in.String()) + case "new_password": + out.NewPassword = string(in.String()) + default: + in.SkipRecursive() + } + in.WantComma() + } + in.Delim('}') + if isTopLevel { + in.Consumed() + } +} +func easyjson521a5691EncodeGithubCom20242BetterCallFirewallInternalModels2(out *jwriter.Writer, in ChangePasswordReq) { + out.RawByte('{') + first := true + _ = first + { + const prefix string = ",\"old_password\":" + out.RawString(prefix[1:]) + out.String(string(in.OldPassword)) + } + { + const prefix string = ",\"new_password\":" + out.RawString(prefix) + out.String(string(in.NewPassword)) + } + out.RawByte('}') +} + +// MarshalJSON supports json.Marshaler interface +func (v ChangePasswordReq) MarshalJSON() ([]byte, error) { + w := jwriter.Writer{} + easyjson521a5691EncodeGithubCom20242BetterCallFirewallInternalModels2(&w, v) + return w.Buffer.BuildBytes(), w.Error +} + +// MarshalEasyJSON supports easyjson.Marshaler interface +func (v ChangePasswordReq) MarshalEasyJSON(w *jwriter.Writer) { + easyjson521a5691EncodeGithubCom20242BetterCallFirewallInternalModels2(w, v) +} + +// UnmarshalJSON supports json.Unmarshaler interface +func (v *ChangePasswordReq) UnmarshalJSON(data []byte) error { + r := jlexer.Lexer{Data: data} + easyjson521a5691DecodeGithubCom20242BetterCallFirewallInternalModels2(&r, v) + return r.Error() +} + +// UnmarshalEasyJSON supports easyjson.Unmarshaler interface +func (v *ChangePasswordReq) UnmarshalEasyJSON(l *jlexer.Lexer) { + easyjson521a5691DecodeGithubCom20242BetterCallFirewallInternalModels2(l, v) +} diff --git a/internal/models/session.go b/internal/models/session.go index 162f6238..334d688e 100644 --- a/internal/models/session.go +++ b/internal/models/session.go @@ -9,6 +9,7 @@ import ( "github.com/2024_2_BetterCallFirewall/pkg/my_err" ) +//easyjson:skip type Session struct { ID string UserID uint32 diff --git a/internal/models/session_easyjson.go b/internal/models/session_easyjson.go new file mode 100644 index 00000000..4d92eaa6 --- /dev/null +++ b/internal/models/session_easyjson.go @@ -0,0 +1,18 @@ +// Code generated by easyjson for marshaling/unmarshaling. DO NOT EDIT. + +package models + +import ( + json "encoding/json" + easyjson "github.com/mailru/easyjson" + jlexer "github.com/mailru/easyjson/jlexer" + jwriter "github.com/mailru/easyjson/jwriter" +) + +// suppress unused package warning +var ( + _ *json.RawMessage + _ *jlexer.Lexer + _ *jwriter.Writer + _ easyjson.Marshaler +) diff --git a/internal/models/sticker_request.go b/internal/models/sticker_request.go new file mode 100644 index 00000000..fe23396b --- /dev/null +++ b/internal/models/sticker_request.go @@ -0,0 +1,5 @@ +package models + +type StickerRequest struct { + File string `json:"file"` +} diff --git a/internal/models/user.go b/internal/models/user.go index 8cc6c5ca..ffc9686d 100644 --- a/internal/models/user.go +++ b/internal/models/user.go @@ -2,6 +2,7 @@ package models type Picture string +//easyjson:json type User struct { ID uint32 `json:"id"` Email string `json:"email"` diff --git a/internal/models/user_easyjson.go b/internal/models/user_easyjson.go new file mode 100644 index 00000000..2c9409a9 --- /dev/null +++ b/internal/models/user_easyjson.go @@ -0,0 +1,120 @@ +// Code generated by easyjson for marshaling/unmarshaling. DO NOT EDIT. + +package models + +import ( + json "encoding/json" + easyjson "github.com/mailru/easyjson" + jlexer "github.com/mailru/easyjson/jlexer" + jwriter "github.com/mailru/easyjson/jwriter" +) + +// suppress unused package warning +var ( + _ *json.RawMessage + _ *jlexer.Lexer + _ *jwriter.Writer + _ easyjson.Marshaler +) + +func easyjson9e1087fdDecodeGithubCom20242BetterCallFirewallInternalModels(in *jlexer.Lexer, out *User) { + isTopLevel := in.IsStart() + if in.IsNull() { + if isTopLevel { + in.Consumed() + } + in.Skip() + return + } + in.Delim('{') + for !in.IsDelim('}') { + key := in.UnsafeFieldName(false) + in.WantColon() + if in.IsNull() { + in.Skip() + in.WantComma() + continue + } + switch key { + case "id": + out.ID = uint32(in.Uint32()) + case "email": + out.Email = string(in.String()) + case "password": + out.Password = string(in.String()) + case "first_name": + out.FirstName = string(in.String()) + case "last_name": + out.LastName = string(in.String()) + case "avatar": + out.Avatar = Picture(in.String()) + default: + in.SkipRecursive() + } + in.WantComma() + } + in.Delim('}') + if isTopLevel { + in.Consumed() + } +} +func easyjson9e1087fdEncodeGithubCom20242BetterCallFirewallInternalModels(out *jwriter.Writer, in User) { + out.RawByte('{') + first := true + _ = first + { + const prefix string = ",\"id\":" + out.RawString(prefix[1:]) + out.Uint32(uint32(in.ID)) + } + { + const prefix string = ",\"email\":" + out.RawString(prefix) + out.String(string(in.Email)) + } + { + const prefix string = ",\"password\":" + out.RawString(prefix) + out.String(string(in.Password)) + } + { + const prefix string = ",\"first_name\":" + out.RawString(prefix) + out.String(string(in.FirstName)) + } + { + const prefix string = ",\"last_name\":" + out.RawString(prefix) + out.String(string(in.LastName)) + } + { + const prefix string = ",\"avatar\":" + out.RawString(prefix) + out.String(string(in.Avatar)) + } + out.RawByte('}') +} + +// MarshalJSON supports json.Marshaler interface +func (v User) MarshalJSON() ([]byte, error) { + w := jwriter.Writer{} + easyjson9e1087fdEncodeGithubCom20242BetterCallFirewallInternalModels(&w, v) + return w.Buffer.BuildBytes(), w.Error +} + +// MarshalEasyJSON supports easyjson.Marshaler interface +func (v User) MarshalEasyJSON(w *jwriter.Writer) { + easyjson9e1087fdEncodeGithubCom20242BetterCallFirewallInternalModels(w, v) +} + +// UnmarshalJSON supports json.Unmarshaler interface +func (v *User) UnmarshalJSON(data []byte) error { + r := jlexer.Lexer{Data: data} + easyjson9e1087fdDecodeGithubCom20242BetterCallFirewallInternalModels(&r, v) + return r.Error() +} + +// UnmarshalEasyJSON supports easyjson.Unmarshaler interface +func (v *User) UnmarshalEasyJSON(l *jlexer.Lexer) { + easyjson9e1087fdDecodeGithubCom20242BetterCallFirewallInternalModels(l, v) +} diff --git a/internal/post/controller/controller.go b/internal/post/controller/controller.go index 4b6fd97b..68fe211d 100644 --- a/internal/post/controller/controller.go +++ b/internal/post/controller/controller.go @@ -2,31 +2,42 @@ package controller import ( "context" - "encoding/json" "errors" "fmt" "math" "net/http" "strconv" + "strings" + "time" "github.com/gorilla/mux" + "github.com/mailru/easyjson" + "github.com/microcosm-cc/bluemonday" + "github.com/2024_2_BetterCallFirewall/internal/middleware" "github.com/2024_2_BetterCallFirewall/internal/models" "github.com/2024_2_BetterCallFirewall/pkg/my_err" ) +const ( + postIDkey = "id" + commentIDKey = "comment_id" + imagePrefix = "/image/" + filePrefix = "/files/" +) + //go:generate mockgen -destination=mock.go -source=$GOFILE -package=${GOPACKAGE} type PostService interface { - Create(ctx context.Context, post *models.Post) (uint32, error) - Get(ctx context.Context, postID, userID uint32) (*models.Post, error) - Update(ctx context.Context, post *models.Post) error + Create(ctx context.Context, post *models.PostDto) (uint32, error) + Get(ctx context.Context, postID, userID uint32) (*models.PostDto, error) + Update(ctx context.Context, post *models.PostDto) error Delete(ctx context.Context, postID uint32) error - GetBatch(ctx context.Context, lastID, userID uint32) ([]*models.Post, error) - GetBatchFromFriend(ctx context.Context, userID uint32, lastID uint32) ([]*models.Post, error) + GetBatch(ctx context.Context, lastID, userID uint32) ([]*models.PostDto, error) + GetBatchFromFriend(ctx context.Context, userID uint32, lastID uint32) ([]*models.PostDto, error) GetPostAuthorID(ctx context.Context, postID uint32) (uint32, error) - GetCommunityPost(ctx context.Context, communityID, userID, lastID uint32) ([]*models.Post, error) - CreateCommunityPost(ctx context.Context, post *models.Post) (uint32, error) + GetCommunityPost(ctx context.Context, communityID, userID, lastID uint32) ([]*models.PostDto, error) + CreateCommunityPost(ctx context.Context, post *models.PostDto) (uint32, error) CheckAccessToCommunity(ctx context.Context, userID uint32, communityID uint32) bool SetLikeToPost(ctx context.Context, postID uint32, userID uint32) error @@ -34,6 +45,13 @@ type PostService interface { CheckLikes(ctx context.Context, postID, userID uint32) (bool, error) } +type CommentService interface { + Comment(ctx context.Context, userID, postID uint32, comment *models.ContentDto) (*models.CommentDto, error) + DeleteComment(ctx context.Context, commentID, userID uint32) error + EditComment(ctx context.Context, commentID, userID uint32, comment *models.ContentDto) error + GetComments(ctx context.Context, postID, lastID uint32, newest bool) ([]*models.CommentDto, error) +} + type Responder interface { OutputJSON(w http.ResponseWriter, data any, requestId string) OutputNoMoreContentJSON(w http.ResponseWriter, requestId string) @@ -44,21 +62,38 @@ type Responder interface { } type PostController struct { - postService PostService - responder Responder + postService PostService + commentService CommentService + responder Responder } -func NewPostController(service PostService, responder Responder) *PostController { +func NewPostController(service PostService, commentService CommentService, responder Responder) *PostController { return &PostController{ - postService: service, - responder: responder, + postService: service, + commentService: commentService, + responder: responder, } } +func sanitize(input string) string { + sanitizer := bluemonday.UGCPolicy() + cleaned := sanitizer.Sanitize(input) + return cleaned +} + +func sanitizeFiles(pics []models.Picture) []models.Picture { + var results []models.Picture + for _, pic := range pics { + results = append(results, models.Picture(sanitize(string(pic)))) + } + + return results +} + func (pc *PostController) Create(w http.ResponseWriter, r *http.Request) { var ( - reqID, ok = r.Context().Value("requestID").(string) - comunity = r.URL.Query().Get("community") + reqID, ok = r.Context().Value(middleware.RequestKey).(string) + comunity = sanitize(r.URL.Query().Get("community")) id uint32 err error ) @@ -73,10 +108,6 @@ func (pc *PostController) Create(w http.ResponseWriter, r *http.Request) { return } - if len(newPost.PostContent.Text) > 499 { - pc.responder.ErrorBadRequest(w, my_err.ErrPostTooLong, reqID) - return - } if comunity != "" { comID, err := strconv.ParseUint(comunity, 10, 32) if err != nil { @@ -104,16 +135,22 @@ func (pc *PostController) Create(w http.ResponseWriter, r *http.Request) { newPost.ID = id - pc.responder.OutputJSON(w, newPost, reqID) + post := newPost.FromDto() + post.PostContent.Text = sanitize(post.PostContent.Text) + post.PostContent.File = sanitizeFiles(post.PostContent.File) + post.Header.Avatar = models.Picture(sanitize(string(post.Header.Avatar))) + post.Header.Author = sanitize(post.Header.Author) + + pc.responder.OutputJSON(w, post, reqID) } func (pc *PostController) GetOne(w http.ResponseWriter, r *http.Request) { - reqID, ok := r.Context().Value("requestID").(string) + reqID, ok := r.Context().Value(middleware.RequestKey).(string) if !ok { pc.responder.LogError(my_err.ErrInvalidContext, "") } - postID, err := getIDFromURL(r) + postID, err := getIDFromURL(r, postIDkey) if err != nil { pc.responder.ErrorBadRequest(w, err, reqID) return @@ -136,15 +173,21 @@ func (pc *PostController) GetOne(w http.ResponseWriter, r *http.Request) { return } } + newPost := post.FromDto() - pc.responder.OutputJSON(w, post, reqID) + newPost.PostContent.Text = sanitize(newPost.PostContent.Text) + newPost.PostContent.File = sanitizeFiles(newPost.PostContent.File) + newPost.Header.Avatar = models.Picture(sanitize(string(newPost.Header.Avatar))) + newPost.Header.Author = sanitize(newPost.Header.Author) + + pc.responder.OutputJSON(w, newPost, reqID) } func (pc *PostController) Update(w http.ResponseWriter, r *http.Request) { var ( - reqID, ok = r.Context().Value("requestID").(string) - id, err = getIDFromURL(r) - community = r.URL.Query().Get("community") + reqID, ok = r.Context().Value(middleware.RequestKey).(string) + id, err = getIDFromURL(r, postIDkey) + community = sanitize(r.URL.Query().Get("community")) ) if err != nil { @@ -178,10 +221,6 @@ func (pc *PostController) Update(w http.ResponseWriter, r *http.Request) { pc.responder.ErrorBadRequest(w, err, reqID) return } - if len(post.PostContent.Text) > 499 { - pc.responder.ErrorBadRequest(w, my_err.ErrPostTooLong, reqID) - return - } post.ID = id if err := pc.postService.Update(r.Context(), post); err != nil { @@ -192,15 +231,21 @@ func (pc *PostController) Update(w http.ResponseWriter, r *http.Request) { pc.responder.ErrorInternal(w, err, reqID) return } + newPost := post.FromDto() - pc.responder.OutputJSON(w, post, reqID) + newPost.PostContent.Text = sanitize(newPost.PostContent.Text) + newPost.PostContent.File = sanitizeFiles(newPost.PostContent.File) + newPost.Header.Avatar = models.Picture(sanitize(string(newPost.Header.Avatar))) + newPost.Header.Author = sanitize(newPost.Header.Author) + + pc.responder.OutputJSON(w, newPost, reqID) } func (pc *PostController) Delete(w http.ResponseWriter, r *http.Request) { var ( - reqID, ok = r.Context().Value("requestID").(string) - postID, err = getIDFromURL(r) - community = r.URL.Query().Get("community") + reqID, ok = r.Context().Value(middleware.RequestKey).(string) + postID, err = getIDFromURL(r, postIDkey) + community = sanitize(r.URL.Query().Get("community")) ) if !ok { @@ -243,10 +288,10 @@ func (pc *PostController) Delete(w http.ResponseWriter, r *http.Request) { func (pc *PostController) GetBatchPosts(w http.ResponseWriter, r *http.Request) { var ( - reqID, ok = r.Context().Value("requestID").(string) - section = r.URL.Query().Get("section") - communityID = r.URL.Query().Get("community") - posts []*models.Post + reqID, ok = r.Context().Value(middleware.RequestKey).(string) + section = sanitize(r.URL.Query().Get("section")) + communityID = sanitize(r.URL.Query().Get("community")) + posts []*models.PostDto intLastID uint64 err error id uint64 @@ -302,35 +347,57 @@ func (pc *PostController) GetBatchPosts(w http.ResponseWriter, r *http.Request) } } - pc.responder.OutputJSON(w, posts, reqID) + res := make([]models.Post, 0, len(posts)) + for _, post := range posts { + newPost := post.FromDto() + + newPost.PostContent.Text = sanitize(newPost.PostContent.Text) + newPost.PostContent.File = sanitizeFiles(newPost.PostContent.File) + newPost.Header.Avatar = models.Picture(sanitize(string(newPost.Header.Avatar))) + newPost.Header.Author = sanitize(newPost.Header.Author) + res = append(res, newPost) + } + + pc.responder.OutputJSON(w, res, reqID) } -func (pc *PostController) getPostFromBody(r *http.Request) (*models.Post, error) { +func (pc *PostController) getPostFromBody(r *http.Request) (*models.PostDto, error) { var newPost models.Post - err := json.NewDecoder(r.Body).Decode(&newPost) + err := easyjson.UnmarshalFromReader(r.Body, &newPost) if err != nil { return nil, err } + newPost.PostContent.Text = sanitize(newPost.PostContent.Text) + newPost.PostContent.File = sanitizeFiles(newPost.PostContent.File) + newPost.Header.Avatar = models.Picture(sanitize(string(newPost.Header.Avatar))) + newPost.Header.Author = sanitize(newPost.Header.Author) + + if !validateContent(newPost.PostContent) { + return nil, my_err.ErrBadPostOrComment + } + sess, err := models.SessionFromContext(r.Context()) if err != nil { return nil, err } newPost.Header.AuthorID = sess.UserID + post := newPost.ToDto() - return &newPost, nil + return &post, nil } -func getIDFromURL(r *http.Request) (uint32, error) { +func getIDFromURL(r *http.Request, key string) (uint32, error) { vars := mux.Vars(r) - id := vars["id"] - if id == "" { + id := vars[key] + clearID := sanitize(id) + if clearID == "" { return 0, errors.New("id is empty") } - uid, err := strconv.ParseUint(id, 10, 32) + uid, err := strconv.ParseUint(clearID, 10, 32) if err != nil { return 0, err } @@ -340,12 +407,12 @@ func getIDFromURL(r *http.Request) (uint32, error) { func getLastID(r *http.Request) (uint64, error) { lastID := r.URL.Query().Get("id") - - if lastID == "" { + clearLastID := sanitize(lastID) + if clearLastID == "" { return math.MaxInt32, nil } - intLastID, err := strconv.ParseUint(lastID, 10, 32) + intLastID, err := strconv.ParseUint(clearLastID, 10, 32) if err != nil { return 0, err } @@ -383,12 +450,12 @@ func (pc *PostController) checkAccessToCommunity(r *http.Request, communityID ui } func (pc *PostController) SetLikeOnPost(w http.ResponseWriter, r *http.Request) { - reqID, ok := r.Context().Value("requestID").(string) + reqID, ok := r.Context().Value(middleware.RequestKey).(string) if !ok { pc.responder.LogError(my_err.ErrInvalidContext, "") } - postID, err := getIDFromURL(r) + postID, err := getIDFromURL(r, postIDkey) if err != nil { pc.responder.ErrorBadRequest(w, err, reqID) return @@ -421,12 +488,12 @@ func (pc *PostController) SetLikeOnPost(w http.ResponseWriter, r *http.Request) } func (pc *PostController) DeleteLikeFromPost(w http.ResponseWriter, r *http.Request) { - reqID, ok := r.Context().Value("requestID").(string) + reqID, ok := r.Context().Value(middleware.RequestKey).(string) if !ok { pc.responder.LogError(my_err.ErrInvalidContext, "") } - postID, err := getIDFromURL(r) + postID, err := getIDFromURL(r, postIDkey) if err != nil { pc.responder.ErrorBadRequest(w, err, reqID) return @@ -456,3 +523,215 @@ func (pc *PostController) DeleteLikeFromPost(w http.ResponseWriter, r *http.Requ pc.responder.OutputJSON(w, "like is unset from post", reqID) } + +func (pc *PostController) Comment(w http.ResponseWriter, r *http.Request) { + reqID, ok := r.Context().Value(middleware.RequestKey).(string) + if !ok { + pc.responder.LogError(my_err.ErrInvalidContext, "") + } + + postID, err := getIDFromURL(r, postIDkey) + if err != nil { + pc.responder.ErrorBadRequest(w, err, reqID) + return + } + + sess, errSession := models.SessionFromContext(r.Context()) + if errSession != nil { + pc.responder.ErrorBadRequest(w, errSession, reqID) + return + } + + var content models.Content + if err := easyjson.UnmarshalFromReader(r.Body, &content); err != nil { + pc.responder.ErrorBadRequest(w, err, reqID) + return + } + + content.Text = sanitize(content.Text) + content.File = sanitizeFiles(content.File) + + if !validateContent(content) { + pc.responder.ErrorBadRequest(w, my_err.ErrBadPostOrComment, reqID) + return + } + + contentDto := content.ToDto() + newComment, err := pc.commentService.Comment(r.Context(), sess.UserID, postID, &contentDto) + if err != nil { + pc.responder.ErrorInternal(w, err, reqID) + return + } + newComment.Content.CreatedAt = time.Now() + newComment.Content.Text = sanitize(newComment.Content.Text) + newComment.Content.File = models.Picture(sanitize(string(newComment.Content.File))) + newComment.Header.Avatar = models.Picture(sanitize(string(newComment.Header.Avatar))) + newComment.Header.Author = sanitize(newComment.Header.Author) + + pc.responder.OutputJSON(w, newComment.FromDto(), reqID) +} + +func (pc *PostController) DeleteComment(w http.ResponseWriter, r *http.Request) { + reqID, ok := r.Context().Value(middleware.RequestKey).(string) + if !ok { + pc.responder.LogError(my_err.ErrInvalidContext, "") + } + + commentID, err := getIDFromURL(r, commentIDKey) + if err != nil { + pc.responder.ErrorBadRequest(w, err, reqID) + return + } + + sess, errSession := models.SessionFromContext(r.Context()) + if errSession != nil { + pc.responder.ErrorBadRequest(w, errSession, reqID) + return + } + + err = pc.commentService.DeleteComment(r.Context(), commentID, sess.UserID) + if err != nil { + if errors.Is(err, my_err.ErrAccessDenied) { + pc.responder.ErrorBadRequest(w, err, reqID) + return + } + + if errors.Is(err, my_err.ErrWrongComment) { + pc.responder.ErrorBadRequest(w, err, reqID) + return + } + + pc.responder.ErrorInternal(w, err, reqID) + return + } + + pc.responder.OutputJSON(w, "comment is deleted", reqID) +} + +func (pc *PostController) EditComment(w http.ResponseWriter, r *http.Request) { + reqID, ok := r.Context().Value(middleware.RequestKey).(string) + if !ok { + pc.responder.LogError(my_err.ErrInvalidContext, "") + } + + commentID, err := getIDFromURL(r, commentIDKey) + if err != nil { + pc.responder.ErrorBadRequest(w, err, reqID) + return + } + + sess, errSession := models.SessionFromContext(r.Context()) + if errSession != nil { + pc.responder.ErrorBadRequest(w, errSession, reqID) + return + } + + var content models.Content + if err := easyjson.UnmarshalFromReader(r.Body, &content); err != nil { + pc.responder.ErrorBadRequest(w, err, reqID) + return + } + + content.Text = sanitize(content.Text) + content.File = sanitizeFiles(content.File) + + if !validateContent(content) { + pc.responder.ErrorBadRequest(w, my_err.ErrBadPostOrComment, reqID) + return + } + + contentDto := content.ToDto() + if err := pc.commentService.EditComment(r.Context(), commentID, sess.UserID, &contentDto); err != nil { + if errors.Is(err, my_err.ErrAccessDenied) { + pc.responder.ErrorBadRequest(w, err, reqID) + return + } + + if errors.Is(err, my_err.ErrWrongComment) { + pc.responder.ErrorBadRequest(w, err, reqID) + return + } + + pc.responder.ErrorInternal(w, err, reqID) + return + } + + pc.responder.OutputJSON(w, "comment is updated", reqID) +} + +func (pc *PostController) GetComments(w http.ResponseWriter, r *http.Request) { + reqID, ok := r.Context().Value(middleware.RequestKey).(string) + if !ok { + pc.responder.LogError(my_err.ErrInvalidContext, "") + } + + postID, err := getIDFromURL(r, postIDkey) + if err != nil { + pc.responder.ErrorBadRequest(w, err, reqID) + return + } + + lastId, err := getLastID(r) + if err != nil { + pc.responder.ErrorBadRequest(w, err, reqID) + return + } + sorting := sanitize(r.URL.Query().Get("sort")) + newest := true + if sorting == "old" { + newest = false + if lastId == math.MaxInt32 { + lastId = 0 + } + } + + comments, err := pc.commentService.GetComments(r.Context(), postID, uint32(lastId), newest) + if err != nil { + if errors.Is(err, my_err.ErrNoMoreContent) { + pc.responder.OutputNoMoreContentJSON(w, reqID) + return + } + + pc.responder.ErrorInternal(w, err, reqID) + return + } + + res := make([]models.Comment, 0, len(comments)) + for _, comment := range comments { + res = append(res, comment.FromDto()) + } + + for _, newComment := range res { + newComment.Content.Text = sanitize(newComment.Content.Text) + newComment.Content.File = sanitizeFiles(newComment.Content.File) + newComment.Header.Avatar = models.Picture(sanitize(string(newComment.Header.Avatar))) + newComment.Header.Author = sanitize(newComment.Header.Author) + } + + pc.responder.OutputJSON(w, res, reqID) +} + +func validateContent(content models.Content) bool { + if len(content.File) == 0 && len(content.Text) == 0 { + return false + } + + return validateFile(content.File) && len([]rune(content.Text)) < 1000 +} + +func validateFile(filepaths []models.Picture) bool { + if len(filepaths) > 10 { + return false + } + + for _, f := range filepaths { + if len([]rune(f)) > 100 { + return false + } + if !(strings.HasPrefix(string(f), filePrefix) || strings.HasPrefix(string(f), imagePrefix)) { + return false + } + } + + return true +} diff --git a/internal/post/controller/controller_test.go b/internal/post/controller/controller_test.go index 7798c478..d46e7374 100644 --- a/internal/post/controller/controller_test.go +++ b/internal/post/controller/controller_test.go @@ -18,11 +18,12 @@ import ( func getController(ctrl *gomock.Controller) (*PostController, *mocks) { m := &mocks{ - postService: NewMockPostService(ctrl), - responder: NewMockResponder(ctrl), + postService: NewMockPostService(ctrl), + responder: NewMockResponder(ctrl), + commentService: NewMockCommentService(ctrl), } - return NewPostController(m.postService, m.responder), m + return NewPostController(m.postService, m.commentService, m.responder), m } func TestNewPostController(t *testing.T) { @@ -53,17 +54,23 @@ func TestCreate(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { name: "2", SetupInput: func() (*Request, error) { - req := httptest.NewRequest(http.MethodPost, "/api/v1/feed", - bytes.NewBuffer([]byte(`{"id":1}`))) + req := httptest.NewRequest( + http.MethodPost, "/api/v1/feed", + bytes.NewBuffer([]byte(`{"id":1, "post_content":{"text":"text"}}`)), + ) w := httptest.NewRecorder() res := &Request{r: req, w: w} return res, nil @@ -79,17 +86,23 @@ func TestCreate(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { name: "3", SetupInput: func() (*Request, error) { - req := httptest.NewRequest(http.MethodPost, "/api/v1/feed", - bytes.NewBuffer([]byte(`{"id":1}`))) + req := httptest.NewRequest( + http.MethodPost, "/api/v1/feed", + bytes.NewBuffer([]byte(`{"id":1, "post_content":{"text":"text"}}`)), + ) w := httptest.NewRecorder() req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) res := &Request{r: req, w: w} @@ -107,17 +120,23 @@ func TestCreate(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.postService.EXPECT().Create(gomock.Any(), gomock.Any()).Return(uint32(0), errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { name: "4", SetupInput: func() (*Request, error) { - req := httptest.NewRequest(http.MethodPost, "/api/v1/feed", - bytes.NewBuffer([]byte(`{"id":1}`))) + req := httptest.NewRequest( + http.MethodPost, "/api/v1/feed", + bytes.NewBuffer([]byte(`{"id":1, "post_content":{"text":"text"}}`)), + ) w := httptest.NewRecorder() req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) res := &Request{r: req, w: w} @@ -135,17 +154,23 @@ func TestCreate(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.postService.EXPECT().Create(gomock.Any(), gomock.Any()).Return(uint32(2), nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, data, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, data, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, { name: "5", SetupInput: func() (*Request, error) { - req := httptest.NewRequest(http.MethodPost, "/api/v1/feed?community=ljkhkg", - bytes.NewBuffer([]byte(`{"id":1}`))) + req := httptest.NewRequest( + http.MethodPost, "/api/v1/feed?community=ljkhkg", + bytes.NewBuffer([]byte(`{"id":1}`)), + ) w := httptest.NewRecorder() req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) res := &Request{r: req, w: w} @@ -162,17 +187,23 @@ func TestCreate(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { name: "6", SetupInput: func() (*Request, error) { - req := httptest.NewRequest(http.MethodPost, "/api/v1/feed?community=10", - bytes.NewBuffer([]byte(`{"id":1}`))) + req := httptest.NewRequest( + http.MethodPost, "/api/v1/feed?community=10", + bytes.NewBuffer([]byte(`{"id":1, "post_content":{"text":"text"}}`)), + ) w := httptest.NewRecorder() req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) res := &Request{r: req, w: w} @@ -190,17 +221,23 @@ func TestCreate(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.postService.EXPECT().CheckAccessToCommunity(gomock.Any(), gomock.Any(), gomock.Any()).Return(false) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { name: "7", SetupInput: func() (*Request, error) { - req := httptest.NewRequest(http.MethodPost, "/api/v1/feed?community=10", - bytes.NewBuffer([]byte(`{"id":1}`))) + req := httptest.NewRequest( + http.MethodPost, "/api/v1/feed?community=10", + bytes.NewBuffer([]byte(`{"id":1, "post_content":{"text":"text"}}`)), + ) w := httptest.NewRecorder() req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) res := &Request{r: req, w: w} @@ -218,18 +255,26 @@ func TestCreate(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.postService.EXPECT().CheckAccessToCommunity(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) - m.postService.EXPECT().CreateCommunityPost(gomock.Any(), gomock.Any()).Return(uint32(0), errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.postService.EXPECT().CreateCommunityPost(gomock.Any(), gomock.Any()).Return( + uint32(0), errors.New("error"), + ) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { name: "8", SetupInput: func() (*Request, error) { - req := httptest.NewRequest(http.MethodPost, "/api/v1/feed?community=10", - bytes.NewBuffer([]byte(`{"id":1}`))) + req := httptest.NewRequest( + http.MethodPost, "/api/v1/feed?community=10", + bytes.NewBuffer([]byte(`{"id":1, "post_content":{"text":"text"}}`)), + ) w := httptest.NewRecorder() req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) res := &Request{r: req, w: w} @@ -248,17 +293,25 @@ func TestCreate(t *testing.T) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.postService.EXPECT().CheckAccessToCommunity(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) m.postService.EXPECT().CreateCommunityPost(gomock.Any(), gomock.Any()).Return(uint32(10), nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, data, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, data, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, { name: "9", SetupInput: func() (*Request, error) { - req := httptest.NewRequest(http.MethodPost, "/api/v1/feed?community=10", - bytes.NewBuffer([]byte(`{"post_content":{"text":"new post Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed tellus arcu, vulputate rutrum enim vitae, tincidunt imperdiet tellus. Aenean vulputate elit consequat lorem pellentesque bibendum. Donec sed mi posuere dolor semper mollis eu eget dolor. Proin et eleifend magna. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Curabitur tempus ultricies mi, eget malesuada metus. Nam sit amet felis nec dolor vehicula dapibus gravida in nunc. Mauris turpis et. "}}`))) + req := httptest.NewRequest( + http.MethodPost, "/api/v1/feed?community=10", + bytes.NewBuffer( + []byte(`{"post_content":{"text":"new post Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed tellus arcu, vulputate rutrum enim vitae, tincidunt imperdiet tellus. Aenean vulputate elit consequat lorem pellentesque bibendum. Donec sed mi posuere dolor semper mollis eu eget dolor. Proin et eleifend magna. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Curabitur tempus ultricies mi, eget malesuada metus. Nam sit amet felis nec dolor vehicula dapibus gravida in nunc. Mauris turpis et. new post Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed tellus arcu, vulputate rutrum enim vitae, tincidunt imperdiet tellus. Aenean vulputate elit consequat lorem pellentesque bibendum. Donec sed mi posuere dolor semper mollis eu eget dolor. Proin et eleifend magna. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Curabitur tempus ultricies mi, eget malesuada metus. Nam sit amet felis nec dolor vehicula dapibus gravida in nunc. Mauris turpis et. wiukyjctg"}}`), + ), + ) w := httptest.NewRecorder() req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) res := &Request{r: req, w: w} @@ -275,40 +328,112 @@ func TestCreate(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, data, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, data, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) + }, + }, + { + name: "10", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest( + http.MethodPost, "/api/v1/feed?community=10", + bytes.NewBuffer([]byte(`{"post_content":{"text":""}}`)), + ) + w := httptest.NewRecorder() + req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *PostController, request Request) (Response, error) { + implementation.Create(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, data, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) + }, + }, + { + name: "11", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest( + http.MethodPost, "/api/v1/feed?community=text", + bytes.NewBuffer([]byte(`{"post_content":{"text":"text"}}`)), + ) + w := httptest.NewRecorder() + req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *PostController, request Request) (Response, error) { + implementation.Create(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, data, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -333,10 +458,14 @@ func TestGetOne(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -359,10 +488,14 @@ func TestGetOne(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -387,10 +520,14 @@ func TestGetOne(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.postService.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, my_err.ErrPostNotFound) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -415,10 +552,14 @@ func TestGetOne(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.postService.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -442,11 +583,15 @@ func TestGetOne(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.postService.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + m.postService.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(&models.PostDto{}, nil) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -469,40 +614,46 @@ func TestGetOne(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -526,10 +677,14 @@ func TestUpdate(t *testing.T) { }, ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -552,10 +707,14 @@ func TestUpdate(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()).Do(func(err, req any) {}) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -579,11 +738,17 @@ func TestUpdate(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()).Do(func(err, req any) {}) - m.postService.EXPECT().GetPostAuthorID(gomock.Any(), gomock.Any()).Return(uint32(0), errors.New("error")) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.postService.EXPECT().GetPostAuthorID(gomock.Any(), gomock.Any()).Return( + uint32(0), errors.New("error"), + ) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -608,17 +773,23 @@ func TestUpdate(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()).Do(func(err, req any) {}) m.postService.EXPECT().GetPostAuthorID(gomock.Any(), gomock.Any()).Return(uint32(10), nil) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { name: "5", SetupInput: func() (*Request, error) { - req := httptest.NewRequest(http.MethodPut, "/api/v1/feed/10", - bytes.NewBuffer([]byte(`{"id":1}`))) + req := httptest.NewRequest( + http.MethodPut, "/api/v1/feed/10", + bytes.NewBuffer([]byte(`{"id":1, "post_content":{"text":"text"}}`)), + ) w := httptest.NewRecorder() req = mux.SetURLVars(req, map[string]string{"id": "10"}) req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) @@ -638,17 +809,23 @@ func TestUpdate(t *testing.T) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()).Do(func(err, req any) {}) m.postService.EXPECT().GetPostAuthorID(gomock.Any(), gomock.Any()).Return(uint32(1), nil) m.postService.EXPECT().Update(gomock.Any(), gomock.Any()).Return(my_err.ErrPostNotFound) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { name: "6", SetupInput: func() (*Request, error) { - req := httptest.NewRequest(http.MethodPut, "/api/v1/feed/10", - bytes.NewBuffer([]byte(`{"id":1}`))) + req := httptest.NewRequest( + http.MethodPut, "/api/v1/feed/10", + bytes.NewBuffer([]byte(`{"id":1, "post_content":{"text":"text"}}`)), + ) w := httptest.NewRecorder() req = mux.SetURLVars(req, map[string]string{"id": "10"}) req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) @@ -668,17 +845,23 @@ func TestUpdate(t *testing.T) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()).Do(func(err, req any) {}) m.postService.EXPECT().GetPostAuthorID(gomock.Any(), gomock.Any()).Return(uint32(1), nil) m.postService.EXPECT().Update(gomock.Any(), gomock.Any()).Return(errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { name: "7", SetupInput: func() (*Request, error) { - req := httptest.NewRequest(http.MethodPut, "/api/v1/feed/10", - bytes.NewBuffer([]byte(`{"id":1}`))) + req := httptest.NewRequest( + http.MethodPut, "/api/v1/feed/10", + bytes.NewBuffer([]byte(`{"id":1, "post_content":{"text":"text"}}`)), + ) w := httptest.NewRecorder() req = mux.SetURLVars(req, map[string]string{"id": "10"}) req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) @@ -698,17 +881,23 @@ func TestUpdate(t *testing.T) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()).Do(func(err, req any) {}) m.postService.EXPECT().GetPostAuthorID(gomock.Any(), gomock.Any()).Return(uint32(1), nil) m.postService.EXPECT().Update(gomock.Any(), gomock.Any()).Return(nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, { name: "8", SetupInput: func() (*Request, error) { - req := httptest.NewRequest(http.MethodPut, "/api/v1/feed/10?community=nkljbkvhj", - bytes.NewBuffer([]byte(`{"id":1}`))) + req := httptest.NewRequest( + http.MethodPut, "/api/v1/feed/10?community=nkljbkvhj", + bytes.NewBuffer([]byte(`{"id":1}`)), + ) w := httptest.NewRecorder() req = mux.SetURLVars(req, map[string]string{"id": "10"}) req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) @@ -726,17 +915,23 @@ func TestUpdate(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()).Do(func(err, req any) {}) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { name: "9", SetupInput: func() (*Request, error) { - req := httptest.NewRequest(http.MethodPut, "/api/v1/feed/10?community=10", - bytes.NewBuffer([]byte(`{"id":1}`))) + req := httptest.NewRequest( + http.MethodPut, "/api/v1/feed/10?community=10", + bytes.NewBuffer([]byte(`{"id":1}`)), + ) w := httptest.NewRecorder() req = mux.SetURLVars(req, map[string]string{"id": "10"}) req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) @@ -755,17 +950,23 @@ func TestUpdate(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()).Do(func(err, req any) {}) m.postService.EXPECT().CheckAccessToCommunity(gomock.Any(), gomock.Any(), gomock.Any()).Return(false) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { name: "10", SetupInput: func() (*Request, error) { - req := httptest.NewRequest(http.MethodPut, "/api/v1/feed/10?community=10", - bytes.NewBuffer([]byte(`{"id":1}`))) + req := httptest.NewRequest( + http.MethodPut, "/api/v1/feed/10?community=10", + bytes.NewBuffer([]byte(`{"id":1, "post_content":{"text":"text"}}`)), + ) w := httptest.NewRecorder() req = mux.SetURLVars(req, map[string]string{"id": "10"}) req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) @@ -785,17 +986,23 @@ func TestUpdate(t *testing.T) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()).Do(func(err, req any) {}) m.postService.EXPECT().CheckAccessToCommunity(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) m.postService.EXPECT().Update(gomock.Any(), gomock.Any()).Return(nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, { name: "11", SetupInput: func() (*Request, error) { - req := httptest.NewRequest(http.MethodPut, "/api/v1/feed/10?community=10", - bytes.NewBuffer([]byte(`{"id"`))) + req := httptest.NewRequest( + http.MethodPut, "/api/v1/feed/10?community=10", + bytes.NewBuffer([]byte(`{"id"`)), + ) w := httptest.NewRecorder() req = mux.SetURLVars(req, map[string]string{"id": "10"}) req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) @@ -814,17 +1021,25 @@ func TestUpdate(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()).Do(func(err, req any) {}) m.postService.EXPECT().CheckAccessToCommunity(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { name: "12", SetupInput: func() (*Request, error) { - req := httptest.NewRequest(http.MethodPut, "/api/v1/feed/10?community=1", - bytes.NewBuffer([]byte(`{"post_content":{"text":"new post Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed tellus arcu, vulputate rutrum enim vitae, tincidunt imperdiet tellus. Aenean vulputate elit consequat lorem pellentesque bibendum. Donec sed mi posuere dolor semper mollis eu eget dolor. Proin et eleifend magna. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Curabitur tempus ultricies mi, eget malesuada metus. Nam sit amet felis nec dolor vehicula dapibus gravida in nunc. Mauris turpis et. "}}`))) + req := httptest.NewRequest( + http.MethodPut, "/api/v1/feed/10?community=1", + bytes.NewBuffer( + []byte(`{"post_content":{"text":"new post Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed tellus arcu, vulputate rutrum enim vitae, tincidunt imperdiet tellus. Aenean vulputate elit consequat lorem pellentesque bibendum. Donec sed mi posuere dolor semper mollis eu eget dolor. Proin et eleifend magna. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Curabitur tempus ultricies mi, eget malesuada metus. Nam sit amet felis nec dolor vehicula dapibus gravida in nunc. Mauris turpis et. new post Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed tellus arcu, vulputate rutrum enim vitae, tincidunt imperdiet tellus. Aenean vulputate elit consequat lorem pellentesque bibendum. Donec sed mi posuere dolor semper mollis eu eget dolor. Proin et eleifend magna. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Curabitur tempus ultricies mi, eget malesuada metus. Nam sit amet felis nec dolor vehicula dapibus gravida in nunc. Mauris turpis et. new post Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed tellus arcu, vulputate rutrum enim vitae, tincidunt imperdiet tellus. Aenean vulputate elit consequat lorem pellentesque bibendum. Donec sed mi posuere dolor semper mollis eu eget dolor. Proin et eleifend magna. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Curabitur tempus ultricies mi, eget malesuada metus. Nam sit amet felis nec dolor vehicula dapibus gravida in nunc. Mauris turpis et. "}}`), + ), + ) w := httptest.NewRecorder() req = mux.SetURLVars(req, map[string]string{"id": "10"}) req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) @@ -843,40 +1058,46 @@ func TestUpdate(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()).Do(func(err, req any) {}) m.postService.EXPECT().CheckAccessToCommunity(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -901,10 +1122,14 @@ func TestDelete(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()).Do(func(err, req any) {}) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -927,10 +1152,14 @@ func TestDelete(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()).Do(func(err, req any) {}) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -956,10 +1185,14 @@ func TestDelete(t *testing.T) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()).Do(func(err, req any) {}) m.postService.EXPECT().GetPostAuthorID(gomock.Any(), gomock.Any()).Return(uint32(1), nil) m.postService.EXPECT().Delete(gomock.Any(), gomock.Any()).Return(my_err.ErrPostNotFound) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -985,10 +1218,14 @@ func TestDelete(t *testing.T) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()).Do(func(err, req any) {}) m.postService.EXPECT().GetPostAuthorID(gomock.Any(), gomock.Any()).Return(uint32(1), nil) m.postService.EXPECT().Delete(gomock.Any(), gomock.Any()).Return(errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1014,10 +1251,14 @@ func TestDelete(t *testing.T) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()).Do(func(err, req any) {}) m.postService.EXPECT().GetPostAuthorID(gomock.Any(), gomock.Any()).Return(uint32(1), nil) m.postService.EXPECT().Delete(gomock.Any(), gomock.Any()).Return(nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, data, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, data, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1041,10 +1282,14 @@ func TestDelete(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()).Do(func(err, req any) {}) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, error, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, error, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1067,10 +1312,14 @@ func TestDelete(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()).Do(func(err, req any) {}) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, error, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, error, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1096,40 +1345,46 @@ func TestDelete(t *testing.T) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()).Do(func(err, req any) {}) m.postService.EXPECT().CheckAccessToCommunity(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) m.postService.EXPECT().Delete(gomock.Any(), gomock.Any()).Return(nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, error, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, error, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -1154,10 +1409,14 @@ func TestGetBatchPost(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()).Do(func(err, req any) {}) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1179,10 +1438,14 @@ func TestGetBatchPost(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()).Do(func(err, req any) {}) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1205,10 +1468,14 @@ func TestGetBatchPost(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()).Do(func(err, req any) {}) - m.postService.EXPECT().GetBatch(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, my_err.ErrNoMoreContent) - m.responder.EXPECT().OutputNoMoreContentJSON(request.w, gomock.Any()).Do(func(w, req any) { - request.w.WriteHeader(http.StatusNoContent) - }) + m.postService.EXPECT().GetBatch(gomock.Any(), gomock.Any(), gomock.Any()).Return( + nil, my_err.ErrNoMoreContent, + ) + m.responder.EXPECT().OutputNoMoreContentJSON(request.w, gomock.Any()).Do( + func(w, req any) { + request.w.WriteHeader(http.StatusNoContent) + }, + ) }, }, { @@ -1231,11 +1498,17 @@ func TestGetBatchPost(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()).Do(func(err, req any) {}) - m.postService.EXPECT().GetBatch(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.postService.EXPECT().GetBatch(gomock.Any(), gomock.Any(), gomock.Any()).Return( + nil, errors.New("error"), + ) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1259,10 +1532,14 @@ func TestGetBatchPost(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()).Do(func(err, req any) {}) m.postService.EXPECT().GetBatch(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1284,10 +1561,14 @@ func TestGetBatchPost(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()).Do(func(err, req any) {}) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1312,10 +1593,14 @@ func TestGetBatchPost(t *testing.T) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()).Do(func(err, req any) {}) m.postService.EXPECT().GetBatchFromFriend(gomock.Any(), gomock.Any(), gomock.Any()). Return(nil, nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1338,10 +1623,14 @@ func TestGetBatchPost(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()).Do(func(err, req any) {}) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1364,11 +1653,23 @@ func TestGetBatchPost(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()).Do(func(err, req any) {}) - m.postService.EXPECT().GetCommunityPost(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + m.postService.EXPECT().GetCommunityPost( + gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), + ).Return( + []*models.PostDto{ + { + ID: 1, + }, + }, nil, + ) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1391,40 +1692,46 @@ func TestGetBatchPost(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()).Do(func(err, req any) {}) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -1449,10 +1756,14 @@ func TestSetLikeOnPost(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1475,10 +1786,14 @@ func TestSetLikeOnPost(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1503,10 +1818,14 @@ func TestSetLikeOnPost(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.postService.EXPECT().CheckLikes(gomock.Any(), gomock.Any(), gomock.Any()).Return(true, nil) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1530,11 +1849,17 @@ func TestSetLikeOnPost(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.postService.EXPECT().CheckLikes(gomock.Any(), gomock.Any(), gomock.Any()).Return(false, errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.postService.EXPECT().CheckLikes(gomock.Any(), gomock.Any(), gomock.Any()).Return( + false, errors.New("error"), + ) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1559,11 +1884,17 @@ func TestSetLikeOnPost(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.postService.EXPECT().CheckLikes(gomock.Any(), gomock.Any(), gomock.Any()).Return(false, nil) - m.postService.EXPECT().SetLikeToPost(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.postService.EXPECT().SetLikeToPost( + gomock.Any(), gomock.Any(), gomock.Any(), + ).Return(errors.New("error")) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1589,40 +1920,46 @@ func TestSetLikeOnPost(t *testing.T) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.postService.EXPECT().CheckLikes(gomock.Any(), gomock.Any(), gomock.Any()).Return(false, nil) m.postService.EXPECT().SetLikeToPost(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, data, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, data, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -1647,10 +1984,14 @@ func TestDeleteLikeFromPost(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1673,10 +2014,14 @@ func TestDeleteLikeFromPost(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1701,10 +2046,14 @@ func TestDeleteLikeFromPost(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.postService.EXPECT().CheckLikes(gomock.Any(), gomock.Any(), gomock.Any()).Return(false, nil) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1728,11 +2077,17 @@ func TestDeleteLikeFromPost(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.postService.EXPECT().CheckLikes(gomock.Any(), gomock.Any(), gomock.Any()).Return(false, errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.postService.EXPECT().CheckLikes(gomock.Any(), gomock.Any(), gomock.Any()).Return( + false, errors.New("error"), + ) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1757,11 +2112,17 @@ func TestDeleteLikeFromPost(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.postService.EXPECT().CheckLikes(gomock.Any(), gomock.Any(), gomock.Any()).Return(true, nil) - m.postService.EXPECT().DeleteLikeFromPost(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.postService.EXPECT().DeleteLikeFromPost( + gomock.Any(), gomock.Any(), gomock.Any(), + ).Return(errors.New("error")) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1787,46 +2148,994 @@ func TestDeleteLikeFromPost(t *testing.T) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.postService.EXPECT().CheckLikes(gomock.Any(), gomock.Any(), gomock.Any()).Return(true, nil) m.postService.EXPECT().DeleteLikeFromPost(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, data, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, data, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) + } +} + +func TestSanitize(t *testing.T) { + test := "" + expected := "" + res := sanitize(test) + assert.Equal(t, expected, res) +} + +func TestSanitizeFiles(t *testing.T) { + test := []models.Picture{"", "filepath"} + expected := []models.Picture{"", "filepath"} + res := sanitizeFiles(test) + assert.Equal(t, expected, res) +} + +func TestComment(t *testing.T) { + tests := []TableTest[Response, Request]{ + { + name: "1", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodPost, "/api/v1/feed/2", nil) + w := httptest.NewRecorder() + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *PostController, request Request) (Response, error) { + implementation.Comment(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + _, _ = request.w.Write([]byte("bad request")) + }, + ) + }, + }, + { + name: "2", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodPost, "/api/v1/feed/2", nil) + w := httptest.NewRecorder() + req = mux.SetURLVars(req, map[string]string{"id": "2"}) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *PostController, request Request) (Response, error) { + implementation.Comment(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + _, _ = request.w.Write([]byte("bad request")) + }, + ) + }, + }, + { + name: "3", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodPost, "/api/v1/feed/2", nil) + w := httptest.NewRecorder() + req = mux.SetURLVars(req, map[string]string{"id": "2"}) + req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *PostController, request Request) (Response, error) { + implementation.Comment(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + _, _ = request.w.Write([]byte("bad request")) + }, + ) + }, + }, + { + name: "4", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest( + http.MethodPost, "/api/v1/feed/2", bytes.NewBuffer([]byte(`{"id":1, "text":"text"}`)), + ) + w := httptest.NewRecorder() + req = mux.SetURLVars(req, map[string]string{"id": "2"}) + req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *PostController, request Request) (Response, error) { + implementation.Comment(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusInternalServerError, Body: "error"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.commentService.EXPECT().Comment(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, errors.New("error")) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + _, _ = request.w.Write([]byte("error")) + }, + ) + }, + }, + { + name: "5", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest( + http.MethodPost, "/api/v1/feed/2", bytes.NewBuffer([]byte(`{"id":1, "text":"text"}`)), + ) + w := httptest.NewRecorder() + req = mux.SetURLVars(req, map[string]string{"id": "2"}) + req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *PostController, request Request) (Response, error) { + implementation.Comment(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusOK, Body: "OK"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.commentService.EXPECT().Comment(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return( + &models.CommentDto{ + Content: models.ContentDto{ + Text: "New comment", + }, + }, nil, + ) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, data, req any) { + request.w.WriteHeader(http.StatusOK) + _, _ = request.w.Write([]byte("OK")) + }, + ) + }, + }, + { + name: "6", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest( + http.MethodPost, "/api/v1/feed/2", + bytes.NewBuffer([]byte(`{"file":["очень большой текст, написанный в поле файл, как он сюда попал - честно говоря хз, надо проверить валидацию на файл длину"]}`)), + ) + w := httptest.NewRecorder() + req = mux.SetURLVars(req, map[string]string{"id": "2"}) + req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *PostController, request Request) (Response, error) { + implementation.Comment(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + _, _ = request.w.Write([]byte("bad request")) + }, + ) + }, + }, + } + for _, v := range tests { + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) + } +} + +func TestGetComments(t *testing.T) { + tests := []TableTest[Response, Request]{ + { + name: "1", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodGet, "/api/v1/feed/2/comment", nil) + w := httptest.NewRecorder() + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *PostController, request Request) (Response, error) { + implementation.GetComments(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + _, _ = request.w.Write([]byte("bad request")) + }, + ) + }, + }, + { + name: "2", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodGet, "/api/v1/feed/2/comment?id=fnf", nil) + w := httptest.NewRecorder() + req = mux.SetURLVars(req, map[string]string{"id": "2"}) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *PostController, request Request) (Response, error) { + implementation.GetComments(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + _, _ = request.w.Write([]byte("bad request")) + }, + ) + }, + }, + { + name: "3", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodGet, "/api/v1/feed/2/comment?id=4", nil) + w := httptest.NewRecorder() + req = mux.SetURLVars(req, map[string]string{"id": "2"}) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *PostController, request Request) (Response, error) { + implementation.GetComments(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusInternalServerError, Body: "error"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.commentService.EXPECT().GetComments(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, errors.New("error")) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + _, _ = request.w.Write([]byte("error")) + }, + ) + }, + }, + { + name: "4", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodGet, "/api/v1/feed/2/comment?sort=old", nil) + w := httptest.NewRecorder() + req = mux.SetURLVars(req, map[string]string{"id": "2"}) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *PostController, request Request) (Response, error) { + implementation.GetComments(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusOK, Body: "OK"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.commentService.EXPECT().GetComments(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return( + []*models.CommentDto{ + {ID: 1}, + }, + nil, + ) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, data, req any) { + request.w.WriteHeader(http.StatusOK) + _, _ = request.w.Write([]byte("OK")) + }, + ) + }, + }, + { + name: "5", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodGet, "/api/v1/feed/2/comment", nil) + w := httptest.NewRecorder() + req = mux.SetURLVars(req, map[string]string{"id": "2"}) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *PostController, request Request) (Response, error) { + implementation.GetComments(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusNoContent}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.commentService.EXPECT().GetComments(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, my_err.ErrNoMoreContent) + m.responder.EXPECT().OutputNoMoreContentJSON(request.w, gomock.Any()).Do( + func(w, req any) { + request.w.WriteHeader(http.StatusNoContent) + }, + ) + }, + }, + } + + for _, v := range tests { + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) + } +} + +func TestEditComment(t *testing.T) { + tests := []TableTest[Response, Request]{ + { + name: "1", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodPut, "/api/v1/feed/2/", nil) + w := httptest.NewRecorder() + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *PostController, request Request) (Response, error) { + implementation.EditComment(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + _, _ = request.w.Write([]byte("bad request")) + }, + ) + }, + }, + { + name: "2", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodPut, "/api/v1/feed/2/1", nil) + w := httptest.NewRecorder() + req = mux.SetURLVars(req, map[string]string{commentIDKey: "1"}) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *PostController, request Request) (Response, error) { + implementation.EditComment(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + _, _ = request.w.Write([]byte("bad request")) + }, + ) + }, + }, + { + name: "3", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodPut, "/api/v1/feed/2/1", nil) + w := httptest.NewRecorder() + req = mux.SetURLVars(req, map[string]string{commentIDKey: "1"}) + req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *PostController, request Request) (Response, error) { + implementation.EditComment(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + _, _ = request.w.Write([]byte("bad request")) + }, + ) + }, + }, + { + name: "4", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodPut, "/api/v1/feed/2/1", bytes.NewBuffer([]byte(`{"text":"1"}`))) + w := httptest.NewRecorder() + req = mux.SetURLVars(req, map[string]string{commentIDKey: "1"}) + req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *PostController, request Request) (Response, error) { + implementation.EditComment(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.commentService.EXPECT().EditComment(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(my_err.ErrAccessDenied) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + _, _ = request.w.Write([]byte("bad request")) + }, + ) + }, + }, + { + name: "5", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodPut, "/api/v1/feed/2/1", bytes.NewBuffer([]byte(`{"text":"1"}`))) + w := httptest.NewRecorder() + req = mux.SetURLVars(req, map[string]string{commentIDKey: "1"}) + req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *PostController, request Request) (Response, error) { + implementation.EditComment(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.commentService.EXPECT().EditComment(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(my_err.ErrWrongComment) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + _, _ = request.w.Write([]byte("bad request")) + }, + ) + }, + }, + { + name: "6", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodPut, "/api/v1/feed/2/1", bytes.NewBuffer([]byte(`{"text":"1"}`))) + w := httptest.NewRecorder() + req = mux.SetURLVars(req, map[string]string{commentIDKey: "1"}) + req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *PostController, request Request) (Response, error) { + implementation.EditComment(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusInternalServerError, Body: "error"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.commentService.EXPECT().EditComment(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(errors.New("err")) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + _, _ = request.w.Write([]byte("error")) + }, + ) + }, + }, + { + name: "7", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodPut, "/api/v1/feed/2/1", bytes.NewBuffer([]byte(`{"text":"1"}`))) + w := httptest.NewRecorder() + req = mux.SetURLVars(req, map[string]string{commentIDKey: "1"}) + req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *PostController, request Request) (Response, error) { + implementation.EditComment(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusOK, Body: "OK"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.commentService.EXPECT().EditComment(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, data, req any) { + request.w.WriteHeader(http.StatusOK) + _, _ = request.w.Write([]byte("OK")) + }, + ) + }, + }, + { + name: "8", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest( + http.MethodPut, "/api/v1/feed/2/1", + bytes.NewBuffer([]byte(`{"file":["","", "", "", "", "", "", "", "", "", "", "", ""]}`)), + ) + w := httptest.NewRecorder() + req = mux.SetURLVars(req, map[string]string{commentIDKey: "1"}) + req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *PostController, request Request) (Response, error) { + implementation.EditComment(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + _, _ = request.w.Write([]byte("bad request")) + }, + ) + }, + }, + { + name: "9", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest( + http.MethodPut, "/api/v1/feed/2/1", + bytes.NewBuffer([]byte(`{"file":["my wrong file"]}`)), + ) + w := httptest.NewRecorder() + req = mux.SetURLVars(req, map[string]string{commentIDKey: "1"}) + req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *PostController, request Request) (Response, error) { + implementation.EditComment(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + _, _ = request.w.Write([]byte("bad request")) + }, + ) + }, + }, + } + + for _, v := range tests { + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) + } +} + +func TestDeleteComment(t *testing.T) { + tests := []TableTest[Response, Request]{ + { + name: "1", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodDelete, "/api/v1/feed/2/", nil) + w := httptest.NewRecorder() + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *PostController, request Request) (Response, error) { + implementation.DeleteComment(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + _, _ = request.w.Write([]byte("bad request")) + }, + ) + }, + }, + { + name: "2", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodDelete, "/api/v1/feed/2/1", nil) + w := httptest.NewRecorder() + req = mux.SetURLVars(req, map[string]string{commentIDKey: "1"}) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *PostController, request Request) (Response, error) { + implementation.DeleteComment(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + _, _ = request.w.Write([]byte("bad request")) + }, + ) + }, + }, + { + name: "3", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodDelete, "/api/v1/feed/2/1", nil) + w := httptest.NewRecorder() + req = mux.SetURLVars(req, map[string]string{commentIDKey: "1"}) + req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *PostController, request Request) (Response, error) { + implementation.DeleteComment(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.commentService.EXPECT().DeleteComment(gomock.Any(), gomock.Any(), gomock.Any()). + Return(my_err.ErrAccessDenied) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + _, _ = request.w.Write([]byte("bad request")) + }, + ) + }, + }, + { + name: "4", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodDelete, "/api/v1/feed/2/1", nil) + w := httptest.NewRecorder() + req = mux.SetURLVars(req, map[string]string{commentIDKey: "1"}) + req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *PostController, request Request) (Response, error) { + implementation.DeleteComment(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.commentService.EXPECT().DeleteComment(gomock.Any(), gomock.Any(), gomock.Any()). + Return(my_err.ErrWrongComment) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + _, _ = request.w.Write([]byte("bad request")) + }, + ) + }, + }, + { + name: "5", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodDelete, "/api/v1/feed/2/1", nil) + w := httptest.NewRecorder() + req = mux.SetURLVars(req, map[string]string{commentIDKey: "1"}) + req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *PostController, request Request) (Response, error) { + implementation.DeleteComment(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusInternalServerError, Body: "error"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.commentService.EXPECT().DeleteComment(gomock.Any(), gomock.Any(), gomock.Any()). + Return(errors.New("err")) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + _, _ = request.w.Write([]byte("error")) + }, + ) + }, + }, + { + name: "6", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodDelete, "/api/v1/feed/2/1", nil) + w := httptest.NewRecorder() + req = mux.SetURLVars(req, map[string]string{commentIDKey: "1"}) + req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func(ctx context.Context, implementation *PostController, request Request) (Response, error) { + implementation.DeleteComment(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusOK, Body: "OK"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.commentService.EXPECT().DeleteComment(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, data, req any) { + request.w.WriteHeader(http.StatusOK) + _, _ = request.w.Write([]byte("OK")) + }, + ) + }, + }, + } + + for _, v := range tests { + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } type mocks struct { - postService *MockPostService - responder *MockResponder + postService *MockPostService + responder *MockResponder + commentService *MockCommentService } type Request struct { diff --git a/internal/post/controller/mock.go b/internal/post/controller/mock.go index b7f365fa..2b4ded61 100644 --- a/internal/post/controller/mock.go +++ b/internal/post/controller/mock.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: controller.go +// +// Generated by this command: +// +// mockgen -destination=mock.go -source=controller.go -package=controller +// // Package controller is a generated GoMock package. package controller @@ -17,6 +22,7 @@ import ( type MockPostService struct { ctrl *gomock.Controller recorder *MockPostServiceMockRecorder + isgomock struct{} } // MockPostServiceMockRecorder is the mock recorder for MockPostService. @@ -45,7 +51,7 @@ func (m *MockPostService) CheckAccessToCommunity(ctx context.Context, userID, co } // CheckAccessToCommunity indicates an expected call of CheckAccessToCommunity. -func (mr *MockPostServiceMockRecorder) CheckAccessToCommunity(ctx, userID, communityID interface{}) *gomock.Call { +func (mr *MockPostServiceMockRecorder) CheckAccessToCommunity(ctx, userID, communityID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckAccessToCommunity", reflect.TypeOf((*MockPostService)(nil).CheckAccessToCommunity), ctx, userID, communityID) } @@ -60,13 +66,13 @@ func (m *MockPostService) CheckLikes(ctx context.Context, postID, userID uint32) } // CheckLikes indicates an expected call of CheckLikes. -func (mr *MockPostServiceMockRecorder) CheckLikes(ctx, postID, userID interface{}) *gomock.Call { +func (mr *MockPostServiceMockRecorder) CheckLikes(ctx, postID, userID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckLikes", reflect.TypeOf((*MockPostService)(nil).CheckLikes), ctx, postID, userID) } // Create mocks base method. -func (m *MockPostService) Create(ctx context.Context, post *models.Post) (uint32, error) { +func (m *MockPostService) Create(ctx context.Context, post *models.PostDto) (uint32, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Create", ctx, post) ret0, _ := ret[0].(uint32) @@ -75,13 +81,13 @@ func (m *MockPostService) Create(ctx context.Context, post *models.Post) (uint32 } // Create indicates an expected call of Create. -func (mr *MockPostServiceMockRecorder) Create(ctx, post interface{}) *gomock.Call { +func (mr *MockPostServiceMockRecorder) Create(ctx, post any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockPostService)(nil).Create), ctx, post) } // CreateCommunityPost mocks base method. -func (m *MockPostService) CreateCommunityPost(ctx context.Context, post *models.Post) (uint32, error) { +func (m *MockPostService) CreateCommunityPost(ctx context.Context, post *models.PostDto) (uint32, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CreateCommunityPost", ctx, post) ret0, _ := ret[0].(uint32) @@ -90,7 +96,7 @@ func (m *MockPostService) CreateCommunityPost(ctx context.Context, post *models. } // CreateCommunityPost indicates an expected call of CreateCommunityPost. -func (mr *MockPostServiceMockRecorder) CreateCommunityPost(ctx, post interface{}) *gomock.Call { +func (mr *MockPostServiceMockRecorder) CreateCommunityPost(ctx, post any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateCommunityPost", reflect.TypeOf((*MockPostService)(nil).CreateCommunityPost), ctx, post) } @@ -104,7 +110,7 @@ func (m *MockPostService) Delete(ctx context.Context, postID uint32) error { } // Delete indicates an expected call of Delete. -func (mr *MockPostServiceMockRecorder) Delete(ctx, postID interface{}) *gomock.Call { +func (mr *MockPostServiceMockRecorder) Delete(ctx, postID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockPostService)(nil).Delete), ctx, postID) } @@ -118,67 +124,67 @@ func (m *MockPostService) DeleteLikeFromPost(ctx context.Context, postID, userID } // DeleteLikeFromPost indicates an expected call of DeleteLikeFromPost. -func (mr *MockPostServiceMockRecorder) DeleteLikeFromPost(ctx, postID, userID interface{}) *gomock.Call { +func (mr *MockPostServiceMockRecorder) DeleteLikeFromPost(ctx, postID, userID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteLikeFromPost", reflect.TypeOf((*MockPostService)(nil).DeleteLikeFromPost), ctx, postID, userID) } // Get mocks base method. -func (m *MockPostService) Get(ctx context.Context, postID, userID uint32) (*models.Post, error) { +func (m *MockPostService) Get(ctx context.Context, postID, userID uint32) (*models.PostDto, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Get", ctx, postID, userID) - ret0, _ := ret[0].(*models.Post) + ret0, _ := ret[0].(*models.PostDto) ret1, _ := ret[1].(error) return ret0, ret1 } // Get indicates an expected call of Get. -func (mr *MockPostServiceMockRecorder) Get(ctx, postID, userID interface{}) *gomock.Call { +func (mr *MockPostServiceMockRecorder) Get(ctx, postID, userID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockPostService)(nil).Get), ctx, postID, userID) } // GetBatch mocks base method. -func (m *MockPostService) GetBatch(ctx context.Context, lastID, userID uint32) ([]*models.Post, error) { +func (m *MockPostService) GetBatch(ctx context.Context, lastID, userID uint32) ([]*models.PostDto, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetBatch", ctx, lastID, userID) - ret0, _ := ret[0].([]*models.Post) + ret0, _ := ret[0].([]*models.PostDto) ret1, _ := ret[1].(error) return ret0, ret1 } // GetBatch indicates an expected call of GetBatch. -func (mr *MockPostServiceMockRecorder) GetBatch(ctx, lastID, userID interface{}) *gomock.Call { +func (mr *MockPostServiceMockRecorder) GetBatch(ctx, lastID, userID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBatch", reflect.TypeOf((*MockPostService)(nil).GetBatch), ctx, lastID, userID) } // GetBatchFromFriend mocks base method. -func (m *MockPostService) GetBatchFromFriend(ctx context.Context, userID, lastID uint32) ([]*models.Post, error) { +func (m *MockPostService) GetBatchFromFriend(ctx context.Context, userID, lastID uint32) ([]*models.PostDto, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetBatchFromFriend", ctx, userID, lastID) - ret0, _ := ret[0].([]*models.Post) + ret0, _ := ret[0].([]*models.PostDto) ret1, _ := ret[1].(error) return ret0, ret1 } // GetBatchFromFriend indicates an expected call of GetBatchFromFriend. -func (mr *MockPostServiceMockRecorder) GetBatchFromFriend(ctx, userID, lastID interface{}) *gomock.Call { +func (mr *MockPostServiceMockRecorder) GetBatchFromFriend(ctx, userID, lastID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBatchFromFriend", reflect.TypeOf((*MockPostService)(nil).GetBatchFromFriend), ctx, userID, lastID) } // GetCommunityPost mocks base method. -func (m *MockPostService) GetCommunityPost(ctx context.Context, communityID, userID, lastID uint32) ([]*models.Post, error) { +func (m *MockPostService) GetCommunityPost(ctx context.Context, communityID, userID, lastID uint32) ([]*models.PostDto, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetCommunityPost", ctx, communityID, userID, lastID) - ret0, _ := ret[0].([]*models.Post) + ret0, _ := ret[0].([]*models.PostDto) ret1, _ := ret[1].(error) return ret0, ret1 } // GetCommunityPost indicates an expected call of GetCommunityPost. -func (mr *MockPostServiceMockRecorder) GetCommunityPost(ctx, communityID, userID, lastID interface{}) *gomock.Call { +func (mr *MockPostServiceMockRecorder) GetCommunityPost(ctx, communityID, userID, lastID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCommunityPost", reflect.TypeOf((*MockPostService)(nil).GetCommunityPost), ctx, communityID, userID, lastID) } @@ -193,7 +199,7 @@ func (m *MockPostService) GetPostAuthorID(ctx context.Context, postID uint32) (u } // GetPostAuthorID indicates an expected call of GetPostAuthorID. -func (mr *MockPostServiceMockRecorder) GetPostAuthorID(ctx, postID interface{}) *gomock.Call { +func (mr *MockPostServiceMockRecorder) GetPostAuthorID(ctx, postID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPostAuthorID", reflect.TypeOf((*MockPostService)(nil).GetPostAuthorID), ctx, postID) } @@ -207,13 +213,13 @@ func (m *MockPostService) SetLikeToPost(ctx context.Context, postID, userID uint } // SetLikeToPost indicates an expected call of SetLikeToPost. -func (mr *MockPostServiceMockRecorder) SetLikeToPost(ctx, postID, userID interface{}) *gomock.Call { +func (mr *MockPostServiceMockRecorder) SetLikeToPost(ctx, postID, userID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLikeToPost", reflect.TypeOf((*MockPostService)(nil).SetLikeToPost), ctx, postID, userID) } // Update mocks base method. -func (m *MockPostService) Update(ctx context.Context, post *models.Post) error { +func (m *MockPostService) Update(ctx context.Context, post *models.PostDto) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Update", ctx, post) ret0, _ := ret[0].(error) @@ -221,15 +227,98 @@ func (m *MockPostService) Update(ctx context.Context, post *models.Post) error { } // Update indicates an expected call of Update. -func (mr *MockPostServiceMockRecorder) Update(ctx, post interface{}) *gomock.Call { +func (mr *MockPostServiceMockRecorder) Update(ctx, post any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockPostService)(nil).Update), ctx, post) } +// MockCommentService is a mock of CommentService interface. +type MockCommentService struct { + ctrl *gomock.Controller + recorder *MockCommentServiceMockRecorder + isgomock struct{} +} + +// MockCommentServiceMockRecorder is the mock recorder for MockCommentService. +type MockCommentServiceMockRecorder struct { + mock *MockCommentService +} + +// NewMockCommentService creates a new mock instance. +func NewMockCommentService(ctrl *gomock.Controller) *MockCommentService { + mock := &MockCommentService{ctrl: ctrl} + mock.recorder = &MockCommentServiceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockCommentService) EXPECT() *MockCommentServiceMockRecorder { + return m.recorder +} + +// Comment mocks base method. +func (m *MockCommentService) Comment(ctx context.Context, userID, postID uint32, comment *models.ContentDto) (*models.CommentDto, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Comment", ctx, userID, postID, comment) + ret0, _ := ret[0].(*models.CommentDto) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Comment indicates an expected call of Comment. +func (mr *MockCommentServiceMockRecorder) Comment(ctx, userID, postID, comment any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Comment", reflect.TypeOf((*MockCommentService)(nil).Comment), ctx, userID, postID, comment) +} + +// DeleteComment mocks base method. +func (m *MockCommentService) DeleteComment(ctx context.Context, commentID, userID uint32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteComment", ctx, commentID, userID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteComment indicates an expected call of DeleteComment. +func (mr *MockCommentServiceMockRecorder) DeleteComment(ctx, commentID, userID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteComment", reflect.TypeOf((*MockCommentService)(nil).DeleteComment), ctx, commentID, userID) +} + +// EditComment mocks base method. +func (m *MockCommentService) EditComment(ctx context.Context, commentID, userID uint32, comment *models.ContentDto) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EditComment", ctx, commentID, userID, comment) + ret0, _ := ret[0].(error) + return ret0 +} + +// EditComment indicates an expected call of EditComment. +func (mr *MockCommentServiceMockRecorder) EditComment(ctx, commentID, userID, comment any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EditComment", reflect.TypeOf((*MockCommentService)(nil).EditComment), ctx, commentID, userID, comment) +} + +// GetComments mocks base method. +func (m *MockCommentService) GetComments(ctx context.Context, postID, lastID uint32, newest bool) ([]*models.CommentDto, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetComments", ctx, postID, lastID, newest) + ret0, _ := ret[0].([]*models.CommentDto) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetComments indicates an expected call of GetComments. +func (mr *MockCommentServiceMockRecorder) GetComments(ctx, postID, lastID, newest any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetComments", reflect.TypeOf((*MockCommentService)(nil).GetComments), ctx, postID, lastID, newest) +} + // MockResponder is a mock of Responder interface. type MockResponder struct { ctrl *gomock.Controller recorder *MockResponderMockRecorder + isgomock struct{} } // MockResponderMockRecorder is the mock recorder for MockResponder. @@ -256,7 +345,7 @@ func (m *MockResponder) ErrorBadRequest(w http.ResponseWriter, err error, reques } // ErrorBadRequest indicates an expected call of ErrorBadRequest. -func (mr *MockResponderMockRecorder) ErrorBadRequest(w, err, requestId interface{}) *gomock.Call { +func (mr *MockResponderMockRecorder) ErrorBadRequest(w, err, requestId any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ErrorBadRequest", reflect.TypeOf((*MockResponder)(nil).ErrorBadRequest), w, err, requestId) } @@ -268,7 +357,7 @@ func (m *MockResponder) ErrorInternal(w http.ResponseWriter, err error, requestI } // ErrorInternal indicates an expected call of ErrorInternal. -func (mr *MockResponderMockRecorder) ErrorInternal(w, err, requestId interface{}) *gomock.Call { +func (mr *MockResponderMockRecorder) ErrorInternal(w, err, requestId any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ErrorInternal", reflect.TypeOf((*MockResponder)(nil).ErrorInternal), w, err, requestId) } @@ -280,7 +369,7 @@ func (m *MockResponder) LogError(err error, requestId string) { } // LogError indicates an expected call of LogError. -func (mr *MockResponderMockRecorder) LogError(err, requestId interface{}) *gomock.Call { +func (mr *MockResponderMockRecorder) LogError(err, requestId any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LogError", reflect.TypeOf((*MockResponder)(nil).LogError), err, requestId) } @@ -292,7 +381,7 @@ func (m *MockResponder) OutputJSON(w http.ResponseWriter, data any, requestId st } // OutputJSON indicates an expected call of OutputJSON. -func (mr *MockResponderMockRecorder) OutputJSON(w, data, requestId interface{}) *gomock.Call { +func (mr *MockResponderMockRecorder) OutputJSON(w, data, requestId any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OutputJSON", reflect.TypeOf((*MockResponder)(nil).OutputJSON), w, data, requestId) } @@ -304,7 +393,7 @@ func (m *MockResponder) OutputNoMoreContentJSON(w http.ResponseWriter, requestId } // OutputNoMoreContentJSON indicates an expected call of OutputNoMoreContentJSON. -func (mr *MockResponderMockRecorder) OutputNoMoreContentJSON(w, requestId interface{}) *gomock.Call { +func (mr *MockResponderMockRecorder) OutputNoMoreContentJSON(w, requestId any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OutputNoMoreContentJSON", reflect.TypeOf((*MockResponder)(nil).OutputNoMoreContentJSON), w, requestId) } diff --git a/internal/post/repository/postgres/postgres.go b/internal/post/repository/postgres/postgres.go index 9336dc35..438ef549 100644 --- a/internal/post/repository/postgres/postgres.go +++ b/internal/post/repository/postgres/postgres.go @@ -28,6 +28,14 @@ const ( DeleteLikeFromPost = `DELETE FROM reaction WHERE post_id = $1 AND user_id = $2;` GetLikesOnPost = `SELECT COUNT(*) FROM reaction WHERE post_id = $1;` CheckLike = `SELECT COUNT(*) FROM reaction WHERE post_id = $1 AND user_id=$2;` + + createComment = `INSERT INTO comment (user_id, post_id, content, file_path) VALUES ($1, $2, $3, $4) RETURNING id;` + updateComment = `UPDATE comment SET content = $1, file_path = $2, updated_at = NOW() WHERE id = $3;` + deleteComment = `DELETE FROM comment WHERE id = $1;` + getCommentsBatch = `SELECT id, user_id, content, file_path, created_at FROM comment WHERE post_id = $1 and id < $2 ORDER BY created_at DESC LIMIT 10;` + getCommentBatchAsc = `SELECT id, user_id, content, file_path, created_at FROM comment WHERE post_id = $1 and id > $2 order by created_at LIMIT 10;` + getCommentAuthor = `SELECT user_id FROM comment WHERE id = $1` + getCommentCount = `SELECT COUNT(*) FROM comment WHERE post_id=$1` ) type Adapter struct { @@ -40,21 +48,26 @@ func NewAdapter(db *sql.DB) *Adapter { } } -func (a *Adapter) Create(ctx context.Context, post *models.Post) (uint32, error) { +func (a *Adapter) Create(ctx context.Context, post *models.PostDto) (uint32, error) { var postID uint32 - if err := a.db.QueryRowContext(ctx, createPost, post.Header.AuthorID, post.PostContent.Text, post.PostContent.File).Scan(&postID); err != nil { + if err := a.db.QueryRowContext( + ctx, createPost, post.Header.AuthorID, post.PostContent.Text, post.PostContent.File, + ).Scan(&postID); err != nil { return 0, fmt.Errorf("postgres create post: %w", err) } return postID, nil } -func (a *Adapter) Get(ctx context.Context, postID uint32) (*models.Post, error) { - var post models.Post +func (a *Adapter) Get(ctx context.Context, postID uint32) (*models.PostDto, error) { + var post models.PostDto if err := a.db.QueryRowContext(ctx, getPost, postID). - Scan(&post.ID, &post.Header.AuthorID, &post.PostContent.Text, &post.PostContent.File, &post.PostContent.CreatedAt); err != nil { + Scan( + &post.ID, &post.Header.AuthorID, &post.PostContent.Text, &post.PostContent.File, + &post.PostContent.CreatedAt, + ); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, my_err.ErrPostNotFound } @@ -84,8 +97,10 @@ func (a *Adapter) Delete(ctx context.Context, postID uint32) error { return nil } -func (a *Adapter) Update(ctx context.Context, post *models.Post) error { - res, err := a.db.ExecContext(ctx, updatePost, post.PostContent.Text, post.PostContent.UpdatedAt, post.PostContent.File, post.ID) +func (a *Adapter) Update(ctx context.Context, post *models.PostDto) error { + res, err := a.db.ExecContext( + ctx, updatePost, post.PostContent.Text, post.PostContent.UpdatedAt, post.PostContent.File, post.ID, + ) if err != nil { return fmt.Errorf("postgres update post: %w", err) @@ -103,7 +118,7 @@ func (a *Adapter) Update(ctx context.Context, post *models.Post) error { return nil } -func (a *Adapter) GetPosts(ctx context.Context, lastID uint32) ([]*models.Post, error) { +func (a *Adapter) GetPosts(ctx context.Context, lastID uint32) ([]*models.PostDto, error) { rows, err := a.db.QueryContext(ctx, getPostBatch, lastID) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -113,12 +128,14 @@ func (a *Adapter) GetPosts(ctx context.Context, lastID uint32) ([]*models.Post, } defer rows.Close() - var posts []*models.Post + var posts []*models.PostDto for rows.Next() { - var post models.Post - if err := rows.Scan(&post.ID, &post.Header.AuthorID, &post.Header.CommunityID, - &post.PostContent.Text, &post.PostContent.File, &post.PostContent.CreatedAt); err != nil { + var post models.PostDto + if err := rows.Scan( + &post.ID, &post.Header.AuthorID, &post.Header.CommunityID, + &post.PostContent.Text, &post.PostContent.File, &post.PostContent.CreatedAt, + ); err != nil { return nil, fmt.Errorf("postgres scan posts: %w", err) } posts = append(posts, &post) @@ -131,7 +148,7 @@ func (a *Adapter) GetPosts(ctx context.Context, lastID uint32) ([]*models.Post, return posts, nil } -func (a *Adapter) GetFriendsPosts(ctx context.Context, friendsID []uint32, lastID uint32) ([]*models.Post, error) { +func (a *Adapter) GetFriendsPosts(ctx context.Context, friendsID []uint32, lastID uint32) ([]*models.PostDto, error) { friends := convertSliceToString(friendsID) rows, err := a.db.QueryContext(ctx, getFriendsPost, lastID, friends) if rows != nil { @@ -148,8 +165,8 @@ func (a *Adapter) GetFriendsPosts(ctx context.Context, friendsID []uint32, lastI return createPostBatchFromRows(rows) } -func (a *Adapter) GetAuthorPosts(ctx context.Context, header *models.Header) ([]*models.Post, error) { - var posts []*models.Post +func (a *Adapter) GetAuthorPosts(ctx context.Context, header *models.Header) ([]*models.PostDto, error) { + var posts []*models.PostDto rows, err := a.db.QueryContext(ctx, getProfilePosts, header.AuthorID) @@ -163,7 +180,7 @@ func (a *Adapter) GetAuthorPosts(ctx context.Context, header *models.Header) ([] defer rows.Close() for rows.Next() { - var post models.Post + var post models.PostDto err = rows.Scan(&post.ID, &post.PostContent.Text, &post.PostContent.File, &post.PostContent.CreatedAt) if err != nil { return nil, fmt.Errorf("postgres get author posts: %w", err) @@ -188,12 +205,15 @@ func (a *Adapter) GetPostAuthor(ctx context.Context, postID uint32) (uint32, err return authorID, nil } -func createPostBatchFromRows(rows *sql.Rows) ([]*models.Post, error) { - var posts []*models.Post +func createPostBatchFromRows(rows *sql.Rows) ([]*models.PostDto, error) { + var posts []*models.PostDto for rows.Next() { - var post models.Post - if err := rows.Scan(&post.ID, &post.Header.AuthorID, &post.PostContent.Text, &post.PostContent.File, &post.PostContent.CreatedAt); err != nil { + var post models.PostDto + if err := rows.Scan( + &post.ID, &post.Header.AuthorID, &post.PostContent.Text, &post.PostContent.File, + &post.PostContent.CreatedAt, + ); err != nil { return nil, fmt.Errorf("postgres scan posts: %w", err) } posts = append(posts, &post) @@ -221,26 +241,34 @@ func convertSliceToString(sl []uint32) string { return res } -func (a *Adapter) CreateCommunityPost(ctx context.Context, post *models.Post, communityID uint32) (uint32, error) { +func (a *Adapter) CreateCommunityPost(ctx context.Context, post *models.PostDto, communityID uint32) (uint32, error) { var ID uint32 - if err := a.db.QueryRowContext(ctx, createCommunityPost, communityID, post.PostContent.Text, post.PostContent.File).Scan(&ID); err != nil { + if err := a.db.QueryRowContext( + ctx, createCommunityPost, communityID, post.PostContent.Text, post.PostContent.File, + ).Scan(&ID); err != nil { return 0, fmt.Errorf("postgres create community post db: %w", err) } return ID, nil - } -func (a *Adapter) GetCommunityPosts(ctx context.Context, communityID, id uint32) ([]*models.Post, error) { - var posts []*models.Post +func (a *Adapter) GetCommunityPosts(ctx context.Context, communityID, id uint32) ([]*models.PostDto, error) { + var posts []*models.PostDto rows, err := a.db.QueryContext(ctx, getCommunityPosts, communityID, id) if err != nil { - return nil, fmt.Errorf("postgres get community posts: %w", err) + if errors.Is(err, sql.ErrNoRows) { + return nil, my_err.ErrNoMoreContent + } + return nil, fmt.Errorf("postgres get posts: %w", err) } + defer rows.Close() for rows.Next() { - post := &models.Post{} - err = rows.Scan(&post.ID, &post.Header.CommunityID, &post.PostContent.Text, &post.PostContent.File, &post.PostContent.CreatedAt) + post := &models.PostDto{} + err = rows.Scan( + &post.ID, &post.Header.CommunityID, &post.PostContent.Text, &post.PostContent.File, + &post.PostContent.CreatedAt, + ) if err != nil { return nil, fmt.Errorf("postgres get community posts: %w", err) } @@ -297,3 +325,122 @@ func (a *Adapter) CheckLikes(ctx context.Context, postID, userID uint32) (bool, return true, nil } + +func (a *Adapter) CreateComment( + ctx context.Context, comment *models.ContentDto, userID, postID uint32, +) (uint32, error) { + var id uint32 + + if err := a.db.QueryRowContext( + ctx, createComment, userID, postID, comment.Text, comment.File, + ).Scan(&id); err != nil { + return 0, fmt.Errorf("postgres create comment: %w", err) + } + + return id, nil +} + +func (a *Adapter) DeleteComment(ctx context.Context, commentID uint32) error { + res, err := a.db.ExecContext(ctx, deleteComment, commentID) + + if err != nil { + return fmt.Errorf("postgres delete comment: %w", err) + } + + affected, err := res.RowsAffected() + if err != nil { + return fmt.Errorf("postgres delete comment: %w", err) + } + + if affected == 0 { + return my_err.ErrWrongComment + } + + return nil +} + +func (a *Adapter) UpdateComment(ctx context.Context, comment *models.ContentDto, commentID uint32) error { + res, err := a.db.ExecContext( + ctx, updateComment, comment.Text, comment.File, commentID, + ) + + if err != nil { + return fmt.Errorf("postgres update comment: %w", err) + } + + affected, err := res.RowsAffected() + if err != nil { + return fmt.Errorf("postgres update comment: %w", err) + } + + if affected == 0 { + return my_err.ErrWrongComment + } + + return nil +} + +func (a *Adapter) GetComments(ctx context.Context, postID, lastID uint32, newest bool) ([]*models.CommentDto, error) { + var ( + rows *sql.Rows + err error + ) + + if newest { + rows, err = a.db.QueryContext(ctx, getCommentsBatch, postID, lastID) + } else { + rows, err = a.db.QueryContext(ctx, getCommentBatchAsc, postID, lastID) + } + + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, my_err.ErrNoMoreContent + } + return nil, fmt.Errorf("postgres get posts: %w", err) + } + + var comments []*models.CommentDto + for rows.Next() { + comment := models.CommentDto{} + if err := rows.Scan( + &comment.ID, &comment.Header.AuthorID, &comment.Content.Text, &comment.Content.File, + &comment.Content.CreatedAt, + ); err != nil { + return nil, fmt.Errorf("postgres get comments: %w", err) + } + + comments = append(comments, &comment) + } + + if len(comments) == 0 { + return comments, my_err.ErrNoMoreContent + } + + return comments, nil +} + +func (a *Adapter) GetCommentAuthor(ctx context.Context, commentID uint32) (uint32, error) { + var authorID uint32 + + if err := a.db.QueryRowContext(ctx, getCommentAuthor, commentID).Scan(&authorID); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return 0, my_err.ErrWrongComment + } + return 0, fmt.Errorf("postgres get comment author: %w", err) + } + + return authorID, nil +} + +func (a *Adapter) GetCommentCount(ctx context.Context, postID uint32) (uint32, error) { + var count uint32 + + if err := a.db.QueryRowContext(ctx, getCommentCount, postID).Scan(&count); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return 0, my_err.ErrWrongPost + } + return 0, fmt.Errorf("postgres get comment count: %w", err) + } + + return count, nil +} diff --git a/internal/post/repository/postgres/postgres_test.go b/internal/post/repository/postgres/postgres_test.go index 936b938c..c58f0832 100644 --- a/internal/post/repository/postgres/postgres_test.go +++ b/internal/post/repository/postgres/postgres_test.go @@ -21,7 +21,7 @@ type TestCaseGet struct { ID uint32 wantErr error dbErr error - wantPost *models.Post + wantPost *models.PostDto } func TestGet(t *testing.T) { @@ -33,11 +33,16 @@ func TestGet(t *testing.T) { var ID uint32 = 1 rows := sqlmock.NewRows([]string{"id", "author_id", "content", "file_path", "created_at"}) - expect := []*models.Post{ - {ID: ID, Header: models.Header{AuthorID: 1}, PostContent: models.Content{Text: "content from user 1", File: "http://somefile", CreatedAt: time.Now()}}, + expect := []*models.PostDto{ + { + ID: ID, Header: models.Header{AuthorID: 1}, + PostContent: models.ContentDto{Text: "content from user 1", File: "http://somefile", CreatedAt: time.Now()}, + }, } for _, post := range expect { - rows = rows.AddRow(ID, post.Header.AuthorID, post.PostContent.Text, post.PostContent.File, post.PostContent.CreatedAt) + rows = rows.AddRow( + ID, post.Header.AuthorID, post.PostContent.Text, post.PostContent.File, post.PostContent.CreatedAt, + ) } repo := NewAdapter(db) @@ -63,7 +68,7 @@ func TestGet(t *testing.T) { } type TestCaseCreate struct { - post *models.Post + post *models.PostDto wantID uint32 wantErr error dbErr error @@ -79,9 +84,22 @@ func TestCreate(t *testing.T) { repo := NewAdapter(db) tests := []TestCaseCreate{ - {post: &models.Post{Header: models.Header{AuthorID: 1}, PostContent: models.Content{Text: "content from user 1"}}, wantID: 1, wantErr: nil, dbErr: nil}, - {post: &models.Post{Header: models.Header{AuthorID: 2}, PostContent: models.Content{Text: "content from user 2", File: "http://someFile"}}, wantID: 2, wantErr: nil, dbErr: nil}, - {post: &models.Post{Header: models.Header{AuthorID: 10}, PostContent: models.Content{Text: "wrong query"}}, wantID: 0, wantErr: errMockDB, dbErr: errMockDB}, + { + post: &models.PostDto{ + Header: models.Header{AuthorID: 1}, PostContent: models.ContentDto{Text: "content from user 1"}, + }, wantID: 1, wantErr: nil, dbErr: nil, + }, + { + post: &models.PostDto{ + Header: models.Header{AuthorID: 2}, + PostContent: models.ContentDto{Text: "content from user 2", File: "http://someFile"}, + }, wantID: 2, wantErr: nil, dbErr: nil, + }, + { + post: &models.PostDto{ + Header: models.Header{AuthorID: 10}, PostContent: models.ContentDto{Text: "wrong query"}, + }, wantID: 0, wantErr: errMockDB, dbErr: errMockDB, + }, } for _, test := range tests { @@ -137,7 +155,7 @@ func TestDelete(t *testing.T) { } type TestCaseUpdate struct { - post *models.Post + post *models.PostDto rowsAffected int64 wantErr error dbErr error @@ -153,15 +171,32 @@ func TestUpdate(t *testing.T) { repo := NewAdapter(db) tests := []TestCaseUpdate{ - {post: &models.Post{ID: 1, PostContent: models.Content{Text: "update post", File: "http://someFile", UpdatedAt: time.Now()}}, wantErr: nil, dbErr: nil, rowsAffected: 1}, - {post: &models.Post{ID: 2, PostContent: models.Content{Text: "wrong ID", UpdatedAt: time.Now()}}, wantErr: my_err.ErrPostNotFound, dbErr: nil, rowsAffected: 0}, - {post: &models.Post{ID: 1, PostContent: models.Content{Text: "update post who was update early", UpdatedAt: time.Now()}}, wantErr: nil, dbErr: nil, rowsAffected: 1}, - {post: &models.Post{ID: 5, PostContent: models.Content{Text: "wrong query", UpdatedAt: time.Now()}}, wantErr: errMockDB, dbErr: errMockDB, rowsAffected: 0}, + { + post: &models.PostDto{ + ID: 1, + PostContent: models.ContentDto{Text: "update post", File: "http://someFile", UpdatedAt: time.Now()}, + }, wantErr: nil, dbErr: nil, rowsAffected: 1, + }, + { + post: &models.PostDto{ID: 2, PostContent: models.ContentDto{Text: "wrong ID", UpdatedAt: time.Now()}}, + wantErr: my_err.ErrPostNotFound, dbErr: nil, rowsAffected: 0, + }, + { + post: &models.PostDto{ + ID: 1, PostContent: models.ContentDto{Text: "update post who was update early", UpdatedAt: time.Now()}, + }, wantErr: nil, dbErr: nil, rowsAffected: 1, + }, + { + post: &models.PostDto{ID: 5, PostContent: models.ContentDto{Text: "wrong query", UpdatedAt: time.Now()}}, + wantErr: errMockDB, dbErr: errMockDB, rowsAffected: 0, + }, } for _, test := range tests { mock.ExpectExec(regexp.QuoteMeta(updatePost)). - WithArgs(test.post.PostContent.Text, test.post.PostContent.UpdatedAt, test.post.PostContent.File, test.post.ID). + WithArgs( + test.post.PostContent.Text, test.post.PostContent.UpdatedAt, test.post.PostContent.File, test.post.ID, + ). WillReturnResult(sqlmock.NewResult(0, test.rowsAffected)). WillReturnError(test.dbErr) @@ -212,7 +247,7 @@ func TestGetPostAuthor(t *testing.T) { type TestCaseGetAuthorPosts struct { Author *models.Header - wantPosts []*models.Post + wantPosts []*models.PostDto dbErr error wantErr error } @@ -230,16 +265,20 @@ func TestGetAuthorsPosts(t *testing.T) { tests := []TestCaseGetAuthorPosts{ { Author: &models.Header{AuthorID: 1}, - wantPosts: []*models.Post{ + wantPosts: []*models.PostDto{ { - ID: 1, - Header: models.Header{AuthorID: 1}, - PostContent: models.Content{Text: "content from user 1", File: "http://somefile", CreatedAt: createTime}, + ID: 1, + Header: models.Header{AuthorID: 1}, + PostContent: models.ContentDto{ + Text: "content from user 1", File: "http://somefile", CreatedAt: createTime, + }, }, { - ID: 2, - Header: models.Header{AuthorID: 1}, - PostContent: models.Content{Text: "another content from user 1", File: "http://somefile2", CreatedAt: createTime}, + ID: 2, + Header: models.Header{AuthorID: 1}, + PostContent: models.ContentDto{ + Text: "another content from user 1", File: "http://somefile2", CreatedAt: createTime, + }, }, }, wantErr: nil, @@ -247,8 +286,11 @@ func TestGetAuthorsPosts(t *testing.T) { }, { Author: &models.Header{AuthorID: 2}, - wantPosts: []*models.Post{ - {ID: 3, Header: models.Header{AuthorID: 2}, PostContent: models.Content{Text: "content from user 2", CreatedAt: createTime}}, + wantPosts: []*models.PostDto{ + { + ID: 3, Header: models.Header{AuthorID: 2}, + PostContent: models.ContentDto{Text: "content from user 2", CreatedAt: createTime}, + }, }, wantErr: nil, dbErr: nil, @@ -287,7 +329,7 @@ func TestGetAuthorsPosts(t *testing.T) { type TestCaseGetPosts struct { lastID uint32 - wantPost []*models.Post + wantPost []*models.PostDto dbErr error wantErr error } @@ -300,18 +342,51 @@ func TestGetPosts(t *testing.T) { defer db.Close() createTime := time.Now() - expect := []*models.Post{ - {ID: 1, Header: models.Header{AuthorID: 1}, PostContent: models.Content{Text: "content from user 1", CreatedAt: createTime}}, - {ID: 2, Header: models.Header{AuthorID: 1}, PostContent: models.Content{Text: "content from user 1", CreatedAt: createTime}}, - {ID: 3, Header: models.Header{AuthorID: 2}, PostContent: models.Content{Text: "content from user 2", CreatedAt: createTime}}, - {ID: 4, Header: models.Header{AuthorID: 2}, PostContent: models.Content{Text: "content from user 2", CreatedAt: createTime}}, - {ID: 5, Header: models.Header{AuthorID: 3}, PostContent: models.Content{Text: "content from user 3", CreatedAt: createTime}}, - {ID: 6, Header: models.Header{AuthorID: 3}, PostContent: models.Content{Text: "content from user 3", CreatedAt: createTime}}, - {ID: 7, Header: models.Header{AuthorID: 6}, PostContent: models.Content{Text: "content from user 6", CreatedAt: createTime}}, - {ID: 8, Header: models.Header{AuthorID: 4}, PostContent: models.Content{Text: "content from user 4", CreatedAt: createTime}}, - {ID: 9, Header: models.Header{AuthorID: 2}, PostContent: models.Content{Text: "content from user 2", CreatedAt: createTime}}, - {ID: 10, Header: models.Header{AuthorID: 1}, PostContent: models.Content{Text: "content from user 1", CreatedAt: createTime}}, - {ID: 11, Header: models.Header{AuthorID: 2}, PostContent: models.Content{Text: "content from user 2", CreatedAt: createTime}}, + expect := []*models.PostDto{ + { + ID: 1, Header: models.Header{AuthorID: 1}, + PostContent: models.ContentDto{Text: "content from user 1", CreatedAt: createTime}, + }, + { + ID: 2, Header: models.Header{AuthorID: 1}, + PostContent: models.ContentDto{Text: "content from user 1", CreatedAt: createTime}, + }, + { + ID: 3, Header: models.Header{AuthorID: 2}, + PostContent: models.ContentDto{Text: "content from user 2", CreatedAt: createTime}, + }, + { + ID: 4, Header: models.Header{AuthorID: 2}, + PostContent: models.ContentDto{Text: "content from user 2", CreatedAt: createTime}, + }, + { + ID: 5, Header: models.Header{AuthorID: 3}, + PostContent: models.ContentDto{Text: "content from user 3", CreatedAt: createTime}, + }, + { + ID: 6, Header: models.Header{AuthorID: 3}, + PostContent: models.ContentDto{Text: "content from user 3", CreatedAt: createTime}, + }, + { + ID: 7, Header: models.Header{AuthorID: 6}, + PostContent: models.ContentDto{Text: "content from user 6", CreatedAt: createTime}, + }, + { + ID: 8, Header: models.Header{AuthorID: 4}, + PostContent: models.ContentDto{Text: "content from user 4", CreatedAt: createTime}, + }, + { + ID: 9, Header: models.Header{AuthorID: 2}, + PostContent: models.ContentDto{Text: "content from user 2", CreatedAt: createTime}, + }, + { + ID: 10, Header: models.Header{AuthorID: 1}, + PostContent: models.ContentDto{Text: "content from user 1", CreatedAt: createTime}, + }, + { + ID: 11, Header: models.Header{AuthorID: 2}, + PostContent: models.ContentDto{Text: "content from user 2", CreatedAt: createTime}, + }, } repo := NewAdapter(db) @@ -336,7 +411,10 @@ func TestGetPosts(t *testing.T) { for _, test := range tests { rows := sqlmock.NewRows([]string{"id", "author_id", "community_id", "content", "file_path", "created_at"}) for _, post := range test.wantPost { - rows.AddRow(post.ID, post.Header.AuthorID, post.Header.CommunityID, post.PostContent.Text, post.PostContent.File, post.PostContent.CreatedAt) + rows.AddRow( + post.ID, post.Header.AuthorID, post.Header.CommunityID, post.PostContent.Text, post.PostContent.File, + post.PostContent.CreatedAt, + ) } mock.ExpectQuery(regexp.QuoteMeta(getPostBatch)). WithArgs(test.lastID). @@ -374,7 +452,7 @@ func TestConvertSliceToString(t *testing.T) { type GetFriendsPosts struct { lastID uint32 friendsID []uint32 - wantPost []*models.Post + wantPost []*models.PostDto wantErr error dbErr error } @@ -388,18 +466,51 @@ func TestGetFriendsPosts(t *testing.T) { createTime := time.Now() - expect := []*models.Post{ - {ID: 1, Header: models.Header{AuthorID: 1}, PostContent: models.Content{Text: "content from user 1", CreatedAt: createTime}}, - {ID: 2, Header: models.Header{AuthorID: 1}, PostContent: models.Content{Text: "content from user 1", CreatedAt: createTime}}, - {ID: 3, Header: models.Header{AuthorID: 2}, PostContent: models.Content{Text: "content from user 2", CreatedAt: createTime}}, - {ID: 4, Header: models.Header{AuthorID: 2}, PostContent: models.Content{Text: "content from user 2", CreatedAt: createTime}}, - {ID: 5, Header: models.Header{AuthorID: 3}, PostContent: models.Content{Text: "content from user 3", CreatedAt: createTime}}, - {ID: 6, Header: models.Header{AuthorID: 3}, PostContent: models.Content{Text: "content from user 3", CreatedAt: createTime}}, - {ID: 7, Header: models.Header{AuthorID: 6}, PostContent: models.Content{Text: "content from user 6", CreatedAt: createTime}}, - {ID: 8, Header: models.Header{AuthorID: 4}, PostContent: models.Content{Text: "content from user 4", CreatedAt: createTime}}, - {ID: 9, Header: models.Header{AuthorID: 2}, PostContent: models.Content{Text: "content from user 2", CreatedAt: createTime}}, - {ID: 10, Header: models.Header{AuthorID: 1}, PostContent: models.Content{Text: "content from user 1", CreatedAt: createTime}}, - {ID: 11, Header: models.Header{AuthorID: 2}, PostContent: models.Content{Text: "content from user 2", CreatedAt: createTime}}, + expect := []*models.PostDto{ + { + ID: 1, Header: models.Header{AuthorID: 1}, + PostContent: models.ContentDto{Text: "content from user 1", CreatedAt: createTime}, + }, + { + ID: 2, Header: models.Header{AuthorID: 1}, + PostContent: models.ContentDto{Text: "content from user 1", CreatedAt: createTime}, + }, + { + ID: 3, Header: models.Header{AuthorID: 2}, + PostContent: models.ContentDto{Text: "content from user 2", CreatedAt: createTime}, + }, + { + ID: 4, Header: models.Header{AuthorID: 2}, + PostContent: models.ContentDto{Text: "content from user 2", CreatedAt: createTime}, + }, + { + ID: 5, Header: models.Header{AuthorID: 3}, + PostContent: models.ContentDto{Text: "content from user 3", CreatedAt: createTime}, + }, + { + ID: 6, Header: models.Header{AuthorID: 3}, + PostContent: models.ContentDto{Text: "content from user 3", CreatedAt: createTime}, + }, + { + ID: 7, Header: models.Header{AuthorID: 6}, + PostContent: models.ContentDto{Text: "content from user 6", CreatedAt: createTime}, + }, + { + ID: 8, Header: models.Header{AuthorID: 4}, + PostContent: models.ContentDto{Text: "content from user 4", CreatedAt: createTime}, + }, + { + ID: 9, Header: models.Header{AuthorID: 2}, + PostContent: models.ContentDto{Text: "content from user 2", CreatedAt: createTime}, + }, + { + ID: 10, Header: models.Header{AuthorID: 1}, + PostContent: models.ContentDto{Text: "content from user 1", CreatedAt: createTime}, + }, + { + ID: 11, Header: models.Header{AuthorID: 2}, + PostContent: models.ContentDto{Text: "content from user 2", CreatedAt: createTime}, + }, } repo := NewAdapter(db) @@ -416,7 +527,9 @@ func TestGetFriendsPosts(t *testing.T) { for _, test := range tests { rows := sqlmock.NewRows([]string{"id", "author_id", "content", "file_path", "created_at"}) for _, post := range test.wantPost { - rows.AddRow(post.ID, post.Header.AuthorID, post.PostContent.Text, post.PostContent.File, post.PostContent.CreatedAt) + rows.AddRow( + post.ID, post.Header.AuthorID, post.PostContent.Text, post.PostContent.File, post.PostContent.CreatedAt, + ) } mock.ExpectQuery(regexp.QuoteMeta(getFriendsPost)). WithArgs(test.lastID, convertSliceToString(test.friendsID)). @@ -430,3 +543,163 @@ func TestGetFriendsPosts(t *testing.T) { assert.Equalf(t, posts, test.wantPost, "result dont match\nwant: %v\ngot:%v", test.wantPost, posts) } } + +type TestCaseCreateCommunity struct { + post *models.PostDto + communityID uint32 + wantID uint32 + wantErr error + dbErr error +} + +func TestCreateCommunityPost(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + repo := NewAdapter(db) + + tests := []TestCaseCreateCommunity{ + { + communityID: 1, + post: &models.PostDto{ + Header: models.Header{AuthorID: 1}, PostContent: models.ContentDto{Text: "content from user 1"}, + }, wantID: 1, wantErr: nil, dbErr: nil, + }, + { + communityID: 2, + post: &models.PostDto{ + Header: models.Header{AuthorID: 2}, + PostContent: models.ContentDto{Text: "content from user 2", File: "http://someFile"}, + }, wantID: 2, wantErr: nil, dbErr: nil, + }, + { + communityID: 3, + post: &models.PostDto{ + Header: models.Header{AuthorID: 10}, PostContent: models.ContentDto{Text: "wrong query"}, + }, wantID: 0, wantErr: errMockDB, dbErr: errMockDB, + }, + } + + for _, test := range tests { + mock.ExpectQuery(regexp.QuoteMeta(createCommunityPost)). + WithArgs(test.communityID, test.post.PostContent.Text, test.post.PostContent.File). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(test.wantID)). + WillReturnError(test.dbErr) + + id, err := repo.CreateCommunityPost(context.Background(), test.post, test.communityID) + if id != test.wantID { + t.Errorf("results not match,\n want %v\n have %v", test.wantID, id) + } + if !errors.Is(err, test.wantErr) { + t.Errorf("unexpected err:\n want:%v\n got:%v", test.wantErr, err) + } + } +} + +type TestCaseGetCommunityPost struct { + communityID uint32 + lastID uint32 + wantPost []*models.PostDto + dbErr error + wantErr error +} + +func TestCommunityPost(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + createTime := time.Now() + expect := []*models.PostDto{ + { + ID: 1, Header: models.Header{CommunityID: 1}, + PostContent: models.ContentDto{Text: "content from user 1", CreatedAt: createTime}, + }, + { + ID: 2, Header: models.Header{CommunityID: 1}, + PostContent: models.ContentDto{Text: "content from user 1", CreatedAt: createTime}, + }, + { + ID: 3, Header: models.Header{CommunityID: 2}, + PostContent: models.ContentDto{Text: "content from user 2", CreatedAt: createTime}, + }, + { + ID: 4, Header: models.Header{CommunityID: 2}, + PostContent: models.ContentDto{Text: "content from user 2", CreatedAt: createTime}, + }, + { + ID: 5, Header: models.Header{CommunityID: 3}, + PostContent: models.ContentDto{Text: "content from user 3", CreatedAt: createTime}, + }, + { + ID: 6, Header: models.Header{CommunityID: 3}, + PostContent: models.ContentDto{Text: "content from user 3", CreatedAt: createTime}, + }, + { + ID: 7, Header: models.Header{CommunityID: 6}, + PostContent: models.ContentDto{Text: "content from user 6", CreatedAt: createTime}, + }, + { + ID: 8, Header: models.Header{CommunityID: 4}, + PostContent: models.ContentDto{Text: "content from user 4", CreatedAt: createTime}, + }, + { + ID: 9, Header: models.Header{CommunityID: 2}, + PostContent: models.ContentDto{Text: "content from user 2", CreatedAt: createTime}, + }, + { + ID: 10, Header: models.Header{CommunityID: 1}, + PostContent: models.ContentDto{Text: "content from user 1", CreatedAt: createTime}, + }, + { + ID: 11, Header: models.Header{CommunityID: 2}, + PostContent: models.ContentDto{Text: "content from user 2", CreatedAt: createTime}, + }, + } + + repo := NewAdapter(db) + + tests := []TestCaseGetCommunityPost{ + {lastID: 0, wantPost: nil, wantErr: my_err.ErrNoMoreContent, dbErr: sql.ErrNoRows}, + {lastID: 1, wantPost: nil, wantErr: errMockDB, dbErr: errMockDB}, + { + lastID: 3, + communityID: 1, + wantPost: expect[:3], + wantErr: nil, + dbErr: nil, + }, + { + lastID: 11, + wantPost: expect[1:11], + communityID: 10, + wantErr: nil, + dbErr: nil, + }, + } + + for _, test := range tests { + rows := sqlmock.NewRows([]string{"id", "community_id", "content", "file_path", "created_at"}) + for _, post := range test.wantPost { + rows.AddRow( + post.ID, post.Header.CommunityID, post.PostContent.Text, post.PostContent.File, + post.PostContent.CreatedAt, + ) + } + mock.ExpectQuery(regexp.QuoteMeta(getCommunityPosts)). + WithArgs(test.communityID, test.lastID). + WillReturnRows(rows). + WillReturnError(test.dbErr) + + posts, err := repo.GetCommunityPosts(context.Background(), test.communityID, test.lastID) + if !errors.Is(err, test.wantErr) { + t.Errorf("unexpected error: got:%v\nwant:%v\n", err, test.wantErr) + } + assert.Equalf(t, posts, test.wantPost, "result dont match\nwant: %v\ngot:%v", test.wantPost, posts) + } +} diff --git a/internal/post/service/comment.go b/internal/post/service/comment.go new file mode 100644 index 00000000..5dd50a80 --- /dev/null +++ b/internal/post/service/comment.go @@ -0,0 +1,112 @@ +package service + +import ( + "context" + "fmt" + + "github.com/2024_2_BetterCallFirewall/internal/models" + "github.com/2024_2_BetterCallFirewall/pkg/my_err" +) + +//go:generate mockgen -destination=comment_mock.go -source=$GOFILE -package=${GOPACKAGE} +type dbI interface { + CreateComment(ctx context.Context, comment *models.ContentDto, userID, postID uint32) (uint32, error) + DeleteComment(ctx context.Context, commentID uint32) error + UpdateComment(ctx context.Context, comment *models.ContentDto, commentID uint32) error + GetComments(ctx context.Context, postID, lastID uint32, newest bool) ([]*models.CommentDto, error) + GetCommentAuthor(ctx context.Context, commentID uint32) (uint32, error) +} + +type profileRepoI interface { + GetHeader(ctx context.Context, userID uint32) (*models.Header, error) +} + +type CommentService struct { + db dbI + profileRepo profileRepoI +} + +func NewCommentService(db dbI, profileRepo profileRepoI) *CommentService { + return &CommentService{ + db: db, + profileRepo: profileRepo, + } +} + +func (s *CommentService) Comment( + ctx context.Context, userID, postID uint32, comment *models.ContentDto, +) (*models.CommentDto, error) { + id, err := s.db.CreateComment(ctx, comment, userID, postID) + if err != nil { + return nil, fmt.Errorf("create comment: %w", err) + } + + header, err := s.profileRepo.GetHeader(ctx, userID) + if err != nil { + return nil, fmt.Errorf("get header: %w", err) + } + + newComment := &models.CommentDto{ + ID: id, + Content: *comment, + Header: *header, + } + + return newComment, nil +} + +func (s *CommentService) DeleteComment(ctx context.Context, commentID, userID uint32) error { + authorID, err := s.db.GetCommentAuthor(ctx, commentID) + if err != nil { + return fmt.Errorf("get comment author: %w", err) + } + + if authorID != userID { + return my_err.ErrAccessDenied + } + + err = s.db.DeleteComment(ctx, commentID) + if err != nil { + return fmt.Errorf("delete comment: %w", err) + } + + return nil +} + +func (s *CommentService) EditComment(ctx context.Context, commentID, userID uint32, comment *models.ContentDto) error { + authorID, err := s.db.GetCommentAuthor(ctx, commentID) + if err != nil { + return fmt.Errorf("get comment author: %w", err) + } + + if authorID != userID { + return my_err.ErrAccessDenied + } + + err = s.db.UpdateComment(ctx, comment, commentID) + if err != nil { + return fmt.Errorf("delete comment: %w", err) + } + + return nil +} + +func (s *CommentService) GetComments( + ctx context.Context, postID, lastID uint32, newest bool, +) ([]*models.CommentDto, error) { + comments, err := s.db.GetComments(ctx, postID, lastID, newest) + if err != nil { + return nil, fmt.Errorf("get comments: %w", err) + } + + for i, c := range comments { + header, err := s.profileRepo.GetHeader(ctx, c.Header.AuthorID) + if err != nil { + return nil, fmt.Errorf("get header %d: %w", i, err) + } + + comments[i].Header = *header + } + + return comments, nil +} diff --git a/internal/post/service/comment_mock.go b/internal/post/service/comment_mock.go new file mode 100644 index 00000000..5168ea2d --- /dev/null +++ b/internal/post/service/comment_mock.go @@ -0,0 +1,154 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: comment.go +// +// Generated by this command: +// +// mockgen -destination=comment_mock.go -source=comment.go -package=service +// + +// Package service is a generated GoMock package. +package service + +import ( + context "context" + reflect "reflect" + + models "github.com/2024_2_BetterCallFirewall/internal/models" + gomock "github.com/golang/mock/gomock" +) + +// MockdbI is a mock of dbI interface. +type MockdbI struct { + ctrl *gomock.Controller + recorder *MockdbIMockRecorder + isgomock struct{} +} + +// MockdbIMockRecorder is the mock recorder for MockdbI. +type MockdbIMockRecorder struct { + mock *MockdbI +} + +// NewMockdbI creates a new mock instance. +func NewMockdbI(ctrl *gomock.Controller) *MockdbI { + mock := &MockdbI{ctrl: ctrl} + mock.recorder = &MockdbIMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockdbI) EXPECT() *MockdbIMockRecorder { + return m.recorder +} + +// CreateComment mocks base method. +func (m *MockdbI) CreateComment(ctx context.Context, comment *models.ContentDto, userID, postID uint32) (uint32, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateComment", ctx, comment, userID, postID) + ret0, _ := ret[0].(uint32) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateComment indicates an expected call of CreateComment. +func (mr *MockdbIMockRecorder) CreateComment(ctx, comment, userID, postID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateComment", reflect.TypeOf((*MockdbI)(nil).CreateComment), ctx, comment, userID, postID) +} + +// DeleteComment mocks base method. +func (m *MockdbI) DeleteComment(ctx context.Context, commentID uint32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteComment", ctx, commentID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteComment indicates an expected call of DeleteComment. +func (mr *MockdbIMockRecorder) DeleteComment(ctx, commentID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteComment", reflect.TypeOf((*MockdbI)(nil).DeleteComment), ctx, commentID) +} + +// GetCommentAuthor mocks base method. +func (m *MockdbI) GetCommentAuthor(ctx context.Context, commentID uint32) (uint32, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetCommentAuthor", ctx, commentID) + ret0, _ := ret[0].(uint32) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetCommentAuthor indicates an expected call of GetCommentAuthor. +func (mr *MockdbIMockRecorder) GetCommentAuthor(ctx, commentID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCommentAuthor", reflect.TypeOf((*MockdbI)(nil).GetCommentAuthor), ctx, commentID) +} + +// GetComments mocks base method. +func (m *MockdbI) GetComments(ctx context.Context, postID, lastID uint32, newest bool) ([]*models.CommentDto, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetComments", ctx, postID, lastID, newest) + ret0, _ := ret[0].([]*models.CommentDto) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetComments indicates an expected call of GetComments. +func (mr *MockdbIMockRecorder) GetComments(ctx, postID, lastID, newest any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetComments", reflect.TypeOf((*MockdbI)(nil).GetComments), ctx, postID, lastID, newest) +} + +// UpdateComment mocks base method. +func (m *MockdbI) UpdateComment(ctx context.Context, comment *models.ContentDto, commentID uint32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateComment", ctx, comment, commentID) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateComment indicates an expected call of UpdateComment. +func (mr *MockdbIMockRecorder) UpdateComment(ctx, comment, commentID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateComment", reflect.TypeOf((*MockdbI)(nil).UpdateComment), ctx, comment, commentID) +} + +// MockprofileRepoI is a mock of profileRepoI interface. +type MockprofileRepoI struct { + ctrl *gomock.Controller + recorder *MockprofileRepoIMockRecorder + isgomock struct{} +} + +// MockprofileRepoIMockRecorder is the mock recorder for MockprofileRepoI. +type MockprofileRepoIMockRecorder struct { + mock *MockprofileRepoI +} + +// NewMockprofileRepoI creates a new mock instance. +func NewMockprofileRepoI(ctrl *gomock.Controller) *MockprofileRepoI { + mock := &MockprofileRepoI{ctrl: ctrl} + mock.recorder = &MockprofileRepoIMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockprofileRepoI) EXPECT() *MockprofileRepoIMockRecorder { + return m.recorder +} + +// GetHeader mocks base method. +func (m *MockprofileRepoI) GetHeader(ctx context.Context, userID uint32) (*models.Header, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetHeader", ctx, userID) + ret0, _ := ret[0].(*models.Header) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetHeader indicates an expected call of GetHeader. +func (mr *MockprofileRepoIMockRecorder) GetHeader(ctx, userID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHeader", reflect.TypeOf((*MockprofileRepoI)(nil).GetHeader), ctx, userID) +} diff --git a/internal/post/service/comment_test.go b/internal/post/service/comment_test.go new file mode 100644 index 00000000..56a6f0f6 --- /dev/null +++ b/internal/post/service/comment_test.go @@ -0,0 +1,534 @@ +package service + +import ( + "context" + "errors" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + + "github.com/2024_2_BetterCallFirewall/internal/models" + "github.com/2024_2_BetterCallFirewall/pkg/my_err" +) + +type commentMocks struct { + repo *MockdbI + profileRepo *MockprofileRepoI +} + +func getCommentService(ctrl *gomock.Controller) (*CommentService, *commentMocks) { + m := &commentMocks{ + repo: NewMockdbI(ctrl), + profileRepo: NewMockprofileRepoI(ctrl), + } + + return NewCommentService(m.repo, m.profileRepo), m +} + +type inputCreate struct { + userID uint32 + postID uint32 + comment *models.ContentDto +} + +func TestNewCommentService(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + service, _ := getCommentService(ctrl) + assert.NotNil(t, service) +} + +func TestCommentService_Comment(t *testing.T) { + tests := []TableTest3[*models.CommentDto, inputCreate]{ + { + name: "1", + SetupInput: func() (*inputCreate, error) { + return &inputCreate{}, nil + }, + Run: func( + ctx context.Context, implementation *CommentService, request inputCreate, + ) (*models.CommentDto, error) { + return implementation.Comment(ctx, request.userID, request.postID, request.comment) + }, + ExpectedResult: func() (*models.CommentDto, error) { + return nil, nil + }, + ExpectedErr: errMock, + SetupMock: func(request inputCreate, m *commentMocks) { + m.repo.EXPECT().CreateComment(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(uint32(0), errMock) + }, + }, + { + name: "2", + SetupInput: func() (*inputCreate, error) { + return &inputCreate{}, nil + }, + Run: func( + ctx context.Context, implementation *CommentService, request inputCreate, + ) (*models.CommentDto, error) { + return implementation.Comment(ctx, request.userID, request.postID, request.comment) + }, + ExpectedResult: func() (*models.CommentDto, error) { + return nil, nil + }, + ExpectedErr: errMock, + SetupMock: func(request inputCreate, m *commentMocks) { + m.repo.EXPECT().CreateComment(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(uint32(1), nil) + m.profileRepo.EXPECT().GetHeader(gomock.Any(), gomock.Any()). + Return(nil, errMock) + }, + }, + { + name: "3", + SetupInput: func() (*inputCreate, error) { + return &inputCreate{ + userID: 1, + postID: 1, + comment: &models.ContentDto{ + Text: "new comment", + }, + }, nil + }, + Run: func( + ctx context.Context, implementation *CommentService, request inputCreate, + ) (*models.CommentDto, error) { + return implementation.Comment(ctx, request.userID, request.postID, request.comment) + }, + ExpectedResult: func() (*models.CommentDto, error) { + return &models.CommentDto{ + ID: 1, + Content: models.ContentDto{ + Text: "new comment", + }, + Header: models.Header{ + AuthorID: 1, + Author: "Alexey Zemliakov", + }, + }, nil + }, + ExpectedErr: nil, + SetupMock: func(request inputCreate, m *commentMocks) { + m.repo.EXPECT().CreateComment(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(uint32(1), nil) + m.profileRepo.EXPECT().GetHeader(gomock.Any(), gomock.Any()). + Return( + &models.Header{ + AuthorID: 1, + Author: "Alexey Zemliakov", + }, nil, + ) + }, + }, + } + + for _, v := range tests { + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getCommentService(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) + } +} + +type inputDelete struct { + userID uint32 + commentID uint32 +} + +func TestCommentService_Delete(t *testing.T) { + tests := []TableTest3[struct{}, inputDelete]{ + { + name: "1", + SetupInput: func() (*inputDelete, error) { + return &inputDelete{}, nil + }, + Run: func( + ctx context.Context, implementation *CommentService, request inputDelete, + ) (struct{}, error) { + return struct{}{}, implementation.DeleteComment(ctx, request.userID, request.commentID) + }, + ExpectedResult: func() (struct{}, error) { + return struct{}{}, nil + }, + ExpectedErr: errMock, + SetupMock: func(request inputDelete, m *commentMocks) { + m.repo.EXPECT().GetCommentAuthor(gomock.Any(), gomock.Any()).Return(uint32(0), errMock) + }, + }, + { + name: "2", + SetupInput: func() (*inputDelete, error) { + return &inputDelete{ + userID: 1, + }, nil + }, + Run: func( + ctx context.Context, implementation *CommentService, request inputDelete, + ) (struct{}, error) { + return struct{}{}, implementation.DeleteComment(ctx, request.commentID, request.userID) + }, + ExpectedResult: func() (struct{}, error) { + return struct{}{}, nil + }, + ExpectedErr: my_err.ErrAccessDenied, + SetupMock: func(request inputDelete, m *commentMocks) { + m.repo.EXPECT().GetCommentAuthor(gomock.Any(), gomock.Any()).Return(uint32(0), nil) + }, + }, + { + name: "3", + SetupInput: func() (*inputDelete, error) { + return &inputDelete{ + userID: 1, + }, nil + }, + Run: func( + ctx context.Context, implementation *CommentService, request inputDelete, + ) (struct{}, error) { + return struct{}{}, implementation.DeleteComment(ctx, request.commentID, request.userID) + }, + ExpectedResult: func() (struct{}, error) { + return struct{}{}, nil + }, + ExpectedErr: errMock, + SetupMock: func(request inputDelete, m *commentMocks) { + m.repo.EXPECT().GetCommentAuthor(gomock.Any(), gomock.Any()).Return(uint32(1), nil) + m.repo.EXPECT().DeleteComment(gomock.Any(), gomock.Any()).Return(errMock) + }, + }, + { + name: "4", + SetupInput: func() (*inputDelete, error) { + return &inputDelete{ + userID: 1, + }, nil + }, + Run: func( + ctx context.Context, implementation *CommentService, request inputDelete, + ) (struct{}, error) { + return struct{}{}, implementation.DeleteComment(ctx, request.commentID, request.userID) + }, + ExpectedResult: func() (struct{}, error) { + return struct{}{}, nil + }, + ExpectedErr: nil, + SetupMock: func(request inputDelete, m *commentMocks) { + m.repo.EXPECT().GetCommentAuthor(gomock.Any(), gomock.Any()).Return(uint32(1), nil) + m.repo.EXPECT().DeleteComment(gomock.Any(), gomock.Any()).Return(nil) + }, + }, + } + + for _, v := range tests { + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getCommentService(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) + } +} + +type inputEdit struct { + userID uint32 + commentID uint32 + comment *models.ContentDto +} + +func TestCommentService_Edit(t *testing.T) { + tests := []TableTest3[struct{}, inputEdit]{ + { + name: "1", + SetupInput: func() (*inputEdit, error) { + return &inputEdit{}, nil + }, + Run: func( + ctx context.Context, implementation *CommentService, request inputEdit, + ) (struct{}, error) { + return struct{}{}, implementation.EditComment(ctx, request.commentID, request.userID, request.comment) + }, + ExpectedResult: func() (struct{}, error) { + return struct{}{}, nil + }, + ExpectedErr: errMock, + SetupMock: func(request inputEdit, m *commentMocks) { + m.repo.EXPECT().GetCommentAuthor(gomock.Any(), gomock.Any()).Return(uint32(0), errMock) + }, + }, + { + name: "2", + SetupInput: func() (*inputEdit, error) { + return &inputEdit{ + userID: 10, + }, nil + }, + Run: func( + ctx context.Context, implementation *CommentService, request inputEdit, + ) (struct{}, error) { + return struct{}{}, implementation.EditComment(ctx, request.commentID, request.userID, request.comment) + }, + ExpectedResult: func() (struct{}, error) { + return struct{}{}, nil + }, + ExpectedErr: my_err.ErrAccessDenied, + SetupMock: func(request inputEdit, m *commentMocks) { + m.repo.EXPECT().GetCommentAuthor(gomock.Any(), gomock.Any()).Return(uint32(1), nil) + }, + }, + { + name: "3", + SetupInput: func() (*inputEdit, error) { + return &inputEdit{ + userID: 10, + }, nil + }, + Run: func( + ctx context.Context, implementation *CommentService, request inputEdit, + ) (struct{}, error) { + return struct{}{}, implementation.EditComment(ctx, request.commentID, request.userID, request.comment) + }, + ExpectedResult: func() (struct{}, error) { + return struct{}{}, nil + }, + ExpectedErr: errMock, + SetupMock: func(request inputEdit, m *commentMocks) { + m.repo.EXPECT().GetCommentAuthor(gomock.Any(), gomock.Any()).Return(uint32(10), nil) + m.repo.EXPECT().UpdateComment(gomock.Any(), gomock.Any(), gomock.Any()).Return(errMock) + }, + }, + { + name: "4", + SetupInput: func() (*inputEdit, error) { + return &inputEdit{ + userID: 10, + }, nil + }, + Run: func( + ctx context.Context, implementation *CommentService, request inputEdit, + ) (struct{}, error) { + return struct{}{}, implementation.EditComment(ctx, request.commentID, request.userID, request.comment) + }, + ExpectedResult: func() (struct{}, error) { + return struct{}{}, nil + }, + ExpectedErr: nil, + SetupMock: func(request inputEdit, m *commentMocks) { + m.repo.EXPECT().GetCommentAuthor(gomock.Any(), gomock.Any()).Return(uint32(10), nil) + m.repo.EXPECT().UpdateComment(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + }, + }, + } + + for _, v := range tests { + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getCommentService(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) + } +} + +type inputGet struct { + postID uint32 + lastID uint32 + newest bool +} + +func TestCommentService_GetComments(t *testing.T) { + tests := []TableTest3[[]*models.CommentDto, inputGet]{ + { + name: "1", + SetupInput: func() (*inputGet, error) { + return &inputGet{}, nil + }, + Run: func( + ctx context.Context, implementation *CommentService, request inputGet, + ) ([]*models.CommentDto, error) { + return implementation.GetComments(ctx, request.postID, request.lastID, request.newest) + }, + ExpectedResult: func() ([]*models.CommentDto, error) { + return nil, nil + }, + ExpectedErr: errMock, + SetupMock: func(request inputGet, m *commentMocks) { + m.repo.EXPECT().GetComments(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errMock) + }, + }, + { + name: "2", + SetupInput: func() (*inputGet, error) { + return &inputGet{}, nil + }, + Run: func( + ctx context.Context, implementation *CommentService, request inputGet, + ) ([]*models.CommentDto, error) { + return implementation.GetComments(ctx, request.postID, request.lastID, request.newest) + }, + ExpectedResult: func() ([]*models.CommentDto, error) { + return []*models.CommentDto{ + { + ID: 1, + Header: models.Header{AuthorID: 1, Author: "Alexey Zemliakov"}, + }, + { + ID: 2, + Header: models.Header{AuthorID: 1, Author: "Alexey Zemliakov"}, + }, + }, nil + }, + ExpectedErr: nil, + SetupMock: func(request inputGet, m *commentMocks) { + m.repo.EXPECT().GetComments(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return( + []*models.CommentDto{ + {ID: 1}, + {ID: 2}, + }, + nil, + ) + m.profileRepo.EXPECT().GetHeader(gomock.Any(), gomock.Any()). + Return( + &models.Header{ + AuthorID: 1, + Author: "Alexey Zemliakov", + }, + nil, + ).AnyTimes() + }, + }, + { + name: "3", + SetupInput: func() (*inputGet, error) { + return &inputGet{}, nil + }, + Run: func( + ctx context.Context, implementation *CommentService, request inputGet, + ) ([]*models.CommentDto, error) { + return implementation.GetComments(ctx, request.postID, request.lastID, request.newest) + }, + ExpectedResult: func() ([]*models.CommentDto, error) { + return nil, nil + }, + ExpectedErr: errMock, + SetupMock: func(request inputGet, m *commentMocks) { + m.repo.EXPECT().GetComments(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return( + []*models.CommentDto{ + {ID: 1}, + {ID: 2}, + }, + nil, + ) + m.profileRepo.EXPECT().GetHeader(gomock.Any(), gomock.Any()). + Return(nil, errMock).AnyTimes() + }, + }, + } + + for _, v := range tests { + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getCommentService(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) + } +} + +type TableTest3[T, In any] struct { + name string + SetupInput func() (*In, error) + Run func(context.Context, *CommentService, In) (T, error) + ExpectedResult func() (T, error) + ExpectedErr error + SetupMock func(In, *commentMocks) +} diff --git a/internal/post/service/mock.go b/internal/post/service/mock.go index 5f540151..c219d6da 100644 --- a/internal/post/service/mock.go +++ b/internal/post/service/mock.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: post.go +// +// Generated by this command: +// +// mockgen -destination=mock.go -source=post.go -package=service +// // Package service is a generated GoMock package. package service @@ -16,6 +21,7 @@ import ( type MockDB struct { ctrl *gomock.Controller recorder *MockDBMockRecorder + isgomock struct{} } // MockDBMockRecorder is the mock recorder for MockDB. @@ -45,13 +51,13 @@ func (m *MockDB) CheckLikes(ctx context.Context, postID, userID uint32) (bool, e } // CheckLikes indicates an expected call of CheckLikes. -func (mr *MockDBMockRecorder) CheckLikes(ctx, postID, userID interface{}) *gomock.Call { +func (mr *MockDBMockRecorder) CheckLikes(ctx, postID, userID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckLikes", reflect.TypeOf((*MockDB)(nil).CheckLikes), ctx, postID, userID) } // Create mocks base method. -func (m *MockDB) Create(ctx context.Context, post *models.Post) (uint32, error) { +func (m *MockDB) Create(ctx context.Context, post *models.PostDto) (uint32, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Create", ctx, post) ret0, _ := ret[0].(uint32) @@ -60,13 +66,13 @@ func (m *MockDB) Create(ctx context.Context, post *models.Post) (uint32, error) } // Create indicates an expected call of Create. -func (mr *MockDBMockRecorder) Create(ctx, post interface{}) *gomock.Call { +func (mr *MockDBMockRecorder) Create(ctx, post any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockDB)(nil).Create), ctx, post) } // CreateCommunityPost mocks base method. -func (m *MockDB) CreateCommunityPost(ctx context.Context, post *models.Post, communityID uint32) (uint32, error) { +func (m *MockDB) CreateCommunityPost(ctx context.Context, post *models.PostDto, communityID uint32) (uint32, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CreateCommunityPost", ctx, post, communityID) ret0, _ := ret[0].(uint32) @@ -75,7 +81,7 @@ func (m *MockDB) CreateCommunityPost(ctx context.Context, post *models.Post, com } // CreateCommunityPost indicates an expected call of CreateCommunityPost. -func (mr *MockDBMockRecorder) CreateCommunityPost(ctx, post, communityID interface{}) *gomock.Call { +func (mr *MockDBMockRecorder) CreateCommunityPost(ctx, post, communityID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateCommunityPost", reflect.TypeOf((*MockDB)(nil).CreateCommunityPost), ctx, post, communityID) } @@ -89,7 +95,7 @@ func (m *MockDB) Delete(ctx context.Context, postID uint32) error { } // Delete indicates an expected call of Delete. -func (mr *MockDBMockRecorder) Delete(ctx, postID interface{}) *gomock.Call { +func (mr *MockDBMockRecorder) Delete(ctx, postID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockDB)(nil).Delete), ctx, postID) } @@ -103,52 +109,67 @@ func (m *MockDB) DeleteLikeFromPost(ctx context.Context, postID, userID uint32) } // DeleteLikeFromPost indicates an expected call of DeleteLikeFromPost. -func (mr *MockDBMockRecorder) DeleteLikeFromPost(ctx, postID, userID interface{}) *gomock.Call { +func (mr *MockDBMockRecorder) DeleteLikeFromPost(ctx, postID, userID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteLikeFromPost", reflect.TypeOf((*MockDB)(nil).DeleteLikeFromPost), ctx, postID, userID) } // Get mocks base method. -func (m *MockDB) Get(ctx context.Context, postID uint32) (*models.Post, error) { +func (m *MockDB) Get(ctx context.Context, postID uint32) (*models.PostDto, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Get", ctx, postID) - ret0, _ := ret[0].(*models.Post) + ret0, _ := ret[0].(*models.PostDto) ret1, _ := ret[1].(error) return ret0, ret1 } // Get indicates an expected call of Get. -func (mr *MockDBMockRecorder) Get(ctx, postID interface{}) *gomock.Call { +func (mr *MockDBMockRecorder) Get(ctx, postID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockDB)(nil).Get), ctx, postID) } +// GetCommentCount mocks base method. +func (m *MockDB) GetCommentCount(ctx context.Context, postID uint32) (uint32, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetCommentCount", ctx, postID) + ret0, _ := ret[0].(uint32) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetCommentCount indicates an expected call of GetCommentCount. +func (mr *MockDBMockRecorder) GetCommentCount(ctx, postID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCommentCount", reflect.TypeOf((*MockDB)(nil).GetCommentCount), ctx, postID) +} + // GetCommunityPosts mocks base method. -func (m *MockDB) GetCommunityPosts(ctx context.Context, communityID, lastID uint32) ([]*models.Post, error) { +func (m *MockDB) GetCommunityPosts(ctx context.Context, communityID, lastID uint32) ([]*models.PostDto, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetCommunityPosts", ctx, communityID, lastID) - ret0, _ := ret[0].([]*models.Post) + ret0, _ := ret[0].([]*models.PostDto) ret1, _ := ret[1].(error) return ret0, ret1 } // GetCommunityPosts indicates an expected call of GetCommunityPosts. -func (mr *MockDBMockRecorder) GetCommunityPosts(ctx, communityID, lastID interface{}) *gomock.Call { +func (mr *MockDBMockRecorder) GetCommunityPosts(ctx, communityID, lastID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCommunityPosts", reflect.TypeOf((*MockDB)(nil).GetCommunityPosts), ctx, communityID, lastID) } // GetFriendsPosts mocks base method. -func (m *MockDB) GetFriendsPosts(ctx context.Context, friendsID []uint32, lastID uint32) ([]*models.Post, error) { +func (m *MockDB) GetFriendsPosts(ctx context.Context, friendsID []uint32, lastID uint32) ([]*models.PostDto, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetFriendsPosts", ctx, friendsID, lastID) - ret0, _ := ret[0].([]*models.Post) + ret0, _ := ret[0].([]*models.PostDto) ret1, _ := ret[1].(error) return ret0, ret1 } // GetFriendsPosts indicates an expected call of GetFriendsPosts. -func (mr *MockDBMockRecorder) GetFriendsPosts(ctx, friendsID, lastID interface{}) *gomock.Call { +func (mr *MockDBMockRecorder) GetFriendsPosts(ctx, friendsID, lastID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFriendsPosts", reflect.TypeOf((*MockDB)(nil).GetFriendsPosts), ctx, friendsID, lastID) } @@ -163,7 +184,7 @@ func (m *MockDB) GetLikesOnPost(ctx context.Context, postID uint32) (uint32, err } // GetLikesOnPost indicates an expected call of GetLikesOnPost. -func (mr *MockDBMockRecorder) GetLikesOnPost(ctx, postID interface{}) *gomock.Call { +func (mr *MockDBMockRecorder) GetLikesOnPost(ctx, postID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLikesOnPost", reflect.TypeOf((*MockDB)(nil).GetLikesOnPost), ctx, postID) } @@ -178,22 +199,22 @@ func (m *MockDB) GetPostAuthor(ctx context.Context, postID uint32) (uint32, erro } // GetPostAuthor indicates an expected call of GetPostAuthor. -func (mr *MockDBMockRecorder) GetPostAuthor(ctx, postID interface{}) *gomock.Call { +func (mr *MockDBMockRecorder) GetPostAuthor(ctx, postID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPostAuthor", reflect.TypeOf((*MockDB)(nil).GetPostAuthor), ctx, postID) } // GetPosts mocks base method. -func (m *MockDB) GetPosts(ctx context.Context, lastID uint32) ([]*models.Post, error) { +func (m *MockDB) GetPosts(ctx context.Context, lastID uint32) ([]*models.PostDto, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetPosts", ctx, lastID) - ret0, _ := ret[0].([]*models.Post) + ret0, _ := ret[0].([]*models.PostDto) ret1, _ := ret[1].(error) return ret0, ret1 } // GetPosts indicates an expected call of GetPosts. -func (mr *MockDBMockRecorder) GetPosts(ctx, lastID interface{}) *gomock.Call { +func (mr *MockDBMockRecorder) GetPosts(ctx, lastID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPosts", reflect.TypeOf((*MockDB)(nil).GetPosts), ctx, lastID) } @@ -207,13 +228,13 @@ func (m *MockDB) SetLikeToPost(ctx context.Context, postID, userID uint32) error } // SetLikeToPost indicates an expected call of SetLikeToPost. -func (mr *MockDBMockRecorder) SetLikeToPost(ctx, postID, userID interface{}) *gomock.Call { +func (mr *MockDBMockRecorder) SetLikeToPost(ctx, postID, userID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLikeToPost", reflect.TypeOf((*MockDB)(nil).SetLikeToPost), ctx, postID, userID) } // Update mocks base method. -func (m *MockDB) Update(ctx context.Context, post *models.Post) error { +func (m *MockDB) Update(ctx context.Context, post *models.PostDto) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Update", ctx, post) ret0, _ := ret[0].(error) @@ -221,7 +242,7 @@ func (m *MockDB) Update(ctx context.Context, post *models.Post) error { } // Update indicates an expected call of Update. -func (mr *MockDBMockRecorder) Update(ctx, post interface{}) *gomock.Call { +func (mr *MockDBMockRecorder) Update(ctx, post any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockDB)(nil).Update), ctx, post) } @@ -230,6 +251,7 @@ func (mr *MockDBMockRecorder) Update(ctx, post interface{}) *gomock.Call { type MockProfileRepo struct { ctrl *gomock.Controller recorder *MockProfileRepoMockRecorder + isgomock struct{} } // MockProfileRepoMockRecorder is the mock recorder for MockProfileRepo. @@ -259,7 +281,7 @@ func (m *MockProfileRepo) GetFriendsID(ctx context.Context, userID uint32) ([]ui } // GetFriendsID indicates an expected call of GetFriendsID. -func (mr *MockProfileRepoMockRecorder) GetFriendsID(ctx, userID interface{}) *gomock.Call { +func (mr *MockProfileRepoMockRecorder) GetFriendsID(ctx, userID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFriendsID", reflect.TypeOf((*MockProfileRepo)(nil).GetFriendsID), ctx, userID) } @@ -274,7 +296,7 @@ func (m *MockProfileRepo) GetHeader(ctx context.Context, userID uint32) (*models } // GetHeader indicates an expected call of GetHeader. -func (mr *MockProfileRepoMockRecorder) GetHeader(ctx, userID interface{}) *gomock.Call { +func (mr *MockProfileRepoMockRecorder) GetHeader(ctx, userID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHeader", reflect.TypeOf((*MockProfileRepo)(nil).GetHeader), ctx, userID) } @@ -283,6 +305,7 @@ func (mr *MockProfileRepoMockRecorder) GetHeader(ctx, userID interface{}) *gomoc type MockCommunityRepo struct { ctrl *gomock.Controller recorder *MockCommunityRepoMockRecorder + isgomock struct{} } // MockCommunityRepoMockRecorder is the mock recorder for MockCommunityRepo. @@ -311,7 +334,7 @@ func (m *MockCommunityRepo) CheckAccess(ctx context.Context, communityID, userID } // CheckAccess indicates an expected call of CheckAccess. -func (mr *MockCommunityRepoMockRecorder) CheckAccess(ctx, communityID, userID interface{}) *gomock.Call { +func (mr *MockCommunityRepoMockRecorder) CheckAccess(ctx, communityID, userID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckAccess", reflect.TypeOf((*MockCommunityRepo)(nil).CheckAccess), ctx, communityID, userID) } @@ -326,7 +349,7 @@ func (m *MockCommunityRepo) GetHeader(ctx context.Context, communityID uint32) ( } // GetHeader indicates an expected call of GetHeader. -func (mr *MockCommunityRepoMockRecorder) GetHeader(ctx, communityID interface{}) *gomock.Call { +func (mr *MockCommunityRepoMockRecorder) GetHeader(ctx, communityID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHeader", reflect.TypeOf((*MockCommunityRepo)(nil).GetHeader), ctx, communityID) } diff --git a/internal/post/service/mock_helper.go b/internal/post/service/mock_helper.go index b4d722bd..7c016606 100644 --- a/internal/post/service/mock_helper.go +++ b/internal/post/service/mock_helper.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: post_profile.go +// +// Generated by this command: +// +// mockgen -destination=mock_helper.go -source=post_profile.go -package=service +// // Package service is a generated GoMock package. package service @@ -16,6 +21,7 @@ import ( type MockPostProfileDB struct { ctrl *gomock.Controller recorder *MockPostProfileDBMockRecorder + isgomock struct{} } // MockPostProfileDBMockRecorder is the mock recorder for MockPostProfileDB. @@ -45,26 +51,41 @@ func (m *MockPostProfileDB) CheckLikes(ctx context.Context, postID, userID uint3 } // CheckLikes indicates an expected call of CheckLikes. -func (mr *MockPostProfileDBMockRecorder) CheckLikes(ctx, postID, userID interface{}) *gomock.Call { +func (mr *MockPostProfileDBMockRecorder) CheckLikes(ctx, postID, userID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckLikes", reflect.TypeOf((*MockPostProfileDB)(nil).CheckLikes), ctx, postID, userID) } // GetAuthorPosts mocks base method. -func (m *MockPostProfileDB) GetAuthorPosts(ctx context.Context, header *models.Header) ([]*models.Post, error) { +func (m *MockPostProfileDB) GetAuthorPosts(ctx context.Context, header *models.Header) ([]*models.PostDto, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetAuthorPosts", ctx, header) - ret0, _ := ret[0].([]*models.Post) + ret0, _ := ret[0].([]*models.PostDto) ret1, _ := ret[1].(error) return ret0, ret1 } // GetAuthorPosts indicates an expected call of GetAuthorPosts. -func (mr *MockPostProfileDBMockRecorder) GetAuthorPosts(ctx, header interface{}) *gomock.Call { +func (mr *MockPostProfileDBMockRecorder) GetAuthorPosts(ctx, header any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorPosts", reflect.TypeOf((*MockPostProfileDB)(nil).GetAuthorPosts), ctx, header) } +// GetCommentCount mocks base method. +func (m *MockPostProfileDB) GetCommentCount(ctx context.Context, postID uint32) (uint32, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetCommentCount", ctx, postID) + ret0, _ := ret[0].(uint32) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetCommentCount indicates an expected call of GetCommentCount. +func (mr *MockPostProfileDBMockRecorder) GetCommentCount(ctx, postID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCommentCount", reflect.TypeOf((*MockPostProfileDB)(nil).GetCommentCount), ctx, postID) +} + // GetLikesOnPost mocks base method. func (m *MockPostProfileDB) GetLikesOnPost(ctx context.Context, postID uint32) (uint32, error) { m.ctrl.T.Helper() @@ -75,7 +96,7 @@ func (m *MockPostProfileDB) GetLikesOnPost(ctx context.Context, postID uint32) ( } // GetLikesOnPost indicates an expected call of GetLikesOnPost. -func (mr *MockPostProfileDBMockRecorder) GetLikesOnPost(ctx, postID interface{}) *gomock.Call { +func (mr *MockPostProfileDBMockRecorder) GetLikesOnPost(ctx, postID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLikesOnPost", reflect.TypeOf((*MockPostProfileDB)(nil).GetLikesOnPost), ctx, postID) } diff --git a/internal/post/service/post.go b/internal/post/service/post.go index 30ac4bad..d2d2b8a2 100644 --- a/internal/post/service/post.go +++ b/internal/post/service/post.go @@ -11,21 +11,23 @@ import ( //go:generate mockgen -destination=mock.go -source=$GOFILE -package=${GOPACKAGE} type DB interface { - Create(ctx context.Context, post *models.Post) (uint32, error) - Get(ctx context.Context, postID uint32) (*models.Post, error) - Update(ctx context.Context, post *models.Post) error + Create(ctx context.Context, post *models.PostDto) (uint32, error) + Get(ctx context.Context, postID uint32) (*models.PostDto, error) + Update(ctx context.Context, post *models.PostDto) error Delete(ctx context.Context, postID uint32) error - GetPosts(ctx context.Context, lastID uint32) ([]*models.Post, error) - GetFriendsPosts(ctx context.Context, friendsID []uint32, lastID uint32) ([]*models.Post, error) + GetPosts(ctx context.Context, lastID uint32) ([]*models.PostDto, error) + GetFriendsPosts(ctx context.Context, friendsID []uint32, lastID uint32) ([]*models.PostDto, error) GetPostAuthor(ctx context.Context, postID uint32) (uint32, error) - CreateCommunityPost(ctx context.Context, post *models.Post, communityID uint32) (uint32, error) - GetCommunityPosts(ctx context.Context, communityID uint32, lastID uint32) ([]*models.Post, error) + CreateCommunityPost(ctx context.Context, post *models.PostDto, communityID uint32) (uint32, error) + GetCommunityPosts(ctx context.Context, communityID uint32, lastID uint32) ([]*models.PostDto, error) SetLikeToPost(ctx context.Context, postID uint32, userID uint32) error DeleteLikeFromPost(ctx context.Context, postID uint32, userID uint32) error GetLikesOnPost(ctx context.Context, postID uint32) (uint32, error) CheckLikes(ctx context.Context, postID, userID uint32) (bool, error) + + GetCommentCount(ctx context.Context, postID uint32) (uint32, error) } type ProfileRepo interface { @@ -52,7 +54,7 @@ func NewPostServiceImpl(db DB, profileRepo ProfileRepo, repo CommunityRepo) *Pos } } -func (s *PostServiceImpl) Create(ctx context.Context, post *models.Post) (uint32, error) { +func (s *PostServiceImpl) Create(ctx context.Context, post *models.PostDto) (uint32, error) { id, err := s.db.Create(ctx, post) if err != nil { return 0, fmt.Errorf("create post: %w", err) @@ -61,7 +63,7 @@ func (s *PostServiceImpl) Create(ctx context.Context, post *models.Post) (uint32 return id, nil } -func (s *PostServiceImpl) Get(ctx context.Context, postID, userID uint32) (*models.Post, error) { +func (s *PostServiceImpl) Get(ctx context.Context, postID, userID uint32) (*models.PostDto, error) { post, err := s.db.Get(ctx, postID) if err != nil { return nil, fmt.Errorf("get post: %w", err) @@ -83,7 +85,7 @@ func (s *PostServiceImpl) Delete(ctx context.Context, postID uint32) error { return nil } -func (s *PostServiceImpl) Update(ctx context.Context, post *models.Post) error { +func (s *PostServiceImpl) Update(ctx context.Context, post *models.PostDto) error { post.PostContent.UpdatedAt = time.Now() err := s.db.Update(ctx, post) @@ -94,7 +96,7 @@ func (s *PostServiceImpl) Update(ctx context.Context, post *models.Post) error { return nil } -func (s *PostServiceImpl) GetBatch(ctx context.Context, lastID, userID uint32) ([]*models.Post, error) { +func (s *PostServiceImpl) GetBatch(ctx context.Context, lastID, userID uint32) ([]*models.PostDto, error) { posts, err := s.db.GetPosts(ctx, lastID) if err != nil { return nil, fmt.Errorf("get posts: %w", err) @@ -109,7 +111,9 @@ func (s *PostServiceImpl) GetBatch(ctx context.Context, lastID, userID uint32) ( return posts, nil } -func (s *PostServiceImpl) GetBatchFromFriend(ctx context.Context, userID uint32, lastID uint32) ([]*models.Post, error) { +func (s *PostServiceImpl) GetBatchFromFriend( + ctx context.Context, userID uint32, lastID uint32, +) ([]*models.PostDto, error) { friends, err := s.profileRepo.GetFriendsID(ctx, userID) if err != nil { return nil, fmt.Errorf("get friends: %w", err) @@ -142,7 +146,7 @@ func (s *PostServiceImpl) GetPostAuthorID(ctx context.Context, postID uint32) (u return id, nil } -func (s *PostServiceImpl) CreateCommunityPost(ctx context.Context, post *models.Post) (uint32, error) { +func (s *PostServiceImpl) CreateCommunityPost(ctx context.Context, post *models.PostDto) (uint32, error) { id, err := s.db.CreateCommunityPost(ctx, post, post.Header.CommunityID) if err != nil { return 0, fmt.Errorf("create post: %w", err) @@ -151,7 +155,9 @@ func (s *PostServiceImpl) CreateCommunityPost(ctx context.Context, post *models. return id, nil } -func (s *PostServiceImpl) GetCommunityPost(ctx context.Context, communityID, userID, lastID uint32) ([]*models.Post, error) { +func (s *PostServiceImpl) GetCommunityPost( + ctx context.Context, communityID, userID, lastID uint32, +) ([]*models.PostDto, error) { posts, err := s.db.GetCommunityPosts(ctx, communityID, lastID) if err != nil { return nil, fmt.Errorf("get posts: %w", err) @@ -199,7 +205,7 @@ func (s *PostServiceImpl) CheckLikes(ctx context.Context, postID, userID uint32) return res, nil } -func (s *PostServiceImpl) setPostFields(ctx context.Context, post *models.Post, userID uint32) error { +func (s *PostServiceImpl) setPostFields(ctx context.Context, post *models.PostDto, userID uint32) error { var ( header *models.Header err error @@ -231,6 +237,10 @@ func (s *PostServiceImpl) setPostFields(ctx context.Context, post *models.Post, post.IsLiked = liked post.PostContent.CreatedAt = convertTime(post.PostContent.CreatedAt) + post.CommentCount, err = s.db.GetCommentCount(ctx, post.ID) + if err != nil { + return fmt.Errorf("get comment count: %w", err) + } return nil } diff --git a/internal/post/service/post_profile.go b/internal/post/service/post_profile.go index 06ed4676..9d6d03e4 100644 --- a/internal/post/service/post_profile.go +++ b/internal/post/service/post_profile.go @@ -9,9 +9,10 @@ import ( //go:generate mockgen -destination=mock_helper.go -source=$GOFILE -package=${GOPACKAGE} type PostProfileDB interface { - GetAuthorPosts(ctx context.Context, header *models.Header) ([]*models.Post, error) + GetAuthorPosts(ctx context.Context, header *models.Header) ([]*models.PostDto, error) GetLikesOnPost(ctx context.Context, postID uint32) (uint32, error) CheckLikes(ctx context.Context, postID, userID uint32) (bool, error) + GetCommentCount(ctx context.Context, postID uint32) (uint32, error) } type PostProfileImpl struct { @@ -24,7 +25,9 @@ func NewPostProfileImpl(db PostProfileDB) *PostProfileImpl { } } -func (p *PostProfileImpl) GetAuthorsPosts(ctx context.Context, header *models.Header, userID uint32) ([]*models.Post, error) { +func (p *PostProfileImpl) GetAuthorsPosts( + ctx context.Context, header *models.Header, userID uint32, +) ([]*models.Post, error) { posts, err := p.db.GetAuthorPosts(ctx, header) if err != nil { return nil, err @@ -43,7 +46,17 @@ func (p *PostProfileImpl) GetAuthorsPosts(ctx context.Context, header *models.He } posts[i].IsLiked = liked posts[i].PostContent.CreatedAt = convertTime(post.PostContent.CreatedAt) + commentCount, err := p.db.GetCommentCount(ctx, post.ID) + if err != nil { + return nil, fmt.Errorf("get comment count: %w", err) + } + posts[i].CommentCount = commentCount + } + res := make([]*models.Post, 0, len(posts)) + for _, post := range posts { + postFrom := post.FromDto() + res = append(res, &postFrom) } - return posts, nil + return res, nil } diff --git a/internal/post/service/post_profile_test.go b/internal/post/service/post_profile_test.go index 6a27105b..cfb3f9b8 100644 --- a/internal/post/service/post_profile_test.go +++ b/internal/post/service/post_profile_test.go @@ -60,11 +60,12 @@ func TestGetAuthorsPost(t *testing.T) { ExpectedErr: errMock, SetupMock: func(request input, m *mocksHelper) { m.repo.EXPECT().GetAuthorPosts(gomock.Any(), gomock.Any()).Return( - []*models.Post{ + []*models.PostDto{ { ID: 1, }, - }, nil) + }, nil, + ) m.repo.EXPECT().GetLikesOnPost(gomock.Any(), gomock.Any()).Return(uint32(0), errMock) }, }, @@ -82,11 +83,12 @@ func TestGetAuthorsPost(t *testing.T) { ExpectedErr: errMock, SetupMock: func(request input, m *mocksHelper) { m.repo.EXPECT().GetAuthorPosts(gomock.Any(), gomock.Any()).Return( - []*models.Post{ + []*models.PostDto{ { ID: 1, }, - }, nil) + }, nil, + ) m.repo.EXPECT().GetLikesOnPost(gomock.Any(), gomock.Any()).Return(uint32(1), nil) m.repo.EXPECT().CheckLikes(gomock.Any(), gomock.Any(), gomock.Any()).Return(false, errMock) }, @@ -102,52 +104,82 @@ func TestGetAuthorsPost(t *testing.T) { ExpectedResult: func() ([]*models.Post, error) { return []*models.Post{ { - ID: 1, - IsLiked: true, - LikesCount: 1, + ID: 1, + IsLiked: true, + LikesCount: 1, + CommentCount: 1, }, }, nil }, ExpectedErr: nil, SetupMock: func(request input, m *mocksHelper) { m.repo.EXPECT().GetAuthorPosts(gomock.Any(), gomock.Any()).Return( - []*models.Post{ + []*models.PostDto{ { ID: 1, }, - }, nil) + }, nil, + ) + m.repo.EXPECT().GetLikesOnPost(gomock.Any(), gomock.Any()).Return(uint32(1), nil).AnyTimes() + m.repo.EXPECT().GetCommentCount(gomock.Any(), gomock.Any()).Return(uint32(1), nil).AnyTimes() + m.repo.EXPECT().CheckLikes(gomock.Any(), gomock.Any(), gomock.Any()).Return(true, nil) + }, + }, + { + name: "5", + SetupInput: func() (*input, error) { + return &input{}, nil + }, + Run: func(ctx context.Context, implementation *PostProfileImpl, request input) ([]*models.Post, error) { + return implementation.GetAuthorsPosts(ctx, request.header, request.userID) + }, + ExpectedResult: func() ([]*models.Post, error) { + return nil, nil + }, + ExpectedErr: errMock, + SetupMock: func(request input, m *mocksHelper) { + m.repo.EXPECT().GetAuthorPosts(gomock.Any(), gomock.Any()).Return( + []*models.PostDto{ + { + ID: 1, + }, + }, nil, + ) m.repo.EXPECT().GetLikesOnPost(gomock.Any(), gomock.Any()).Return(uint32(1), nil) + m.repo.EXPECT().GetCommentCount(gomock.Any(), gomock.Any()).Return(uint32(0), errMock) m.repo.EXPECT().CheckLikes(gomock.Any(), gomock.Any(), gomock.Any()).Return(true, nil) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getServiceHelper(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getServiceHelper(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } diff --git a/internal/post/service/post_test.go b/internal/post/service/post_test.go index 823e93ea..248344e2 100644 --- a/internal/post/service/post_test.go +++ b/internal/post/service/post_test.go @@ -38,67 +38,69 @@ func TestNewPostService(t *testing.T) { var errMock = errors.New("mock error") func TestCreate(t *testing.T) { - tests := []TableTest[uint32, models.Post]{ + tests := []TableTest[uint32, models.PostDto]{ { name: "1", - SetupInput: func() (*models.Post, error) { - return &models.Post{PostContent: models.Content{Text: "new post"}}, nil + SetupInput: func() (*models.PostDto, error) { + return &models.PostDto{PostContent: models.ContentDto{Text: "new post"}}, nil }, - Run: func(ctx context.Context, implementation *PostServiceImpl, request models.Post) (uint32, error) { + Run: func(ctx context.Context, implementation *PostServiceImpl, request models.PostDto) (uint32, error) { return implementation.Create(ctx, &request) }, ExpectedResult: func() (uint32, error) { return 0, nil }, ExpectedErr: errMock, - SetupMock: func(request models.Post, m *mocks) { + SetupMock: func(request models.PostDto, m *mocks) { m.postRepo.EXPECT().Create(gomock.Any(), gomock.Any()).Return(uint32(0), errMock) }, }, { name: "2", - SetupInput: func() (*models.Post, error) { - return &models.Post{PostContent: models.Content{Text: "new real post"}}, nil + SetupInput: func() (*models.PostDto, error) { + return &models.PostDto{PostContent: models.ContentDto{Text: "new real post"}}, nil }, - Run: func(ctx context.Context, implementation *PostServiceImpl, request models.Post) (uint32, error) { + Run: func(ctx context.Context, implementation *PostServiceImpl, request models.PostDto) (uint32, error) { return implementation.Create(ctx, &request) }, ExpectedResult: func() (uint32, error) { return 1, nil }, ExpectedErr: nil, - SetupMock: func(request models.Post, m *mocks) { + SetupMock: func(request models.PostDto, m *mocks) { m.postRepo.EXPECT().Create(gomock.Any(), gomock.Any()).Return(uint32(1), nil) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getService(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getService(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -108,16 +110,18 @@ type userAndPostIDs struct { } func TestGet(t *testing.T) { - tests := []TableTest[*models.Post, userAndPostIDs]{ + tests := []TableTest[*models.PostDto, userAndPostIDs]{ { name: "1", SetupInput: func() (*userAndPostIDs, error) { return &userAndPostIDs{postId: 0, userID: 0}, nil }, - Run: func(ctx context.Context, implementation *PostServiceImpl, request userAndPostIDs) (*models.Post, error) { + Run: func( + ctx context.Context, implementation *PostServiceImpl, request userAndPostIDs, + ) (*models.PostDto, error) { return implementation.Get(ctx, request.postId, request.userID) }, - ExpectedResult: func() (*models.Post, error) { + ExpectedResult: func() (*models.PostDto, error) { return nil, nil }, ExpectedErr: errMock, @@ -130,22 +134,25 @@ func TestGet(t *testing.T) { SetupInput: func() (*userAndPostIDs, error) { return &userAndPostIDs{postId: 1, userID: 1}, nil }, - Run: func(ctx context.Context, implementation *PostServiceImpl, request userAndPostIDs) (*models.Post, error) { + Run: func( + ctx context.Context, implementation *PostServiceImpl, request userAndPostIDs, + ) (*models.PostDto, error) { return implementation.Get(ctx, request.postId, request.userID) }, - ExpectedResult: func() (*models.Post, error) { + ExpectedResult: func() (*models.PostDto, error) { return nil, nil }, ExpectedErr: errMock, SetupMock: func(request userAndPostIDs, m *mocks) { m.postRepo.EXPECT().Get(gomock.Any(), gomock.Any()).Return( - &models.Post{ + &models.PostDto{ Header: models.Header{ CommunityID: 0, AuthorID: 1, }, }, - nil) + nil, + ) m.profileRepo.EXPECT().GetHeader(gomock.Any(), gomock.Any()).Return(nil, errMock) }, }, @@ -154,22 +161,25 @@ func TestGet(t *testing.T) { SetupInput: func() (*userAndPostIDs, error) { return &userAndPostIDs{postId: 1, userID: 1}, nil }, - Run: func(ctx context.Context, implementation *PostServiceImpl, request userAndPostIDs) (*models.Post, error) { + Run: func( + ctx context.Context, implementation *PostServiceImpl, request userAndPostIDs, + ) (*models.PostDto, error) { return implementation.Get(ctx, request.postId, request.userID) }, - ExpectedResult: func() (*models.Post, error) { + ExpectedResult: func() (*models.PostDto, error) { return nil, nil }, ExpectedErr: errMock, SetupMock: func(request userAndPostIDs, m *mocks) { m.postRepo.EXPECT().Get(gomock.Any(), gomock.Any()).Return( - &models.Post{ + &models.PostDto{ Header: models.Header{ CommunityID: 1, AuthorID: 0, }, }, - nil) + nil, + ) m.communityRepo.EXPECT().GetHeader(gomock.Any(), gomock.Any()).Return(nil, errMock) }, }, @@ -178,27 +188,32 @@ func TestGet(t *testing.T) { SetupInput: func() (*userAndPostIDs, error) { return &userAndPostIDs{postId: 1, userID: 1}, nil }, - Run: func(ctx context.Context, implementation *PostServiceImpl, request userAndPostIDs) (*models.Post, error) { + Run: func( + ctx context.Context, implementation *PostServiceImpl, request userAndPostIDs, + ) (*models.PostDto, error) { return implementation.Get(ctx, request.postId, request.userID) }, - ExpectedResult: func() (*models.Post, error) { + ExpectedResult: func() (*models.PostDto, error) { return nil, nil }, ExpectedErr: errMock, SetupMock: func(request userAndPostIDs, m *mocks) { m.postRepo.EXPECT().Get(gomock.Any(), gomock.Any()).Return( - &models.Post{ + &models.PostDto{ Header: models.Header{ CommunityID: 0, AuthorID: 1, }, }, - nil) - m.profileRepo.EXPECT().GetHeader(gomock.Any(), gomock.Any()).Return(&models.Header{ - CommunityID: 0, - AuthorID: 1, - Author: "user", - }, nil) + nil, + ) + m.profileRepo.EXPECT().GetHeader(gomock.Any(), gomock.Any()).Return( + &models.Header{ + CommunityID: 0, + AuthorID: 1, + Author: "user", + }, nil, + ) m.postRepo.EXPECT().GetLikesOnPost(gomock.Any(), gomock.Any()).Return(uint32(0), errMock) }, }, @@ -207,27 +222,32 @@ func TestGet(t *testing.T) { SetupInput: func() (*userAndPostIDs, error) { return &userAndPostIDs{postId: 1, userID: 1}, nil }, - Run: func(ctx context.Context, implementation *PostServiceImpl, request userAndPostIDs) (*models.Post, error) { + Run: func( + ctx context.Context, implementation *PostServiceImpl, request userAndPostIDs, + ) (*models.PostDto, error) { return implementation.Get(ctx, request.postId, request.userID) }, - ExpectedResult: func() (*models.Post, error) { + ExpectedResult: func() (*models.PostDto, error) { return nil, nil }, ExpectedErr: errMock, SetupMock: func(request userAndPostIDs, m *mocks) { m.postRepo.EXPECT().Get(gomock.Any(), gomock.Any()).Return( - &models.Post{ + &models.PostDto{ Header: models.Header{ CommunityID: 1, AuthorID: 0, }, }, - nil) - m.communityRepo.EXPECT().GetHeader(gomock.Any(), gomock.Any()).Return(&models.Header{ - CommunityID: 1, - AuthorID: 0, - Author: "community", - }, nil) + nil, + ) + m.communityRepo.EXPECT().GetHeader(gomock.Any(), gomock.Any()).Return( + &models.Header{ + CommunityID: 1, + AuthorID: 0, + Author: "community", + }, nil, + ) m.postRepo.EXPECT().GetLikesOnPost(gomock.Any(), gomock.Any()).Return(uint32(1), nil) m.postRepo.EXPECT().CheckLikes(gomock.Any(), gomock.Any(), gomock.Any()).Return(false, errMock) }, @@ -237,11 +257,13 @@ func TestGet(t *testing.T) { SetupInput: func() (*userAndPostIDs, error) { return &userAndPostIDs{postId: 1, userID: 1}, nil }, - Run: func(ctx context.Context, implementation *PostServiceImpl, request userAndPostIDs) (*models.Post, error) { + Run: func( + ctx context.Context, implementation *PostServiceImpl, request userAndPostIDs, + ) (*models.PostDto, error) { return implementation.Get(ctx, request.postId, request.userID) }, - ExpectedResult: func() (*models.Post, error) { - return &models.Post{ + ExpectedResult: func() (*models.PostDto, error) { + return &models.PostDto{ Header: models.Header{ CommunityID: 1, AuthorID: 0, @@ -254,18 +276,58 @@ func TestGet(t *testing.T) { ExpectedErr: nil, SetupMock: func(request userAndPostIDs, m *mocks) { m.postRepo.EXPECT().Get(gomock.Any(), gomock.Any()).Return( - &models.Post{ + &models.PostDto{ + Header: models.Header{ + CommunityID: 1, + AuthorID: 0, + }, + }, + nil, + ) + m.communityRepo.EXPECT().GetHeader(gomock.Any(), gomock.Any()).Return( + &models.Header{ + CommunityID: 1, + AuthorID: 0, + Author: "community", + }, nil, + ) + m.postRepo.EXPECT().GetCommentCount(gomock.Any(), gomock.Any()).Return(uint32(0), nil) + m.postRepo.EXPECT().GetLikesOnPost(gomock.Any(), gomock.Any()).Return(uint32(1), nil) + m.postRepo.EXPECT().CheckLikes(gomock.Any(), gomock.Any(), gomock.Any()).Return(true, nil) + }, + }, + { + name: "7", + SetupInput: func() (*userAndPostIDs, error) { + return &userAndPostIDs{postId: 1, userID: 1}, nil + }, + Run: func( + ctx context.Context, implementation *PostServiceImpl, request userAndPostIDs, + ) (*models.PostDto, error) { + return implementation.Get(ctx, request.postId, request.userID) + }, + ExpectedResult: func() (*models.PostDto, error) { + return nil, nil + }, + ExpectedErr: errMock, + SetupMock: func(request userAndPostIDs, m *mocks) { + m.postRepo.EXPECT().Get(gomock.Any(), gomock.Any()).Return( + &models.PostDto{ Header: models.Header{ CommunityID: 1, AuthorID: 0, }, }, - nil) - m.communityRepo.EXPECT().GetHeader(gomock.Any(), gomock.Any()).Return(&models.Header{ - CommunityID: 1, - AuthorID: 0, - Author: "community", - }, nil) + nil, + ) + m.communityRepo.EXPECT().GetHeader(gomock.Any(), gomock.Any()).Return( + &models.Header{ + CommunityID: 1, + AuthorID: 0, + Author: "community", + }, nil, + ) + m.postRepo.EXPECT().GetCommentCount(gomock.Any(), gomock.Any()).Return(uint32(0), errMock) m.postRepo.EXPECT().GetLikesOnPost(gomock.Any(), gomock.Any()).Return(uint32(1), nil) m.postRepo.EXPECT().CheckLikes(gomock.Any(), gomock.Any(), gomock.Any()).Return(true, nil) }, @@ -273,31 +335,33 @@ func TestGet(t *testing.T) { } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getService(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getService(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -342,42 +406,44 @@ func TestDelete(t *testing.T) { } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getService(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getService(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } func TestUpdate(t *testing.T) { - tests := []TableTest[struct{}, models.Post]{ + tests := []TableTest[struct{}, models.PostDto]{ { name: "1", - SetupInput: func() (*models.Post, error) { - return &models.Post{}, nil + SetupInput: func() (*models.PostDto, error) { + return &models.PostDto{}, nil }, - Run: func(ctx context.Context, implementation *PostServiceImpl, request models.Post) (struct{}, error) { + Run: func(ctx context.Context, implementation *PostServiceImpl, request models.PostDto) (struct{}, error) { err := implementation.Update(ctx, &request) return struct{}{}, err }, @@ -385,16 +451,16 @@ func TestUpdate(t *testing.T) { return struct{}{}, nil }, ExpectedErr: errMock, - SetupMock: func(request models.Post, m *mocks) { + SetupMock: func(request models.PostDto, m *mocks) { m.postRepo.EXPECT().Update(gomock.Any(), gomock.Any()).Return(errMock) }, }, { name: "2", - SetupInput: func() (*models.Post, error) { - return &models.Post{ID: 1, PostContent: models.Content{Text: "New post"}}, nil + SetupInput: func() (*models.PostDto, error) { + return &models.PostDto{ID: 1, PostContent: models.ContentDto{Text: "New post"}}, nil }, - Run: func(ctx context.Context, implementation *PostServiceImpl, request models.Post) (struct{}, error) { + Run: func(ctx context.Context, implementation *PostServiceImpl, request models.PostDto) (struct{}, error) { err := implementation.Update(ctx, &request) return struct{}{}, err }, @@ -402,38 +468,40 @@ func TestUpdate(t *testing.T) { return struct{}{}, nil }, ExpectedErr: nil, - SetupMock: func(request models.Post, m *mocks) { + SetupMock: func(request models.PostDto, m *mocks) { m.postRepo.EXPECT().Update(gomock.Any(), gomock.Any()).Return(nil) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getService(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getService(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -443,16 +511,18 @@ type userAndLastIDs struct { } func TestGetBatch(t *testing.T) { - tests := []TableTest[[]*models.Post, userAndLastIDs]{ + tests := []TableTest[[]*models.PostDto, userAndLastIDs]{ { name: "1", SetupInput: func() (*userAndLastIDs, error) { return &userAndLastIDs{}, nil }, - Run: func(ctx context.Context, implementation *PostServiceImpl, request userAndLastIDs) ([]*models.Post, error) { + Run: func( + ctx context.Context, implementation *PostServiceImpl, request userAndLastIDs, + ) ([]*models.PostDto, error) { return implementation.GetBatch(ctx, request.LastId, request.UserID) }, - ExpectedResult: func() ([]*models.Post, error) { + ExpectedResult: func() ([]*models.PostDto, error) { return nil, nil }, ExpectedErr: errMock, @@ -465,18 +535,21 @@ func TestGetBatch(t *testing.T) { SetupInput: func() (*userAndLastIDs, error) { return &userAndLastIDs{UserID: 1, LastId: 2}, nil }, - Run: func(ctx context.Context, implementation *PostServiceImpl, request userAndLastIDs) ([]*models.Post, error) { + Run: func( + ctx context.Context, implementation *PostServiceImpl, request userAndLastIDs, + ) ([]*models.PostDto, error) { return implementation.GetBatch(ctx, request.LastId, request.UserID) }, - ExpectedResult: func() ([]*models.Post, error) { + ExpectedResult: func() ([]*models.PostDto, error) { return nil, nil }, ExpectedErr: errMock, SetupMock: func(request userAndLastIDs, m *mocks) { m.postRepo.EXPECT().GetPosts(gomock.Any(), gomock.Any()).Return( - []*models.Post{ + []*models.PostDto{ {ID: 1, Header: models.Header{CommunityID: 1}}, - }, nil) + }, nil, + ) m.communityRepo.EXPECT().GetHeader(gomock.Any(), gomock.Any()).Return(nil, errMock) }, }, @@ -485,72 +558,107 @@ func TestGetBatch(t *testing.T) { SetupInput: func() (*userAndLastIDs, error) { return &userAndLastIDs{UserID: 1, LastId: 2}, nil }, - Run: func(ctx context.Context, implementation *PostServiceImpl, request userAndLastIDs) ([]*models.Post, error) { + Run: func( + ctx context.Context, implementation *PostServiceImpl, request userAndLastIDs, + ) ([]*models.PostDto, error) { return implementation.GetBatch(ctx, request.LastId, request.UserID) }, - ExpectedResult: func() ([]*models.Post, error) { - return []*models.Post{ + ExpectedResult: func() ([]*models.PostDto, error) { + return []*models.PostDto{ { - ID: 1, - Header: models.Header{AuthorID: 1}, - IsLiked: true, - LikesCount: 1, + ID: 1, + Header: models.Header{AuthorID: 1}, + IsLiked: true, + LikesCount: 1, + CommentCount: 1, }, }, nil }, ExpectedErr: nil, SetupMock: func(request userAndLastIDs, m *mocks) { m.postRepo.EXPECT().GetPosts(gomock.Any(), gomock.Any()).Return( - []*models.Post{ + []*models.PostDto{ {ID: 1, Header: models.Header{AuthorID: 1}}, - }, nil) + }, nil, + ) m.profileRepo.EXPECT().GetHeader(gomock.Any(), gomock.Any()).Return(&models.Header{AuthorID: 1}, nil) m.postRepo.EXPECT().GetLikesOnPost(gomock.Any(), gomock.Any()).Return(uint32(1), nil) + m.postRepo.EXPECT().GetCommentCount(gomock.Any(), gomock.Any()).Return(uint32(1), nil) + m.postRepo.EXPECT().CheckLikes(gomock.Any(), gomock.Any(), gomock.Any()).Return(true, nil) + }, + }, + { + name: "4", + SetupInput: func() (*userAndLastIDs, error) { + return &userAndLastIDs{UserID: 1, LastId: 2}, nil + }, + Run: func( + ctx context.Context, implementation *PostServiceImpl, request userAndLastIDs, + ) ([]*models.PostDto, error) { + return implementation.GetBatch(ctx, request.LastId, request.UserID) + }, + ExpectedResult: func() ([]*models.PostDto, error) { + return nil, nil + }, + ExpectedErr: errMock, + SetupMock: func(request userAndLastIDs, m *mocks) { + m.postRepo.EXPECT().GetPosts(gomock.Any(), gomock.Any()).Return( + []*models.PostDto{ + {ID: 1, Header: models.Header{AuthorID: 1}}, + }, nil, + ) + m.profileRepo.EXPECT().GetHeader(gomock.Any(), gomock.Any()).Return(&models.Header{AuthorID: 1}, nil) + m.postRepo.EXPECT().GetLikesOnPost(gomock.Any(), gomock.Any()).Return(uint32(1), nil) + m.postRepo.EXPECT().GetCommentCount(gomock.Any(), gomock.Any()).Return(uint32(0), errMock) m.postRepo.EXPECT().CheckLikes(gomock.Any(), gomock.Any(), gomock.Any()).Return(true, nil) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getService(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getService(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } func TestGetBatchFromFriend(t *testing.T) { - tests := []TableTest[[]*models.Post, userAndLastIDs]{ + tests := []TableTest[[]*models.PostDto, userAndLastIDs]{ { name: "1", SetupInput: func() (*userAndLastIDs, error) { return &userAndLastIDs{}, nil }, - Run: func(ctx context.Context, implementation *PostServiceImpl, request userAndLastIDs) ([]*models.Post, error) { + Run: func( + ctx context.Context, implementation *PostServiceImpl, request userAndLastIDs, + ) ([]*models.PostDto, error) { return implementation.GetBatchFromFriend(ctx, request.LastId, request.UserID) }, - ExpectedResult: func() ([]*models.Post, error) { + ExpectedResult: func() ([]*models.PostDto, error) { return nil, nil }, ExpectedErr: errMock, @@ -563,10 +671,12 @@ func TestGetBatchFromFriend(t *testing.T) { SetupInput: func() (*userAndLastIDs, error) { return &userAndLastIDs{}, nil }, - Run: func(ctx context.Context, implementation *PostServiceImpl, request userAndLastIDs) ([]*models.Post, error) { + Run: func( + ctx context.Context, implementation *PostServiceImpl, request userAndLastIDs, + ) ([]*models.PostDto, error) { return implementation.GetBatchFromFriend(ctx, request.LastId, request.UserID) }, - ExpectedResult: func() ([]*models.Post, error) { + ExpectedResult: func() ([]*models.PostDto, error) { return nil, nil }, ExpectedErr: my_err.ErrNoMoreContent, @@ -579,10 +689,12 @@ func TestGetBatchFromFriend(t *testing.T) { SetupInput: func() (*userAndLastIDs, error) { return &userAndLastIDs{}, nil }, - Run: func(ctx context.Context, implementation *PostServiceImpl, request userAndLastIDs) ([]*models.Post, error) { + Run: func( + ctx context.Context, implementation *PostServiceImpl, request userAndLastIDs, + ) ([]*models.PostDto, error) { return implementation.GetBatchFromFriend(ctx, request.LastId, request.UserID) }, - ExpectedResult: func() ([]*models.Post, error) { + ExpectedResult: func() ([]*models.PostDto, error) { return nil, nil }, ExpectedErr: errMock, @@ -596,19 +708,22 @@ func TestGetBatchFromFriend(t *testing.T) { SetupInput: func() (*userAndLastIDs, error) { return &userAndLastIDs{UserID: 1, LastId: 2}, nil }, - Run: func(ctx context.Context, implementation *PostServiceImpl, request userAndLastIDs) ([]*models.Post, error) { + Run: func( + ctx context.Context, implementation *PostServiceImpl, request userAndLastIDs, + ) ([]*models.PostDto, error) { return implementation.GetBatchFromFriend(ctx, request.LastId, request.UserID) }, - ExpectedResult: func() ([]*models.Post, error) { + ExpectedResult: func() ([]*models.PostDto, error) { return nil, nil }, ExpectedErr: errMock, SetupMock: func(request userAndLastIDs, m *mocks) { m.profileRepo.EXPECT().GetFriendsID(gomock.Any(), gomock.Any()).Return([]uint32{1, 2, 3}, nil) m.postRepo.EXPECT().GetFriendsPosts(gomock.Any(), gomock.Any(), gomock.Any()).Return( - []*models.Post{ + []*models.PostDto{ {ID: 1, Header: models.Header{CommunityID: 1}}, - }, nil) + }, nil, + ) m.communityRepo.EXPECT().GetHeader(gomock.Any(), gomock.Any()).Return(nil, errMock) }, }, @@ -617,11 +732,13 @@ func TestGetBatchFromFriend(t *testing.T) { SetupInput: func() (*userAndLastIDs, error) { return &userAndLastIDs{UserID: 1, LastId: 2}, nil }, - Run: func(ctx context.Context, implementation *PostServiceImpl, request userAndLastIDs) ([]*models.Post, error) { + Run: func( + ctx context.Context, implementation *PostServiceImpl, request userAndLastIDs, + ) ([]*models.PostDto, error) { return implementation.GetBatchFromFriend(ctx, request.LastId, request.UserID) }, - ExpectedResult: func() ([]*models.Post, error) { - return []*models.Post{ + ExpectedResult: func() ([]*models.PostDto, error) { + return []*models.PostDto{ { ID: 1, Header: models.Header{AuthorID: 1}, @@ -634,10 +751,39 @@ func TestGetBatchFromFriend(t *testing.T) { SetupMock: func(request userAndLastIDs, m *mocks) { m.profileRepo.EXPECT().GetFriendsID(gomock.Any(), gomock.Any()).Return([]uint32{1, 2, 3}, nil) m.postRepo.EXPECT().GetFriendsPosts(gomock.Any(), gomock.Any(), gomock.Any()).Return( - []*models.Post{ + []*models.PostDto{ + {ID: 1, Header: models.Header{AuthorID: 1}}, + }, nil, + ) + m.profileRepo.EXPECT().GetHeader(gomock.Any(), gomock.Any()).Return(&models.Header{AuthorID: 1}, nil) + m.postRepo.EXPECT().GetCommentCount(gomock.Any(), gomock.Any()).Return(uint32(0), nil) + m.postRepo.EXPECT().GetLikesOnPost(gomock.Any(), gomock.Any()).Return(uint32(1), nil) + m.postRepo.EXPECT().CheckLikes(gomock.Any(), gomock.Any(), gomock.Any()).Return(true, nil) + }, + }, + { + name: "6", + SetupInput: func() (*userAndLastIDs, error) { + return &userAndLastIDs{UserID: 1, LastId: 2}, nil + }, + Run: func( + ctx context.Context, implementation *PostServiceImpl, request userAndLastIDs, + ) ([]*models.PostDto, error) { + return implementation.GetBatchFromFriend(ctx, request.LastId, request.UserID) + }, + ExpectedResult: func() ([]*models.PostDto, error) { + return nil, nil + }, + ExpectedErr: errMock, + SetupMock: func(request userAndLastIDs, m *mocks) { + m.profileRepo.EXPECT().GetFriendsID(gomock.Any(), gomock.Any()).Return([]uint32{1, 2, 3}, nil) + m.postRepo.EXPECT().GetFriendsPosts(gomock.Any(), gomock.Any(), gomock.Any()).Return( + []*models.PostDto{ {ID: 1, Header: models.Header{AuthorID: 1}}, - }, nil) + }, nil, + ) m.profileRepo.EXPECT().GetHeader(gomock.Any(), gomock.Any()).Return(&models.Header{AuthorID: 1}, nil) + m.postRepo.EXPECT().GetCommentCount(gomock.Any(), gomock.Any()).Return(uint32(0), errMock) m.postRepo.EXPECT().GetLikesOnPost(gomock.Any(), gomock.Any()).Return(uint32(1), nil) m.postRepo.EXPECT().CheckLikes(gomock.Any(), gomock.Any(), gomock.Any()).Return(true, nil) }, @@ -645,31 +791,33 @@ func TestGetBatchFromFriend(t *testing.T) { } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getService(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getService(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -712,96 +860,102 @@ func TestGetPostAuthor(t *testing.T) { } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getService(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getService(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } func TestCreateCommunityPost(t *testing.T) { - tests := []TableTest[uint32, models.Post]{ + tests := []TableTest[uint32, models.PostDto]{ { name: "1", - SetupInput: func() (*models.Post, error) { - return &models.Post{}, nil + SetupInput: func() (*models.PostDto, error) { + return &models.PostDto{}, nil }, - Run: func(ctx context.Context, implementation *PostServiceImpl, request models.Post) (uint32, error) { + Run: func(ctx context.Context, implementation *PostServiceImpl, request models.PostDto) (uint32, error) { return implementation.CreateCommunityPost(ctx, &request) }, ExpectedResult: func() (uint32, error) { return uint32(0), nil }, ExpectedErr: errMock, - SetupMock: func(request models.Post, m *mocks) { - m.postRepo.EXPECT().CreateCommunityPost(gomock.Any(), gomock.Any(), gomock.Any()).Return(uint32(0), errMock) + SetupMock: func(request models.PostDto, m *mocks) { + m.postRepo.EXPECT().CreateCommunityPost(gomock.Any(), gomock.Any(), gomock.Any()).Return( + uint32(0), errMock, + ) }, }, { name: "2", - SetupInput: func() (*models.Post, error) { - return &models.Post{PostContent: models.Content{Text: "new post"}}, nil + SetupInput: func() (*models.PostDto, error) { + return &models.PostDto{PostContent: models.ContentDto{Text: "new post"}}, nil }, - Run: func(ctx context.Context, implementation *PostServiceImpl, request models.Post) (uint32, error) { + Run: func(ctx context.Context, implementation *PostServiceImpl, request models.PostDto) (uint32, error) { return implementation.CreateCommunityPost(ctx, &request) }, ExpectedResult: func() (uint32, error) { return uint32(1), nil }, ExpectedErr: nil, - SetupMock: func(request models.Post, m *mocks) { + SetupMock: func(request models.PostDto, m *mocks) { m.postRepo.EXPECT().CreateCommunityPost(gomock.Any(), gomock.Any(), gomock.Any()).Return(uint32(1), nil) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getService(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getService(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -812,16 +966,16 @@ type IDs struct { } func TestGetCommunityPost(t *testing.T) { - tests := []TableTest[[]*models.Post, IDs]{ + tests := []TableTest[[]*models.PostDto, IDs]{ { name: "1", SetupInput: func() (*IDs, error) { return &IDs{}, nil }, - Run: func(ctx context.Context, implementation *PostServiceImpl, request IDs) ([]*models.Post, error) { + Run: func(ctx context.Context, implementation *PostServiceImpl, request IDs) ([]*models.PostDto, error) { return implementation.GetCommunityPost(ctx, request.communityID, request.userID, request.lastID) }, - ExpectedResult: func() ([]*models.Post, error) { + ExpectedResult: func() ([]*models.PostDto, error) { return nil, nil }, ExpectedErr: errMock, @@ -834,18 +988,19 @@ func TestGetCommunityPost(t *testing.T) { SetupInput: func() (*IDs, error) { return &IDs{userID: 1, lastID: 2, communityID: 1}, nil }, - Run: func(ctx context.Context, implementation *PostServiceImpl, request IDs) ([]*models.Post, error) { + Run: func(ctx context.Context, implementation *PostServiceImpl, request IDs) ([]*models.PostDto, error) { return implementation.GetCommunityPost(ctx, request.lastID, request.userID, request.communityID) }, - ExpectedResult: func() ([]*models.Post, error) { + ExpectedResult: func() ([]*models.PostDto, error) { return nil, nil }, ExpectedErr: errMock, SetupMock: func(request IDs, m *mocks) { m.postRepo.EXPECT().GetCommunityPosts(gomock.Any(), gomock.Any(), gomock.Any()).Return( - []*models.Post{ + []*models.PostDto{ {ID: 1, Header: models.Header{CommunityID: 1}}, - }, nil) + }, nil, + ) m.communityRepo.EXPECT().GetHeader(gomock.Any(), gomock.Any()).Return(nil, errMock) }, }, @@ -854,11 +1009,11 @@ func TestGetCommunityPost(t *testing.T) { SetupInput: func() (*IDs, error) { return &IDs{userID: 1, lastID: 2, communityID: 3}, nil }, - Run: func(ctx context.Context, implementation *PostServiceImpl, request IDs) ([]*models.Post, error) { + Run: func(ctx context.Context, implementation *PostServiceImpl, request IDs) ([]*models.PostDto, error) { return implementation.GetCommunityPost(ctx, request.lastID, request.userID, request.communityID) }, - ExpectedResult: func() ([]*models.Post, error) { - return []*models.Post{ + ExpectedResult: func() ([]*models.PostDto, error) { + return []*models.PostDto{ { ID: 1, Header: models.Header{AuthorID: 1}, @@ -870,42 +1025,70 @@ func TestGetCommunityPost(t *testing.T) { ExpectedErr: nil, SetupMock: func(request IDs, m *mocks) { m.postRepo.EXPECT().GetCommunityPosts(gomock.Any(), gomock.Any(), gomock.Any()).Return( - []*models.Post{ + []*models.PostDto{ {ID: 1, Header: models.Header{AuthorID: 1}}, - }, nil) + }, nil, + ) m.profileRepo.EXPECT().GetHeader(gomock.Any(), gomock.Any()).Return(&models.Header{AuthorID: 1}, nil) m.postRepo.EXPECT().GetLikesOnPost(gomock.Any(), gomock.Any()).Return(uint32(1), nil) + m.postRepo.EXPECT().GetCommentCount(gomock.Any(), gomock.Any()).Return(uint32(0), nil) + m.postRepo.EXPECT().CheckLikes(gomock.Any(), gomock.Any(), gomock.Any()).Return(true, nil) + }, + }, + { + name: "4", + SetupInput: func() (*IDs, error) { + return &IDs{userID: 1, lastID: 2, communityID: 3}, nil + }, + Run: func(ctx context.Context, implementation *PostServiceImpl, request IDs) ([]*models.PostDto, error) { + return implementation.GetCommunityPost(ctx, request.lastID, request.userID, request.communityID) + }, + ExpectedResult: func() ([]*models.PostDto, error) { + return nil, nil + }, + ExpectedErr: errMock, + SetupMock: func(request IDs, m *mocks) { + m.postRepo.EXPECT().GetCommunityPosts(gomock.Any(), gomock.Any(), gomock.Any()).Return( + []*models.PostDto{ + {ID: 1, Header: models.Header{AuthorID: 1}}, + }, nil, + ) + m.profileRepo.EXPECT().GetHeader(gomock.Any(), gomock.Any()).Return(&models.Header{AuthorID: 1}, nil) + m.postRepo.EXPECT().GetLikesOnPost(gomock.Any(), gomock.Any()).Return(uint32(1), nil) + m.postRepo.EXPECT().GetCommentCount(gomock.Any(), gomock.Any()).Return(uint32(0), errMock) m.postRepo.EXPECT().CheckLikes(gomock.Any(), gomock.Any(), gomock.Any()).Return(true, nil) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getService(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getService(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -953,31 +1136,33 @@ func TestCheckAccessToCommunity(t *testing.T) { } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getService(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getService(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -1020,31 +1205,33 @@ func TestSetLikeOnPost(t *testing.T) { } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getService(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getService(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -1087,31 +1274,33 @@ func TestDeleteLikeFromPost(t *testing.T) { } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getService(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getService(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -1152,31 +1341,33 @@ func TestCheckLikes(t *testing.T) { } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getService(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getService(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } diff --git a/internal/profile/controller/controller.go b/internal/profile/controller/controller.go index 470eb84d..b7daa92f 100644 --- a/internal/profile/controller/controller.go +++ b/internal/profile/controller/controller.go @@ -1,7 +1,6 @@ package controller import ( - "encoding/json" "errors" "fmt" "math" @@ -9,7 +8,11 @@ import ( "strconv" "github.com/gorilla/mux" + "github.com/mailru/easyjson" + "github.com/microcosm-cc/bluemonday" + "golang.org/x/crypto/bcrypt" + "github.com/2024_2_BetterCallFirewall/internal/middleware" "github.com/2024_2_BetterCallFirewall/internal/models" "github.com/2024_2_BetterCallFirewall/internal/profile" "github.com/2024_2_BetterCallFirewall/pkg/my_err" @@ -37,8 +40,29 @@ func NewProfileController(manager profile.ProfileUsecase, responder Responder) * } } +func sanitize(input string) string { + sanitizer := bluemonday.UGCPolicy() + cleaned := sanitizer.Sanitize(input) + return cleaned +} + +func sanitizeProfile(fullProfile *models.FullProfile) { + fullProfile.Bio = sanitize(fullProfile.Bio) + fullProfile.Avatar = models.Picture(sanitize(string(fullProfile.Avatar))) + fullProfile.FirstName = sanitize(fullProfile.FirstName) + fullProfile.LastName = sanitize(fullProfile.LastName) +} + +func sanitizeProfiles(shorts []*models.ShortProfile) { + for _, short := range shorts { + short.Avatar = models.Picture(sanitize(string(short.Avatar))) + short.FirstName = sanitize(short.FirstName) + short.LastName = sanitize(short.LastName) + } +} + func (h *ProfileHandlerImplementation) GetHeader(w http.ResponseWriter, r *http.Request) { - reqID, ok := r.Context().Value("requestID").(string) + reqID, ok := r.Context().Value(middleware.RequestKey).(string) if !ok { h.Responder.LogError(my_err.ErrInvalidContext, "") } @@ -60,11 +84,15 @@ func (h *ProfileHandlerImplementation) GetHeader(w http.ResponseWriter, r *http. return } + if header != nil { + header.Author = sanitize(header.Author) + header.Avatar = models.Picture(sanitize(string(header.Avatar))) + } h.Responder.OutputJSON(w, &header, reqID) } func (h *ProfileHandlerImplementation) GetProfile(w http.ResponseWriter, r *http.Request) { - reqID, ok := r.Context().Value("requestID").(string) + reqID, ok := r.Context().Value(middleware.RequestKey).(string) if !ok { h.Responder.LogError(my_err.ErrInvalidContext, "") } @@ -88,12 +116,13 @@ func (h *ProfileHandlerImplementation) GetProfile(w http.ResponseWriter, r *http h.Responder.ErrorInternal(w, err, reqID) return } + sanitizeProfile(userProfile) h.Responder.OutputJSON(w, userProfile, reqID) } func (h *ProfileHandlerImplementation) UpdateProfile(w http.ResponseWriter, r *http.Request) { - reqID, ok := r.Context().Value("requestID").(string) + reqID, ok := r.Context().Value(middleware.RequestKey).(string) if !ok { h.Responder.LogError(my_err.ErrInvalidContext, "") } @@ -117,22 +146,30 @@ func (h *ProfileHandlerImplementation) UpdateProfile(w http.ResponseWriter, r *h return } + sanitizeProfile(newProfile) h.Responder.OutputJSON(w, newProfile, reqID) } func (h *ProfileHandlerImplementation) getNewProfile(r *http.Request) (*models.FullProfile, error) { newProfile := models.FullProfile{} - err := json.NewDecoder(r.Body).Decode(&newProfile) + err := easyjson.UnmarshalFromReader(r.Body, &newProfile) if err != nil { return nil, err } + sanitizeProfile(&newProfile) + + if len([]rune(newProfile.FirstName)) < 3 || len([]rune(newProfile.FirstName)) > 30 || + len([]rune(newProfile.LastName)) < 3 || len([]rune(newProfile.LastName)) > 30 || + len(newProfile.Bio) > 100 || len([]rune(newProfile.Avatar)) > 100 { + return nil, errors.New("invalid profile") + } return &newProfile, nil } func (h *ProfileHandlerImplementation) DeleteProfile(w http.ResponseWriter, r *http.Request) { var ( - reqID, ok = r.Context().Value("requestID").(string) + reqID, ok = r.Context().Value(middleware.RequestKey).(string) sess, err = models.SessionFromContext(r.Context()) ) @@ -157,7 +194,7 @@ func (h *ProfileHandlerImplementation) DeleteProfile(w http.ResponseWriter, r *h func GetIdFromURL(r *http.Request) (uint32, error) { vars := mux.Vars(r) - id := vars["id"] + id := sanitize(vars["id"]) if id == "" { return 0, my_err.ErrEmptyId } @@ -174,7 +211,7 @@ func GetIdFromURL(r *http.Request) (uint32, error) { func (h *ProfileHandlerImplementation) GetProfileById(w http.ResponseWriter, r *http.Request) { var ( - reqID, ok = r.Context().Value("requestID").(string) + reqID, ok = r.Context().Value(middleware.RequestKey).(string) id, err = GetIdFromURL(r) ) @@ -187,7 +224,7 @@ func (h *ProfileHandlerImplementation) GetProfileById(w http.ResponseWriter, r * return } - profile, err := h.ProfileManager.GetProfileById(r.Context(), id) + uprofile, err := h.ProfileManager.GetProfileById(r.Context(), id) if err != nil { if errors.Is(err, my_err.ErrProfileNotFound) { h.Responder.ErrorBadRequest(w, err, reqID) @@ -197,11 +234,12 @@ func (h *ProfileHandlerImplementation) GetProfileById(w http.ResponseWriter, r * return } - h.Responder.OutputJSON(w, profile, reqID) + sanitizeProfile(uprofile) + h.Responder.OutputJSON(w, uprofile, reqID) } func GetLastId(r *http.Request) (uint32, error) { - strLastId := r.URL.Query().Get("last_id") + strLastId := sanitize(r.URL.Query().Get("last_id")) if strLastId == "" { return 0, nil } @@ -214,7 +252,7 @@ func GetLastId(r *http.Request) (uint32, error) { func (h *ProfileHandlerImplementation) GetAll(w http.ResponseWriter, r *http.Request) { var ( - reqID, ok = r.Context().Value("requestID").(string) + reqID, ok = r.Context().Value(middleware.RequestKey).(string) sess, err = models.SessionFromContext(r.Context()) ) @@ -242,6 +280,8 @@ func (h *ProfileHandlerImplementation) GetAll(w http.ResponseWriter, r *http.Req h.Responder.OutputNoMoreContentJSON(w, reqID) return } + + sanitizeProfiles(profiles) h.Responder.OutputJSON(w, profiles, reqID) } @@ -261,7 +301,7 @@ func GetReceiverAndSender(r *http.Request) (uint32, uint32, error) { func (h *ProfileHandlerImplementation) SendFriendReq(w http.ResponseWriter, r *http.Request) { var ( - reqID, ok = r.Context().Value("requestID").(string) + reqID, ok = r.Context().Value(middleware.RequestKey).(string) receiver, sender, err = GetReceiverAndSender(r) ) @@ -285,7 +325,7 @@ func (h *ProfileHandlerImplementation) SendFriendReq(w http.ResponseWriter, r *h func (h *ProfileHandlerImplementation) AcceptFriendReq(w http.ResponseWriter, r *http.Request) { var ( - reqID, ok = r.Context().Value("requestID").(string) + reqID, ok = r.Context().Value(middleware.RequestKey).(string) whose, who, err = GetReceiverAndSender(r) ) @@ -307,7 +347,7 @@ func (h *ProfileHandlerImplementation) AcceptFriendReq(w http.ResponseWriter, r func (h *ProfileHandlerImplementation) RemoveFromFriends(w http.ResponseWriter, r *http.Request) { var ( - reqID, ok = r.Context().Value("requestID").(string) + reqID, ok = r.Context().Value(middleware.RequestKey).(string) whose, who, err = GetReceiverAndSender(r) ) @@ -329,7 +369,7 @@ func (h *ProfileHandlerImplementation) RemoveFromFriends(w http.ResponseWriter, func (h *ProfileHandlerImplementation) Unsubscribe(w http.ResponseWriter, r *http.Request) { var ( - reqID, ok = r.Context().Value("requestID").(string) + reqID, ok = r.Context().Value(middleware.RequestKey).(string) ) if !ok { @@ -351,7 +391,7 @@ func (h *ProfileHandlerImplementation) Unsubscribe(w http.ResponseWriter, r *htt func (h *ProfileHandlerImplementation) GetAllFriends(w http.ResponseWriter, r *http.Request) { var ( - reqID, ok = r.Context().Value("requestID").(string) + reqID, ok = r.Context().Value(middleware.RequestKey).(string) id, err = GetIdFromURL(r) ) @@ -379,13 +419,13 @@ func (h *ProfileHandlerImplementation) GetAllFriends(w http.ResponseWriter, r *h h.Responder.OutputNoMoreContentJSON(w, reqID) return } - + sanitizeProfiles(profiles) h.Responder.OutputJSON(w, profiles, reqID) } func (h *ProfileHandlerImplementation) GetAllSubs(w http.ResponseWriter, r *http.Request) { var ( - reqID, ok = r.Context().Value("requestID").(string) + reqID, ok = r.Context().Value(middleware.RequestKey).(string) id, err = GetIdFromURL(r) ) @@ -412,12 +452,14 @@ func (h *ProfileHandlerImplementation) GetAllSubs(w http.ResponseWriter, r *http h.Responder.OutputNoMoreContentJSON(w, reqID) return } + + sanitizeProfiles(profiles) h.Responder.OutputJSON(w, profiles, reqID) } func (h *ProfileHandlerImplementation) GetAllSubscriptions(w http.ResponseWriter, r *http.Request) { var ( - reqID, ok = r.Context().Value("requestID").(string) + reqID, ok = r.Context().Value(middleware.RequestKey).(string) id, err = GetIdFromURL(r) ) @@ -446,12 +488,13 @@ func (h *ProfileHandlerImplementation) GetAllSubscriptions(w http.ResponseWriter return } + sanitizeProfiles(profiles) h.Responder.OutputJSON(w, profiles, reqID) } func (h *ProfileHandlerImplementation) GetCommunitySubs(w http.ResponseWriter, r *http.Request) { var ( - reqID, ok = r.Context().Value("requestID").(string) + reqID, ok = r.Context().Value(middleware.RequestKey).(string) id, err = GetIdFromURL(r) ) @@ -481,14 +524,15 @@ func (h *ProfileHandlerImplementation) GetCommunitySubs(w http.ResponseWriter, r return } + sanitizeProfiles(subs) h.Responder.OutputJSON(w, subs, reqID) } func (h *ProfileHandlerImplementation) SearchProfile(w http.ResponseWriter, r *http.Request) { var ( - reqID, ok = r.Context().Value("requestID").(string) - subStr = r.URL.Query().Get("q") - lastID = r.URL.Query().Get("id") + reqID, ok = r.Context().Value(middleware.RequestKey).(string) + subStr = sanitize(r.URL.Query().Get("q")) + lastID = sanitize(r.URL.Query().Get("id")) id uint64 err error ) @@ -523,5 +567,55 @@ func (h *ProfileHandlerImplementation) SearchProfile(w http.ResponseWriter, r *h return } + sanitizeProfiles(profiles) h.Responder.OutputJSON(w, profiles, reqID) } + +func (h *ProfileHandlerImplementation) ChangePassword(w http.ResponseWriter, r *http.Request) { + reqID, ok := r.Context().Value(middleware.RequestKey).(string) + if !ok { + h.Responder.LogError(my_err.ErrInvalidContext, "") + } + + sess, err := models.SessionFromContext(r.Context()) + if err != nil { + h.Responder.ErrorBadRequest(w, fmt.Errorf("update profile: %w", my_err.ErrSessionNotFound), reqID) + return + } + + var request models.ChangePasswordReq + if err := easyjson.UnmarshalFromReader(r.Body, &request); err != nil { + h.Responder.ErrorBadRequest(w, err, reqID) + return + } + if !validate(request) { + h.Responder.ErrorBadRequest(w, errors.New("too small password or old and new same"), reqID) + return + } + + if err = h.ProfileManager.ChangePassword( + r.Context(), sess.UserID, request.OldPassword, request.NewPassword, + ); err != nil { + if errors.Is(err, my_err.ErrUserNotFound) || + errors.Is(err, my_err.ErrWrongEmailOrPassword) || + errors.Is(err, bcrypt.ErrPasswordTooLong) { + h.Responder.ErrorBadRequest(w, err, reqID) + return + } + + h.Responder.ErrorInternal(w, err, reqID) + return + } + + h.Responder.OutputJSON(w, "password change", reqID) +} + +func validate(request models.ChangePasswordReq) bool { + request.OldPassword = sanitize(request.OldPassword) + request.NewPassword = sanitize(request.NewPassword) + if len([]rune(request.OldPassword)) < 6 || len([]rune(request.NewPassword)) < 6 || request.OldPassword == request.NewPassword { + return false + } + + return true +} diff --git a/internal/profile/controller/controller_test.go b/internal/profile/controller/controller_test.go index 13d0cdc6..0d5a21ff 100644 --- a/internal/profile/controller/controller_test.go +++ b/internal/profile/controller/controller_test.go @@ -47,7 +47,9 @@ func TestGetHeader(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetHeader(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -58,10 +60,14 @@ func TestGetHeader(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -73,7 +79,9 @@ func TestGetHeader(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetHeader(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -85,10 +93,14 @@ func TestGetHeader(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.profileManager.EXPECT().GetHeader(gomock.Any(), gomock.Any()).Return(nil, my_err.ErrProfileNotFound) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -100,7 +112,9 @@ func TestGetHeader(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetHeader(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -112,10 +126,14 @@ func TestGetHeader(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.profileManager.EXPECT().GetHeader(gomock.Any(), gomock.Any()).Return(nil, errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("internal error")) - }) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("internal error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -127,7 +145,9 @@ func TestGetHeader(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetHeader(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -139,41 +159,98 @@ func TestGetHeader(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.profileManager.EXPECT().GetHeader(gomock.Any(), gomock.Any()).Return(&models.Header{}, nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, header, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, header, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) + } +} + +func TestSanitize(t *testing.T) { + test := "" + expected := "" + res := sanitize(test) + assert.Equal(t, expected, res) +} + +func TestSanitizeProfile(t *testing.T) { + testProfile := &models.FullProfile{ + FirstName: "Andrew", + LastName: "", + Bio: "", + Avatar: "file", } + + expProfile := &models.FullProfile{ + FirstName: "Andrew", + LastName: "", + Bio: "", + Avatar: "file", + } + + sanitizeProfile(testProfile) + assert.Equal(t, expProfile, testProfile) +} + +func TestSanitizeProfiles(t *testing.T) { + short1 := &models.ShortProfile{ + FirstName: "Andrew", + LastName: "Savvateev", + Avatar: "filepath", + } + + short2 := &models.ShortProfile{ + FirstName: "Andrew", + LastName: "", + Avatar: "", + } + + short3 := &models.ShortProfile{ + FirstName: "Andrew", + LastName: "", + Avatar: "", + } + + test := []*models.ShortProfile{short1, short2} + exp := []*models.ShortProfile{short1, short3} + sanitizeProfiles(test) + assert.Equal(t, exp, test) } func TestGetProfile(t *testing.T) { @@ -186,7 +263,9 @@ func TestGetProfile(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetProfile(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -197,10 +276,14 @@ func TestGetProfile(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -212,7 +295,9 @@ func TestGetProfile(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetProfile(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -223,11 +308,17 @@ func TestGetProfile(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.profileManager.EXPECT().GetProfileById(gomock.Any(), gomock.Any()).Return(&models.FullProfile{}, my_err.ErrProfileNotFound) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.profileManager.EXPECT().GetProfileById(gomock.Any(), gomock.Any()).Return( + &models.FullProfile{}, my_err.ErrProfileNotFound, + ) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -239,7 +330,9 @@ func TestGetProfile(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetProfile(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -250,10 +343,14 @@ func TestGetProfile(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.profileManager.EXPECT().GetProfileById(gomock.Any(), gomock.Any()).Return(&models.FullProfile{}, my_err.ErrNoMoreContent) - m.responder.EXPECT().OutputNoMoreContentJSON(request.w, gomock.Any()).Do(func(w, req any) { - request.w.WriteHeader(http.StatusNoContent) - }) + m.profileManager.EXPECT().GetProfileById(gomock.Any(), gomock.Any()).Return( + &models.FullProfile{}, my_err.ErrNoMoreContent, + ) + m.responder.EXPECT().OutputNoMoreContentJSON(request.w, gomock.Any()).Do( + func(w, req any) { + request.w.WriteHeader(http.StatusNoContent) + }, + ) }, }, { @@ -265,7 +362,9 @@ func TestGetProfile(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetProfile(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -276,11 +375,17 @@ func TestGetProfile(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.profileManager.EXPECT().GetProfileById(gomock.Any(), gomock.Any()).Return(&models.FullProfile{}, errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.profileManager.EXPECT().GetProfileById(gomock.Any(), gomock.Any()).Return( + &models.FullProfile{}, errors.New("error"), + ) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -292,7 +397,9 @@ func TestGetProfile(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetProfile(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -304,40 +411,46 @@ func TestGetProfile(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.profileManager.EXPECT().GetProfileById(gomock.Any(), gomock.Any()).Return(&models.FullProfile{}, nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, data, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, data, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -351,7 +464,9 @@ func TestUpdateProfile(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.UpdateProfile(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -362,10 +477,14 @@ func TestUpdateProfile(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -377,7 +496,9 @@ func TestUpdateProfile(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.UpdateProfile(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -388,23 +509,31 @@ func TestUpdateProfile(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { name: "3", SetupInput: func() (*Request, error) { - req := httptest.NewRequest(http.MethodPut, "/api/v1/profile", - bytes.NewBuffer([]byte(`{"id":0, "first_name":"Alexey"}`))) + req := httptest.NewRequest( + http.MethodPut, "/api/v1/profile", + bytes.NewBuffer([]byte(`{"id":0, "first_name":"Alexey", "last_name":"Zemliakov"}`)), + ) w := httptest.NewRecorder() req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.UpdateProfile(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -416,23 +545,31 @@ func TestUpdateProfile(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.profileManager.EXPECT().UpdateProfile(gomock.Any(), gomock.Any()).Return(errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { name: "4", SetupInput: func() (*Request, error) { - req := httptest.NewRequest(http.MethodPut, "/api/v1/profile", - bytes.NewBuffer([]byte(`{"id":1, "first_name":"Alexey"}`))) + req := httptest.NewRequest( + http.MethodPut, "/api/v1/profile", + bytes.NewBuffer([]byte(`{"id":1, "first_name":"Alexey", "last_name":"Zemliakov"}`)), + ) w := httptest.NewRecorder() req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.UpdateProfile(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -444,40 +581,79 @@ func TestUpdateProfile(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.profileManager.EXPECT().UpdateProfile(gomock.Any(), gomock.Any()).Return(nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, data, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, data, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) + }, + }, + { + name: "5", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest( + http.MethodPut, "/api/v1/profile", + bytes.NewBuffer([]byte(`{"id":0, "first_name":"Alexey", "last_name":""}`)), + ) + w := httptest.NewRecorder() + req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { + implementation.UpdateProfile(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + _, _ = request.w.Write([]byte("bad request")) + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -491,7 +667,9 @@ func TestDeleteProfile(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.DeleteProfile(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -502,10 +680,14 @@ func TestDeleteProfile(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -517,7 +699,9 @@ func TestDeleteProfile(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.DeleteProfile(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -529,10 +713,14 @@ func TestDeleteProfile(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.profileManager.EXPECT().DeleteProfile(gomock.Any()).Return(errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -544,7 +732,9 @@ func TestDeleteProfile(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.DeleteProfile(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -556,40 +746,46 @@ func TestDeleteProfile(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.profileManager.EXPECT().DeleteProfile(gomock.Any()).Return(nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, data, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, data, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -603,7 +799,9 @@ func TestGetProfileByID(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetProfileById(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -614,10 +812,14 @@ func TestGetProfileByID(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -629,7 +831,9 @@ func TestGetProfileByID(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetProfileById(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -640,10 +844,14 @@ func TestGetProfileByID(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -655,7 +863,9 @@ func TestGetProfileByID(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetProfileById(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -666,10 +876,14 @@ func TestGetProfileByID(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -681,7 +895,9 @@ func TestGetProfileByID(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetProfileById(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -692,11 +908,17 @@ func TestGetProfileByID(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.profileManager.EXPECT().GetProfileById(gomock.Any(), gomock.Any()).Return(&models.FullProfile{}, my_err.ErrProfileNotFound) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.profileManager.EXPECT().GetProfileById(gomock.Any(), gomock.Any()).Return( + &models.FullProfile{}, my_err.ErrProfileNotFound, + ) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -708,7 +930,9 @@ func TestGetProfileByID(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetProfileById(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -719,11 +943,17 @@ func TestGetProfileByID(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.profileManager.EXPECT().GetProfileById(gomock.Any(), gomock.Any()).Return(&models.FullProfile{}, errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.profileManager.EXPECT().GetProfileById(gomock.Any(), gomock.Any()).Return( + &models.FullProfile{}, errors.New("error"), + ) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -735,7 +965,9 @@ func TestGetProfileByID(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetProfileById(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -747,40 +979,46 @@ func TestGetProfileByID(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.profileManager.EXPECT().GetProfileById(gomock.Any(), gomock.Any()).Return(&models.FullProfile{}, nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, data, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, data, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -794,7 +1032,9 @@ func TestGetAll(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetAll(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -805,10 +1045,14 @@ func TestGetAll(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -820,7 +1064,9 @@ func TestGetAll(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetAll(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -831,10 +1077,14 @@ func TestGetAll(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -846,7 +1096,9 @@ func TestGetAll(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetAll(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -857,11 +1109,17 @@ func TestGetAll(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.profileManager.EXPECT().GetAll(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.profileManager.EXPECT().GetAll(gomock.Any(), gomock.Any(), gomock.Any()).Return( + nil, errors.New("error"), + ) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -873,7 +1131,9 @@ func TestGetAll(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetAll(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -885,9 +1145,11 @@ func TestGetAll(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.profileManager.EXPECT().GetAll(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) - m.responder.EXPECT().OutputNoMoreContentJSON(request.w, gomock.Any()).Do(func(w, req any) { - request.w.WriteHeader(http.StatusNoContent) - }) + m.responder.EXPECT().OutputNoMoreContentJSON(request.w, gomock.Any()).Do( + func(w, req any) { + request.w.WriteHeader(http.StatusNoContent) + }, + ) }, }, { @@ -899,7 +1161,9 @@ func TestGetAll(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetAll(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -912,41 +1176,48 @@ func TestGetAll(t *testing.T) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.profileManager.EXPECT().GetAll(gomock.Any(), gomock.Any(), gomock.Any()).Return( []*models.ShortProfile{{ID: 1}}, - nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, data, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + nil, + ) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, data, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -960,7 +1231,9 @@ func TestSendFriendReq(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.SendFriendReq(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -971,10 +1244,14 @@ func TestSendFriendReq(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -986,7 +1263,9 @@ func TestSendFriendReq(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.SendFriendReq(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -997,10 +1276,14 @@ func TestSendFriendReq(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1013,7 +1296,9 @@ func TestSendFriendReq(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.SendFriendReq(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -1025,10 +1310,14 @@ func TestSendFriendReq(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.profileManager.EXPECT().SendFriendReq(gomock.Any(), gomock.Any()).Return(errors.New("error")) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1041,7 +1330,9 @@ func TestSendFriendReq(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.SendFriendReq(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -1053,40 +1344,46 @@ func TestSendFriendReq(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.profileManager.EXPECT().SendFriendReq(gomock.Any(), gomock.Any()).Return(nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, data, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, data, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -1100,7 +1397,9 @@ func TestAcceptFriendReq(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.AcceptFriendReq(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -1111,10 +1410,14 @@ func TestAcceptFriendReq(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1126,7 +1429,9 @@ func TestAcceptFriendReq(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.AcceptFriendReq(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -1137,10 +1442,14 @@ func TestAcceptFriendReq(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1153,7 +1462,9 @@ func TestAcceptFriendReq(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.AcceptFriendReq(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -1165,10 +1476,14 @@ func TestAcceptFriendReq(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.profileManager.EXPECT().AcceptFriendReq(gomock.Any(), gomock.Any()).Return(errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1181,7 +1496,9 @@ func TestAcceptFriendReq(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.AcceptFriendReq(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -1193,40 +1510,46 @@ func TestAcceptFriendReq(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.profileManager.EXPECT().AcceptFriendReq(gomock.Any(), gomock.Any()).Return(nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -1240,7 +1563,9 @@ func TestRemoveFromFriends(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.RemoveFromFriends(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -1251,10 +1576,14 @@ func TestRemoveFromFriends(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1266,7 +1595,9 @@ func TestRemoveFromFriends(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.RemoveFromFriends(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -1277,10 +1608,14 @@ func TestRemoveFromFriends(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1293,7 +1628,9 @@ func TestRemoveFromFriends(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.RemoveFromFriends(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -1305,10 +1642,14 @@ func TestRemoveFromFriends(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.profileManager.EXPECT().RemoveFromFriends(gomock.Any(), gomock.Any()).Return(errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1321,7 +1662,9 @@ func TestRemoveFromFriends(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.RemoveFromFriends(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -1333,40 +1676,46 @@ func TestRemoveFromFriends(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.profileManager.EXPECT().RemoveFromFriends(gomock.Any(), gomock.Any()).Return(nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -1380,7 +1729,9 @@ func TestUnsubscribe(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.Unsubscribe(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -1391,10 +1742,14 @@ func TestUnsubscribe(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1406,7 +1761,9 @@ func TestUnsubscribe(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.Unsubscribe(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -1417,10 +1774,14 @@ func TestUnsubscribe(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1433,7 +1794,9 @@ func TestUnsubscribe(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.Unsubscribe(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -1445,10 +1808,14 @@ func TestUnsubscribe(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.profileManager.EXPECT().Unsubscribe(gomock.Any(), gomock.Any()).Return(errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1461,7 +1828,9 @@ func TestUnsubscribe(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.Unsubscribe(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -1473,40 +1842,46 @@ func TestUnsubscribe(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.profileManager.EXPECT().Unsubscribe(gomock.Any(), gomock.Any()).Return(nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -1520,7 +1895,9 @@ func TestGetAllFriends(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetAllFriends(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -1531,10 +1908,14 @@ func TestGetAllFriends(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1546,7 +1927,9 @@ func TestGetAllFriends(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetAllFriends(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -1557,10 +1940,14 @@ func TestGetAllFriends(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1572,7 +1959,9 @@ func TestGetAllFriends(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetAllFriends(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -1583,11 +1972,17 @@ func TestGetAllFriends(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.profileManager.EXPECT().GetAllFriends(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.profileManager.EXPECT().GetAllFriends(gomock.Any(), gomock.Any(), gomock.Any()).Return( + nil, errors.New("error"), + ) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1599,7 +1994,9 @@ func TestGetAllFriends(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetAllFriends(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -1611,9 +2008,11 @@ func TestGetAllFriends(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.profileManager.EXPECT().GetAllFriends(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) - m.responder.EXPECT().OutputNoMoreContentJSON(request.w, gomock.Any()).Do(func(w, req any) { - request.w.WriteHeader(http.StatusNoContent) - }) + m.responder.EXPECT().OutputNoMoreContentJSON(request.w, gomock.Any()).Do( + func(w, req any) { + request.w.WriteHeader(http.StatusNoContent) + }, + ) }, }, { @@ -1625,7 +2024,9 @@ func TestGetAllFriends(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetAllFriends(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -1638,41 +2039,48 @@ func TestGetAllFriends(t *testing.T) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.profileManager.EXPECT().GetAllFriends(gomock.Any(), gomock.Any(), gomock.Any()).Return( []*models.ShortProfile{{ID: 1}}, - nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, data, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + nil, + ) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, data, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -1686,7 +2094,9 @@ func TestGetAllSubs(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetAllSubs(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -1697,10 +2107,14 @@ func TestGetAllSubs(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1712,7 +2126,9 @@ func TestGetAllSubs(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetAllSubs(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -1723,10 +2139,14 @@ func TestGetAllSubs(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1738,7 +2158,9 @@ func TestGetAllSubs(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetAllSubs(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -1749,11 +2171,17 @@ func TestGetAllSubs(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.profileManager.EXPECT().GetAllSubs(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.profileManager.EXPECT().GetAllSubs(gomock.Any(), gomock.Any(), gomock.Any()).Return( + nil, errors.New("error"), + ) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1765,7 +2193,9 @@ func TestGetAllSubs(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetAllSubs(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -1777,9 +2207,11 @@ func TestGetAllSubs(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.profileManager.EXPECT().GetAllSubs(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) - m.responder.EXPECT().OutputNoMoreContentJSON(request.w, gomock.Any()).Do(func(w, req any) { - request.w.WriteHeader(http.StatusNoContent) - }) + m.responder.EXPECT().OutputNoMoreContentJSON(request.w, gomock.Any()).Do( + func(w, req any) { + request.w.WriteHeader(http.StatusNoContent) + }, + ) }, }, { @@ -1791,7 +2223,9 @@ func TestGetAllSubs(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetAllSubs(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -1804,41 +2238,48 @@ func TestGetAllSubs(t *testing.T) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.profileManager.EXPECT().GetAllSubs(gomock.Any(), gomock.Any(), gomock.Any()).Return( []*models.ShortProfile{{ID: 1}}, - nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, data, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + nil, + ) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, data, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -1852,7 +2293,9 @@ func TestGetAllSubscriptions(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetAllSubscriptions(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -1863,10 +2306,14 @@ func TestGetAllSubscriptions(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1878,7 +2325,9 @@ func TestGetAllSubscriptions(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetAllSubscriptions(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -1889,10 +2338,14 @@ func TestGetAllSubscriptions(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1904,7 +2357,9 @@ func TestGetAllSubscriptions(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetAllSubscriptions(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -1915,11 +2370,17 @@ func TestGetAllSubscriptions(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.profileManager.EXPECT().GetAllSubscriptions(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.profileManager.EXPECT().GetAllSubscriptions(gomock.Any(), gomock.Any(), gomock.Any()).Return( + nil, errors.New("error"), + ) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -1931,7 +2392,9 @@ func TestGetAllSubscriptions(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetAllSubscriptions(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -1943,9 +2406,11 @@ func TestGetAllSubscriptions(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.profileManager.EXPECT().GetAllSubscriptions(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) - m.responder.EXPECT().OutputNoMoreContentJSON(request.w, gomock.Any()).Do(func(w, req any) { - request.w.WriteHeader(http.StatusNoContent) - }) + m.responder.EXPECT().OutputNoMoreContentJSON(request.w, gomock.Any()).Do( + func(w, req any) { + request.w.WriteHeader(http.StatusNoContent) + }, + ) }, }, { @@ -1957,7 +2422,9 @@ func TestGetAllSubscriptions(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetAllSubscriptions(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -1970,41 +2437,48 @@ func TestGetAllSubscriptions(t *testing.T) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.profileManager.EXPECT().GetAllSubscriptions(gomock.Any(), gomock.Any(), gomock.Any()).Return( []*models.ShortProfile{{ID: 1}}, - nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, data, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + nil, + ) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, data, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -2018,7 +2492,9 @@ func TestGetCommunitySubs(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetCommunitySubs(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -2029,10 +2505,14 @@ func TestGetCommunitySubs(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -2044,7 +2524,9 @@ func TestGetCommunitySubs(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetCommunitySubs(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -2055,10 +2537,14 @@ func TestGetCommunitySubs(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -2070,7 +2556,9 @@ func TestGetCommunitySubs(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetCommunitySubs(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -2081,11 +2569,17 @@ func TestGetCommunitySubs(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.profileManager.EXPECT().GetCommunitySubs(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.profileManager.EXPECT().GetCommunitySubs(gomock.Any(), gomock.Any(), gomock.Any()).Return( + nil, errors.New("error"), + ) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -2097,7 +2591,9 @@ func TestGetCommunitySubs(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetCommunitySubs(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -2109,9 +2605,11 @@ func TestGetCommunitySubs(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.profileManager.EXPECT().GetCommunitySubs(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) - m.responder.EXPECT().OutputNoMoreContentJSON(request.w, gomock.Any()).Do(func(w, req any) { - request.w.WriteHeader(http.StatusNoContent) - }) + m.responder.EXPECT().OutputNoMoreContentJSON(request.w, gomock.Any()).Do( + func(w, req any) { + request.w.WriteHeader(http.StatusNoContent) + }, + ) }, }, { @@ -2123,7 +2621,9 @@ func TestGetCommunitySubs(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.GetCommunitySubs(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -2136,41 +2636,48 @@ func TestGetCommunitySubs(t *testing.T) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.profileManager.EXPECT().GetCommunitySubs(gomock.Any(), gomock.Any(), gomock.Any()).Return( []*models.ShortProfile{{ID: 1}}, - nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, data, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + nil, + ) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, data, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } @@ -2184,7 +2691,9 @@ func TestSearch(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.SearchProfile(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -2195,10 +2704,14 @@ func TestSearch(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -2209,7 +2722,9 @@ func TestSearch(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.SearchProfile(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -2220,11 +2735,17 @@ func TestSearch(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.profileManager.EXPECT().Search(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("error")) - m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusInternalServerError) - request.w.Write([]byte("error")) - }) + m.profileManager.EXPECT().Search(gomock.Any(), gomock.Any(), gomock.Any()).Return( + nil, errors.New("error"), + ) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -2235,7 +2756,9 @@ func TestSearch(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.SearchProfile(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -2246,10 +2769,14 @@ func TestSearch(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -2260,7 +2787,9 @@ func TestSearch(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.SearchProfile(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -2272,10 +2801,14 @@ func TestSearch(t *testing.T) { SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) m.profileManager.EXPECT().Search(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) - m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do(func(w, data, req any) { - request.w.WriteHeader(http.StatusOK) - request.w.Write([]byte("OK")) - }) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, data, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) }, }, { @@ -2286,7 +2819,9 @@ func TestSearch(t *testing.T) { res := &Request{r: req, w: w} return res, nil }, - Run: func(ctx context.Context, implementation *ProfileHandlerImplementation, request Request) (Response, error) { + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { implementation.SearchProfile(request.w, request.r) res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} return res, nil @@ -2297,41 +2832,331 @@ func TestSearch(t *testing.T) { ExpectedErr: nil, SetupMock: func(request Request, m *mocks) { m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) - m.profileManager.EXPECT().Search(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, my_err.ErrSessionNotFound) - m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do(func(w, err, req any) { - request.w.WriteHeader(http.StatusBadRequest) - request.w.Write([]byte("bad request")) - }) + m.profileManager.EXPECT().Search(gomock.Any(), gomock.Any(), gomock.Any()).Return( + nil, my_err.ErrSessionNotFound, + ) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) }, }, } for _, v := range tests { - t.Run(v.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - serv, mock := getController(ctrl) - ctx := context.Background() - - input, err := v.SetupInput() - if err != nil { - t.Error(err) - } - - v.SetupMock(*input, mock) - - res, err := v.ExpectedResult() - if err != nil { - t.Error(err) - } - - actual, err := v.Run(ctx, serv, *input) - assert.Equal(t, res, actual) - if !errors.Is(err, v.ExpectedErr) { - t.Errorf("expect %v, got %v", v.ExpectedErr, err) - } - }) + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) + } +} + +func TestChangePassword(t *testing.T) { + tests := []TableTest[Response, Request]{ + { + name: "1", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodGet, "/api/v1/profile/password", nil) + w := httptest.NewRecorder() + res := &Request{r: req, w: w} + return res, nil + }, + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { + implementation.ChangePassword(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) + }, + }, + { + name: "2", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest( + http.MethodGet, "/api/v1/profile/password", + bytes.NewBuffer([]byte("dejfoh")), + ) + w := httptest.NewRecorder() + req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { + implementation.ChangePassword(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) + }, + }, + { + name: "3", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest( + http.MethodGet, "/api/v1/profile/password", + bytes.NewBuffer([]byte(`{"old_password":"password"}`)), + ) + w := httptest.NewRecorder() + req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { + implementation.ChangePassword(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) + }, + }, + { + name: "4", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest( + http.MethodGet, "/api/v1/profile/password", + bytes.NewBuffer([]byte(`{"old_password":"password", "new_password":"password"}`)), + ) + w := httptest.NewRecorder() + req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { + implementation.ChangePassword(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) + }, + }, + { + name: "5", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest( + http.MethodGet, "/api/v1/profile/password", + bytes.NewBuffer([]byte(`{"old_password":"password", "new_password":"password1"}`)), + ) + w := httptest.NewRecorder() + req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { + implementation.ChangePassword(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.profileManager.EXPECT().ChangePassword(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(my_err.ErrUserNotFound) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) + }, + }, + { + name: "6", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest( + http.MethodGet, "/api/v1/profile/password", + bytes.NewBuffer([]byte(`{"old_password":"password", "new_password":"password1"}`)), + ) + w := httptest.NewRecorder() + req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { + implementation.ChangePassword(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusInternalServerError, Body: "error"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.profileManager.EXPECT().ChangePassword(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(errors.New("error")) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + if _, err1 := request.w.Write([]byte("error")); err1 != nil { + panic(err1) + } + }, + ) + }, + }, + { + name: "7", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest( + http.MethodGet, "/api/v1/profile/password", + bytes.NewBuffer([]byte(`{"old_password":"password", "new_password":"password1"}`)), + ) + w := httptest.NewRecorder() + req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func( + ctx context.Context, implementation *ProfileHandlerImplementation, request Request, + ) (Response, error) { + implementation.ChangePassword(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusOK, Body: "OK"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.profileManager.EXPECT().ChangePassword(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, data, req any) { + request.w.WriteHeader(http.StatusOK) + if _, err1 := request.w.Write([]byte("OK")); err1 != nil { + panic(err1) + } + }, + ) + }, + }, + } + + for _, v := range tests { + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) } } diff --git a/internal/profile/controller/mock.go b/internal/profile/controller/mock.go index a005d3a4..20f96de8 100644 --- a/internal/profile/controller/mock.go +++ b/internal/profile/controller/mock.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: controller.go +// +// Generated by this command: +// +// mockgen -destination=mock.go -source=controller.go -package=controller +// // Package controller is a generated GoMock package. package controller @@ -13,10 +18,95 @@ import ( gomock "github.com/golang/mock/gomock" ) +// MockResponder is a mock of Responder interface. +type MockResponder struct { + ctrl *gomock.Controller + recorder *MockResponderMockRecorder + isgomock struct{} +} + +// MockResponderMockRecorder is the mock recorder for MockResponder. +type MockResponderMockRecorder struct { + mock *MockResponder +} + +// NewMockResponder creates a new mock instance. +func NewMockResponder(ctrl *gomock.Controller) *MockResponder { + mock := &MockResponder{ctrl: ctrl} + mock.recorder = &MockResponderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockResponder) EXPECT() *MockResponderMockRecorder { + return m.recorder +} + +// ErrorBadRequest mocks base method. +func (m *MockResponder) ErrorBadRequest(w http.ResponseWriter, err error, requestID string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ErrorBadRequest", w, err, requestID) +} + +// ErrorBadRequest indicates an expected call of ErrorBadRequest. +func (mr *MockResponderMockRecorder) ErrorBadRequest(w, err, requestID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ErrorBadRequest", reflect.TypeOf((*MockResponder)(nil).ErrorBadRequest), w, err, requestID) +} + +// ErrorInternal mocks base method. +func (m *MockResponder) ErrorInternal(w http.ResponseWriter, err error, requestID string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ErrorInternal", w, err, requestID) +} + +// ErrorInternal indicates an expected call of ErrorInternal. +func (mr *MockResponderMockRecorder) ErrorInternal(w, err, requestID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ErrorInternal", reflect.TypeOf((*MockResponder)(nil).ErrorInternal), w, err, requestID) +} + +// LogError mocks base method. +func (m *MockResponder) LogError(err error, requestID string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "LogError", err, requestID) +} + +// LogError indicates an expected call of LogError. +func (mr *MockResponderMockRecorder) LogError(err, requestID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LogError", reflect.TypeOf((*MockResponder)(nil).LogError), err, requestID) +} + +// OutputJSON mocks base method. +func (m *MockResponder) OutputJSON(w http.ResponseWriter, data any, requestID string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OutputJSON", w, data, requestID) +} + +// OutputJSON indicates an expected call of OutputJSON. +func (mr *MockResponderMockRecorder) OutputJSON(w, data, requestID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OutputJSON", reflect.TypeOf((*MockResponder)(nil).OutputJSON), w, data, requestID) +} + +// OutputNoMoreContentJSON mocks base method. +func (m *MockResponder) OutputNoMoreContentJSON(w http.ResponseWriter, requestID string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OutputNoMoreContentJSON", w, requestID) +} + +// OutputNoMoreContentJSON indicates an expected call of OutputNoMoreContentJSON. +func (mr *MockResponderMockRecorder) OutputNoMoreContentJSON(w, requestID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OutputNoMoreContentJSON", reflect.TypeOf((*MockResponder)(nil).OutputNoMoreContentJSON), w, requestID) +} + // MockProfileUsecase is a mock of ProfileUsecase interface. type MockProfileUsecase struct { ctrl *gomock.Controller recorder *MockProfileUsecaseMockRecorder + isgomock struct{} } // MockProfileUsecaseMockRecorder is the mock recorder for MockProfileUsecase. @@ -45,11 +135,25 @@ func (m *MockProfileUsecase) AcceptFriendReq(who, whose uint32) error { } // AcceptFriendReq indicates an expected call of AcceptFriendReq. -func (mr *MockProfileUsecaseMockRecorder) AcceptFriendReq(who, whose interface{}) *gomock.Call { +func (mr *MockProfileUsecaseMockRecorder) AcceptFriendReq(who, whose any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptFriendReq", reflect.TypeOf((*MockProfileUsecase)(nil).AcceptFriendReq), who, whose) } +// ChangePassword mocks base method. +func (m *MockProfileUsecase) ChangePassword(ctx context.Context, userID uint32, oldPassword, newPassword string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ChangePassword", ctx, userID, oldPassword, newPassword) + ret0, _ := ret[0].(error) + return ret0 +} + +// ChangePassword indicates an expected call of ChangePassword. +func (mr *MockProfileUsecaseMockRecorder) ChangePassword(ctx, userID, oldPassword, newPassword any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ChangePassword", reflect.TypeOf((*MockProfileUsecase)(nil).ChangePassword), ctx, userID, oldPassword, newPassword) +} + // DeleteProfile mocks base method. func (m *MockProfileUsecase) DeleteProfile(arg0 uint32) error { m.ctrl.T.Helper() @@ -59,7 +163,7 @@ func (m *MockProfileUsecase) DeleteProfile(arg0 uint32) error { } // DeleteProfile indicates an expected call of DeleteProfile. -func (mr *MockProfileUsecaseMockRecorder) DeleteProfile(arg0 interface{}) *gomock.Call { +func (mr *MockProfileUsecaseMockRecorder) DeleteProfile(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteProfile", reflect.TypeOf((*MockProfileUsecase)(nil).DeleteProfile), arg0) } @@ -74,7 +178,7 @@ func (m *MockProfileUsecase) GetAll(ctx context.Context, self, lastId uint32) ([ } // GetAll indicates an expected call of GetAll. -func (mr *MockProfileUsecaseMockRecorder) GetAll(ctx, self, lastId interface{}) *gomock.Call { +func (mr *MockProfileUsecaseMockRecorder) GetAll(ctx, self, lastId any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAll", reflect.TypeOf((*MockProfileUsecase)(nil).GetAll), ctx, self, lastId) } @@ -89,7 +193,7 @@ func (m *MockProfileUsecase) GetAllFriends(ctx context.Context, id, lastId uint3 } // GetAllFriends indicates an expected call of GetAllFriends. -func (mr *MockProfileUsecaseMockRecorder) GetAllFriends(ctx, id, lastId interface{}) *gomock.Call { +func (mr *MockProfileUsecaseMockRecorder) GetAllFriends(ctx, id, lastId any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllFriends", reflect.TypeOf((*MockProfileUsecase)(nil).GetAllFriends), ctx, id, lastId) } @@ -104,7 +208,7 @@ func (m *MockProfileUsecase) GetAllSubs(ctx context.Context, id, lastId uint32) } // GetAllSubs indicates an expected call of GetAllSubs. -func (mr *MockProfileUsecaseMockRecorder) GetAllSubs(ctx, id, lastId interface{}) *gomock.Call { +func (mr *MockProfileUsecaseMockRecorder) GetAllSubs(ctx, id, lastId any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllSubs", reflect.TypeOf((*MockProfileUsecase)(nil).GetAllSubs), ctx, id, lastId) } @@ -119,7 +223,7 @@ func (m *MockProfileUsecase) GetAllSubscriptions(ctx context.Context, id, lastId } // GetAllSubscriptions indicates an expected call of GetAllSubscriptions. -func (mr *MockProfileUsecaseMockRecorder) GetAllSubscriptions(ctx, id, lastId interface{}) *gomock.Call { +func (mr *MockProfileUsecaseMockRecorder) GetAllSubscriptions(ctx, id, lastId any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllSubscriptions", reflect.TypeOf((*MockProfileUsecase)(nil).GetAllSubscriptions), ctx, id, lastId) } @@ -134,7 +238,7 @@ func (m *MockProfileUsecase) GetCommunitySubs(ctx context.Context, communityID, } // GetCommunitySubs indicates an expected call of GetCommunitySubs. -func (mr *MockProfileUsecaseMockRecorder) GetCommunitySubs(ctx, communityID, lastID interface{}) *gomock.Call { +func (mr *MockProfileUsecaseMockRecorder) GetCommunitySubs(ctx, communityID, lastID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCommunitySubs", reflect.TypeOf((*MockProfileUsecase)(nil).GetCommunitySubs), ctx, communityID, lastID) } @@ -149,7 +253,7 @@ func (m *MockProfileUsecase) GetHeader(ctx context.Context, userID uint32) (*mod } // GetHeader indicates an expected call of GetHeader. -func (mr *MockProfileUsecaseMockRecorder) GetHeader(ctx, userID interface{}) *gomock.Call { +func (mr *MockProfileUsecaseMockRecorder) GetHeader(ctx, userID any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHeader", reflect.TypeOf((*MockProfileUsecase)(nil).GetHeader), ctx, userID) } @@ -164,7 +268,7 @@ func (m *MockProfileUsecase) GetProfileById(arg0 context.Context, arg1 uint32) ( } // GetProfileById indicates an expected call of GetProfileById. -func (mr *MockProfileUsecaseMockRecorder) GetProfileById(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockProfileUsecaseMockRecorder) GetProfileById(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProfileById", reflect.TypeOf((*MockProfileUsecase)(nil).GetProfileById), arg0, arg1) } @@ -178,7 +282,7 @@ func (m *MockProfileUsecase) RemoveFromFriends(who, whose uint32) error { } // RemoveFromFriends indicates an expected call of RemoveFromFriends. -func (mr *MockProfileUsecaseMockRecorder) RemoveFromFriends(who, whose interface{}) *gomock.Call { +func (mr *MockProfileUsecaseMockRecorder) RemoveFromFriends(who, whose any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveFromFriends", reflect.TypeOf((*MockProfileUsecase)(nil).RemoveFromFriends), who, whose) } @@ -193,7 +297,7 @@ func (m *MockProfileUsecase) Search(ctx context.Context, subStr string, lastId u } // Search indicates an expected call of Search. -func (mr *MockProfileUsecaseMockRecorder) Search(ctx, subStr, lastId interface{}) *gomock.Call { +func (mr *MockProfileUsecaseMockRecorder) Search(ctx, subStr, lastId any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Search", reflect.TypeOf((*MockProfileUsecase)(nil).Search), ctx, subStr, lastId) } @@ -207,7 +311,7 @@ func (m *MockProfileUsecase) SendFriendReq(receiver, sender uint32) error { } // SendFriendReq indicates an expected call of SendFriendReq. -func (mr *MockProfileUsecaseMockRecorder) SendFriendReq(receiver, sender interface{}) *gomock.Call { +func (mr *MockProfileUsecaseMockRecorder) SendFriendReq(receiver, sender any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendFriendReq", reflect.TypeOf((*MockProfileUsecase)(nil).SendFriendReq), receiver, sender) } @@ -221,7 +325,7 @@ func (m *MockProfileUsecase) Unsubscribe(who, whose uint32) error { } // Unsubscribe indicates an expected call of Unsubscribe. -func (mr *MockProfileUsecaseMockRecorder) Unsubscribe(who, whose interface{}) *gomock.Call { +func (mr *MockProfileUsecaseMockRecorder) Unsubscribe(who, whose any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unsubscribe", reflect.TypeOf((*MockProfileUsecase)(nil).Unsubscribe), who, whose) } @@ -235,90 +339,7 @@ func (m *MockProfileUsecase) UpdateProfile(arg0 context.Context, arg1 *models.Fu } // UpdateProfile indicates an expected call of UpdateProfile. -func (mr *MockProfileUsecaseMockRecorder) UpdateProfile(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockProfileUsecaseMockRecorder) UpdateProfile(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProfile", reflect.TypeOf((*MockProfileUsecase)(nil).UpdateProfile), arg0, arg1) } - -// MockResponder is a mock of Responder interface. -type MockResponder struct { - ctrl *gomock.Controller - recorder *MockResponderMockRecorder -} - -// MockResponderMockRecorder is the mock recorder for MockResponder. -type MockResponderMockRecorder struct { - mock *MockResponder -} - -// NewMockResponder creates a new mock instance. -func NewMockResponder(ctrl *gomock.Controller) *MockResponder { - mock := &MockResponder{ctrl: ctrl} - mock.recorder = &MockResponderMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockResponder) EXPECT() *MockResponderMockRecorder { - return m.recorder -} - -// ErrorBadRequest mocks base method. -func (m *MockResponder) ErrorBadRequest(w http.ResponseWriter, err error, requestID string) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ErrorBadRequest", w, err, requestID) -} - -// ErrorBadRequest indicates an expected call of ErrorBadRequest. -func (mr *MockResponderMockRecorder) ErrorBadRequest(w, err, requestID interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ErrorBadRequest", reflect.TypeOf((*MockResponder)(nil).ErrorBadRequest), w, err, requestID) -} - -// ErrorInternal mocks base method. -func (m *MockResponder) ErrorInternal(w http.ResponseWriter, err error, requestID string) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ErrorInternal", w, err, requestID) -} - -// ErrorInternal indicates an expected call of ErrorInternal. -func (mr *MockResponderMockRecorder) ErrorInternal(w, err, requestID interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ErrorInternal", reflect.TypeOf((*MockResponder)(nil).ErrorInternal), w, err, requestID) -} - -// LogError mocks base method. -func (m *MockResponder) LogError(err error, requestID string) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "LogError", err, requestID) -} - -// LogError indicates an expected call of LogError. -func (mr *MockResponderMockRecorder) LogError(err, requestID interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LogError", reflect.TypeOf((*MockResponder)(nil).LogError), err, requestID) -} - -// OutputJSON mocks base method. -func (m *MockResponder) OutputJSON(w http.ResponseWriter, data any, requestID string) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "OutputJSON", w, data, requestID) -} - -// OutputJSON indicates an expected call of OutputJSON. -func (mr *MockResponderMockRecorder) OutputJSON(w, data, requestID interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OutputJSON", reflect.TypeOf((*MockResponder)(nil).OutputJSON), w, data, requestID) -} - -// OutputNoMoreContentJSON mocks base method. -func (m *MockResponder) OutputNoMoreContentJSON(w http.ResponseWriter, requestID string) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "OutputNoMoreContentJSON", w, requestID) -} - -// OutputNoMoreContentJSON indicates an expected call of OutputNoMoreContentJSON. -func (mr *MockResponderMockRecorder) OutputNoMoreContentJSON(w, requestID interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OutputNoMoreContentJSON", reflect.TypeOf((*MockResponder)(nil).OutputNoMoreContentJSON), w, requestID) -} diff --git a/internal/profile/repository.go b/internal/profile/repository.go index 831163a1..b5c543ce 100644 --- a/internal/profile/repository.go +++ b/internal/profile/repository.go @@ -14,6 +14,8 @@ type Repository interface { UpdateWithAvatar(context.Context, *models.FullProfile) error DeleteProfile(uint32) error Search(ctx context.Context, subStr string, lastId uint32) ([]*models.ShortProfile, error) + GetUserById(ctx context.Context, id uint32) (*models.User, error) + ChangePassword(ctx context.Context, id uint32, password string) error CheckFriendship(context.Context, uint32, uint32) (bool, error) AddFriendsReq(receiver uint32, sender uint32) error diff --git a/internal/profile/repository/QueryConsts.go b/internal/profile/repository/QueryConsts.go index 14a9fe42..efe49bdc 100644 --- a/internal/profile/repository/QueryConsts.go +++ b/internal/profile/repository/QueryConsts.go @@ -3,15 +3,17 @@ package repository const ( CreateUser = `INSERT INTO profile (first_name, last_name, email, hashed_password) VALUES ($1, $2, $3, $4) ON CONFLICT (email) DO NOTHING RETURNING id;` GetUserByEmail = `SELECT id, first_name, last_name, email, hashed_password FROM profile WHERE email = $1 LIMIT 1;` + GetUserByID = `SELECT id, first_name, last_name, email, hashed_password FROM profile WHERE id = $1 LIMIT 1;` + ChangePassword = `UPDATE profile SET hashed_password = $1 WHERE id = $2;` GetProfileByID = "SELECT profile.id, first_name, last_name, bio, avatar FROM profile WHERE profile.id = $1 LIMIT 1;" - GetStatus = "SELECT status FROM friend WHERE (sender = $1 AND receiver = $2) LIMIT 1" - GetAllProfilesBatch = "WITH friends AS (SELECT sender AS friend FROM friend WHERE (receiver = $1 AND status = 0) UNION SELECT receiver AS friend FROM friend WHERE (sender = $1 AND status = 0)), subscriptions AS (SELECT sender AS subscription FROM friend WHERE (receiver = $1 AND status = -1) UNION SELECT receiver AS subscriber FROM friend WHERE (sender = $1 AND status = 1)) SELECT p.id, first_name, last_name, avatar FROM profile p WHERE p.id <> $1 AND p.id > $2 AND p.id NOT IN (SELECT friend FROM friends) AND p.id NOT IN (SELECT subscription FROM subscriptions) ORDER BY p.id LIMIT $3;" + GetStatus = "SELECT sender, status FROM friend WHERE (sender = $1 AND receiver = $2) OR (sender = $2 AND receiver = $1) LIMIT 1" + GetAllProfilesBatch = "WITH friends AS (SELECT sender AS friend FROM friend WHERE (receiver = $1 AND status = 0) UNION SELECT receiver AS friend FROM friend WHERE (sender = $1 AND status = 0)), subscriptions AS (SELECT sender AS subscription FROM friend WHERE (receiver = $1 AND status = -1) UNION SELECT receiver AS subscription FROM friend WHERE (sender = $1 AND status = 1)), subscribers AS (SELECT sender AS subscriber FROM friend WHERE (receiver = $1 AND status = 1) UNION SELECT receiver AS subscriber FROM friend WHERE (sender = $1 AND status = -1)) SELECT p.id, first_name, last_name, avatar FROM profile p WHERE p.id <> $1 AND p.id > $2 AND p.id NOT IN (SELECT friend FROM friends) AND p.id NOT IN (SELECT subscription FROM subscriptions) AND p.id NOT IN (SELECT subscriber FROM subscribers) ORDER BY p.id LIMIT $3;" UpdateProfile = "UPDATE profile SET first_name = $1, last_name = $2, bio = $3 WHERE id = $4;" UpdateProfileAvatar = "UPDATE profile SET avatar = $2, first_name = $3, last_name = $4, bio = $5 WHERE id = $1;" DeleteProfile = "DELETE FROM profile WHERE id = $1;" AddFriends = "INSERT INTO friend(sender, receiver, status) VALUES ($1, $2, 1);" - AcceptFriendReq = "UPDATE friend SET status = 0 WHERE sender = $1 AND receiver = $2;" + AcceptFriendReq = "UPDATE friend SET status = 0 WHERE (sender = $1 AND receiver = $2) OR (sender = $2 AND receiver = $1);" RemoveFriendsReq = "UPDATE friend SET status = ( CASE WHEN sender = $1 THEN -1 ELSE 1 END) WHERE (receiver = $1 AND sender = $2) OR (sender = $1 AND receiver = $2);" GetAllFriends = "WITH friends AS (SELECT sender AS friend FROM friend WHERE (receiver = $1 AND status = 0) UNION SELECT receiver AS friend FROM friend WHERE (sender = $1 AND status = 0)) SELECT profile.id, first_name, last_name, avatar FROM profile INNER JOIN friends ON friend = profile.id WHERE profile.id > $2 ORDER BY profile.id LIMIT $3;" GetAllSubs = "WITH subs AS ( SELECT sender AS subscriber FROM friend WHERE (receiver = $1 AND status = 1) UNION SELECT receiver AS subscriber FROM friend WHERE (sender = $1 AND status = -1)) SELECT profile.id, first_name, last_name, avatar FROM profile INNER JOIN subs ON subscriber = profile.id WHERE profile.id > $2 ORDER BY profile.id LIMIT $3;" @@ -23,8 +25,21 @@ const ( GetFriendsID = "SELECT sender AS friend FROM friend WHERE (receiver = $1 AND status = 0) UNION SELECT receiver AS friend FROM friend WHERE (sender = $1 AND status = 0)" GetSubsID = "SELECT sender AS subscriber FROM friend WHERE (receiver = $1 AND status = 1) UNION SELECT receiver AS subscriber FROM friend WHERE (sender = $1 AND status = -1)" GetSubscriptionsID = "SELECT sender AS subscription FROM friend WHERE (receiver = $1 AND status = -1) UNION SELECT receiver AS subscriber FROM friend WHERE (sender = $1 AND status = 1)" - GetAllStatuses = "WITH friends AS (\n SELECT sender AS friend\n FROM friend\n WHERE (receiver = $1 AND status = 0)\n UNION\n SELECT receiver AS friend\n FROM friend\n WHERE (sender = $1 AND status = 0)\n), subscriptions AS (\n SELECT sender AS subscription FROM friend WHERE (receiver = $1 AND status = -1) UNION SELECT receiver AS subscriber FROM friend WHERE (sender = $1 AND status = 1)\n), subscribers AS (\n SELECT sender AS subscriber FROM friend WHERE (receiver = $1 AND status = 1) UNION SELECT receiver AS subscriber FROM friend WHERE (sender = $1 AND status = -1)) SELECT (SELECT json_agg(friend) FROM friends) AS friends, (SELECT json_agg(subscriber) FROM subscribers) AS subscribers, (SELECT json_agg(subscription) FROM subscriptions) AS subscriptions;" - GetShortProfile = "SELECT first_name || ' ' || last_name AS name, avatar FROM profile WHERE profile.id = $1 LIMIT 1;" + GetAllStatuses = `WITH friends AS ( +SELECT sender AS friend +FROM friend WHERE (receiver = $1 AND status = 0) UNION +SELECT receiver AS friend FROM friend WHERE (sender = $1 AND status = 0) +), subscriptions AS ( +SELECT sender AS subscription FROM friend WHERE (receiver = $1 AND status = -1) UNION +SELECT receiver AS subscriber FROM friend WHERE (sender = $1 AND status = 1) +), subscribers AS ( +SELECT sender AS subscriber FROM friend WHERE (receiver = $1 AND status = 1) UNION +SELECT receiver AS subscriber FROM friend WHERE (sender = $1 AND status = -1) +) SELECT (SELECT json_agg(friend) FROM friends) AS friends, +(SELECT json_agg(subscriber) FROM subscribers) AS subscribers, +(SELECT json_agg(subscription) FROM subscriptions) AS subscriptions;` + + GetShortProfile = "SELECT first_name || ' ' || last_name AS name, avatar FROM profile WHERE profile.id = $1 LIMIT 1;" GetCommunitySubs = `WITH subs AS (SELECT profile_id AS id FROM community_profile WHERE community_id = $1) SELECT p.id, first_name, last_name, avatar FROM profile p JOIN subs ON p.id = subs.id WHERE id > $2 ORDER BY id LIMIT $3;` diff --git a/internal/profile/repository/postgres.go b/internal/profile/repository/postgres.go index 446e1f75..5dc7bae3 100644 --- a/internal/profile/repository/postgres.go +++ b/internal/profile/repository/postgres.go @@ -41,7 +41,9 @@ func (p *ProfileRepo) Create(user *models.User, ctx context.Context) (uint32, er func (p *ProfileRepo) GetByEmail(email string, ctx context.Context) (*models.User, error) { user := &models.User{} - err := p.DB.QueryRowContext(ctx, GetUserByEmail, email).Scan(&user.ID, &user.FirstName, &user.LastName, &user.Email, &user.Password) + err := p.DB.QueryRowContext(ctx, GetUserByEmail, email).Scan( + &user.ID, &user.FirstName, &user.LastName, &user.Email, &user.Password, + ) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, fmt.Errorf("postgres get user: %w", my_err.ErrUserNotFound) @@ -54,7 +56,9 @@ func (p *ProfileRepo) GetByEmail(email string, ctx context.Context) (*models.Use func (p *ProfileRepo) GetProfileById(ctx context.Context, id uint32) (*models.FullProfile, error) { res := &models.FullProfile{} - err := p.DB.QueryRowContext(ctx, GetProfileByID, id).Scan(&res.ID, &res.FirstName, &res.LastName, &res.Bio, &res.Avatar) + err := p.DB.QueryRowContext(ctx, GetProfileByID, id).Scan( + &res.ID, &res.FirstName, &res.LastName, &res.Bio, &res.Avatar, + ) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, my_err.ErrProfileNotFound @@ -65,11 +69,20 @@ func (p *ProfileRepo) GetProfileById(ctx context.Context, id uint32) (*models.Fu } func (p *ProfileRepo) GetStatus(ctx context.Context, self uint32, profile uint32) (int, error) { - var status int - err := p.DB.QueryRowContext(ctx, GetStatus, self, profile).Scan(&status) + var ( + status int + sender uint32 + ) + + err := p.DB.QueryRowContext(ctx, GetStatus, self, profile).Scan(&sender, &status) if err != nil { return 0, err } + + if sender != self && status != 0 { + status = -status + } + return status, nil } @@ -138,7 +151,10 @@ func (p *ProfileRepo) UpdateProfile(ctx context.Context, profile *models.FullPro } func (p *ProfileRepo) UpdateWithAvatar(ctx context.Context, newProfile *models.FullProfile) error { - _, err := p.DB.ExecContext(ctx, UpdateProfileAvatar, newProfile.ID, newProfile.Avatar, newProfile.FirstName, newProfile.LastName, newProfile.Bio) + _, err := p.DB.ExecContext( + ctx, UpdateProfileAvatar, newProfile.ID, newProfile.Avatar, newProfile.FirstName, newProfile.LastName, + newProfile.Bio, + ) if err != nil { return fmt.Errorf("update profile with avatar %w", err) } @@ -159,7 +175,7 @@ func (p *ProfileRepo) CheckFriendship(ctx context.Context, self uint32, profile err := p.DB.QueryRowContext(ctx, CheckFriendship, self, profile).Scan(&status) if err != nil { if errors.Is(err, sql.ErrNoRows) { - return true, nil + return false, nil } return false, fmt.Errorf("check friendship: %w", err) } @@ -233,7 +249,9 @@ func (p *ProfileRepo) GetAllSubs(ctx context.Context, u uint32, lastId uint32) ( return res, nil } -func (p *ProfileRepo) GetAllSubscriptions(ctx context.Context, u uint32, lastId uint32) ([]*models.ShortProfile, error) { +func (p *ProfileRepo) GetAllSubscriptions( + ctx context.Context, u uint32, lastId uint32, +) ([]*models.ShortProfile, error) { res := make([]*models.ShortProfile, 0) rows, err := p.DB.QueryContext(ctx, GetAllSubscriptions, u, lastId, LIMIT) if err != nil { @@ -313,7 +331,9 @@ func (p *ProfileRepo) GetHeader(ctx context.Context, u uint32) (*models.Header, return profile, nil } -func (p *ProfileRepo) GetCommunitySubs(ctx context.Context, communityID uint32, lastInsertId uint32) ([]*models.ShortProfile, error) { +func (p *ProfileRepo) GetCommunitySubs( + ctx context.Context, communityID uint32, lastInsertId uint32, +) ([]*models.ShortProfile, error) { var subs []*models.ShortProfile rows, err := p.DB.QueryContext(ctx, GetCommunitySubs, communityID, lastInsertId, LIMIT) if err != nil { @@ -354,3 +374,30 @@ func (p *ProfileRepo) Search(ctx context.Context, query string, lastID uint32) ( return res, nil } + +func (p *ProfileRepo) GetUserById(ctx context.Context, id uint32) (*models.User, error) { + user := &models.User{} + err := p.DB.QueryRowContext(ctx, GetUserByID, id).Scan( + &user.ID, &user.FirstName, &user.LastName, &user.Email, &user.Password, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, fmt.Errorf("postgres get user: %w", my_err.ErrUserNotFound) + } + return nil, fmt.Errorf("postgres get user: %w", err) + } + + return user, nil +} + +func (p *ProfileRepo) ChangePassword(ctx context.Context, id uint32, password string) error { + if err := p.DB.QueryRowContext(ctx, ChangePassword, password, id).Err(); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("postgres change password: %w", my_err.ErrUserNotFound) + } + + return fmt.Errorf("change password: %w", err) + } + + return nil +} diff --git a/internal/profile/service/usecase.go b/internal/profile/service/usecase.go index 8b61de0c..f00bf100 100644 --- a/internal/profile/service/usecase.go +++ b/internal/profile/service/usecase.go @@ -7,6 +7,8 @@ import ( "fmt" "slices" + "golang.org/x/crypto/bcrypt" + "github.com/2024_2_BetterCallFirewall/internal/models" "github.com/2024_2_BetterCallFirewall/internal/profile" "github.com/2024_2_BetterCallFirewall/pkg/my_err" @@ -64,7 +66,9 @@ func (p ProfileUsecaseImplementation) GetProfileById(ctx context.Context, u uint return profile, nil } -func (p ProfileUsecaseImplementation) GetAll(ctx context.Context, self uint32, lastId uint32) ([]*models.ShortProfile, error) { +func (p ProfileUsecaseImplementation) GetAll( + ctx context.Context, self uint32, lastId uint32, +) ([]*models.ShortProfile, error) { profiles, err := p.repo.GetAll(ctx, self, lastId) if err != nil { return nil, fmt.Errorf("get all profiles usecase: %w", err) @@ -172,7 +176,9 @@ func (p ProfileUsecaseImplementation) setStatuses(ctx context.Context, profiles return nil } -func (p ProfileUsecaseImplementation) GetAllFriends(ctx context.Context, id uint32, lastId uint32) ([]*models.ShortProfile, error) { +func (p ProfileUsecaseImplementation) GetAllFriends( + ctx context.Context, id uint32, lastId uint32, +) ([]*models.ShortProfile, error) { res, err := p.repo.GetAllFriends(ctx, id, lastId) if err != nil { return nil, fmt.Errorf("get all friends usecase: %w", err) @@ -189,7 +195,9 @@ func (p ProfileUsecaseImplementation) GetAllFriends(ctx context.Context, id uint return res, nil } -func (p ProfileUsecaseImplementation) GetAllSubs(ctx context.Context, id uint32, lastId uint32) ([]*models.ShortProfile, error) { +func (p ProfileUsecaseImplementation) GetAllSubs( + ctx context.Context, id uint32, lastId uint32, +) ([]*models.ShortProfile, error) { res, err := p.repo.GetAllSubs(ctx, id, lastId) if err != nil { return nil, fmt.Errorf("get all subs usecase: %w", err) @@ -203,7 +211,9 @@ func (p ProfileUsecaseImplementation) GetAllSubs(ctx context.Context, id uint32, return res, nil } -func (p ProfileUsecaseImplementation) GetAllSubscriptions(ctx context.Context, id uint32, lastId uint32) ([]*models.ShortProfile, error) { +func (p ProfileUsecaseImplementation) GetAllSubscriptions( + ctx context.Context, id uint32, lastId uint32, +) ([]*models.ShortProfile, error) { res, err := p.repo.GetAllSubscriptions(ctx, id, lastId) if err != nil { return nil, fmt.Errorf("get all subscriptions usecase: %w", err) @@ -226,7 +236,9 @@ func (p ProfileUsecaseImplementation) GetHeader(ctx context.Context, userID uint return header, nil } -func (p ProfileUsecaseImplementation) GetCommunitySubs(ctx context.Context, communityID, lastId uint32) ([]*models.ShortProfile, error) { +func (p ProfileUsecaseImplementation) GetCommunitySubs( + ctx context.Context, communityID, lastId uint32, +) ([]*models.ShortProfile, error) { subs, err := p.repo.GetCommunitySubs(ctx, communityID, lastId) if err != nil { return nil, fmt.Errorf("get subs: %w", err) @@ -235,7 +247,9 @@ func (p ProfileUsecaseImplementation) GetCommunitySubs(ctx context.Context, comm return subs, nil } -func (p ProfileUsecaseImplementation) Search(ctx context.Context, subStr string, lastId uint32) ([]*models.ShortProfile, error) { +func (p ProfileUsecaseImplementation) Search( + ctx context.Context, subStr string, lastId uint32, +) ([]*models.ShortProfile, error) { profiles, err := p.repo.Search(ctx, subStr, lastId) if err != nil { return nil, fmt.Errorf("search: %w", err) @@ -248,3 +262,29 @@ func (p ProfileUsecaseImplementation) Search(ctx context.Context, subStr string, return profiles, nil } + +func (p ProfileUsecaseImplementation) ChangePassword( + ctx context.Context, + userID uint32, + oldPassword, + newPassword string, +) error { + user, err := p.repo.GetUserById(ctx, userID) + if err != nil { + return fmt.Errorf("change password usecase: %w", err) + } + if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(oldPassword)); err != nil { + return fmt.Errorf("change password usecase: %w", my_err.ErrWrongEmailOrPassword) + } + + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost) + if err != nil { + return fmt.Errorf("change password usecase: %w", err) + } + + if err := p.repo.ChangePassword(ctx, userID, string(hashedPassword)); err != nil { + return fmt.Errorf("change password usecase: %w", err) + } + + return nil +} diff --git a/internal/profile/service/usecase_test.go b/internal/profile/service/usecase_test.go index c865e44f..25d24abc 100644 --- a/internal/profile/service/usecase_test.go +++ b/internal/profile/service/usecase_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "golang.org/x/crypto/bcrypt" "github.com/2024_2_BetterCallFirewall/internal/models" "github.com/2024_2_BetterCallFirewall/pkg/my_err" @@ -16,11 +17,30 @@ type MockProfileDB struct { Storage struct{} } +func (m MockProfileDB) GetUserById(ctx context.Context, id uint32) (*models.User, error) { + if id == 0 { + return nil, ErrExec + } + + hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost) + return &models.User{Password: string(hashedPassword)}, nil +} + +func (m MockProfileDB) ChangePassword(ctx context.Context, id uint32, password string) error { + if id == 1 { + return errMock + } + + return nil +} + func (m MockProfileDB) CheckFriendship(ctx context.Context, u uint32, u2 uint32) (bool, error) { return false, nil } -func (m MockProfileDB) GetCommunitySubs(ctx context.Context, communityID uint32, lastInsertId uint32) ([]*models.ShortProfile, error) { +func (m MockProfileDB) GetCommunitySubs( + ctx context.Context, communityID uint32, lastInsertId uint32, +) ([]*models.ShortProfile, error) { return nil, nil } @@ -201,6 +221,31 @@ func (m MockPostDB) GetAuthorsPosts(ctx context.Context, header *models.Header, return nil, nil } +type changePasswordTests struct { + ctx context.Context + userID uint32 + oldPassword string + newPassword string + err error +} + +func TestChangePassword(t *testing.T) { + tests := []changePasswordTests{ + {ctx: context.Background(), userID: 0, err: ErrExec}, + {ctx: context.Background(), userID: 1, err: my_err.ErrWrongEmailOrPassword}, + {ctx: context.Background(), userID: 1, oldPassword: "password", err: errMock}, + {ctx: context.Background(), userID: 2, oldPassword: "password", err: nil}, + } + + for i, tt := range tests { + err := pu.ChangePassword(tt.ctx, tt.userID, tt.oldPassword, tt.newPassword) + if !errors.Is(err, tt.err) { + t.Errorf("case %d: expected error %s, got %s", i, tt.err, err) + } + } + +} + func TestGetProfileByID(t *testing.T) { sessId1, err := models.NewSession(1) if err != nil { @@ -632,8 +677,11 @@ func TestSearch(t *testing.T) { tests := []TestSearchInput{ {str: "", ID: 0, ctx: context.Background(), want: nil, err: ErrExec}, {str: "alexey", ID: 1, ctx: context.Background(), want: nil, err: my_err.ErrSessionNotFound}, - {str: "alexey", ID: 10, ctx: models.ContextWithSession(context.Background(), &models.Session{ID: "1", UserID: 10}), - want: nil, err: nil}, + { + str: "alexey", ID: 10, + ctx: models.ContextWithSession(context.Background(), &models.Session{ID: "1", UserID: 10}), + want: nil, err: nil, + }, } for caseNum, test := range tests { diff --git a/internal/profile/usecase.go b/internal/profile/usecase.go index 8cc1d3db..0592afbc 100644 --- a/internal/profile/usecase.go +++ b/internal/profile/usecase.go @@ -12,6 +12,7 @@ type ProfileUsecase interface { UpdateProfile(context.Context, *models.FullProfile) error DeleteProfile(uint32) error Search(ctx context.Context, subStr string, lastId uint32) ([]*models.ShortProfile, error) + ChangePassword(ctx context.Context, userID uint32, oldPassword, newPassword string) error SendFriendReq(receiver uint32, sender uint32) error AcceptFriendReq(who uint32, whose uint32) error diff --git a/internal/router/file/router.go b/internal/router/file/router.go index c5ae529a..3beebf22 100644 --- a/internal/router/file/router.go +++ b/internal/router/file/router.go @@ -21,6 +21,7 @@ type SessionManager interface { type FileController interface { Upload(w http.ResponseWriter, r *http.Request) Download(w http.ResponseWriter, r *http.Request) + UploadNonImage(w http.ResponseWriter, r *http.Request) } func NewRouter( @@ -30,6 +31,7 @@ func NewRouter( router.HandleFunc("/image/{name}", fc.Upload).Methods(http.MethodGet, http.MethodOptions) router.HandleFunc("/image", fc.Download).Methods(http.MethodPost, http.MethodOptions) + router.HandleFunc("/files/{name}", fc.UploadNonImage).Methods(http.MethodGet, http.MethodOptions) router.Handle("/api/v1/metrics", promhttp.Handler()) diff --git a/internal/router/file/router_test.go b/internal/router/file/router_test.go index 6af2200b..1d88b5c0 100644 --- a/internal/router/file/router_test.go +++ b/internal/router/file/router_test.go @@ -27,6 +27,8 @@ func (m mockSessionManager) Destroy(sess *models.Session) error { type mockFileController struct{} +func (m mockFileController) UploadNonImage(w http.ResponseWriter, r *http.Request) {} + func (m mockFileController) Upload(w http.ResponseWriter, r *http.Request) {} func (m mockFileController) Download(w http.ResponseWriter, r *http.Request) {} diff --git a/internal/router/post/router.go b/internal/router/post/router.go index 43c3b3a7..5a3c5c5d 100644 --- a/internal/router/post/router.go +++ b/internal/router/post/router.go @@ -27,6 +27,11 @@ type Controller interface { SetLikeOnPost(w http.ResponseWriter, r *http.Request) DeleteLikeFromPost(w http.ResponseWriter, r *http.Request) + + Comment(w http.ResponseWriter, r *http.Request) + DeleteComment(w http.ResponseWriter, r *http.Request) + EditComment(w http.ResponseWriter, r *http.Request) + GetComments(w http.ResponseWriter, r *http.Request) } func NewRouter( @@ -39,6 +44,13 @@ func NewRouter( router.HandleFunc("/api/v1/feed/{id}", contr.Delete).Methods(http.MethodDelete, http.MethodOptions) router.HandleFunc("/api/v1/feed", contr.GetBatchPosts).Methods(http.MethodGet, http.MethodOptions) + router.HandleFunc("/api/v1/feed/{id}", contr.Comment).Methods(http.MethodPost, http.MethodOptions) + router.HandleFunc("/api/v1/feed/{id}/{comment_id}", contr.EditComment).Methods(http.MethodPut, http.MethodOptions) + router.HandleFunc("/api/v1/feed/{id}/{comment_id}", contr.DeleteComment).Methods( + http.MethodDelete, http.MethodOptions, + ) + router.HandleFunc("/api/v1/feed/{id}/comments", contr.GetComments).Methods(http.MethodGet, http.MethodOptions) + router.HandleFunc("/api/v1/feed/{id}/like", contr.SetLikeOnPost).Methods(http.MethodPost, http.MethodOptions) router.HandleFunc("/api/v1/feed/{id}/unlike", contr.DeleteLikeFromPost).Methods(http.MethodPost, http.MethodOptions) diff --git a/internal/router/post/router_test.go b/internal/router/post/router_test.go index 89058dd5..ea893a5c 100644 --- a/internal/router/post/router_test.go +++ b/internal/router/post/router_test.go @@ -27,6 +27,14 @@ func (m mockSessionManager) Destroy(sess *models.Session) error { type mockPostController struct{} +func (m mockPostController) Comment(w http.ResponseWriter, r *http.Request) {} + +func (m mockPostController) DeleteComment(w http.ResponseWriter, r *http.Request) {} + +func (m mockPostController) EditComment(w http.ResponseWriter, r *http.Request) {} + +func (m mockPostController) GetComments(w http.ResponseWriter, r *http.Request) {} + func (m mockPostController) SetLikeOnPost(w http.ResponseWriter, r *http.Request) {} func (m mockPostController) DeleteLikeFromPost(w http.ResponseWriter, r *http.Request) {} diff --git a/internal/router/profile/router.go b/internal/router/profile/router.go index b7064138..05b7daca 100644 --- a/internal/router/profile/router.go +++ b/internal/router/profile/router.go @@ -30,6 +30,7 @@ type ProfileController interface { GetAllSubscriptions(w http.ResponseWriter, r *http.Request) GetCommunitySubs(w http.ResponseWriter, r *http.Request) + ChangePassword(w http.ResponseWriter, r *http.Request) } type SessionManager interface { @@ -51,7 +52,7 @@ func NewRouter( router.HandleFunc("/api/v1/profile/{id}", profileControl.GetProfileById).Methods(http.MethodGet, http.MethodOptions) router.HandleFunc("/api/v1/profiles", profileControl.GetAll).Methods(http.MethodGet, http.MethodOptions) router.HandleFunc("/api/v1/profile", profileControl.UpdateProfile).Methods(http.MethodPut, http.MethodOptions) - router.HandleFunc("api/v1/profile", profileControl.DeleteProfile).Methods(http.MethodDelete, http.MethodOptions) + router.HandleFunc("/api/v1/profile", profileControl.DeleteProfile).Methods(http.MethodDelete, http.MethodOptions) router.HandleFunc("/api/v1/profile/{id}/friend/subscribe", profileControl.SendFriendReq).Methods( http.MethodPost, http.MethodOptions, ) @@ -79,6 +80,9 @@ func NewRouter( router.HandleFunc("/api/v1/profile/search/", profileControl.SearchProfile).Methods( http.MethodGet, http.MethodOptions, ) + router.HandleFunc("/api/v1/profile/password", profileControl.ChangePassword).Methods( + http.MethodPut, http.MethodOptions, + ) router.Handle("/api/v1/metrics", promhttp.Handler()) diff --git a/internal/router/profile/router_test.go b/internal/router/profile/router_test.go index 8e30b2cc..0f536550 100644 --- a/internal/router/profile/router_test.go +++ b/internal/router/profile/router_test.go @@ -57,6 +57,8 @@ func (m mockProfileController) GetCommunitySubs(w http.ResponseWriter, r *http.R func (m mockProfileController) SearchProfile(w http.ResponseWriter, r *http.Request) {} +func (m mockProfileController) ChangePassword(w http.ResponseWriter, r *http.Request) {} + func TestNewRouter(t *testing.T) { r := NewRouter(mockProfileController{}, mockSessionManager{}, logrus.New(), &metrics.HttpMetrics{}) assert.NotNil(t, r) diff --git a/internal/router/responder.go b/internal/router/responder.go index 29ce8895..8e69de91 100644 --- a/internal/router/responder.go +++ b/internal/router/responder.go @@ -6,6 +6,7 @@ import ( "errors" "net/http" + "github.com/mailru/easyjson" log "github.com/sirupsen/logrus" "github.com/2024_2_BetterCallFirewall/pkg/my_err" @@ -22,12 +23,14 @@ func fullUnwrap(err error) error { return last } +//esyjson:json type Response struct { Success bool `json:"success"` Data any `json:"data,omitempty"` Message string `json:"message,omitempty"` } +//easyjson:skip type Respond struct { logger *log.Logger } @@ -47,7 +50,8 @@ func (r *Respond) OutputJSON(w http.ResponseWriter, data any, requestID string) w.WriteHeader(http.StatusOK) r.logger.Infof("req: %s: success request", requestID) - if err := json.NewEncoder(w).Encode(&Response{Success: true, Data: data}); err != nil { + + if _, err := easyjson.MarshalToWriter(&Response{Success: true, Data: data}, w); err != nil { r.logger.Error(err) } } @@ -60,8 +64,15 @@ func (r *Respond) OutputNoMoreContentJSON(w http.ResponseWriter, requestID strin } func (r *Respond) OutputBytes(w http.ResponseWriter, data []byte, requestID string) { - w.Header().Set("Content-Type", "image/avif,image/webp") - w.Header().Set("Access-Control-Allow-Origin", "http://185.241.194.197:8000") + var format string + if len(data) < 512 { + format = http.DetectContentType(data) + } else { + format = http.DetectContentType(data[:512]) + } + + w.Header().Set("Content-Type", format) + w.Header().Set("Access-Control-Allow-Origin", "http://vilka.online") w.Header().Set("Access-Control-Allow-Credentials", "true") w.WriteHeader(http.StatusOK) diff --git a/internal/router/responder_easyjson.go b/internal/router/responder_easyjson.go new file mode 100644 index 00000000..2c504463 --- /dev/null +++ b/internal/router/responder_easyjson.go @@ -0,0 +1,111 @@ +// Code generated by easyjson for marshaling/unmarshaling. DO NOT EDIT. + +package router + +import ( + json "encoding/json" + easyjson "github.com/mailru/easyjson" + jlexer "github.com/mailru/easyjson/jlexer" + jwriter "github.com/mailru/easyjson/jwriter" +) + +// suppress unused package warning +var ( + _ *json.RawMessage + _ *jlexer.Lexer + _ *jwriter.Writer + _ easyjson.Marshaler +) + +func easyjson7e1c2d3cDecodeGithubCom20242BetterCallFirewallInternalRouter(in *jlexer.Lexer, out *Response) { + isTopLevel := in.IsStart() + if in.IsNull() { + if isTopLevel { + in.Consumed() + } + in.Skip() + return + } + in.Delim('{') + for !in.IsDelim('}') { + key := in.UnsafeFieldName(false) + in.WantColon() + if in.IsNull() { + in.Skip() + in.WantComma() + continue + } + switch key { + case "success": + out.Success = bool(in.Bool()) + case "data": + if m, ok := out.Data.(easyjson.Unmarshaler); ok { + m.UnmarshalEasyJSON(in) + } else if m, ok := out.Data.(json.Unmarshaler); ok { + _ = m.UnmarshalJSON(in.Raw()) + } else { + out.Data = in.Interface() + } + case "message": + out.Message = string(in.String()) + default: + in.SkipRecursive() + } + in.WantComma() + } + in.Delim('}') + if isTopLevel { + in.Consumed() + } +} +func easyjson7e1c2d3cEncodeGithubCom20242BetterCallFirewallInternalRouter(out *jwriter.Writer, in Response) { + out.RawByte('{') + first := true + _ = first + { + const prefix string = ",\"success\":" + out.RawString(prefix[1:]) + out.Bool(bool(in.Success)) + } + if in.Data != nil { + const prefix string = ",\"data\":" + out.RawString(prefix) + if m, ok := in.Data.(easyjson.Marshaler); ok { + m.MarshalEasyJSON(out) + } else if m, ok := in.Data.(json.Marshaler); ok { + out.Raw(m.MarshalJSON()) + } else { + out.Raw(json.Marshal(in.Data)) + } + } + if in.Message != "" { + const prefix string = ",\"message\":" + out.RawString(prefix) + out.String(string(in.Message)) + } + out.RawByte('}') +} + +// MarshalJSON supports json.Marshaler interface +func (v Response) MarshalJSON() ([]byte, error) { + w := jwriter.Writer{} + easyjson7e1c2d3cEncodeGithubCom20242BetterCallFirewallInternalRouter(&w, v) + return w.Buffer.BuildBytes(), w.Error +} + +// MarshalEasyJSON supports easyjson.Marshaler interface +func (v Response) MarshalEasyJSON(w *jwriter.Writer) { + easyjson7e1c2d3cEncodeGithubCom20242BetterCallFirewallInternalRouter(w, v) +} + +// UnmarshalJSON supports json.Unmarshaler interface +func (v *Response) UnmarshalJSON(data []byte) error { + r := jlexer.Lexer{Data: data} + easyjson7e1c2d3cDecodeGithubCom20242BetterCallFirewallInternalRouter(&r, v) + return r.Error() +} + +// UnmarshalEasyJSON supports easyjson.Unmarshaler interface +func (v *Response) UnmarshalEasyJSON(l *jlexer.Lexer) { + easyjson7e1c2d3cDecodeGithubCom20242BetterCallFirewallInternalRouter(l, v) +} diff --git a/internal/router/stickers/router.go b/internal/router/stickers/router.go new file mode 100644 index 00000000..3180fb79 --- /dev/null +++ b/internal/router/stickers/router.go @@ -0,0 +1,37 @@ +package stickers + +import ( + "net/http" + + "github.com/gorilla/mux" + "github.com/sirupsen/logrus" + + "github.com/2024_2_BetterCallFirewall/internal/middleware" + "github.com/2024_2_BetterCallFirewall/internal/models" +) + +type Controller interface { + AddNewSticker(w http.ResponseWriter, r *http.Request) + GetAllStickers(w http.ResponseWriter, r *http.Request) + GetMineStickers(w http.ResponseWriter, r *http.Request) +} + +type SessionManager interface { + Check(string) (*models.Session, error) + Create(userID uint32) (*models.Session, error) + Destroy(sess *models.Session) error +} + +func NewRouter(controller Controller, sm SessionManager, logger *logrus.Logger) http.Handler { + router := mux.NewRouter() + + router.HandleFunc("/api/v1/stickers/all", controller.GetAllStickers).Methods(http.MethodGet, http.MethodOptions) + router.HandleFunc("/api/v1/stickers", controller.AddNewSticker).Methods(http.MethodPost, http.MethodOptions) + router.HandleFunc("/api/v1/stickers", controller.GetMineStickers).Methods(http.MethodGet, http.MethodOptions) + + res := middleware.Auth(sm, router) + res = middleware.Preflite(res) + res = middleware.AccessLog(logger, res) + + return res +} diff --git a/internal/router/stickers/touter_test.go b/internal/router/stickers/touter_test.go new file mode 100644 index 00000000..ba5de64a --- /dev/null +++ b/internal/router/stickers/touter_test.go @@ -0,0 +1,38 @@ +package stickers + +import ( + "net/http" + "testing" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + + "github.com/2024_2_BetterCallFirewall/internal/models" +) + +type mockSessionManager struct{} + +func (m mockSessionManager) Check(s string) (*models.Session, error) { + return nil, nil +} + +func (m mockSessionManager) Create(userID uint32) (*models.Session, error) { + return nil, nil +} + +func (m mockSessionManager) Destroy(sess *models.Session) error { + return nil +} + +type mockStickerController struct{} + +func (m mockStickerController) GetAllStickers(w http.ResponseWriter, r *http.Request) {} + +func (m mockStickerController) AddNewSticker(w http.ResponseWriter, r *http.Request) {} + +func (m mockStickerController) GetMineStickers(w http.ResponseWriter, r *http.Request) {} + +func TestNewRouter(t *testing.T) { + r := NewRouter(mockStickerController{}, mockSessionManager{}, logrus.New()) + assert.NotNil(t, r) +} diff --git a/internal/stickers/controller/controller.go b/internal/stickers/controller/controller.go new file mode 100644 index 00000000..152f2b7b --- /dev/null +++ b/internal/stickers/controller/controller.go @@ -0,0 +1,136 @@ +package controller + +import ( + "encoding/json" + "errors" + "net/http" + "strings" + + "github.com/microcosm-cc/bluemonday" + + "github.com/2024_2_BetterCallFirewall/internal/middleware" + "github.com/2024_2_BetterCallFirewall/internal/models" + "github.com/2024_2_BetterCallFirewall/internal/stickers" + "github.com/2024_2_BetterCallFirewall/pkg/my_err" +) + +const imagePrefix = "/image/" + +//go:generate mockgen -destination=mock.go -source=$GOFILE -package=${GOPACKAGE} +type Responder interface { + OutputJSON(w http.ResponseWriter, data any, requestID string) + OutputNoMoreContentJSON(w http.ResponseWriter, requestID string) + + ErrorBadRequest(w http.ResponseWriter, err error, requestID string) + ErrorInternal(w http.ResponseWriter, err error, requestID string) + LogError(err error, requestID string) +} + +type StickersHandlerImplementation struct { + StickersManager stickers.Usecase + Responder Responder +} + +func NewStickerController(manager stickers.Usecase, responder Responder) *StickersHandlerImplementation { + return &StickersHandlerImplementation{ + StickersManager: manager, + Responder: responder, + } +} + +func sanitize(input string) string { + sanitizer := bluemonday.UGCPolicy() + cleaned := sanitizer.Sanitize(input) + return cleaned +} + +func sanitizeFiles(pics []*models.Picture) { + for _, pic := range pics { + *pic = models.Picture(sanitize(string(*pic))) + } +} + +func (s StickersHandlerImplementation) AddNewSticker(w http.ResponseWriter, r *http.Request) { + reqID, ok := r.Context().Value(middleware.RequestKey).(string) + + if !ok { + s.Responder.LogError(my_err.ErrInvalidContext, "") + } + + filePath := models.StickerRequest{} + if err := json.NewDecoder(r.Body).Decode(&filePath); err != nil { + s.Responder.ErrorBadRequest(w, my_err.ErrNoFile, reqID) + return + } + + filePath.File = sanitize(filePath.File) + if !validate(filePath.File) { + s.Responder.ErrorBadRequest(w, my_err.ErrNoImage, reqID) + return + } + + sess, err := models.SessionFromContext(r.Context()) + if err != nil { + s.Responder.ErrorBadRequest(w, err, reqID) + return + } + + err = s.StickersManager.AddNewSticker(r.Context(), filePath.File, sess.UserID) + if err != nil { + s.Responder.ErrorInternal(w, err, reqID) + return + } + + s.Responder.OutputJSON(w, "New sticker is added", reqID) +} + +func (s StickersHandlerImplementation) GetAllStickers(w http.ResponseWriter, r *http.Request) { + reqID, ok := r.Context().Value(middleware.RequestKey).(string) + + if !ok { + s.Responder.LogError(my_err.ErrInvalidContext, "") + } + + res, err := s.StickersManager.GetAllStickers(r.Context()) + if err != nil { + if errors.Is(err, my_err.ErrNoStickers) { + s.Responder.OutputNoMoreContentJSON(w, reqID) + return + } + s.Responder.ErrorInternal(w, err, reqID) + return + } + + sanitizeFiles(res) + s.Responder.OutputJSON(w, res, reqID) +} + +func (s StickersHandlerImplementation) GetMineStickers(w http.ResponseWriter, r *http.Request) { + reqID, ok := r.Context().Value(middleware.RequestKey).(string) + + if !ok { + s.Responder.LogError(my_err.ErrInvalidContext, "") + } + + sess, err := models.SessionFromContext(r.Context()) + if err != nil { + s.Responder.ErrorBadRequest(w, err, reqID) + return + } + + res, err := s.StickersManager.GetMineStickers(r.Context(), sess.UserID) + if err != nil { + if errors.Is(err, my_err.ErrNoStickers) { + s.Responder.OutputNoMoreContentJSON(w, reqID) + return + } + s.Responder.ErrorInternal(w, err, reqID) + return + } + sanitizeFiles(res) + s.Responder.OutputJSON(w, res, reqID) +} + +func validate(filepath string) bool { + return len([]rune(filepath)) < 100 && strings.HasPrefix(filepath, imagePrefix) +} diff --git a/internal/stickers/controller/controller_test.go b/internal/stickers/controller/controller_test.go new file mode 100644 index 00000000..fff95e3a --- /dev/null +++ b/internal/stickers/controller/controller_test.go @@ -0,0 +1,557 @@ +package controller + +import ( + "bytes" + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + + "github.com/2024_2_BetterCallFirewall/internal/models" + "github.com/2024_2_BetterCallFirewall/pkg/my_err" +) + +var errMock = errors.New("mock error") + +func getController(ctrl *gomock.Controller) (*StickersHandlerImplementation, *mocks) { + m := &mocks{ + stickerService: NewMockUsecase(ctrl), + responder: NewMockResponder(ctrl), + } + + return NewStickerController(m.stickerService, m.responder), m +} + +func TestNewPostController(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + handler, _ := getController(ctrl) + assert.NotNil(t, handler) +} + +func TestAddNewSticker(t *testing.T) { + tests := []TableTest[Response, Request]{ + { + name: "1", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodPost, "/api/v1/stickers", nil) + w := httptest.NewRecorder() + res := &Request{r: req, w: w} + return res, nil + }, + Run: func( + ctx context.Context, implementation *StickersHandlerImplementation, request Request, + ) (Response, error) { + implementation.AddNewSticker(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) + }, + }, + { + name: "2", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodPost, "/api/v1/stickers", bytes.NewBuffer([]byte(`"files"`))) + w := httptest.NewRecorder() + res := &Request{r: req, w: w} + return res, nil + }, + Run: func( + ctx context.Context, implementation *StickersHandlerImplementation, request Request, + ) (Response, error) { + implementation.AddNewSticker(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) + }, + }, + { + name: "3", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest( + http.MethodPost, "/api/v1/stickers", bytes.NewBuffer([]byte(`"/image/someimage"`)), + ) + w := httptest.NewRecorder() + res := &Request{r: req, w: w} + return res, nil + }, + Run: func( + ctx context.Context, implementation *StickersHandlerImplementation, request Request, + ) (Response, error) { + implementation.AddNewSticker(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + if _, err1 := request.w.Write([]byte("bad request")); err1 != nil { + panic(err1) + } + }, + ) + }, + }, + { + name: "4", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest( + http.MethodPost, "/api/v1/stickers", bytes.NewBuffer([]byte(`{"file":"/image/someimage"}`)), + ) + w := httptest.NewRecorder() + req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func( + ctx context.Context, implementation *StickersHandlerImplementation, request Request, + ) (Response, error) { + implementation.AddNewSticker(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusInternalServerError, Body: "error"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.stickerService.EXPECT().AddNewSticker(gomock.Any(), gomock.Any(), gomock.Any()).Return(errMock) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + _, _ = request.w.Write([]byte("error")) + }, + ) + }, + }, + { + name: "5", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest( + http.MethodPost, "/api/v1/stickers", bytes.NewBuffer([]byte(`{"file":"/image/someimage"}`)), + ) + w := httptest.NewRecorder() + req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func( + ctx context.Context, implementation *StickersHandlerImplementation, request Request, + ) (Response, error) { + implementation.AddNewSticker(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusOK, Body: "OK"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.stickerService.EXPECT().AddNewSticker(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusOK) + _, _ = request.w.Write([]byte("OK")) + }, + ) + }, + }, + } + + for _, v := range tests { + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) + } +} + +func TestSanitize(t *testing.T) { + test := "" + expected := "" + res := sanitize(test) + assert.Equal(t, expected, res) +} + +func TestSanitizeFiles(t *testing.T) { + xssV := models.Picture("") + file := models.Picture("filepath") + empty := models.Picture("") + test := []*models.Picture{&xssV, &file} + expected := []*models.Picture{&empty, &file} + sanitizeFiles(test) + assert.Equal(t, expected, test) +} + +func TestGetAllSticker(t *testing.T) { + tests := []TableTest[Response, Request]{ + { + name: "1", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodGet, "/api/v1/stickers/all", nil) + w := httptest.NewRecorder() + res := &Request{r: req, w: w} + return res, nil + }, + Run: func( + ctx context.Context, implementation *StickersHandlerImplementation, request Request, + ) (Response, error) { + implementation.GetAllStickers(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusNoContent}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.stickerService.EXPECT().GetAllStickers(gomock.Any()).Return(nil, my_err.ErrNoStickers) + m.responder.EXPECT().OutputNoMoreContentJSON(request.w, gomock.Any()).Do( + func(w, req any) { + request.w.WriteHeader(http.StatusNoContent) + }, + ) + }, + }, + { + name: "2", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodGet, "/api/v1/stickers/all", nil) + w := httptest.NewRecorder() + res := &Request{r: req, w: w} + return res, nil + }, + Run: func( + ctx context.Context, implementation *StickersHandlerImplementation, request Request, + ) (Response, error) { + implementation.GetAllStickers(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusInternalServerError, Body: "error"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.stickerService.EXPECT().GetAllStickers(gomock.Any()).Return(nil, errMock) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + _, _ = request.w.Write([]byte("error")) + }, + ) + }, + }, + { + name: "3", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodGet, "/api/v1/stickers/all", nil) + w := httptest.NewRecorder() + res := &Request{r: req, w: w} + return res, nil + }, + Run: func( + ctx context.Context, implementation *StickersHandlerImplementation, request Request, + ) (Response, error) { + implementation.GetAllStickers(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusOK, Body: "OK"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + pic := models.Picture("/image/sticker1") + m.stickerService.EXPECT().GetAllStickers(gomock.Any()).Return([]*models.Picture{&pic}, nil) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusOK) + _, _ = request.w.Write([]byte("OK")) + }, + ) + }, + }, + } + + for _, v := range tests { + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) + } +} + +func TestGetMineSticker(t *testing.T) { + tests := []TableTest[Response, Request]{ + { + name: "1", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodGet, "/api/v1/stickers", nil) + w := httptest.NewRecorder() + res := &Request{r: req, w: w} + return res, nil + }, + Run: func( + ctx context.Context, implementation *StickersHandlerImplementation, request Request, + ) (Response, error) { + implementation.GetMineStickers(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusBadRequest, Body: "bad request"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.responder.EXPECT().ErrorBadRequest(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusBadRequest) + _, _ = request.w.Write([]byte("bad request")) + }, + ) + }, + }, + { + name: "2", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodGet, "/api/v1/stickers", nil) + w := httptest.NewRecorder() + req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func( + ctx context.Context, implementation *StickersHandlerImplementation, request Request, + ) (Response, error) { + implementation.GetMineStickers(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusInternalServerError, Body: "error"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.stickerService.EXPECT().GetMineStickers(gomock.Any(), gomock.Any()).Return(nil, errMock) + m.responder.EXPECT().ErrorInternal(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusInternalServerError) + _, _ = request.w.Write([]byte("error")) + }, + ) + }, + }, + { + name: "3", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodGet, "/api/v1/stickers", nil) + w := httptest.NewRecorder() + req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func( + ctx context.Context, implementation *StickersHandlerImplementation, request Request, + ) (Response, error) { + implementation.GetMineStickers(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusOK, Body: "OK"}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + pic := models.Picture("/image/sticker1") + m.stickerService.EXPECT().GetMineStickers(gomock.Any(), gomock.Any()).Return( + []*models.Picture{&pic}, nil, + ) + m.responder.EXPECT().OutputJSON(request.w, gomock.Any(), gomock.Any()).Do( + func(w, err, req any) { + request.w.WriteHeader(http.StatusOK) + _, _ = request.w.Write([]byte("OK")) + }, + ) + }, + }, + { + name: "4", + SetupInput: func() (*Request, error) { + req := httptest.NewRequest(http.MethodGet, "/api/v1/stickers", nil) + w := httptest.NewRecorder() + req = req.WithContext(models.ContextWithSession(req.Context(), &models.Session{ID: "1", UserID: 1})) + res := &Request{r: req, w: w} + return res, nil + }, + Run: func( + ctx context.Context, implementation *StickersHandlerImplementation, request Request, + ) (Response, error) { + implementation.GetMineStickers(request.w, request.r) + res := Response{StatusCode: request.w.Code, Body: request.w.Body.String()} + return res, nil + }, + ExpectedResult: func() (Response, error) { + return Response{StatusCode: http.StatusNoContent}, nil + }, + ExpectedErr: nil, + SetupMock: func(request Request, m *mocks) { + m.responder.EXPECT().LogError(gomock.Any(), gomock.Any()) + m.stickerService.EXPECT().GetMineStickers(gomock.Any(), gomock.Any()).Return( + nil, my_err.ErrNoStickers, + ) + m.responder.EXPECT().OutputNoMoreContentJSON(request.w, gomock.Any()).Do( + func(w, req any) { + request.w.WriteHeader(http.StatusNoContent) + }, + ) + }, + }, + } + + for _, v := range tests { + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getController(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) + } +} + +type mocks struct { + stickerService *MockUsecase + responder *MockResponder +} + +type Request struct { + w *httptest.ResponseRecorder + r *http.Request +} + +type Response struct { + StatusCode int + Body string +} + +type TableTest[T, In any] struct { + name string + SetupInput func() (*In, error) + Run func(context.Context, *StickersHandlerImplementation, In) (T, error) + ExpectedResult func() (T, error) + ExpectedErr error + SetupMock func(In, *mocks) +} diff --git a/internal/stickers/controller/mock.go b/internal/stickers/controller/mock.go new file mode 100644 index 00000000..cabb0cfa --- /dev/null +++ b/internal/stickers/controller/mock.go @@ -0,0 +1,172 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: controller.go +// +// Generated by this command: +// +// mockgen -destination=mock.go -source=controller.go -package=controller +// + +// Package controller is a generated GoMock package. +package controller + +import ( + context "context" + http "net/http" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + + models "github.com/2024_2_BetterCallFirewall/internal/models" +) + +// MockUsecase is a mock of Usecase interface. +type MockUsecase struct { + ctrl *gomock.Controller + recorder *MockUsecaseMockRecorder + isgomock struct{} +} + +// MockUsecaseMockRecorder is the mock recorder for MockUsecase. +type MockUsecaseMockRecorder struct { + mock *MockUsecase +} + +// NewMockUsecase creates a new mock instance. +func NewMockUsecase(ctrl *gomock.Controller) *MockUsecase { + mock := &MockUsecase{ctrl: ctrl} + mock.recorder = &MockUsecaseMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockUsecase) EXPECT() *MockUsecaseMockRecorder { + return m.recorder +} + +// AddNewSticker mocks base method. +func (m *MockUsecase) AddNewSticker(ctx context.Context, filepath string, userID uint32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddNewSticker", ctx, filepath, userID) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddNewSticker indicates an expected call of AddNewSticker. +func (mr *MockUsecaseMockRecorder) AddNewSticker(ctx, filepath, userID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddNewSticker", reflect.TypeOf((*MockUsecase)(nil).AddNewSticker), ctx, filepath, userID) +} + +// GetAllStickers mocks base method. +func (m *MockUsecase) GetAllStickers(ctx context.Context) ([]*models.Picture, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAllStickers", ctx) + ret0, _ := ret[0].([]*models.Picture) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAllStickers indicates an expected call of GetAllStickers. +func (mr *MockUsecaseMockRecorder) GetAllStickers(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllStickers", reflect.TypeOf((*MockUsecase)(nil).GetAllStickers), ctx) +} + +// GetMineStickers mocks base method. +func (m *MockUsecase) GetMineStickers(ctx context.Context, userID uint32) ([]*models.Picture, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMineStickers", ctx, userID) + ret0, _ := ret[0].([]*models.Picture) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetMineStickers indicates an expected call of GetMineStickers. +func (mr *MockUsecaseMockRecorder) GetMineStickers(ctx, userID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMineStickers", reflect.TypeOf((*MockUsecase)(nil).GetMineStickers), ctx, userID) +} + +// MockResponder is a mock of Responder interface. +type MockResponder struct { + ctrl *gomock.Controller + recorder *MockResponderMockRecorder + isgomock struct{} +} + +// MockResponderMockRecorder is the mock recorder for MockResponder. +type MockResponderMockRecorder struct { + mock *MockResponder +} + +// NewMockResponder creates a new mock instance. +func NewMockResponder(ctrl *gomock.Controller) *MockResponder { + mock := &MockResponder{ctrl: ctrl} + mock.recorder = &MockResponderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockResponder) EXPECT() *MockResponderMockRecorder { + return m.recorder +} + +// ErrorBadRequest mocks base method. +func (m *MockResponder) ErrorBadRequest(w http.ResponseWriter, err error, requestID string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ErrorBadRequest", w, err, requestID) +} + +// ErrorBadRequest indicates an expected call of ErrorBadRequest. +func (mr *MockResponderMockRecorder) ErrorBadRequest(w, err, requestID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ErrorBadRequest", reflect.TypeOf((*MockResponder)(nil).ErrorBadRequest), w, err, requestID) +} + +// ErrorInternal mocks base method. +func (m *MockResponder) ErrorInternal(w http.ResponseWriter, err error, requestID string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ErrorInternal", w, err, requestID) +} + +// ErrorInternal indicates an expected call of ErrorInternal. +func (mr *MockResponderMockRecorder) ErrorInternal(w, err, requestID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ErrorInternal", reflect.TypeOf((*MockResponder)(nil).ErrorInternal), w, err, requestID) +} + +// LogError mocks base method. +func (m *MockResponder) LogError(err error, requestID string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "LogError", err, requestID) +} + +// LogError indicates an expected call of LogError. +func (mr *MockResponderMockRecorder) LogError(err, requestID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LogError", reflect.TypeOf((*MockResponder)(nil).LogError), err, requestID) +} + +// OutputJSON mocks base method. +func (m *MockResponder) OutputJSON(w http.ResponseWriter, data any, requestID string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OutputJSON", w, data, requestID) +} + +// OutputJSON indicates an expected call of OutputJSON. +func (mr *MockResponderMockRecorder) OutputJSON(w, data, requestID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OutputJSON", reflect.TypeOf((*MockResponder)(nil).OutputJSON), w, data, requestID) +} + +// OutputNoMoreContentJSON mocks base method. +func (m *MockResponder) OutputNoMoreContentJSON(w http.ResponseWriter, requestID string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OutputNoMoreContentJSON", w, requestID) +} + +// OutputNoMoreContentJSON indicates an expected call of OutputNoMoreContentJSON. +func (mr *MockResponderMockRecorder) OutputNoMoreContentJSON(w, requestID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OutputNoMoreContentJSON", reflect.TypeOf((*MockResponder)(nil).OutputNoMoreContentJSON), w, requestID) +} diff --git a/internal/stickers/repository.go b/internal/stickers/repository.go new file mode 100644 index 00000000..d1f10877 --- /dev/null +++ b/internal/stickers/repository.go @@ -0,0 +1,13 @@ +package stickers + +import ( + "context" + + "github.com/2024_2_BetterCallFirewall/internal/models" +) + +type Repository interface { + AddNewSticker(ctx context.Context, filepath string, userID uint32) error + GetAllStickers(ctx context.Context) ([]*models.Picture, error) + GetMineStickers(ctx context.Context, userID uint32) ([]*models.Picture, error) +} diff --git a/internal/stickers/repository/QueryConsts.go b/internal/stickers/repository/QueryConsts.go new file mode 100644 index 00000000..8ea9f2fc --- /dev/null +++ b/internal/stickers/repository/QueryConsts.go @@ -0,0 +1,7 @@ +package repository + +const ( + InsertNewSticker = `INSERT INTO sticker(file_path, profile_id) VALUES ($1, $2)` + GetAllSticker = `SELECT file_path FROM sticker` + GetUserStickers = `SELECT file_path FROM sticker WHERE user_id = $1` +) diff --git a/internal/stickers/repository/postgres.go b/internal/stickers/repository/postgres.go new file mode 100644 index 00000000..829834a2 --- /dev/null +++ b/internal/stickers/repository/postgres.go @@ -0,0 +1,73 @@ +package repository + +import ( + "context" + "database/sql" + "errors" + + "github.com/2024_2_BetterCallFirewall/internal/models" + "github.com/2024_2_BetterCallFirewall/pkg/my_err" +) + +type StickerRepo struct { + DB *sql.DB +} + +func NewStickerRepo(db *sql.DB) *StickerRepo { + repo := &StickerRepo{ + DB: db, + } + return repo +} + +func (s StickerRepo) AddNewSticker(ctx context.Context, filepath string, userID uint32) error { + _, err := s.DB.ExecContext(ctx, InsertNewSticker, filepath, userID) + if err != nil { + return err + } + return nil +} + +func (s StickerRepo) GetAllStickers(ctx context.Context) ([]*models.Picture, error) { + rows, err := s.DB.QueryContext(ctx, GetAllSticker) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, my_err.ErrNoStickers + } + return nil, err + } + defer rows.Close() + var res []*models.Picture + for rows.Next() { + var pic models.Picture + err = rows.Scan(&pic) + if err != nil { + return nil, err + } + res = append(res, &pic) + } + + return res, nil +} + +func (s StickerRepo) GetMineStickers(ctx context.Context, userID uint32) ([]*models.Picture, error) { + rows, err := s.DB.QueryContext(ctx, GetUserStickers, userID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, my_err.ErrNoStickers + } + return nil, err + } + defer rows.Close() + var res []*models.Picture + for rows.Next() { + var pic models.Picture + err = rows.Scan(&pic) + if err != nil { + return nil, err + } + res = append(res, &pic) + } + + return res, nil +} diff --git a/internal/stickers/service/mock.go b/internal/stickers/service/mock.go new file mode 100644 index 00000000..6c63501b --- /dev/null +++ b/internal/stickers/service/mock.go @@ -0,0 +1,86 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: usecase.go +// +// Generated by this command: +// +// mockgen -destination=mock.go -source=usecase.go -package=service +// + +// Package service is a generated GoMock package. +package service + +import ( + context "context" + reflect "reflect" + + models "github.com/2024_2_BetterCallFirewall/internal/models" + gomock "github.com/golang/mock/gomock" +) + +// MockRepository is a mock of Repository interface. +type MockRepository struct { + ctrl *gomock.Controller + recorder *MockRepositoryMockRecorder + isgomock struct{} +} + +// MockRepositoryMockRecorder is the mock recorder for MockRepository. +type MockRepositoryMockRecorder struct { + mock *MockRepository +} + +// NewMockRepository creates a new mock instance. +func NewMockRepository(ctrl *gomock.Controller) *MockRepository { + mock := &MockRepository{ctrl: ctrl} + mock.recorder = &MockRepositoryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRepository) EXPECT() *MockRepositoryMockRecorder { + return m.recorder +} + +// AddNewSticker mocks base method. +func (m *MockRepository) AddNewSticker(ctx context.Context, filepath string, userID uint32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddNewSticker", ctx, filepath, userID) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddNewSticker indicates an expected call of AddNewSticker. +func (mr *MockRepositoryMockRecorder) AddNewSticker(ctx, filepath, userID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddNewSticker", reflect.TypeOf((*MockRepository)(nil).AddNewSticker), ctx, filepath, userID) +} + +// GetAllStickers mocks base method. +func (m *MockRepository) GetAllStickers(ctx context.Context) ([]*models.Picture, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAllStickers", ctx) + ret0, _ := ret[0].([]*models.Picture) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAllStickers indicates an expected call of GetAllStickers. +func (mr *MockRepositoryMockRecorder) GetAllStickers(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllStickers", reflect.TypeOf((*MockRepository)(nil).GetAllStickers), ctx) +} + +// GetMineStickers mocks base method. +func (m *MockRepository) GetMineStickers(ctx context.Context, userID uint32) ([]*models.Picture, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMineStickers", ctx, userID) + ret0, _ := ret[0].([]*models.Picture) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetMineStickers indicates an expected call of GetMineStickers. +func (mr *MockRepositoryMockRecorder) GetMineStickers(ctx, userID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMineStickers", reflect.TypeOf((*MockRepository)(nil).GetMineStickers), ctx, userID) +} diff --git a/internal/stickers/service/usecase.go b/internal/stickers/service/usecase.go new file mode 100644 index 00000000..e4891e6b --- /dev/null +++ b/internal/stickers/service/usecase.go @@ -0,0 +1,40 @@ +package service + +import ( + "context" + + "github.com/2024_2_BetterCallFirewall/internal/models" + "github.com/2024_2_BetterCallFirewall/internal/stickers" +) + +type StickerUsecaseImplementation struct { + repo stickers.Repository +} + +func NewStickerUsecase(stickerRepo stickers.Repository) *StickerUsecaseImplementation { + return &StickerUsecaseImplementation{repo: stickerRepo} +} + +func (s StickerUsecaseImplementation) AddNewSticker(ctx context.Context, filepath string, userID uint32) error { + err := s.repo.AddNewSticker(ctx, filepath, userID) + if err != nil { + return err + } + return nil +} + +func (s StickerUsecaseImplementation) GetAllStickers(ctx context.Context) ([]*models.Picture, error) { + res, err := s.repo.GetAllStickers(ctx) + if err != nil { + return nil, err + } + return res, nil +} + +func (s StickerUsecaseImplementation) GetMineStickers(ctx context.Context, userID uint32) ([]*models.Picture, error) { + res, err := s.repo.GetMineStickers(ctx, userID) + if err != nil { + return nil, err + } + return res, nil +} diff --git a/internal/stickers/service/usecase_test.go b/internal/stickers/service/usecase_test.go new file mode 100644 index 00000000..5258a714 --- /dev/null +++ b/internal/stickers/service/usecase_test.go @@ -0,0 +1,262 @@ +package service + +import ( + "context" + "errors" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + + "github.com/2024_2_BetterCallFirewall/internal/models" +) + +var ( + errMock = errors.New("mock error") + pic = models.Picture("/image/sticker") +) + +func getUseCase(ctrl *gomock.Controller) (*StickerUsecaseImplementation, *mocks) { + m := &mocks{ + repository: NewMockRepository(ctrl), + } + + return NewStickerUsecase(m.repository), m +} + +type input struct { + userID uint32 + filepath string +} + +func TestAddNewSticker(t *testing.T) { + tests := []TableTest[struct{}, input]{ + { + name: "1", + SetupInput: func() (*input, error) { + return &input{filepath: "", userID: 0}, nil + }, + Run: func( + ctx context.Context, implementation *StickerUsecaseImplementation, request input, + ) (struct{}, error) { + err := implementation.AddNewSticker(ctx, request.filepath, request.userID) + return struct{}{}, err + }, + ExpectedResult: func() (struct{}, error) { + return struct{}{}, nil + }, + ExpectedErr: errMock, + SetupMock: func(request input, m *mocks) { + m.repository.EXPECT().AddNewSticker(gomock.Any(), gomock.Any(), gomock.Any()). + Return(errMock) + }, + }, + { + name: "2", + SetupInput: func() (*input, error) { + return &input{filepath: "/image/sticker", userID: 1}, nil + }, + Run: func( + ctx context.Context, implementation *StickerUsecaseImplementation, request input, + ) (struct{}, error) { + err := implementation.AddNewSticker(ctx, request.filepath, request.userID) + return struct{}{}, err + }, + ExpectedResult: func() (struct{}, error) { + return struct{}{}, nil + }, + ExpectedErr: nil, + SetupMock: func(request input, m *mocks) { + m.repository.EXPECT().AddNewSticker(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil) + }, + }, + } + + for _, v := range tests { + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getUseCase(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) + } +} + +func TestGetMineSticker(t *testing.T) { + tests := []TableTest[[]*models.Picture, uint32]{ + { + name: "1", + SetupInput: func() (*uint32, error) { + req := uint32(0) + return &req, nil + }, + Run: func( + ctx context.Context, implementation *StickerUsecaseImplementation, request uint32, + ) ([]*models.Picture, error) { + return implementation.GetMineStickers(ctx, request) + }, + ExpectedResult: func() ([]*models.Picture, error) { + return nil, nil + }, + ExpectedErr: errMock, + SetupMock: func(request uint32, m *mocks) { + m.repository.EXPECT().GetMineStickers(gomock.Any(), gomock.Any()).Return(nil, errMock) + }, + }, + { + name: "2", + SetupInput: func() (*uint32, error) { + req := uint32(10) + return &req, nil + }, + Run: func( + ctx context.Context, implementation *StickerUsecaseImplementation, request uint32, + ) ([]*models.Picture, error) { + return implementation.GetMineStickers(ctx, request) + }, + ExpectedResult: func() ([]*models.Picture, error) { + return []*models.Picture{&pic}, nil + }, + ExpectedErr: nil, + SetupMock: func(request uint32, m *mocks) { + m.repository.EXPECT().GetMineStickers(gomock.Any(), gomock.Any()).Return([]*models.Picture{&pic}, nil) + }, + }, + } + + for _, v := range tests { + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getUseCase(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) + } +} + +func TestGetAllSticker(t *testing.T) { + tests := []TableTest[[]*models.Picture, struct{}]{ + { + name: "1", + SetupInput: func() (*struct{}, error) { + return &struct{}{}, nil + }, + Run: func( + ctx context.Context, implementation *StickerUsecaseImplementation, request struct{}, + ) ([]*models.Picture, error) { + return implementation.GetAllStickers(ctx) + }, + ExpectedResult: func() ([]*models.Picture, error) { + return nil, nil + }, + ExpectedErr: errMock, + SetupMock: func(request struct{}, m *mocks) { + m.repository.EXPECT().GetAllStickers(gomock.Any()).Return(nil, errMock) + }, + }, + { + name: "2", + SetupInput: func() (*struct{}, error) { + return &struct{}{}, nil + }, + Run: func( + ctx context.Context, implementation *StickerUsecaseImplementation, request struct{}, + ) ([]*models.Picture, error) { + return implementation.GetAllStickers(ctx) + }, + ExpectedResult: func() ([]*models.Picture, error) { + return []*models.Picture{&pic}, nil + }, + ExpectedErr: nil, + SetupMock: func(request struct{}, m *mocks) { + m.repository.EXPECT().GetAllStickers(gomock.Any()).Return([]*models.Picture{&pic}, nil) + }, + }, + } + + for _, v := range tests { + t.Run( + v.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + serv, mock := getUseCase(ctrl) + ctx := context.Background() + + input, err := v.SetupInput() + if err != nil { + t.Error(err) + } + + v.SetupMock(*input, mock) + + res, err := v.ExpectedResult() + if err != nil { + t.Error(err) + } + + actual, err := v.Run(ctx, serv, *input) + assert.Equal(t, res, actual) + if !errors.Is(err, v.ExpectedErr) { + t.Errorf("expect %v, got %v", v.ExpectedErr, err) + } + }, + ) + } +} + +type TableTest[T, In any] struct { + name string + SetupInput func() (*In, error) + Run func(context.Context, *StickerUsecaseImplementation, In) (T, error) + ExpectedResult func() (T, error) + ExpectedErr error + SetupMock func(In, *mocks) +} + +type mocks struct { + repository *MockRepository +} diff --git a/internal/stickers/sticker_usecase.go b/internal/stickers/sticker_usecase.go new file mode 100644 index 00000000..315e6939 --- /dev/null +++ b/internal/stickers/sticker_usecase.go @@ -0,0 +1,13 @@ +package stickers + +import ( + "context" + + "github.com/2024_2_BetterCallFirewall/internal/models" +) + +type Usecase interface { + AddNewSticker(ctx context.Context, filepath string, userID uint32) error + GetAllStickers(ctx context.Context) ([]*models.Picture, error) + GetMineStickers(ctx context.Context, userID uint32) ([]*models.Picture, error) +} diff --git a/perf_test/README.md b/perf_test/README.md new file mode 100644 index 00000000..e69de29b diff --git a/pkg/my_err/error.go b/pkg/my_err/error.go index 04ceb597..ab339a95 100644 --- a/pkg/my_err/error.go +++ b/pkg/my_err/error.go @@ -33,5 +33,11 @@ var ( ErrLikeAlreadyExists = errors.New("like already exists") ErrWrongCommunity = errors.New("wrong community") ErrWrongPost = errors.New("wrong post") - ErrPostTooLong = errors.New("post len is too big") + ErrBadPostOrComment = errors.New("content is bad") + ErrWrongComment = errors.New("wrong comment") + ErrBadMessageContent = errors.New("bad message content") + ErrNoStickers = errors.New("no stickers found") + ErrNoImage = errors.New("file is no image") + ErrBadCommunity = errors.New("bad community data") + ErrBadUserInfo = errors.New("bad username or password") ) diff --git a/proto/post.proto b/proto/post.proto index 01ebd048..26fef9a7 100644 --- a/proto/post.proto +++ b/proto/post.proto @@ -30,11 +30,12 @@ message Post { Header Head = 3; uint32 LikesCount = 4; bool IsLiked = 5; + uint32 CommentCount = 6; } message Content { string Text = 1; - string File = 2; + repeated string File = 2; int64 CreatedAt = 3; int64 UpdatedAt = 4; } \ No newline at end of file diff --git a/test b/test new file mode 100644 index 00000000..30d74d25 --- /dev/null +++ b/test @@ -0,0 +1 @@ +test \ No newline at end of file