From 155b1977289acb26a6ea34da3914e3a6b5c54962 Mon Sep 17 00:00:00 2001 From: Chrstopher Hunter <8398225+crhntr@users.noreply.github.com> Date: Sat, 24 Aug 2024 09:27:49 -0700 Subject: [PATCH] feat: only write header when response is not a receiver method parameter --- generate.go | 26 ++++++++++++++++---------- generate_internal_test.go | 8 +++----- generate_test.go | 2 +- 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/generate.go b/generate.go index d6ab07a..b304d48 100644 --- a/generate.go +++ b/generate.go @@ -80,7 +80,7 @@ func Generate(templateNames []TemplateName, packageName, templatesVariableName, Type: method, }) } - handlerFunc, methodImports, err := pattern.funcLit(templatesVariableName, method) + handlerFunc, methodImports, err := pattern.funcLit(method) if err != nil { return "", err } @@ -129,9 +129,9 @@ func (def TemplateName) callHandleFunc(handlerFuncLit *ast.FuncLit) *ast.ExprStm }} } -func (def TemplateName) funcLit(templatesVariableIdent string, method *ast.FuncType) (*ast.FuncLit, []*ast.ImportSpec, error) { +func (def TemplateName) funcLit(method *ast.FuncType) (*ast.FuncLit, []*ast.ImportSpec, error) { if def.handler == "" { - return def.httpRequestReceiverTemplateHandlerFunc(templatesVariableIdent), nil, nil + return def.httpRequestReceiverTemplateHandlerFunc(), nil, nil } lit := &ast.FuncLit{ Type: httpHandlerFuncType(), @@ -148,11 +148,17 @@ func (def TemplateName) funcLit(templatesVariableIdent string, method *ast.FuncT } } } - var imports []*ast.ImportSpec + var ( + imports []*ast.ImportSpec + writeHeader = true + ) for i, a := range def.call.Args { arg := a.(*ast.Ident) switch arg.Name { - case TemplateNameScopeIdentifierHTTPRequest, TemplateNameScopeIdentifierHTTPResponse: + case TemplateNameScopeIdentifierHTTPResponse: + writeHeader = false + fallthrough + case TemplateNameScopeIdentifierHTTPRequest: call.Args = append(call.Args, ast.NewIdent(arg.Name)) imports = append(imports, importSpec("net/http")) case TemplateNameScopeIdentifierContext: @@ -205,7 +211,7 @@ func (def TemplateName) funcLit(templatesVariableIdent string, method *ast.FuncT } else { lit.Body.List = append(lit.Body.List, &ast.AssignStmt{Lhs: []ast.Expr{ast.NewIdent(dataVarIdent)}, Tok: token.DEFINE, Rhs: []ast.Expr{call}}) } - lit.Body.List = append(lit.Body.List, def.executeCall(httpStatusCode(httpStatusCode200Ident), ast.NewIdent(dataVarIdent))) + lit.Body.List = append(lit.Body.List, def.executeCall(httpStatusCode(httpStatusCode200Ident), ast.NewIdent(dataVarIdent), writeHeader)) return lit, imports, nil } @@ -830,13 +836,13 @@ func paramParseError(errVar *ast.Ident) *ast.IfStmt { } } -func (def TemplateName) executeCall(status, data ast.Expr) *ast.ExprStmt { +func (def TemplateName) executeCall(status, data ast.Expr, writeHeader bool) *ast.ExprStmt { return &ast.ExprStmt{X: &ast.CallExpr{ Fun: ast.NewIdent(executeIdentName), Args: []ast.Expr{ ast.NewIdent(TemplateNameScopeIdentifierHTTPResponse), ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), - ast.NewIdent("true"), + ast.NewIdent(strconv.FormatBool(writeHeader)), &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(def.name)}, status, data, @@ -844,10 +850,10 @@ func (def TemplateName) executeCall(status, data ast.Expr) *ast.ExprStmt { }} } -func (def TemplateName) httpRequestReceiverTemplateHandlerFunc(templatesVariableName string) *ast.FuncLit { +func (def TemplateName) httpRequestReceiverTemplateHandlerFunc() *ast.FuncLit { return &ast.FuncLit{ Type: httpHandlerFuncType(), - Body: &ast.BlockStmt{List: []ast.Stmt{def.executeCall(httpStatusCode(httpStatusCode200Ident), ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest))}}, + Body: &ast.BlockStmt{List: []ast.Stmt{def.executeCall(httpStatusCode(httpStatusCode200Ident), ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), true)}}, } } diff --git a/generate_internal_test.go b/generate_internal_test.go index 8d4fd26..4c32560 100644 --- a/generate_internal_test.go +++ b/generate_internal_test.go @@ -54,7 +54,7 @@ func TestTemplateName_funcLit(t *testing.T) { }, Out: `func(response http.ResponseWriter, request *http.Request) { data := receiver.F(response) - execute(response, request, true, "GET / F(response)", http.StatusOK, data) + execute(response, request, false, "GET / F(response)", http.StatusOK, data) }`, }, { @@ -105,8 +105,7 @@ func TestTemplateName_funcLit(t *testing.T) { pat, err, ok := NewTemplateName(tt.In) require.True(t, ok) require.NoError(t, err) - tv := "templates" - out, _, err := pat.funcLit(tv, tt.Method) + out, _, err := pat.funcLit(tt.Method) require.NoError(t, err) assert.Equal(t, tt.Out, source.Format(out)) }) @@ -212,8 +211,7 @@ func TestTemplateName_HandlerFuncLit_err(t *testing.T) { pat, err, ok := NewTemplateName(tt.In) require.True(t, ok) require.NoError(t, err) - tv := "templates" - _, _, err = pat.funcLit(tv, tt.Method) + _, _, err = pat.funcLit(tt.Method) assert.ErrorContains(t, err, tt.ErrSub) }) } diff --git a/generate_test.go b/generate_test.go index fe44022..d34e5e3 100644 --- a/generate_test.go +++ b/generate_test.go @@ -116,7 +116,7 @@ func routes(mux *http.ServeMux, receiver RoutesReceiver) { projectID := request.PathValue("projectID") taskID := request.PathValue("taskID") data := receiver.F(ctx, response, request, projectID, taskID) - execute(response, request, true, "GET /project/{projectID}/task/{taskID} F(ctx, response, request, projectID, taskID)", http.StatusOK, data) + execute(response, request, false, "GET /project/{projectID}/task/{taskID} F(ctx, response, request, projectID, taskID)", http.StatusOK, data) }) } func execute(response http.ResponseWriter, request *http.Request, writeHeader bool, name string, code int, data any) {