Skip to content

Commit

Permalink
feat: only write header when response is not a receiver method parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
crhntr committed Aug 24, 2024
1 parent 9a45e2e commit 155b197
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 16 deletions.
26 changes: 16 additions & 10 deletions generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func Generate(templateNames []TemplateName, packageName, templatesVariableName,
Type: method,
})
}
handlerFunc, methodImports, err := pattern.funcLit(templatesVariableName, method)
handlerFunc, methodImports, err := pattern.funcLit(method)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -129,9 +129,9 @@ func (def TemplateName) callHandleFunc(handlerFuncLit *ast.FuncLit) *ast.ExprStm
}}
}

func (def TemplateName) funcLit(templatesVariableIdent string, method *ast.FuncType) (*ast.FuncLit, []*ast.ImportSpec, error) {
func (def TemplateName) funcLit(method *ast.FuncType) (*ast.FuncLit, []*ast.ImportSpec, error) {
if def.handler == "" {
return def.httpRequestReceiverTemplateHandlerFunc(templatesVariableIdent), nil, nil
return def.httpRequestReceiverTemplateHandlerFunc(), nil, nil
}
lit := &ast.FuncLit{
Type: httpHandlerFuncType(),
Expand All @@ -148,11 +148,17 @@ func (def TemplateName) funcLit(templatesVariableIdent string, method *ast.FuncT
}
}
}
var imports []*ast.ImportSpec
var (
imports []*ast.ImportSpec
writeHeader = true
)
for i, a := range def.call.Args {
arg := a.(*ast.Ident)
switch arg.Name {
case TemplateNameScopeIdentifierHTTPRequest, TemplateNameScopeIdentifierHTTPResponse:
case TemplateNameScopeIdentifierHTTPResponse:
writeHeader = false
fallthrough
case TemplateNameScopeIdentifierHTTPRequest:
call.Args = append(call.Args, ast.NewIdent(arg.Name))
imports = append(imports, importSpec("net/http"))
case TemplateNameScopeIdentifierContext:
Expand Down Expand Up @@ -205,7 +211,7 @@ func (def TemplateName) funcLit(templatesVariableIdent string, method *ast.FuncT
} else {
lit.Body.List = append(lit.Body.List, &ast.AssignStmt{Lhs: []ast.Expr{ast.NewIdent(dataVarIdent)}, Tok: token.DEFINE, Rhs: []ast.Expr{call}})
}
lit.Body.List = append(lit.Body.List, def.executeCall(httpStatusCode(httpStatusCode200Ident), ast.NewIdent(dataVarIdent)))
lit.Body.List = append(lit.Body.List, def.executeCall(httpStatusCode(httpStatusCode200Ident), ast.NewIdent(dataVarIdent), writeHeader))
return lit, imports, nil
}

Expand Down Expand Up @@ -830,24 +836,24 @@ func paramParseError(errVar *ast.Ident) *ast.IfStmt {
}
}

func (def TemplateName) executeCall(status, data ast.Expr) *ast.ExprStmt {
func (def TemplateName) executeCall(status, data ast.Expr, writeHeader bool) *ast.ExprStmt {
return &ast.ExprStmt{X: &ast.CallExpr{
Fun: ast.NewIdent(executeIdentName),
Args: []ast.Expr{
ast.NewIdent(TemplateNameScopeIdentifierHTTPResponse),
ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest),
ast.NewIdent("true"),
ast.NewIdent(strconv.FormatBool(writeHeader)),
&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(def.name)},
status,
data,
},
}}
}

func (def TemplateName) httpRequestReceiverTemplateHandlerFunc(templatesVariableName string) *ast.FuncLit {
func (def TemplateName) httpRequestReceiverTemplateHandlerFunc() *ast.FuncLit {
return &ast.FuncLit{
Type: httpHandlerFuncType(),
Body: &ast.BlockStmt{List: []ast.Stmt{def.executeCall(httpStatusCode(httpStatusCode200Ident), ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest))}},
Body: &ast.BlockStmt{List: []ast.Stmt{def.executeCall(httpStatusCode(httpStatusCode200Ident), ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), true)}},
}
}

Expand Down
8 changes: 3 additions & 5 deletions generate_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func TestTemplateName_funcLit(t *testing.T) {
},
Out: `func(response http.ResponseWriter, request *http.Request) {
data := receiver.F(response)
execute(response, request, true, "GET / F(response)", http.StatusOK, data)
execute(response, request, false, "GET / F(response)", http.StatusOK, data)
}`,
},
{
Expand Down Expand Up @@ -105,8 +105,7 @@ func TestTemplateName_funcLit(t *testing.T) {
pat, err, ok := NewTemplateName(tt.In)
require.True(t, ok)
require.NoError(t, err)
tv := "templates"
out, _, err := pat.funcLit(tv, tt.Method)
out, _, err := pat.funcLit(tt.Method)
require.NoError(t, err)
assert.Equal(t, tt.Out, source.Format(out))
})
Expand Down Expand Up @@ -212,8 +211,7 @@ func TestTemplateName_HandlerFuncLit_err(t *testing.T) {
pat, err, ok := NewTemplateName(tt.In)
require.True(t, ok)
require.NoError(t, err)
tv := "templates"
_, _, err = pat.funcLit(tv, tt.Method)
_, _, err = pat.funcLit(tt.Method)
assert.ErrorContains(t, err, tt.ErrSub)
})
}
Expand Down
2 changes: 1 addition & 1 deletion generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func routes(mux *http.ServeMux, receiver RoutesReceiver) {
projectID := request.PathValue("projectID")
taskID := request.PathValue("taskID")
data := receiver.F(ctx, response, request, projectID, taskID)
execute(response, request, true, "GET /project/{projectID}/task/{taskID} F(ctx, response, request, projectID, taskID)", http.StatusOK, data)
execute(response, request, false, "GET /project/{projectID}/task/{taskID} F(ctx, response, request, projectID, taskID)", http.StatusOK, data)
})
}
func execute(response http.ResponseWriter, request *http.Request, writeHeader bool, name string, code int, data any) {
Expand Down

0 comments on commit 155b197

Please sign in to comment.