diff --git a/internal/source/go.go b/internal/source/go.go index 0020cbb..af1202f 100644 --- a/internal/source/go.go +++ b/internal/source/go.go @@ -241,10 +241,12 @@ func String(s string) *ast.BasicLit { return &ast.BasicLit{Value: strconv.Quote(s), Kind: token.STRING} } +func Nil() *ast.Ident { return ast.NewIdent("nil") } + func ErrorCheckReturn(errVarIdent string, body ...ast.Stmt) *ast.IfStmt { return &ast.IfStmt{ - Cond: &ast.BinaryExpr{X: ast.NewIdent(errVarIdent), Op: token.NEQ, Y: ast.NewIdent("nil")}, - Body: &ast.BlockStmt{List: body}, + Cond: &ast.BinaryExpr{X: ast.NewIdent(errVarIdent), Op: token.NEQ, Y: Nil()}, + Body: &ast.BlockStmt{List: append(body, &ast.ReturnStmt{})}, } } @@ -261,3 +263,13 @@ func FieldIndex(fields []*ast.Field, i int) (*ast.Ident, ast.Expr, bool) { } return nil, nil, false } + +func CallError(errIdent string) *ast.CallExpr { + return &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent(errIdent), + Sel: ast.NewIdent("Error"), + }, + Args: []ast.Expr{}, + } +} diff --git a/internal/source/imports.go b/internal/source/imports.go index 521f612..befeaf8 100644 --- a/internal/source/imports.go +++ b/internal/source/imports.go @@ -52,6 +52,22 @@ func (imports *Imports) Add(pkgIdent, pkgPath string) string { return pkgIdent } +func (imports *Imports) Ident(pkgPath string) string { + if imports != nil && imports.GenDecl != nil { + for _, s := range imports.GenDecl.Specs { + spec := s.(*ast.ImportSpec) + pp, _ := strconv.Unquote(spec.Path.Value) + if pp == pkgPath { + if spec.Name != nil && spec.Name.Name != "" { + return spec.Name.Name + } + return path.Base(pp) + } + } + } + return path.Base(pkgPath) +} + func (imports *Imports) ImportSpecs() []*ast.ImportSpec { result := make([]*ast.ImportSpec, 0, len(imports.GenDecl.Specs)) for _, spec := range imports.GenDecl.Specs { @@ -72,3 +88,17 @@ func (imports *Imports) SortImports() { func (imports *Imports) AddNetHTTP() string { return imports.Add("", "net/http") } func (imports *Imports) AddHTMLTemplate() string { return imports.Add("", "html/template") } func (imports *Imports) AddContext() string { return imports.Add("", "context") } + +func (imports *Imports) HTTPErrorCall(response ast.Expr, message ast.Expr, code int) *ast.ExprStmt { + return &ast.ExprStmt{X: &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent(imports.AddNetHTTP()), + Sel: ast.NewIdent("Error"), + }, + Args: []ast.Expr{ + response, + message, + HTTPStatusCode(imports, code), + }, + }} +} diff --git a/internal/source/parse.go b/internal/source/parse.go index fccc9d8..fd49df8 100644 --- a/internal/source/parse.go +++ b/internal/source/parse.go @@ -239,20 +239,10 @@ func GenerateValidations(imports *Imports, variable, variableType ast.Expr, inpu var statements []ast.Stmt for _, validation := range validations { statements = append(statements, validation.GenerateValidation(imports, variable, func(message string) ast.Stmt { - return &ast.ExprStmt{X: &ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent(imports.AddNetHTTP()), - Sel: ast.NewIdent("Error"), - }, - Args: []ast.Expr{ - ast.NewIdent(responseIdent), - &ast.BasicLit{ - Kind: token.STRING, - Value: strconv.Quote(message), - }, - HTTPStatusCode(imports, http.StatusBadRequest), - }, - }} + return imports.HTTPErrorCall(ast.NewIdent(responseIdent), &ast.BasicLit{ + Kind: token.STRING, + Value: strconv.Quote(message), + }, http.StatusBadRequest) })) } return statements, nil, true diff --git a/routes.go b/routes.go index cf8b734..ad2caf3 100644 --- a/routes.go +++ b/routes.go @@ -43,6 +43,8 @@ const ( receiverInterfaceIdent = "RoutesReceiver" InputAttributeNameStructTag = "name" + + errIdent = "err" ) func Generate(templateNames []TemplateName, ts *template.Template, packageName, templatesVariableName, routesFunctionName, receiverTypeIdent, output string, fileSet *token.FileSet, receiverPackage, templatesPackage []*ast.File, log *log.Logger) (string, error) { @@ -121,13 +123,7 @@ func (def TemplateName) callHandleFunc(handlerFuncLit *ast.FuncLit) *ast.ExprStm X: ast.NewIdent(muxVarIdent), Sel: ast.NewIdent(httpHandleFuncIdent), }, - Args: []ast.Expr{ - &ast.BasicLit{ - Kind: token.STRING, - Value: strconv.Quote(def.endpoint), - }, - handlerFuncLit, - }, + Args: []ast.Expr{source.String(def.endpoint), handlerFuncLit}, }} } @@ -152,7 +148,6 @@ func (def TemplateName) funcLit(imports *source.Imports, method *ast.FuncType, t formStruct = s } } - const errVarIdent = "err" writeHeader := true for i, a := range def.call.Args { arg := a.(*ast.Ident) @@ -185,20 +180,8 @@ func (def TemplateName) funcLit(imports *source.Imports, method *ast.FuncType, t X: ast.NewIdent(arg.Name), Sel: ast.NewIdent(name.Name), } - errCheck := source.ErrorCheckReturn(errVarIdent, &ast.ExprStmt{X: &ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent(imports.AddNetHTTP()), - Sel: ast.NewIdent("Error"), - }, - Args: []ast.Expr{ - ast.NewIdent(httpResponseField(imports).Names[0].Name), - &ast.CallExpr{ - Fun: &ast.SelectorExpr{X: ast.NewIdent("err"), Sel: ast.NewIdent("Error")}, - Args: []ast.Expr{}, - }, - source.HTTPStatusCode(imports, http.StatusBadRequest), - }, - }}, &ast.ReturnStmt{}) + + errCheck := source.ErrorCheckReturn(errIdent, imports.HTTPErrorCall(ast.NewIdent(httpResponseField(imports).Names[0].Name), source.CallError(errIdent), http.StatusBadRequest)) const parsedVariableName = "value" if fieldType, ok := field.Type.(*ast.ArrayType); ok { @@ -217,7 +200,7 @@ func (def TemplateName) funcLit(imports *source.Imports, method *ast.FuncType, t if ok && err != nil { return nil, err } - statements, err := source.GenerateParseValueFromStringStatements(imports, parsedVariableName, errVarIdent, ast.NewIdent(valVar), fieldType.Elt, errCheck, validations, assignment) + statements, err := source.GenerateParseValueFromStringStatements(imports, parsedVariableName, errIdent, ast.NewIdent(valVar), fieldType.Elt, errCheck, validations, assignment) if err != nil { return nil, fmt.Errorf("failed to generate parse statements for form field %s: %w", name.Name, err) } @@ -263,7 +246,7 @@ func (def TemplateName) funcLit(imports *source.Imports, method *ast.FuncType, t if ok && err != nil { return nil, err } - statements, err := source.GenerateParseValueFromStringStatements(imports, parsedVariableName, errVarIdent, str, field.Type, errCheck, validations, assignment) + statements, err := source.GenerateParseValueFromStringStatements(imports, parsedVariableName, errIdent, str, field.Type, errCheck, validations, assignment) if err != nil { return nil, fmt.Errorf("failed to generate parse statements for form field %s: %w", name.Name, err) } @@ -283,20 +266,7 @@ func (def TemplateName) funcLit(imports *source.Imports, method *ast.FuncType, t } call.Args = append(call.Args, ast.NewIdent(arg.Name)) default: - errCheck := source.ErrorCheckReturn(errVarIdent, &ast.ExprStmt{X: &ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent(imports.AddNetHTTP()), - Sel: ast.NewIdent("Error"), - }, - Args: []ast.Expr{ - ast.NewIdent(httpResponseField(imports).Names[0].Name), - &ast.CallExpr{ - Fun: &ast.SelectorExpr{X: ast.NewIdent("err"), Sel: ast.NewIdent("Error")}, - Args: []ast.Expr{}, - }, - source.HTTPStatusCode(imports, http.StatusBadRequest), - }, - }}, &ast.ReturnStmt{}) + errCheck := source.ErrorCheckReturn(errIdent, imports.HTTPErrorCall(ast.NewIdent(httpResponseField(imports).Names[0].Name), source.CallError(errIdent), http.StatusBadRequest)) src := &ast.CallExpr{ Fun: &ast.SelectorExpr{ X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), @@ -304,7 +274,7 @@ func (def TemplateName) funcLit(imports *source.Imports, method *ast.FuncType, t }, Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(arg.Name)}}, } - statements, err := httpPathValueAssignment(imports, method, i, arg, errVarIdent, src, token.DEFINE, errCheck) + statements, err := httpPathValueAssignment(imports, method, i, arg, errIdent, src, token.DEFINE, errCheck) if err != nil { return nil, err } @@ -315,31 +285,13 @@ func (def TemplateName) funcLit(imports *source.Imports, method *ast.FuncType, t const dataVarIdent = "data" if len(method.Results.List) > 1 { - errVar := ast.NewIdent("err") - lit.Body.List = append(lit.Body.List, - &ast.AssignStmt{Lhs: []ast.Expr{ast.NewIdent(dataVarIdent), ast.NewIdent(errVar.Name)}, Tok: token.DEFINE, Rhs: []ast.Expr{call}}, + &ast.AssignStmt{Lhs: []ast.Expr{ast.NewIdent(dataVarIdent), ast.NewIdent(errIdent)}, Tok: token.DEFINE, Rhs: []ast.Expr{call}}, &ast.IfStmt{ - Cond: &ast.BinaryExpr{X: ast.NewIdent(errVar.Name), Op: token.NEQ, Y: ast.NewIdent("nil")}, + Cond: &ast.BinaryExpr{X: ast.NewIdent(errIdent), Op: token.NEQ, Y: source.Nil()}, Body: &ast.BlockStmt{ List: []ast.Stmt{ - &ast.ExprStmt{X: &ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent(imports.AddNetHTTP()), - Sel: ast.NewIdent("Error"), - }, - Args: []ast.Expr{ - ast.NewIdent(httpResponseField(imports).Names[0].Name), - &ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent("err"), - Sel: ast.NewIdent("Error"), - }, - Args: []ast.Expr{}, - }, - source.HTTPStatusCode(imports, http.StatusInternalServerError), - }, - }}, + imports.HTTPErrorCall(ast.NewIdent(httpResponseField(imports).Names[0].Name), source.CallError(errIdent), http.StatusInternalServerError), &ast.ReturnStmt{}, }, }, @@ -379,7 +331,7 @@ func (def TemplateName) funcType(imports *source.Imports) *ast.FuncType { case TemplateNameScopeIdentifierContext: method.Params.List = append(method.Params.List, contextContextField(imports)) case TemplateNameScopeIdentifierForm: - method.Params.List = append(method.Params.List, urlValuesField(arg.Name)) + method.Params.List = append(method.Params.List, urlValuesField(imports, arg.Name)) default: method.Params.List = append(method.Params.List, pathValueField(arg.Name)) } @@ -525,10 +477,10 @@ func routesFuncType(imports *source.Imports, receiverType ast.Expr) *ast.FuncTyp }} } -func urlValuesField(ident string) *ast.Field { +func urlValuesField(imports *source.Imports, ident string) *ast.Field { return &ast.Field{ Names: []*ast.Ident{ast.NewIdent(ident)}, - Type: &ast.SelectorExpr{X: ast.NewIdent("url"), Sel: ast.NewIdent("Values")}, + Type: &ast.SelectorExpr{X: ast.NewIdent(imports.Add("", "net/url")), Sel: ast.NewIdent("Values")}, } } @@ -565,7 +517,8 @@ func contextAssignment() *ast.AssignStmt { } func formDeclaration(imports *source.Imports, ident string, typeExp ast.Expr) *ast.DeclStmt { - if matchSelectorIdents(typeExp, "url", "Values", false) { + if matchSelectorIdents(typeExp, imports.Ident("net/url"), "Values", false) { + imports.Add("", "net/url") return &ast.DeclStmt{ Decl: &ast.GenDecl{ Tok: token.VAR, @@ -581,7 +534,6 @@ func formDeclaration(imports *source.Imports, ident string, typeExp ast.Expr) *a }, } } - return &ast.DeclStmt{ Decl: &ast.GenDecl{ Tok: token.VAR, @@ -665,7 +617,6 @@ func (def TemplateName) matchReceiver(funcDecl *ast.FuncDecl, receiverTypeIdent func executeFuncDecl(imports *source.Imports, templatesVariableIdent string) *ast.FuncDecl { const writeHeaderIdent = "writeHeader" - imports.Add("", "bytes") return &ast.FuncDecl{ Name: ast.NewIdent(executeIdentName), Type: &ast.FuncType{ @@ -687,15 +638,15 @@ func executeFuncDecl(imports *source.Imports, templatesVariableIdent string) *as Tok: token.DEFINE, Rhs: []ast.Expr{&ast.CallExpr{ Fun: &ast.SelectorExpr{ - X: ast.NewIdent("bytes"), + X: ast.NewIdent(imports.Add("", "bytes")), Sel: ast.NewIdent("NewBuffer"), }, - Args: []ast.Expr{ast.NewIdent("nil")}, + Args: []ast.Expr{source.Nil()}, }}, }, &ast.IfStmt{ Init: &ast.AssignStmt{ - Lhs: []ast.Expr{ast.NewIdent("err")}, + Lhs: []ast.Expr{ast.NewIdent(errIdent)}, Tok: token.DEFINE, Rhs: []ast.Expr{&ast.CallExpr{ Fun: &ast.SelectorExpr{ @@ -706,29 +657,13 @@ func executeFuncDecl(imports *source.Imports, templatesVariableIdent string) *as }}, }, Cond: &ast.BinaryExpr{ - X: ast.NewIdent("err"), + X: ast.NewIdent(errIdent), Op: token.NEQ, - Y: ast.NewIdent("nil"), + Y: source.Nil(), }, Body: &ast.BlockStmt{ List: []ast.Stmt{ - &ast.ExprStmt{X: &ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent(imports.AddNetHTTP()), - Sel: ast.NewIdent("Error"), - }, - Args: []ast.Expr{ - ast.NewIdent(httpResponseField(imports).Names[0].Name), - &ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent("err"), - Sel: ast.NewIdent("Error"), - }, - Args: []ast.Expr{}, - }, - source.HTTPStatusCode(imports, http.StatusInternalServerError), - }, - }}, + imports.HTTPErrorCall(ast.NewIdent(httpResponseField(imports).Names[0].Name), source.CallError(errIdent), http.StatusInternalServerError), &ast.ReturnStmt{}, }, }, @@ -738,7 +673,7 @@ func executeFuncDecl(imports *source.Imports, templatesVariableIdent string) *as Body: &ast.BlockStmt{List: []ast.Stmt{ &ast.ExprStmt{X: &ast.CallExpr{ Fun: &ast.SelectorExpr{X: &ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent(TemplateNameScopeIdentifierHTTPResponse), Sel: ast.NewIdent("Header")}, Args: []ast.Expr{}}, Sel: ast.NewIdent("Set")}, - Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote("content-type")}, &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote("text/html; charset=utf-8")}}, + Args: []ast.Expr{source.String("content-type"), source.String("text/html; charset=utf-8")}, }}, &ast.ExprStmt{X: &ast.CallExpr{ Fun: &ast.SelectorExpr{X: ast.NewIdent(httpResponseField(imports).Names[0].Name), Sel: ast.NewIdent("WriteHeader")},