From 2a43e9036443e32cb32e4d03d371064be037d876 Mon Sep 17 00:00:00 2001 From: Chrstopher Hunter <8398225+crhntr@users.noreply.github.com> Date: Fri, 11 Oct 2024 00:47:42 -0700 Subject: [PATCH] fix log test and standardize TemplateName variable ident --- cmd/muxt/generate.go | 2 +- name.go | 37 +++++++++++-------- routes.go | 88 ++++++++++++++++++++++---------------------- routes_test.go | 2 +- 4 files changed, 68 insertions(+), 61 deletions(-) diff --git a/cmd/muxt/generate.go b/cmd/muxt/generate.go index 74bdf49..186770f 100644 --- a/cmd/muxt/generate.go +++ b/cmd/muxt/generate.go @@ -107,7 +107,7 @@ func generateCommand(args []string, workingDirectory string, getEnv func(string) return err } out := log.New(stdout, "", 0) - s, err := muxt.Generate(templateNames, ts, g.goPackage, g.templatesVariable, g.routesFunction, g.receiverIdent, g.receiverInterfaceIdent, g.outputFilename, g.Package.Fset, g.Package.Syntax, g.Package.Syntax, out) + s, err := muxt.Generate(templateNames, g.goPackage, g.templatesVariable, g.routesFunction, g.receiverIdent, g.receiverInterfaceIdent, g.outputFilename, g.Package.Fset, g.Package.Syntax, g.Package.Syntax, out) if err != nil { return err } diff --git a/name.go b/name.go index e9a8da0..8b3f0f8 100644 --- a/name.go +++ b/name.go @@ -20,18 +20,19 @@ func TemplateNames(ts *template.Template) ([]TemplateName, error) { var templateNames []TemplateName routes := make(map[string]struct{}) for _, t := range ts.Templates() { - pat, err, ok := NewTemplateName(t.Name()) + templateName, err, ok := NewTemplateName(t.Name()) if !ok { continue } if err != nil { return templateNames, err } - if _, exists := routes[pat.method+pat.path]; exists { - return templateNames, fmt.Errorf("duplicate route pattern: %s", pat.endpoint) + if _, exists := routes[templateName.method+templateName.path]; exists { + return templateNames, fmt.Errorf("duplicate route pattern: %s", templateName.endpoint) } - routes[pat.method+pat.path] = struct{}{} - templateNames = append(templateNames, pat) + templateName.template = t + routes[templateName.method+templateName.path] = struct{}{} + templateNames = append(templateNames, templateName) } slices.SortFunc(templateNames, TemplateName.byPathThenMethod) return templateNames, nil @@ -55,6 +56,8 @@ type TemplateName struct { fileSet *token.FileSet + template *template.Template + pathValueNames []string } @@ -114,11 +117,11 @@ var ( templateNameMux = regexp.MustCompile(`^(?P(((?P[A-Z]+)\s+)?)(?P([^/])*)(?P(/(\S)*)))(\s+(?P(\d|http\.Status)\S+))?(?P.*)?$`) ) -func (def TemplateName) parsePathValueNames() ([]string, error) { +func (tn TemplateName) parsePathValueNames() ([]string, error) { var result []string - for _, match := range pathSegmentPattern.FindAllStringSubmatch(def.path, strings.Count(def.path, "/")) { + for _, match := range pathSegmentPattern.FindAllStringSubmatch(tn.path, strings.Count(tn.path, "/")) { n := match[1] - if n == "$" && strings.Count(def.path, "$") == 1 && strings.HasSuffix(def.path, "{$}") { + if n == "$" && strings.Count(tn.path, "$") == 1 && strings.HasSuffix(tn.path, "{$}") { continue } n = strings.TrimSuffix(n, "...") @@ -142,19 +145,23 @@ func checkPathValueNames(in []string) error { return nil } -func (def TemplateName) String() string { return def.name } -func (def TemplateName) Pattern() string { return def.method + " " + def.path } +func (tn TemplateName) String() string { return tn.name } +func (tn TemplateName) Pattern() string { + return tn.method + " " + tn.path +} -func (def TemplateName) sameRoute(p TemplateName) bool { return def.endpoint == p.endpoint } +func (tn TemplateName) sameRoute(p TemplateName) bool { + return tn.endpoint == p.endpoint +} -func (def TemplateName) byPathThenMethod(d TemplateName) int { - if n := cmp.Compare(def.path, d.path); n != 0 { +func (tn TemplateName) byPathThenMethod(d TemplateName) int { + if n := cmp.Compare(tn.path, d.path); n != 0 { return n } - if m := cmp.Compare(def.method, d.method); m != 0 { + if m := cmp.Compare(tn.method, d.method); m != 0 { return m } - return cmp.Compare(def.handler, d.handler) + return cmp.Compare(tn.handler, d.handler) } func parseHandler(fileSet *token.FileSet, def *TemplateName) error { diff --git a/routes.go b/routes.go index b2ba017..6748e62 100644 --- a/routes.go +++ b/routes.go @@ -48,7 +48,7 @@ const ( errIdent = "err" ) -func Generate(templateNames []TemplateName, ts *template.Template, packageName, templatesVariableName, routesFunctionName, receiverTypeIdent, receiverInterfaceIdent, output string, fileSet *token.FileSet, receiverPackage, templatesPackage []*ast.File, log *log.Logger) (string, error) { +func Generate(templateNames []TemplateName, packageName, templatesVariableName, routesFunctionName, receiverTypeIdent, receiverInterfaceIdent, output string, fileSet *token.FileSet, receiverPackage, templatesPackage []*ast.File, log *log.Logger) (string, error) { packageName = cmp.Or(packageName, defaultPackageName) templatesVariableName = cmp.Or(templatesVariableName, DefaultTemplatesVariableName) routesFunctionName = cmp.Or(routesFunctionName, DefaultRoutesFunctionName) @@ -64,7 +64,7 @@ func Generate(templateNames []TemplateName, ts *template.Template, packageName, receiverMethods := source.StaticTypeMethods(receiverPackage, receiverTypeIdent) receiverInterface := receiverInterfaceType(imports, receiverMethods, templateNames) - routesFunc, err := routesFuncDeclaration(imports, ts, routesFunctionName, receiverInterfaceIdent, receiverInterface, receiverPackage, templateNames) + routesFunc, err := routesFuncDeclaration(imports, routesFunctionName, receiverInterfaceIdent, receiverInterface, receiverPackage, templateNames, log) if err != nil { return "", err } @@ -92,26 +92,26 @@ func addExecuteFunction(imports *source.Imports, fileSet *token.FileSet, files [ file.Decls = append(file.Decls, executeFuncDecl(imports, templatesVariableName)) } -func routesFuncDeclaration(imports *source.Imports, ts *template.Template, routesFunctionName, receiverInterfaceIdent string, receiverInterfaceType *ast.InterfaceType, receiverPackage []*ast.File, templateNames []TemplateName) (*ast.FuncDecl, error) { +func routesFuncDeclaration(imports *source.Imports, routesFunctionName, receiverInterfaceIdent string, receiverInterfaceType *ast.InterfaceType, receiverPackage []*ast.File, templateNames []TemplateName, log *log.Logger) (*ast.FuncDecl, error) { routes := &ast.FuncDecl{ Name: ast.NewIdent(routesFunctionName), Type: routesFuncType(imports, ast.NewIdent(receiverInterfaceIdent)), Body: &ast.BlockStmt{}, } - for _, name := range templateNames { - t := ts.Lookup(name.name) - var hf *ast.FuncLit - if name.fun == nil { - hf = name.httpRequestReceiverTemplateHandlerFunc(imports, name.statusCode) - } else { - var err error - hf, err = name.funcLit(imports, t, receiverInterfaceType, receiverPackage) - if err != nil { - return nil, err - } + for _, tn := range templateNames { + log.Printf("%s has route for %s", routesFunctionName, tn.endpoint) + if tn.fun == nil { + hf := tn.httpRequestReceiverTemplateHandlerFunc(imports, tn.statusCode) + routes.Body.List = append(routes.Body.List, tn.callHandleFunc(hf)) + continue + } + + hf, err := tn.funcLit(imports, receiverInterfaceType, receiverPackage) + if err != nil { + return nil, err } - routes.Body.List = append(routes.Body.List, name.callHandleFunc(hf)) + routes.Body.List = append(routes.Body.List, tn.callHandleFunc(hf)) } return routes, nil @@ -120,50 +120,50 @@ func routesFuncDeclaration(imports *source.Imports, ts *template.Template, route func receiverInterfaceType(imports *source.Imports, receiverMethods *ast.FieldList, templateNames []TemplateName) *ast.InterfaceType { interfaceMethods := new(ast.FieldList) - for _, name := range templateNames { - if name.fun == nil { + for _, tn := range templateNames { + if tn.fun == nil { continue } - if source.HasFieldWithName(interfaceMethods, name.fun.Name) { + if source.HasFieldWithName(interfaceMethods, tn.fun.Name) { continue } - if field, ok := source.FindFieldWithName(receiverMethods, name.fun.Name); ok { + if field, ok := source.FindFieldWithName(receiverMethods, tn.fun.Name); ok { interfaceMethods.List = append(interfaceMethods.List, field) continue } - interfaceMethods.List = append(interfaceMethods.List, name.methodField(imports)) + interfaceMethods.List = append(interfaceMethods.List, tn.methodField(imports)) } return &ast.InterfaceType{Methods: interfaceMethods} } -func (def TemplateName) callHandleFunc(handlerFuncLit *ast.FuncLit) *ast.ExprStmt { +func (tn TemplateName) callHandleFunc(handlerFuncLit *ast.FuncLit) *ast.ExprStmt { return &ast.ExprStmt{X: &ast.CallExpr{ Fun: &ast.SelectorExpr{ X: ast.NewIdent(muxVarIdent), Sel: ast.NewIdent(httpHandleFuncIdent), }, - Args: []ast.Expr{source.String(def.endpoint), handlerFuncLit}, + Args: []ast.Expr{source.String(tn.endpoint), handlerFuncLit}, }} } -func (def TemplateName) funcLit(imports *source.Imports, t *template.Template, receiverInterfaceType *ast.InterfaceType, files []*ast.File) (*ast.FuncLit, error) { - methodField, ok := source.FindFieldWithName(receiverInterfaceType.Methods, def.fun.Name) +func (tn TemplateName) funcLit(imports *source.Imports, receiverInterfaceType *ast.InterfaceType, files []*ast.File) (*ast.FuncLit, error) { + methodField, ok := source.FindFieldWithName(receiverInterfaceType.Methods, tn.fun.Name) if !ok { - log.Fatalf("receiver does not have a method declaration for %s", def.fun.Name) + log.Fatalf("receiver does not have a method declaration for %s", tn.fun.Name) } method := methodField.Type.(*ast.FuncType) lit := &ast.FuncLit{ Type: httpHandlerFuncType(imports), Body: &ast.BlockStmt{}, } - call := &ast.CallExpr{Fun: callReceiverMethod(def.fun)} - if method.Params.NumFields() != len(def.call.Args) { - return nil, errWrongNumberOfArguments(def, method) + call := &ast.CallExpr{Fun: callReceiverMethod(tn.fun)} + if method.Params.NumFields() != len(tn.call.Args) { + return nil, errWrongNumberOfArguments(tn, method) } var formStruct *ast.StructType for pi, pt := range fieldListTypes(method.Params) { - if err := checkArgument(imports, method, pi, def.call.Args[pi], pt, files); err != nil { + if err := checkArgument(imports, method, pi, tn.call.Args[pi], pt, files); err != nil { return nil, err } if s, ok := findFormStruct(pt, files); ok { @@ -171,7 +171,7 @@ func (def TemplateName) funcLit(imports *source.Imports, t *template.Template, r } } writeHeader := true - for i, a := range def.call.Args { + for i, a := range tn.call.Args { arg := a.(*ast.Ident) switch arg.Name { case TemplateNameScopeIdentifierHTTPResponse: @@ -203,7 +203,7 @@ func (def TemplateName) funcLit(imports *source.Imports, t *template.Template, r Sel: ast.NewIdent(name.Name), } - fieldTemplate := formInputTemplate(field, t) + fieldTemplate := formInputTemplate(field, tn.template) errCheck := func(exp ast.Expr) ast.Stmt { return &ast.ExprStmt{ @@ -326,7 +326,7 @@ func (def TemplateName) funcLit(imports *source.Imports, t *template.Template, r okIdent = "ok" ) if method.Results == nil || len(method.Results.List) == 0 { - return lit, fmt.Errorf("method for endpoint %q has no results it should have one or two", def) + return lit, fmt.Errorf("method for endpoint %q has no results it should have one or two", tn) } else if len(method.Results.List) > 1 { _, lastResultType, ok := source.FieldIndex(method.Results.List, method.Results.NumFields()-1) if !ok { @@ -370,7 +370,7 @@ func (def TemplateName) funcLit(imports *source.Imports, t *template.Template, r } else { lit.Body.List = append(lit.Body.List, &ast.AssignStmt{Lhs: []ast.Expr{ast.NewIdent(dataVarIdent)}, Tok: token.DEFINE, Rhs: []ast.Expr{call}}) } - lit.Body.List = append(lit.Body.List, def.executeCall(source.HTTPStatusCode(imports, def.statusCode), ast.NewIdent(dataVarIdent), writeHeader)) + lit.Body.List = append(lit.Body.List, tn.executeCall(source.HTTPStatusCode(imports, tn.statusCode), ast.NewIdent(dataVarIdent), writeHeader)) return lit, nil } @@ -398,19 +398,19 @@ func formInputTemplate(field *ast.Field, t *template.Template) *template.Templat return t } -func (def TemplateName) methodField(imports *source.Imports) *ast.Field { +func (tn TemplateName) methodField(imports *source.Imports) *ast.Field { return &ast.Field{ - Names: []*ast.Ident{ast.NewIdent(def.fun.Name)}, - Type: def.funcType(imports), + Names: []*ast.Ident{ast.NewIdent(tn.fun.Name)}, + Type: tn.funcType(imports), } } -func (def TemplateName) funcType(imports *source.Imports) *ast.FuncType { +func (tn TemplateName) funcType(imports *source.Imports) *ast.FuncType { method := &ast.FuncType{ Params: &ast.FieldList{}, Results: &ast.FieldList{List: []*ast.Field{{Type: ast.NewIdent("any")}}}, } - for _, a := range def.call.Args { + for _, a := range tn.call.Args { arg := a.(*ast.Ident) switch arg.Name { case TemplateNameScopeIdentifierHTTPRequest: @@ -686,29 +686,29 @@ func appendAssignment(assignTok token.Token, result ast.Expr) func(exp ast.Expr) } } -func (def TemplateName) executeCall(status, data ast.Expr, writeHeader bool) *ast.ExprStmt { +func (tn TemplateName) executeCall(status, data ast.Expr, writeHeader bool) *ast.ExprStmt { return &ast.ExprStmt{X: &ast.CallExpr{ Fun: ast.NewIdent(executeIdentName), Args: []ast.Expr{ ast.NewIdent(TemplateNameScopeIdentifierHTTPResponse), ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), ast.NewIdent(strconv.FormatBool(writeHeader)), - &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(def.name)}, + &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(tn.name)}, status, data, }, }} } -func (def TemplateName) httpRequestReceiverTemplateHandlerFunc(imports *source.Imports, statusCode int) *ast.FuncLit { +func (tn TemplateName) httpRequestReceiverTemplateHandlerFunc(imports *source.Imports, statusCode int) *ast.FuncLit { return &ast.FuncLit{ Type: httpHandlerFuncType(imports), - Body: &ast.BlockStmt{List: []ast.Stmt{def.executeCall(source.HTTPStatusCode(imports, statusCode), ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), true)}}, + Body: &ast.BlockStmt{List: []ast.Stmt{tn.executeCall(source.HTTPStatusCode(imports, statusCode), ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), true)}}, } } -func (def TemplateName) matchReceiver(funcDecl *ast.FuncDecl, receiverTypeIdent string) bool { - if funcDecl == nil || funcDecl.Name == nil || funcDecl.Name.Name != def.fun.Name || +func (tn TemplateName) matchReceiver(funcDecl *ast.FuncDecl, receiverTypeIdent string) bool { + if funcDecl == nil || funcDecl.Name == nil || funcDecl.Name.Name != tn.fun.Name || funcDecl.Recv == nil || len(funcDecl.Recv.List) < 1 { return false } diff --git a/routes_test.go b/routes_test.go index fe37730..741cbf5 100644 --- a/routes_test.go +++ b/routes_test.go @@ -1667,7 +1667,7 @@ func execute(response http.ResponseWriter, request *http.Request, writeHeader bo logs := log.New(io.Discard, "", 0) set := token.NewFileSet() goFiles := methodFuncTypeLoader(t, set, tt.ReceiverPackage) - out, err := muxt.Generate(templateNames, ts, tt.PackageName, tt.TemplatesVar, tt.RoutesFunc, tt.Receiver, tt.Interface, muxt.DefaultOutputFileName, set, goFiles, goFiles, logs) + out, err := muxt.Generate(templateNames, tt.PackageName, tt.TemplatesVar, tt.RoutesFunc, tt.Receiver, tt.Interface, muxt.DefaultOutputFileName, set, goFiles, goFiles, logs) if tt.ExpectedError == "" { assert.NoError(t, err) assert.Equal(t, tt.ExpectedFile, out)