Skip to content

Commit

Permalink
fix log test and standardize TemplateName variable ident
Browse files Browse the repository at this point in the history
  • Loading branch information
crhntr committed Oct 11, 2024
1 parent 0c9ef56 commit 2a43e90
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 61 deletions.
2 changes: 1 addition & 1 deletion cmd/muxt/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func generateCommand(args []string, workingDirectory string, getEnv func(string)
return err
}
out := log.New(stdout, "", 0)
s, err := muxt.Generate(templateNames, ts, g.goPackage, g.templatesVariable, g.routesFunction, g.receiverIdent, g.receiverInterfaceIdent, g.outputFilename, g.Package.Fset, g.Package.Syntax, g.Package.Syntax, out)
s, err := muxt.Generate(templateNames, g.goPackage, g.templatesVariable, g.routesFunction, g.receiverIdent, g.receiverInterfaceIdent, g.outputFilename, g.Package.Fset, g.Package.Syntax, g.Package.Syntax, out)
if err != nil {
return err
}
Expand Down
37 changes: 22 additions & 15 deletions name.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,19 @@ func TemplateNames(ts *template.Template) ([]TemplateName, error) {
var templateNames []TemplateName
routes := make(map[string]struct{})
for _, t := range ts.Templates() {
pat, err, ok := NewTemplateName(t.Name())
templateName, err, ok := NewTemplateName(t.Name())
if !ok {
continue
}
if err != nil {
return templateNames, err
}
if _, exists := routes[pat.method+pat.path]; exists {
return templateNames, fmt.Errorf("duplicate route pattern: %s", pat.endpoint)
if _, exists := routes[templateName.method+templateName.path]; exists {
return templateNames, fmt.Errorf("duplicate route pattern: %s", templateName.endpoint)
}
routes[pat.method+pat.path] = struct{}{}
templateNames = append(templateNames, pat)
templateName.template = t
routes[templateName.method+templateName.path] = struct{}{}
templateNames = append(templateNames, templateName)
}
slices.SortFunc(templateNames, TemplateName.byPathThenMethod)
return templateNames, nil
Expand All @@ -55,6 +56,8 @@ type TemplateName struct {

fileSet *token.FileSet

template *template.Template

pathValueNames []string
}

Expand Down Expand Up @@ -114,11 +117,11 @@ var (
templateNameMux = regexp.MustCompile(`^(?P<endpoint>(((?P<method>[A-Z]+)\s+)?)(?P<host>([^/])*)(?P<path>(/(\S)*)))(\s+(?P<code>(\d|http\.Status)\S+))?(?P<handler>.*)?$`)
)

func (def TemplateName) parsePathValueNames() ([]string, error) {
func (tn TemplateName) parsePathValueNames() ([]string, error) {
var result []string
for _, match := range pathSegmentPattern.FindAllStringSubmatch(def.path, strings.Count(def.path, "/")) {
for _, match := range pathSegmentPattern.FindAllStringSubmatch(tn.path, strings.Count(tn.path, "/")) {
n := match[1]
if n == "$" && strings.Count(def.path, "$") == 1 && strings.HasSuffix(def.path, "{$}") {
if n == "$" && strings.Count(tn.path, "$") == 1 && strings.HasSuffix(tn.path, "{$}") {
continue
}
n = strings.TrimSuffix(n, "...")
Expand All @@ -142,19 +145,23 @@ func checkPathValueNames(in []string) error {
return nil
}

func (def TemplateName) String() string { return def.name }
func (def TemplateName) Pattern() string { return def.method + " " + def.path }
func (tn TemplateName) String() string { return tn.name }
func (tn TemplateName) Pattern() string {
return tn.method + " " + tn.path
}

func (def TemplateName) sameRoute(p TemplateName) bool { return def.endpoint == p.endpoint }
func (tn TemplateName) sameRoute(p TemplateName) bool {
return tn.endpoint == p.endpoint
}

func (def TemplateName) byPathThenMethod(d TemplateName) int {
if n := cmp.Compare(def.path, d.path); n != 0 {
func (tn TemplateName) byPathThenMethod(d TemplateName) int {
if n := cmp.Compare(tn.path, d.path); n != 0 {
return n
}
if m := cmp.Compare(def.method, d.method); m != 0 {
if m := cmp.Compare(tn.method, d.method); m != 0 {
return m
}
return cmp.Compare(def.handler, d.handler)
return cmp.Compare(tn.handler, d.handler)
}

func parseHandler(fileSet *token.FileSet, def *TemplateName) error {
Expand Down
88 changes: 44 additions & 44 deletions routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ const (
errIdent = "err"
)

func Generate(templateNames []TemplateName, ts *template.Template, packageName, templatesVariableName, routesFunctionName, receiverTypeIdent, receiverInterfaceIdent, output string, fileSet *token.FileSet, receiverPackage, templatesPackage []*ast.File, log *log.Logger) (string, error) {
func Generate(templateNames []TemplateName, packageName, templatesVariableName, routesFunctionName, receiverTypeIdent, receiverInterfaceIdent, output string, fileSet *token.FileSet, receiverPackage, templatesPackage []*ast.File, log *log.Logger) (string, error) {
packageName = cmp.Or(packageName, defaultPackageName)
templatesVariableName = cmp.Or(templatesVariableName, DefaultTemplatesVariableName)
routesFunctionName = cmp.Or(routesFunctionName, DefaultRoutesFunctionName)
Expand All @@ -64,7 +64,7 @@ func Generate(templateNames []TemplateName, ts *template.Template, packageName,

receiverMethods := source.StaticTypeMethods(receiverPackage, receiverTypeIdent)
receiverInterface := receiverInterfaceType(imports, receiverMethods, templateNames)
routesFunc, err := routesFuncDeclaration(imports, ts, routesFunctionName, receiverInterfaceIdent, receiverInterface, receiverPackage, templateNames)
routesFunc, err := routesFuncDeclaration(imports, routesFunctionName, receiverInterfaceIdent, receiverInterface, receiverPackage, templateNames, log)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -92,26 +92,26 @@ func addExecuteFunction(imports *source.Imports, fileSet *token.FileSet, files [
file.Decls = append(file.Decls, executeFuncDecl(imports, templatesVariableName))
}

func routesFuncDeclaration(imports *source.Imports, ts *template.Template, routesFunctionName, receiverInterfaceIdent string, receiverInterfaceType *ast.InterfaceType, receiverPackage []*ast.File, templateNames []TemplateName) (*ast.FuncDecl, error) {
func routesFuncDeclaration(imports *source.Imports, routesFunctionName, receiverInterfaceIdent string, receiverInterfaceType *ast.InterfaceType, receiverPackage []*ast.File, templateNames []TemplateName, log *log.Logger) (*ast.FuncDecl, error) {
routes := &ast.FuncDecl{
Name: ast.NewIdent(routesFunctionName),
Type: routesFuncType(imports, ast.NewIdent(receiverInterfaceIdent)),
Body: &ast.BlockStmt{},
}

for _, name := range templateNames {
t := ts.Lookup(name.name)
var hf *ast.FuncLit
if name.fun == nil {
hf = name.httpRequestReceiverTemplateHandlerFunc(imports, name.statusCode)
} else {
var err error
hf, err = name.funcLit(imports, t, receiverInterfaceType, receiverPackage)
if err != nil {
return nil, err
}
for _, tn := range templateNames {
log.Printf("%s has route for %s", routesFunctionName, tn.endpoint)
if tn.fun == nil {
hf := tn.httpRequestReceiverTemplateHandlerFunc(imports, tn.statusCode)
routes.Body.List = append(routes.Body.List, tn.callHandleFunc(hf))
continue
}

hf, err := tn.funcLit(imports, receiverInterfaceType, receiverPackage)
if err != nil {
return nil, err
}
routes.Body.List = append(routes.Body.List, name.callHandleFunc(hf))
routes.Body.List = append(routes.Body.List, tn.callHandleFunc(hf))
}

return routes, nil
Expand All @@ -120,58 +120,58 @@ func routesFuncDeclaration(imports *source.Imports, ts *template.Template, route
func receiverInterfaceType(imports *source.Imports, receiverMethods *ast.FieldList, templateNames []TemplateName) *ast.InterfaceType {
interfaceMethods := new(ast.FieldList)

for _, name := range templateNames {
if name.fun == nil {
for _, tn := range templateNames {
if tn.fun == nil {
continue
}
if source.HasFieldWithName(interfaceMethods, name.fun.Name) {
if source.HasFieldWithName(interfaceMethods, tn.fun.Name) {
continue
}
if field, ok := source.FindFieldWithName(receiverMethods, name.fun.Name); ok {
if field, ok := source.FindFieldWithName(receiverMethods, tn.fun.Name); ok {
interfaceMethods.List = append(interfaceMethods.List, field)
continue
}
interfaceMethods.List = append(interfaceMethods.List, name.methodField(imports))
interfaceMethods.List = append(interfaceMethods.List, tn.methodField(imports))
}

return &ast.InterfaceType{Methods: interfaceMethods}
}

func (def TemplateName) callHandleFunc(handlerFuncLit *ast.FuncLit) *ast.ExprStmt {
func (tn TemplateName) callHandleFunc(handlerFuncLit *ast.FuncLit) *ast.ExprStmt {
return &ast.ExprStmt{X: &ast.CallExpr{
Fun: &ast.SelectorExpr{
X: ast.NewIdent(muxVarIdent),
Sel: ast.NewIdent(httpHandleFuncIdent),
},
Args: []ast.Expr{source.String(def.endpoint), handlerFuncLit},
Args: []ast.Expr{source.String(tn.endpoint), handlerFuncLit},
}}
}

func (def TemplateName) funcLit(imports *source.Imports, t *template.Template, receiverInterfaceType *ast.InterfaceType, files []*ast.File) (*ast.FuncLit, error) {
methodField, ok := source.FindFieldWithName(receiverInterfaceType.Methods, def.fun.Name)
func (tn TemplateName) funcLit(imports *source.Imports, receiverInterfaceType *ast.InterfaceType, files []*ast.File) (*ast.FuncLit, error) {
methodField, ok := source.FindFieldWithName(receiverInterfaceType.Methods, tn.fun.Name)
if !ok {
log.Fatalf("receiver does not have a method declaration for %s", def.fun.Name)
log.Fatalf("receiver does not have a method declaration for %s", tn.fun.Name)
}
method := methodField.Type.(*ast.FuncType)
lit := &ast.FuncLit{
Type: httpHandlerFuncType(imports),
Body: &ast.BlockStmt{},
}
call := &ast.CallExpr{Fun: callReceiverMethod(def.fun)}
if method.Params.NumFields() != len(def.call.Args) {
return nil, errWrongNumberOfArguments(def, method)
call := &ast.CallExpr{Fun: callReceiverMethod(tn.fun)}
if method.Params.NumFields() != len(tn.call.Args) {
return nil, errWrongNumberOfArguments(tn, method)
}
var formStruct *ast.StructType
for pi, pt := range fieldListTypes(method.Params) {
if err := checkArgument(imports, method, pi, def.call.Args[pi], pt, files); err != nil {
if err := checkArgument(imports, method, pi, tn.call.Args[pi], pt, files); err != nil {
return nil, err
}
if s, ok := findFormStruct(pt, files); ok {
formStruct = s
}
}
writeHeader := true
for i, a := range def.call.Args {
for i, a := range tn.call.Args {
arg := a.(*ast.Ident)
switch arg.Name {
case TemplateNameScopeIdentifierHTTPResponse:
Expand Down Expand Up @@ -203,7 +203,7 @@ func (def TemplateName) funcLit(imports *source.Imports, t *template.Template, r
Sel: ast.NewIdent(name.Name),
}

fieldTemplate := formInputTemplate(field, t)
fieldTemplate := formInputTemplate(field, tn.template)

errCheck := func(exp ast.Expr) ast.Stmt {
return &ast.ExprStmt{
Expand Down Expand Up @@ -326,7 +326,7 @@ func (def TemplateName) funcLit(imports *source.Imports, t *template.Template, r
okIdent = "ok"
)
if method.Results == nil || len(method.Results.List) == 0 {
return lit, fmt.Errorf("method for endpoint %q has no results it should have one or two", def)
return lit, fmt.Errorf("method for endpoint %q has no results it should have one or two", tn)
} else if len(method.Results.List) > 1 {
_, lastResultType, ok := source.FieldIndex(method.Results.List, method.Results.NumFields()-1)
if !ok {
Expand Down Expand Up @@ -370,7 +370,7 @@ func (def TemplateName) funcLit(imports *source.Imports, t *template.Template, r
} else {
lit.Body.List = append(lit.Body.List, &ast.AssignStmt{Lhs: []ast.Expr{ast.NewIdent(dataVarIdent)}, Tok: token.DEFINE, Rhs: []ast.Expr{call}})
}
lit.Body.List = append(lit.Body.List, def.executeCall(source.HTTPStatusCode(imports, def.statusCode), ast.NewIdent(dataVarIdent), writeHeader))
lit.Body.List = append(lit.Body.List, tn.executeCall(source.HTTPStatusCode(imports, tn.statusCode), ast.NewIdent(dataVarIdent), writeHeader))
return lit, nil
}

Expand Down Expand Up @@ -398,19 +398,19 @@ func formInputTemplate(field *ast.Field, t *template.Template) *template.Templat
return t
}

func (def TemplateName) methodField(imports *source.Imports) *ast.Field {
func (tn TemplateName) methodField(imports *source.Imports) *ast.Field {
return &ast.Field{
Names: []*ast.Ident{ast.NewIdent(def.fun.Name)},
Type: def.funcType(imports),
Names: []*ast.Ident{ast.NewIdent(tn.fun.Name)},
Type: tn.funcType(imports),
}
}

func (def TemplateName) funcType(imports *source.Imports) *ast.FuncType {
func (tn TemplateName) funcType(imports *source.Imports) *ast.FuncType {
method := &ast.FuncType{
Params: &ast.FieldList{},
Results: &ast.FieldList{List: []*ast.Field{{Type: ast.NewIdent("any")}}},
}
for _, a := range def.call.Args {
for _, a := range tn.call.Args {
arg := a.(*ast.Ident)
switch arg.Name {
case TemplateNameScopeIdentifierHTTPRequest:
Expand Down Expand Up @@ -686,29 +686,29 @@ func appendAssignment(assignTok token.Token, result ast.Expr) func(exp ast.Expr)
}
}

func (def TemplateName) executeCall(status, data ast.Expr, writeHeader bool) *ast.ExprStmt {
func (tn TemplateName) executeCall(status, data ast.Expr, writeHeader bool) *ast.ExprStmt {
return &ast.ExprStmt{X: &ast.CallExpr{
Fun: ast.NewIdent(executeIdentName),
Args: []ast.Expr{
ast.NewIdent(TemplateNameScopeIdentifierHTTPResponse),
ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest),
ast.NewIdent(strconv.FormatBool(writeHeader)),
&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(def.name)},
&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(tn.name)},
status,
data,
},
}}
}

func (def TemplateName) httpRequestReceiverTemplateHandlerFunc(imports *source.Imports, statusCode int) *ast.FuncLit {
func (tn TemplateName) httpRequestReceiverTemplateHandlerFunc(imports *source.Imports, statusCode int) *ast.FuncLit {
return &ast.FuncLit{
Type: httpHandlerFuncType(imports),
Body: &ast.BlockStmt{List: []ast.Stmt{def.executeCall(source.HTTPStatusCode(imports, statusCode), ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), true)}},
Body: &ast.BlockStmt{List: []ast.Stmt{tn.executeCall(source.HTTPStatusCode(imports, statusCode), ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), true)}},
}
}

func (def TemplateName) matchReceiver(funcDecl *ast.FuncDecl, receiverTypeIdent string) bool {
if funcDecl == nil || funcDecl.Name == nil || funcDecl.Name.Name != def.fun.Name ||
func (tn TemplateName) matchReceiver(funcDecl *ast.FuncDecl, receiverTypeIdent string) bool {
if funcDecl == nil || funcDecl.Name == nil || funcDecl.Name.Name != tn.fun.Name ||
funcDecl.Recv == nil || len(funcDecl.Recv.List) < 1 {
return false
}
Expand Down
2 changes: 1 addition & 1 deletion routes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1667,7 +1667,7 @@ func execute(response http.ResponseWriter, request *http.Request, writeHeader bo
logs := log.New(io.Discard, "", 0)
set := token.NewFileSet()
goFiles := methodFuncTypeLoader(t, set, tt.ReceiverPackage)
out, err := muxt.Generate(templateNames, ts, tt.PackageName, tt.TemplatesVar, tt.RoutesFunc, tt.Receiver, tt.Interface, muxt.DefaultOutputFileName, set, goFiles, goFiles, logs)
out, err := muxt.Generate(templateNames, tt.PackageName, tt.TemplatesVar, tt.RoutesFunc, tt.Receiver, tt.Interface, muxt.DefaultOutputFileName, set, goFiles, goFiles, logs)
if tt.ExpectedError == "" {
assert.NoError(t, err)
assert.Equal(t, tt.ExpectedFile, out)
Expand Down

0 comments on commit 2a43e90

Please sign in to comment.