diff --git a/cmd/muxt/generate.go b/cmd/muxt/generate.go index 1435dae..2547a7f 100644 --- a/cmd/muxt/generate.go +++ b/cmd/muxt/generate.go @@ -104,11 +104,15 @@ func generateCommand(args []string, workingDirectory string, getEnv func(string) if err != nil { return err } - out := log.New(stdout, "", 0) - s, err := muxt.Routes(templates, g.goPackage, g.templatesVariable, g.routesFunction, g.receiverIdent, g.receiverInterfaceIdent, g.outputFilename, packageList, out) - if err != nil { - return err - } + s, err := muxt.TemplateRoutesFile(workingDirectory, templates, log.New(stdout, "", 0), muxt.RoutesFileConfiguration{ + Package: g.goPackage, + PackagePath: g.Package.PkgPath, + TemplatesVar: g.templatesVariable, + RoutesFunc: g.routesFunction, + ReceiverType: g.receiverIdent, + ReceiverInterface: g.receiverInterfaceIdent, + Output: g.outputFilename, + }) var sb bytes.Buffer sb.WriteString(CodeGenerationComment) if v, ok := cliVersion(); ok { diff --git a/internal/source/html.go b/internal/source/html.go index c45c4fd..2845388 100644 --- a/internal/source/html.go +++ b/internal/source/html.go @@ -5,6 +5,7 @@ import ( "fmt" "go/ast" "go/token" + "go/types" "regexp" "slices" "strings" @@ -17,7 +18,7 @@ type ValidationGenerator interface { GenerateValidation(imports *Imports, variable ast.Expr, handleError func(string) ast.Stmt) ast.Stmt } -func ParseInputValidations(name string, input spec.Element, tp ast.Expr) ([]ValidationGenerator, error) { +func ParseInputValidations(name string, input spec.Element, tp types.Type) ([]ValidationGenerator, error) { if tag := strings.ToLower(input.TagName()); tag != atom.Input.String() { return nil, fmt.Errorf("expected element to have tag got <%s>", tag) } diff --git a/internal/source/imports.go b/internal/source/imports.go index 36231f6..705ee2b 100644 --- a/internal/source/imports.go +++ b/internal/source/imports.go @@ -1,8 +1,11 @@ package source import ( + "fmt" "go/ast" + "go/parser" "go/token" + "go/types" "log" "path" "slices" @@ -12,6 +15,10 @@ import ( type Imports struct { *ast.GenDecl + fileSet *token.FileSet + types map[string]*types.Package + files map[string]*ast.File + outputPackage string } func NewImports(decl *ast.GenDecl) *Imports { @@ -20,7 +27,82 @@ func NewImports(decl *ast.GenDecl) *Imports { log.Panicf("expected decl to have token.IMPORT Tok got %s", got) } } - return &Imports{GenDecl: decl} + return &Imports{GenDecl: decl, types: make(map[string]*types.Package), files: make(map[string]*ast.File)} +} + +func (imports *Imports) AddPackages(p *types.Package) { + recursivelyRegisterPackages(imports.types, p) +} + +func (imports *Imports) FileSet() *token.FileSet { + if imports.fileSet == nil { + imports.fileSet = token.NewFileSet() + } + return imports.fileSet +} + +func (imports *Imports) SetOutputPackage(pkgPath string) { + imports.outputPackage = pkgPath +} + +func (imports *Imports) OutputPackage() string { + return imports.outputPackage +} + +func (imports *Imports) SyntaxFile(pos token.Pos) (*ast.File, *token.FileSet, error) { + position := imports.FileSet().Position(pos) + fSet := token.NewFileSet() + file, err := parser.ParseFile(fSet, position.Filename, nil, parser.AllErrors|parser.ParseComments|parser.SkipObjectResolution) + return file, fSet, err +} + +func (imports *Imports) FieldTag(pos token.Pos) (*ast.Field, error) { + file, fileSet, err := imports.SyntaxFile(pos) + if err != nil { + return nil, err + } + position := imports.fileSet.Position(pos) + for _, d := range file.Decls { + switch decl := d.(type) { + case *ast.GenDecl: + for _, s := range decl.Specs { + switch spec := s.(type) { + case *ast.TypeSpec: + tp, ok := spec.Type.(*ast.StructType) + if !ok { + continue + } + + for _, field := range tp.Fields.List { + for _, name := range field.Names { + p := fileSet.Position(name.Pos()) + if p != position { + continue + } + return field, nil + } + } + } + } + } + + } + return nil, fmt.Errorf("failed to find field") +} + +func (imports *Imports) Types(pkgPath string) (*types.Package, bool) { + p, ok := imports.types[pkgPath] + return p, ok +} + +func recursivelyRegisterPackages(set map[string]*types.Package, pkg *types.Package) { + if pkg == nil { + return + } + set[pkg.Path()] = pkg + for _, p := range pkg.Imports() { + recursivelyRegisterPackages(set, p) + } } func (imports *Imports) Add(pkgIdent, pkgPath string) string { @@ -31,27 +113,29 @@ func (imports *Imports) Add(pkgIdent, pkgPath string) string { if pkgIdent == "" { pkgIdent = path.Base(pkgPath) } - for _, s := range imports.GenDecl.Specs { - spec := s.(*ast.ImportSpec) - pp, _ := strconv.Unquote(spec.Path.Value) - if pp == pkgPath { - if spec.Name != nil && spec.Name.Name != "" && spec.Name.Name != pkgIdent { - return spec.Name.Name + if pkgPath != imports.outputPackage { + for _, s := range imports.GenDecl.Specs { + spec := s.(*ast.ImportSpec) + pp, _ := strconv.Unquote(spec.Path.Value) + if pp == pkgPath { + if spec.Name != nil && spec.Name.Name != "" && spec.Name.Name != pkgIdent { + return spec.Name.Name + } + return path.Base(pp) } - return path.Base(pp) } + var pi *ast.Ident + if path.Base(pkgPath) != pkgIdent { + pi = Ident(pkgIdent) + } + imports.GenDecl.Specs = append(imports.GenDecl.Specs, &ast.ImportSpec{ + Path: String(pkgPath), + Name: pi, + }) + slices.SortFunc(imports.GenDecl.Specs, func(a, b ast.Spec) int { + return strings.Compare(a.(*ast.ImportSpec).Path.Value, b.(*ast.ImportSpec).Path.Value) + }) } - var pi *ast.Ident - if path.Base(pkgPath) != pkgIdent { - pi = Ident(pkgIdent) - } - imports.GenDecl.Specs = append(imports.GenDecl.Specs, &ast.ImportSpec{ - Path: String(pkgPath), - Name: pi, - }) - slices.SortFunc(imports.GenDecl.Specs, func(a, b ast.Spec) int { - return strings.Compare(a.(*ast.ImportSpec).Path.Value, b.(*ast.ImportSpec).Path.Value) - }) return pkgIdent } diff --git a/internal/source/parse.go b/internal/source/parse.go index 6aadd41..a0dd6ca 100644 --- a/internal/source/parse.go +++ b/internal/source/parse.go @@ -4,6 +4,7 @@ import ( "fmt" "go/ast" "go/token" + "go/types" "net/http" "regexp" "slices" @@ -71,7 +72,7 @@ func GenerateParseValueFromStringStatements(imports *Imports, tmp string, str, t return nil, fmt.Errorf("unsupported type: %s", Format(typeExp)) } -func GenerateValidations(imports *Imports, variable, variableType ast.Expr, inputQuery, inputName, responseIdent string, fragment spec.DocumentFragment) ([]ast.Stmt, error, bool) { +func GenerateValidations(imports *Imports, variable ast.Expr, variableType types.Type, inputQuery, inputName, responseIdent string, fragment spec.DocumentFragment) ([]ast.Stmt, error, bool) { input := fragment.QuerySelector(inputQuery) if input == nil { return nil, nil, false diff --git a/internal/source/parse_test.go b/internal/source/parse_test.go index 6f9b237..99a2853 100644 --- a/internal/source/parse_test.go +++ b/internal/source/parse_test.go @@ -3,7 +3,7 @@ package source_test import ( "fmt" "go/ast" - "go/parser" + "go/types" "html/template" "strings" "testing" @@ -21,20 +21,20 @@ import ( func Test_inputValidations(t *testing.T) { for _, tt := range []struct { Name string - Type string + Type types.Type Template string Result string Error string }{ { - Type: "int", + Type: types.Universe.Lookup("int").Type(), Name: "no attributes", Template: ``, Result: `{ }`, }, { - Type: "int", + Type: types.Universe.Lookup("int").Type(), Name: "min", Template: ``, Result: `{ @@ -45,7 +45,7 @@ func Test_inputValidations(t *testing.T) { }`, }, { - Type: "int", + Type: types.Universe.Lookup("int").Type(), Name: "negative min", Template: ``, Result: `{ @@ -56,7 +56,7 @@ func Test_inputValidations(t *testing.T) { }`, }, { - Type: "int", + Type: types.Universe.Lookup("int").Type(), Name: "zero min", Template: ``, Result: `{ @@ -67,7 +67,7 @@ func Test_inputValidations(t *testing.T) { }`, }, { - Type: "int8", + Type: types.Universe.Lookup("int8").Type(), Name: "zero min", Template: ``, Result: `{ @@ -78,7 +78,7 @@ func Test_inputValidations(t *testing.T) { }`, }, { - Type: "int16", + Type: types.Universe.Lookup("int16").Type(), Name: "zero min", Template: ``, Result: `{ @@ -89,7 +89,7 @@ func Test_inputValidations(t *testing.T) { }`, }, { - Type: "int32", + Type: types.Universe.Lookup("int32").Type(), Name: "zero min", Template: ``, Result: `{ @@ -100,7 +100,7 @@ func Test_inputValidations(t *testing.T) { }`, }, { - Type: "int64", + Type: types.Universe.Lookup("int64").Type(), Name: "zero min", Template: ``, Result: `{ @@ -111,7 +111,7 @@ func Test_inputValidations(t *testing.T) { }`, }, { - Type: "uint", + Type: types.Universe.Lookup("uint").Type(), Name: "zero min", Template: ``, Result: `{ @@ -122,7 +122,7 @@ func Test_inputValidations(t *testing.T) { }`, }, { - Type: "uint8", + Type: types.Universe.Lookup("uint8").Type(), Name: "zero min", Template: ``, Result: `{ @@ -133,7 +133,7 @@ func Test_inputValidations(t *testing.T) { }`, }, { - Type: "uint16", + Type: types.Universe.Lookup("uint16").Type(), Name: "zero min", Template: ``, Result: `{ @@ -144,7 +144,7 @@ func Test_inputValidations(t *testing.T) { }`, }, { - Type: "uint32", + Type: types.Universe.Lookup("uint32").Type(), Name: "zero min", Template: ``, Result: `{ @@ -155,7 +155,7 @@ func Test_inputValidations(t *testing.T) { }`, }, { - Type: "uint64", + Type: types.Universe.Lookup("uint64").Type(), Name: "zero min", Template: ``, Result: `{ @@ -166,91 +166,79 @@ func Test_inputValidations(t *testing.T) { }`, }, { - Type: "int", + Type: types.Universe.Lookup("int").Type(), Name: "out of range", Template: ``, Error: `strconv.ParseInt: parsing "18446744073709551616": value out of range`, }, { - Type: "int8", + Type: types.Universe.Lookup("int8").Type(), Name: "out of range", Template: ``, Error: `strconv.ParseInt: parsing "256": value out of range`, }, { - Type: "int16", + Type: types.Universe.Lookup("int16").Type(), Name: "out of range", Template: ``, Error: `strconv.ParseInt: parsing "32768": value out of range`, }, { - Type: "int32", + Type: types.Universe.Lookup("int32").Type(), Name: "out of range", Template: ``, Error: `strconv.ParseInt: parsing "2147483648": value out of range`, }, { - Type: "int64", + Type: types.Universe.Lookup("int64").Type(), Name: "out of range", Template: ``, Error: `strconv.ParseInt: parsing "9223372036854775808": value out of range`, }, { - Type: "uint", + Type: types.Universe.Lookup("uint").Type(), Name: "out of range", Template: ``, Error: `strconv.ParseUint: parsing "-10": invalid syntax`, }, { - Type: "uint8", + Type: types.Universe.Lookup("uint8").Type(), Name: "out of range", Template: ``, Error: `strconv.ParseUint: parsing "256": value out of range`, }, { - Type: "uint16", + Type: types.Universe.Lookup("uint16").Type(), Name: "out of range", Template: ``, Error: `strconv.ParseUint: parsing "65536": value out of range`, }, { - Type: "uint32", + Type: types.Universe.Lookup("uint32").Type(), Name: "out of range", Template: ``, Error: `strconv.ParseUint: parsing "4294967296": value out of range`, }, { - Type: "uint64", + Type: types.Universe.Lookup("uint64").Type(), Name: "out of range", Template: ``, Error: `strconv.ParseUint: parsing "18446744073709551616": value out of range`, }, { - Type: "*T", - Name: "unsupported type", - Template: ``, - Error: `type *T is not supported`, - }, - { - Type: "T", - Name: "type unknown", - Template: ``, - Error: `type T unknown`, - }, - { - Type: "int", + Type: types.Universe.Lookup("int").Type(), Name: "not a number", Template: ``, Error: `strconv.ParseInt: parsing "NaN": invalid syntax`, }, { - Type: "int", + Type: types.Universe.Lookup("int").Type(), Name: "wrong tag", Template: `
`, Error: `expected element to have tag got
`, }, { - Type: "uint32", + Type: types.Universe.Lookup("uint32").Type(), Name: "zero max", Template: ``, Result: `{ @@ -263,8 +251,6 @@ func Test_inputValidations(t *testing.T) { } { t.Run(fmt.Sprintf("cromulent attribute type %s %s", tt.Type, tt.Name), func(t *testing.T) { v := ast.NewIdent("v") - tp, err := parser.ParseExpr(tt.Type) - require.NoError(t, err) ts := template.Must(template.New("").Parse(tt.Template)) nodes, err := html.ParseFragment(strings.NewReader(ts.Tree.Root.String()), &html.Node{ Type: html.ElementNode, @@ -273,7 +259,7 @@ func Test_inputValidations(t *testing.T) { }) fragment := dom.NewDocumentFragment(nodes) imports := source.NewImports(nil) - statements, err, ok := source.GenerateValidations(imports, v, tp, `[name="field"]`, "field", "response", fragment) + statements, err, ok := source.GenerateValidations(imports, v, tt.Type, `[name="field"]`, "field", "response", fragment) require.True(t, ok) if tt.Error != "" { require.Error(t, err) diff --git a/internal/source/reflect.go b/internal/source/reflect.go index 167ae9e..3b3fad5 100644 --- a/internal/source/reflect.go +++ b/internal/source/reflect.go @@ -2,17 +2,13 @@ package source import ( "fmt" - "go/ast" + "go/types" "reflect" "strconv" ) -func ParseStringWithType(val string, tp ast.Expr) (reflect.Value, error) { - tpIdent, ok := tp.(*ast.Ident) - if !ok { - return reflect.Value{}, fmt.Errorf("type %s is not supported", Format(tp)) - } - switch tpIdent.Name { +func ParseStringWithType(val string, tp types.Type) (reflect.Value, error) { + switch tp.Underlying().String() { case reflect.Int.String(): n, err := strconv.ParseInt(val, 10, 64) if err != nil { @@ -74,6 +70,6 @@ func ParseStringWithType(val string, tp ast.Expr) (reflect.Value, error) { } return reflect.ValueOf(n), nil default: - return reflect.Value{}, fmt.Errorf("type %s unknown", Format(tp)) + return reflect.Value{}, fmt.Errorf("type %s unknown", tp.String()) } } diff --git a/routes.go b/routes.go index 2fb959c..b77967f 100644 --- a/routes.go +++ b/routes.go @@ -14,8 +14,10 @@ import ( "slices" "strconv" "strings" + "time" "github.com/crhntr/dom" + "github.com/stretchr/testify/assert" "golang.org/x/net/html" "golang.org/x/net/html/atom" "golang.org/x/tools/go/packages" @@ -33,12 +35,9 @@ const ( requestPathValue = "PathValue" httpRequestContextMethod = "Context" httpResponseWriterIdent = "ResponseWriter" - httpServeMuxIdent = "ServeMux" httpRequestIdent = "Request" httpHandleFuncIdent = "HandleFunc" - contextContextTypeIdent = "Context" - defaultPackageName = "main" DefaultTemplatesVariableName = "templates" DefaultRoutesFunctionName = "routes" @@ -48,413 +47,307 @@ const ( InputAttributeNameStructTag = "name" InputAttributeTemplateStructTag = "template" + muxParamName = "mux" + receiverParamName = "receiver" + errIdent = "err" ) -func Routes(templates []Template, packageName, templatesVariableName, routesFunctionName, receiverTypeIdent, receiverInterfaceIdent, output string, packageList []*packages.Package, log *log.Logger) (string, error) { - packageName = cmp.Or(packageName, defaultPackageName) - templatesVariableName = cmp.Or(templatesVariableName, DefaultTemplatesVariableName) - routesFunctionName = cmp.Or(routesFunctionName, DefaultRoutesFunctionName) - receiverInterfaceIdent = cmp.Or(receiverInterfaceIdent, DefaultReceiverInterfaceName) +type RoutesFileConfiguration struct { + executeFunc bool + Package, + PackagePath, + TemplatesVar, + RoutesFunc, + ReceiverType, + ReceiverInterface, + Output string +} +func (config RoutesFileConfiguration) applyDefaults() RoutesFileConfiguration { + config.Package = cmp.Or(config.Package, defaultPackageName) + config.TemplatesVar = cmp.Or(config.TemplatesVar, DefaultTemplatesVariableName) + config.RoutesFunc = cmp.Or(config.RoutesFunc, DefaultRoutesFunctionName) + config.ReceiverInterface = cmp.Or(config.ReceiverInterface, DefaultReceiverInterfaceName) + config.executeFunc = true + return config +} + +func TemplateRoutesFile(wd string, templates []Template, logger *log.Logger, config RoutesFileConfiguration) (string, error) { + config = config.applyDefaults() imports := source.NewImports(&ast.GenDecl{Tok: token.IMPORT}) - var ( - receiverType *types.Named - pkg *packages.Package - receiverPackage []*ast.File - ) - for _, p := range packageList { - if p.Types.Name() == packageName { - pkg = p - receiverPackage = pkg.Syntax - break - } + + var pkg *types.Package + pl, err := packages.Load(&packages.Config{ + Fset: imports.FileSet(), + Mode: packages.NeedModule | packages.NeedName | packages.NeedFiles | packages.NeedTypes | packages.NeedSyntax | packages.NeedEmbedPatterns | packages.NeedEmbedFiles, + Dir: wd, + }, ".", "net/http") + if err != nil { + return "", err + } + var receiver *types.Named + for _, p := range pl { + imports.AddPackages(p.Types) + } + if len(pl) > 0 && pl[0].PkgPath != "net/http" { + imports.SetOutputPackage(pl[0].PkgPath) + } else { + imports.SetOutputPackage(filepath.Base(wd)) } - if pkg != nil { - receiverTypeObj := pkg.Types.Scope().Lookup(receiverTypeIdent) - if receiverTypeObj != nil { - named, ok := receiverTypeObj.Type().(*types.Named) - if ok { - receiverType = named + for _, p := range pl { + if p.Types.Path() == config.PackagePath { + if executeObj := p.Types.Scope().Lookup("execute"); executeObj != nil { + if _, ok := executeObj.(*types.Func); ok { + config.executeFunc = filepath.Base(p.Fset.Position(executeObj.Pos()).Filename) == config.Output + } + } + obj := p.Types.Scope().Lookup(config.ReceiverType) + if obj != nil { + named, ok := obj.Type().(*types.Named) + if !ok { + return "", fmt.Errorf("expected receiver to be a named type") + } + receiver = named } + break } } - receiverInterface := receiverInterfaceType(imports, packageList, receiverType, templates) - routesFunc, err := routesFuncDeclaration(imports, routesFunctionName, receiverInterfaceIdent, receiverInterface, receiverPackage, templates, log) - if err != nil { - return "", err + if receiver == nil { + receiver = types.NewNamed(types.NewTypeName(0, pkg, "Receiver", nil), types.NewStruct(nil, nil), nil) } - typesDecl := &ast.GenDecl{ - Tok: token.TYPE, - Specs: []ast.Spec{ - &ast.TypeSpec{Name: ast.NewIdent(receiverInterfaceIdent), Type: receiverInterface}, + receiverInterface := &ast.InterfaceType{ + Methods: new(ast.FieldList), + } + + routesFunc := &ast.FuncDecl{ + Name: ast.NewIdent(config.RoutesFunc), + Type: &ast.FuncType{ + Params: &ast.FieldList{ + List: []*ast.Field{ + httpServMuxField(imports), + { + Names: []*ast.Ident{ast.NewIdent(receiverParamName)}, + Type: ast.NewIdent(config.ReceiverInterface), + }, + }, + }, }, + Body: new(ast.BlockStmt), + } + + for _, t := range templates { + logger.Printf("routes has route for %s", t.endpoint) + if t.fun == nil { + hf := t.httpRequestReceiverTemplateHandlerFunc(imports, t.statusCode) + routesFunc.Body.List = append(routesFunc.Body.List, t.callHandleFunc(hf)) + continue + } + writeHeader := !hasHTTPResponseWriterArgument(t.call) + handlerFunc := &ast.FuncLit{ + Type: httpHandlerFuncType(imports), + Body: &ast.BlockStmt{}, + } + if err := ensureMethodSignature(imports, t, receiver, receiverInterface, t.call); err != nil { + return "", err + } + methodObj, _, _ := types.LookupFieldOrMethod(receiver, true, receiver.Obj().Pkg(), t.fun.Name) + if methodObj == nil { + return "", fmt.Errorf("failed to generate method %s", t.fun.Name) + } + sig := methodObj.Type().(*types.Signature) + if sig.Results().Len() == 0 { + return "", fmt.Errorf("method for endpoint %q has no results it should have one or two", t.name) + } + if handlerFunc.Body.List, err = appendParseArgumentStatements(handlerFunc.Body.List, t, imports, nil, receiver, t.call); err != nil { + return "", err + } + const dataVarIdent = "data" + receiverCallStatements, err := callReceiverMethod(imports, dataVarIdent, sig, &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent(receiverIdent), + Sel: t.fun, + }, + Args: slices.Clone(t.call.Args), + }) + if err != nil { + return "", err + } + handlerFunc.Body.List = append(handlerFunc.Body.List, receiverCallStatements...) + handlerFunc.Body.List = append(handlerFunc.Body.List, t.executeCall(source.HTTPStatusCode(imports, t.statusCode), ast.NewIdent(dataVarIdent), writeHeader)) + routesFunc.Body.List = append(routesFunc.Body.List, t.callHandleFunc(handlerFunc)) } imports.SortImports() file := &ast.File{ - Name: ast.NewIdent(packageName), + Name: ast.NewIdent(config.Package), Decls: []ast.Decl{ + // import imports.GenDecl, - typesDecl, + + // type + &ast.GenDecl{ + Tok: token.TYPE, + Specs: []ast.Spec{ + &ast.TypeSpec{Name: ast.NewIdent(config.ReceiverInterface), Type: receiverInterface}, + }, + }, + + // func routes routesFunc, }, } - if pkg != nil { - if obj := pkg.Types.Scope().Lookup(executeIdentName); obj == nil { - file.Decls = append(file.Decls, executeFuncDecl(imports, templatesVariableName)) - } else { - pos := pkg.Fset.Position(obj.Pos()) - if filepath.Base(pos.Filename) == output { - file.Decls = append(file.Decls, executeFuncDecl(imports, templatesVariableName)) - } - } - } else { - file.Decls = append(file.Decls, executeFuncDecl(imports, templatesVariableName)) + + if config.executeFunc { + file.Decls = append(file.Decls, executeFuncDecl(imports, config.TemplatesVar)) } return source.Format(file), nil } -func routesFuncDeclaration(imports *source.Imports, routesFunctionName, receiverInterfaceIdent string, receiverInterfaceType *ast.InterfaceType, receiverPackage []*ast.File, templateNames []Template, log *log.Logger) (*ast.FuncDecl, error) { - routes := &ast.FuncDecl{ - Name: ast.NewIdent(routesFunctionName), - Type: routesFuncType(imports, ast.NewIdent(receiverInterfaceIdent)), - Body: &ast.BlockStmt{}, +func appendParseArgumentStatements(statements []ast.Stmt, t Template, imports *source.Imports, parsed map[string]struct{}, receiver *types.Named, call *ast.CallExpr) ([]ast.Stmt, error) { + fun, ok := call.Fun.(*ast.Ident) + if !ok { + return nil, fmt.Errorf("expected function to be identifier") } - - for _, tn := range templateNames { - log.Printf("%s has route for %s", routesFunctionName, tn.endpoint) - if tn.fun == nil { - hf := tn.httpRequestReceiverTemplateHandlerFunc(imports, tn.statusCode) - routes.Body.List = append(routes.Body.List, tn.callHandleFunc(hf)) - continue - } - - hf, err := tn.funcLit(imports, receiverInterfaceType, receiverPackage) - if err != nil { - return nil, err - } - routes.Body.List = append(routes.Body.List, tn.callHandleFunc(hf)) + obj, _, _ := types.LookupFieldOrMethod(receiver, true, receiver.Obj().Pkg(), fun.Name) + if obj == nil { + return nil, fmt.Errorf("method %s not defined on %s", fun.Name, receiver.Obj().Type()) } - - return routes, nil -} - -func receiverInterfaceType(imports *source.Imports, packageList []*packages.Package, receiverType *types.Named, templateNames []Template) *ast.InterfaceType { - interfaceMethods := new(ast.FieldList) - for _, tn := range templateNames { - if tn.fun == nil { - continue - } - for _, arg := range tn.call.Args { - switch exp := arg.(type) { - case *ast.CallExpr: - callIdent, ok := exp.Fun.(*ast.Ident) - if !ok { - continue - } - method, ok := interfaceMethodForCall(imports, interfaceMethods, callIdent, exp, receiverType, packageList) - if !ok { - continue - } - interfaceMethods.List = append(interfaceMethods.List, method) - } - } - if source.HasFieldWithName(interfaceMethods, tn.fun.Name) { - continue - } - method, ok := interfaceMethodForCall(imports, interfaceMethods, tn.fun, tn.call, receiverType, packageList) - if !ok { - continue - } - interfaceMethods.List = append(interfaceMethods.List, method) + signature, ok := obj.Type().(*types.Signature) + if !ok { + return nil, fmt.Errorf("expected method") } - return &ast.InterfaceType{Methods: interfaceMethods} -} - -func interfaceMethodForCall(imports *source.Imports, interfaceMethods *ast.FieldList, methodName *ast.Ident, call *ast.CallExpr, receiverType *types.Named, packageList []*packages.Package) (*ast.Field, bool) { - if source.HasFieldWithName(interfaceMethods, methodName.Name) { - return nil, false + //const parsedVariableName = "parsed" + if exp := signature.Params().Len(); exp != len(call.Args) { // TODO: (signature.Variadic() && exp > len(call.Args)) + sigStr := fun.Name + strings.TrimPrefix(signature.String(), "func") + return nil, fmt.Errorf("handler func %s expects %d arguments but call %s has %d", sigStr, signature.Params().Len(), source.Format(call), len(call.Args)) } - if receiverType != nil { - obj, _, _ := types.LookupFieldOrMethod(receiverType, true, receiverType.Obj().Pkg(), methodName.Name) - if obj != nil { - fn, ok := obj.(*types.Func) - if !ok { - return nil, false - } - for _, pkg := range packageList { - if pkg.PkgPath != obj.Pkg().Path() { - continue - } - for _, file := range pkg.Syntax { - for _, decl := range file.Decls { - fd, ok := decl.(*ast.FuncDecl) - if !ok || fd.Name.Pos() != fn.Pos() { - continue - } - return &ast.Field{ - Names: []*ast.Ident{ast.NewIdent(fd.Name.Name)}, - Type: fd.Type, - }, true - } - } - } - } + if parsed == nil { + parsed = make(map[string]struct{}) } - return &ast.Field{ - Names: []*ast.Ident{ast.NewIdent(methodName.Name)}, - Type: generateFuncTypeFromArguments(imports, call), - }, true -} + resultCount := 0 + for i, a := range call.Args { + param := signature.Params().At(i) -func (t Template) callHandleFunc(handlerFuncLit *ast.FuncLit) *ast.ExprStmt { - return &ast.ExprStmt{X: &ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent(muxVarIdent), - Sel: ast.NewIdent(httpHandleFuncIdent), - }, - Args: []ast.Expr{source.String(t.endpoint), handlerFuncLit}, - }} -} + switch arg := a.(type) { + default: + // TODO: add error case + case *ast.CallExpr: + parseArgStatements, err := appendParseArgumentStatements(statements, t, imports, parsed, receiver, arg) + if err != nil { + return nil, err + } + resultVarIdent := "result" + strconv.Itoa(resultCount) + call.Args[i] = ast.NewIdent(resultVarIdent) + resultCount++ -func (t Template) funcLit(imports *source.Imports, receiverInterfaceType *ast.InterfaceType, files []*ast.File) (*ast.FuncLit, error) { - methodField, ok := source.FindFieldWithName(receiverInterfaceType.Methods, t.fun.Name) - if !ok { - log.Fatalf("receiver does not have a method declaration for %s", t.fun.Name) - } - method := methodField.Type.(*ast.FuncType) - lit := &ast.FuncLit{ - Type: httpHandlerFuncType(imports), - Body: &ast.BlockStmt{}, - } - call := &ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent(receiverIdent), Sel: ast.NewIdent(t.fun.Name)}} - if method.Params.NumFields() != len(t.call.Args) { - return nil, errWrongNumberOfArguments(t, method) - } - argTypes := slices.Collect(fieldListTypes(method.Params)) - resultArgCount := 0 - var handledIdents []string - for i, arg := range t.call.Args { - argType := argTypes[i] - switch exp := arg.(type) { - case *ast.Ident: - call.Args = append(call.Args, ast.NewIdent(exp.Name)) - if slices.Contains(handledIdents, exp.Name) { - continue + obj, _, _ := types.LookupFieldOrMethod(receiver.Obj().Type(), true, receiver.Obj().Pkg(), arg.Fun.(*ast.Ident).Name) + methodSignature, err := astTypeExpression(imports, obj.Type()) + if err != nil { + return nil, err } - statements, err := t.identifierArgument(imports, i, exp, argType, method, files) + + callMethodStatements, err := t.callReceiverMethod(imports, resultVarIdent, methodSignature.(*ast.FuncType), arg) if err != nil { return nil, err } - lit.Body.List = append(lit.Body.List, statements...) - handledIdents = append(handledIdents, exp.Name) - case *ast.CallExpr: - callMethodIdent, ok := exp.Fun.(*ast.Ident) - if !ok { - return nil, fmt.Errorf("argument call expression function must be an identifier: got %s", source.Format(exp.Fun)) + arg.Fun = &ast.SelectorExpr{ + X: ast.NewIdent(receiverIdent), + Sel: ast.NewIdent(arg.Fun.(*ast.Ident).Name), } - callMethodField, ok := source.FindFieldWithName(receiverInterfaceType.Methods, callMethodIdent.Name) + + statements = append(parseArgStatements, callMethodStatements...) + case *ast.Ident: + argType, ok := defaultTemplateNameScope(imports, t, arg.Name) if !ok { - log.Fatalf("receiver does not have a method declaration for %s", callMethodIdent.Name) + return nil, fmt.Errorf("failed to determine type for %s", arg.Name) } - callMethod := callMethodField.Type.(*ast.FuncType) - callArgTypes := slices.Collect(fieldListTypes(callMethod.Params)) - callCall := &ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent(receiverIdent), Sel: ast.NewIdent(callMethodIdent.Name)}} - - if aCount, pCount := len(exp.Args), callMethod.Params.NumFields(); aCount != pCount { - return nil, fmt.Errorf("expected %d arguments for method %s with %d parameters", aCount, callMethodIdent.Name, pCount) + src := &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), + Sel: ast.NewIdent(requestPathValue), + }, + Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(arg.Name)}}, } - - for j, callArg := range exp.Args { - callArgType := callArgTypes[j] - switch callArgExp := callArg.(type) { - case *ast.Ident: - callCall.Args = append(callCall.Args, ast.NewIdent(callArgExp.Name)) - if slices.Contains(handledIdents, callArgExp.Name) { - continue - } - parseStatements, err := t.identifierArgument(imports, j, callArgExp, callArgType, callMethod, files) - if err != nil { - return nil, err + if types.AssignableTo(argType, param.Type()) { + if _, ok := parsed[arg.Name]; !ok { + parsed[arg.Name] = struct{}{} + switch arg.Name { + case TemplateNameScopeIdentifierForm: + declareFormVar, err := formVariableAssignment(imports, arg, param.Type()) + if err != nil { + return nil, err + } + statements = append(statements, callParseForm(), declareFormVar) + case TemplateNameScopeIdentifierContext: + statements = append(statements, contextAssignment(TemplateNameScopeIdentifierContext)) + default: + if slices.Contains(t.parsePathValueNames(), arg.Name) { + statements = append(statements, singleAssignment(token.DEFINE, ast.NewIdent(arg.Name))(src)) + } } - lit.Body.List = append(lit.Body.List, parseStatements...) - handledIdents = append(handledIdents, callArgExp.Name) } + continue } - resultVar := "result" + strconv.Itoa(resultArgCount) - - receiverCallStatements, err := t.callReceiverMethod(imports, resultVar, callMethod, callCall) - if err != nil { - return nil, err + if _, ok := parsed[arg.Name]; ok { + continue } - lit.Body.List = append(lit.Body.List, receiverCallStatements...) - call.Args = append(call.Args, ast.NewIdent(resultVar)) - resultArgCount++ - } - } - const dataVarIdent = "data" - receiverCallStatements, err := t.callReceiverMethod(imports, dataVarIdent, method, call) - if err != nil { - return nil, err - } - lit.Body.List = append(lit.Body.List, receiverCallStatements...) - - lit.Body.List = append(lit.Body.List, t.executeCall(source.HTTPStatusCode(imports, t.statusCode), ast.NewIdent(dataVarIdent), t.callWriteHeader(receiverInterfaceType))) - return lit, nil -} - -func (t Template) callReceiverMethod(imports *source.Imports, dataVarIdent string, method *ast.FuncType, call *ast.CallExpr) ([]ast.Stmt, error) { - const ( - okIdent = "ok" - ) - if method.Results == nil || len(method.Results.List) == 0 { - return nil, fmt.Errorf("method for endpoint %q has no results it should have one or two", t) - } else if len(method.Results.List) > 1 { - _, lastResultType, ok := source.FieldIndex(method.Results.List, method.Results.NumFields()-1) - if !ok { - return nil, fmt.Errorf("failed to get the last method result") - } - switch rt := lastResultType.(type) { - case *ast.Ident: - switch rt.Name { - case "error": - return []ast.Stmt{ - &ast.AssignStmt{Lhs: []ast.Expr{ast.NewIdent(dataVarIdent), ast.NewIdent(errIdent)}, Tok: token.DEFINE, Rhs: []ast.Expr{call}}, - &ast.IfStmt{ - Cond: &ast.BinaryExpr{X: ast.NewIdent(errIdent), Op: token.NEQ, Y: source.Nil()}, - Body: &ast.BlockStmt{ - List: []ast.Stmt{ - &ast.ExprStmt{X: imports.HTTPErrorCall(ast.NewIdent(httpResponseField(imports).Names[0].Name), source.CallError(errIdent), http.StatusInternalServerError)}, - &ast.ReturnStmt{}, - }, - }, - }, - }, nil - case "bool": - return []ast.Stmt{ - &ast.AssignStmt{Lhs: []ast.Expr{ast.NewIdent(dataVarIdent), ast.NewIdent(okIdent)}, Tok: token.DEFINE, Rhs: []ast.Expr{call}}, - &ast.IfStmt{ - Cond: &ast.UnaryExpr{Op: token.NOT, X: ast.NewIdent(okIdent)}, - Body: &ast.BlockStmt{ - List: []ast.Stmt{ - &ast.ReturnStmt{}, - }, - }, - }, - }, nil + switch { + case slices.Contains(t.parsePathValueNames(), arg.Name): + parsed[arg.Name] = struct{}{} + s, err := generateParseValueFromStringStatements(imports, arg.Name+"Parsed", src, param.Type(), errCheck(imports), nil, singleAssignment(token.DEFINE, ast.NewIdent(arg.Name))) + if err != nil { + return nil, err + } + statements = append(statements, s...) + case arg.Name == TemplateNameScopeIdentifierForm: + s, err := appendFormParseStatements(statements, t, imports, arg, param) + if err != nil { + return nil, err + } + statements = s default: - return nil, fmt.Errorf("expected last result to be either an error or a bool") + pt, _ := astTypeExpression(imports, param.Type()) + at, _ := astTypeExpression(imports, argType) + return nil, fmt.Errorf("method expects type %s but %s is %s", source.Format(pt), arg.Name, source.Format(at)) } - default: - return nil, fmt.Errorf("expected last result to be either an error or a bool") } - } else { - return []ast.Stmt{&ast.AssignStmt{Lhs: []ast.Expr{ast.NewIdent(dataVarIdent)}, Tok: token.DEFINE, Rhs: []ast.Expr{call}}}, nil } + return statements, nil } -func (t Template) identifierArgument(imports *source.Imports, i int, arg *ast.Ident, argType ast.Expr, method *ast.FuncType, files []*ast.File) ([]ast.Stmt, error) { - switch arg.Name { - case TemplateNameScopeIdentifierHTTPResponse: - if !matchSelectorIdents(argType, imports.AddNetHTTP(), httpResponseWriterIdent, false) { - return nil, fmt.Errorf("method expects type %s but %s is %s.%s", source.Format(argType), arg.Name, imports.AddNetHTTP(), httpResponseWriterIdent) - } - imports.AddNetHTTP() - return nil, nil - case TemplateNameScopeIdentifierHTTPRequest: - if !matchSelectorIdents(argType, imports.AddNetHTTP(), httpRequestIdent, true) { - return nil, fmt.Errorf("method expects type %s but %s is *%s.%s", source.Format(argType), arg.Name, imports.AddNetHTTP(), httpRequestIdent) - } - imports.AddNetHTTP() - return nil, nil - case TemplateNameScopeIdentifierContext: - if !matchSelectorIdents(argType, imports.AddContext(), contextContextTypeIdent, false) { - return nil, fmt.Errorf("method expects type %s but %s is %s.%s", source.Format(argType), arg.Name, imports.AddContext(), contextContextTypeIdent) - } - imports.AddContext() - return []ast.Stmt{contextAssignment(TemplateNameScopeIdentifierContext)}, nil - case TemplateNameScopeIdentifierForm: - _, tp, _ := source.FieldIndex(method.Params.List, i) - statements := []ast.Stmt{callParseForm(), formDeclaration(imports, arg.Name, tp)} - var formStruct *ast.StructType - if s, ok := findFormStruct(argType, files); ok { - formStruct = s - } - if formStruct == nil && !matchSelectorIdents(argType, "url", "Values", false) { - return nil, fmt.Errorf("method expects form to have type url.Values or T (where T is some struct type)") - } - if formStruct != nil { - result, err := t.parseFormFields(imports, formStruct, arg) - if err != nil { - return nil, err - } - statements = append(statements, result...) - } else { - imports.Add("", "net/url") - } - return statements, nil - default: - for paramIndex, paramType := range source.IterateFieldTypes(method.Params.List) { - if i != paramIndex { - continue - } - if err := compareTypes(paramType, argType); err != nil { - return nil, fmt.Errorf("method argument and param mismatch: %w", err) - } - break - } - src := &ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), - Sel: ast.NewIdent(requestPathValue), - }, - Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(arg.Name)}}, - } - statements, err := httpPathValueAssignment(imports, method, i, arg, src, token.DEFINE, errCheck(imports)) +func appendFormParseStatements(statements []ast.Stmt, t Template, imports *source.Imports, arg *ast.Ident, param types.Object) ([]ast.Stmt, error) { + const parsedVariableName = "value" + statements = append(statements, callParseForm()) + switch tp := param.Type().(type) { + case *types.Named: + declareFormVar, err := formVariableDeclaration(imports, arg, tp) if err != nil { return nil, err } - return statements, nil - } -} + statements = append(statements, declareFormVar) -func errCheck(imports *source.Imports) func(msg ast.Expr) ast.Stmt { - return func(msg ast.Expr) ast.Stmt { - return &ast.ExprStmt{ - X: imports.HTTPErrorCall(ast.NewIdent(httpResponseField(imports).Names[0].Name), msg, http.StatusBadRequest), + form, ok := tp.Underlying().(*types.Struct) + if !ok { + return nil, fmt.Errorf("expected form parameter type to be a struct") } - } -} -func (t Template) parseFormFields(imports *source.Imports, formStruct *ast.StructType, arg *ast.Ident) ([]ast.Stmt, error) { - var statements []ast.Stmt - for _, field := range formStruct.Fields.List { - for _, name := range field.Names { - fieldExpr := &ast.SelectorExpr{ - X: ast.NewIdent(arg.Name), - Sel: ast.NewIdent(name.Name), + parseErrCheck := func(exp ast.Expr) ast.Stmt { + return &ast.ExprStmt{ + X: imports.HTTPErrorCall(ast.NewIdent(httpResponseField(imports).Names[0].Name), source.CallError(errIdent), http.StatusBadRequest), } + } - fieldTemplate := formInputTemplate(field, t.template) - - errCheck := func(exp ast.Expr) ast.Stmt { - return &ast.ExprStmt{ - X: imports.HTTPErrorCall(ast.NewIdent(httpResponseField(imports).Names[0].Name), source.CallError(errIdent), http.StatusBadRequest), - } + for i := 0; i < form.NumFields(); i++ { + field := form.Field(i) + inputName, fieldTemplate, err := fieldTags(imports, t.template, field) + if err != nil { + return nil, err } - - const parsedVariableName = "value" - if fieldType, ok := field.Type.(*ast.ArrayType); ok { - inputName := formInputName(field, name) - const valVar = "val" - assignment := appendAssignment(token.ASSIGN, &ast.SelectorExpr{ - X: ast.NewIdent(arg.Name), - Sel: ast.NewIdent(name.Name), - }) - var templateNodes []*html.Node + var templateNodes []*html.Node + if fieldTemplate != nil { if fieldTemplate != nil { templateNodes, _ = html.ParseFragment(strings.NewReader(fieldTemplate.Tree.Root.String()), &html.Node{ Type: html.ElementNode, @@ -462,250 +355,483 @@ func (t Template) parseFormFields(imports *source.Imports, formStruct *ast.Struc Data: atom.Body.String(), }) } - validations, err, ok := source.GenerateValidations(imports, ast.NewIdent(parsedVariableName), fieldType.Elt, fmt.Sprintf("[name=%q]", inputName), inputName, httpResponseField(imports).Names[0].Name, dom.NewDocumentFragment(templateNodes)) + } + + var ( + parseResult func(expr ast.Expr) ast.Stmt + str ast.Expr + elemType types.Type + ) + switch ft := field.Type().(type) { + case *types.Slice: + parseResult = func(expr ast.Expr) ast.Stmt { + return &ast.AssignStmt{ + Lhs: []ast.Expr{&ast.SelectorExpr{X: ast.NewIdent(TemplateNameScopeIdentifierForm), Sel: ast.NewIdent(field.Name())}}, + Tok: token.ASSIGN, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: ast.NewIdent("append"), + Args: []ast.Expr{&ast.SelectorExpr{X: ast.NewIdent(TemplateNameScopeIdentifierForm), Sel: ast.NewIdent(field.Name())}, expr}, + }}, + } + } + str = ast.NewIdent("val") + elemType = ft.Elem() + validations, err, ok := source.GenerateValidations(imports, ast.NewIdent(parsedVariableName), elemType, fmt.Sprintf("[name=%q]", inputName), inputName, httpResponseField(imports).Names[0].Name, dom.NewDocumentFragment(templateNodes)) if ok && err != nil { return nil, err } - loopBlockStatements, err := source.GenerateParseValueFromStringStatements(imports, parsedVariableName, ast.NewIdent(valVar), fieldType.Elt, errCheck, validations, assignment) + parseStatements, err := generateParseValueFromStringStatements(imports, parsedVariableName, str, elemType, parseErrCheck, validations, parseResult) if err != nil { - return nil, fmt.Errorf("failed to generate parse statements for form field %s: %w", name.Name, err) + return nil, fmt.Errorf("failed to generate parse statements for form field %s: %w", field.Name(), err) } - - forLoop := &ast.RangeStmt{ + statements = append(statements, &ast.RangeStmt{ Key: ast.NewIdent("_"), - Value: ast.NewIdent(valVar), + Value: ast.NewIdent("val"), Tok: token.DEFINE, - X: &ast.IndexExpr{ - X: &ast.SelectorExpr{ - X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), - Sel: ast.NewIdent("Form"), - }, - Index: source.String(inputName), - }, - Body: &ast.BlockStmt{ - List: loopBlockStatements, - }, - } - - statements = append(statements, forLoop) - } else { - assignment := singleAssignment(token.ASSIGN, fieldExpr) - inputName := formInputName(field, name) - str := &ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), - Sel: ast.NewIdent("FormValue"), - }, - Args: []ast.Expr{ - &ast.BasicLit{ - Kind: token.STRING, - Value: strconv.Quote(inputName), - }, - }, - } - var templateNodes []*html.Node - if fieldTemplate != nil { - templateNodes, _ = html.ParseFragment(strings.NewReader(fieldTemplate.Tree.Root.String()), &html.Node{ - Type: html.ElementNode, - DataAtom: atom.Body, - Data: atom.Body.String(), - }) + X: &ast.IndexExpr{X: &ast.SelectorExpr{X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), Sel: ast.NewIdent("Form")}, Index: &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(inputName)}}, + Body: &ast.BlockStmt{List: parseStatements}, + }) + default: + parseResult = func(expr ast.Expr) ast.Stmt { + return &ast.AssignStmt{ + Lhs: []ast.Expr{&ast.SelectorExpr{X: ast.NewIdent(TemplateNameScopeIdentifierForm), Sel: ast.NewIdent(field.Name())}}, + Tok: token.ASSIGN, + Rhs: []ast.Expr{expr}, + } } - validations, err, ok := source.GenerateValidations(imports, ast.NewIdent(parsedVariableName), field.Type, fmt.Sprintf("[name=%q]", inputName), inputName, httpResponseField(imports).Names[0].Name, dom.NewDocumentFragment(templateNodes)) + str = &ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), Sel: ast.NewIdent("FormValue")}, Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(inputName)}}} + elemType = field.Type() + validations, err, ok := source.GenerateValidations(imports, ast.NewIdent(parsedVariableName), elemType, fmt.Sprintf("[name=%q]", inputName), inputName, httpResponseField(imports).Names[0].Name, dom.NewDocumentFragment(templateNodes)) if ok && err != nil { return nil, err } - parseStatements, err := source.GenerateParseValueFromStringStatements(imports, parsedVariableName, str, field.Type, errCheck, validations, assignment) + parseStatements, err := generateParseValueFromStringStatements(imports, parsedVariableName, str, elemType, parseErrCheck, validations, parseResult) if err != nil { - return nil, fmt.Errorf("failed to generate parse statements for form field %s: %w", name.Name, err) + return nil, fmt.Errorf("failed to generate parse statements for form field %s: %w", field.Name(), err) } - if len(parseStatements) > 1 { - parseStatements = []ast.Stmt{&ast.BlockStmt{ + statements = append(statements, &ast.BlockStmt{ List: parseStatements, - }} + }) + } else { + statements = append(statements, parseStatements...) } - - statements = append(statements, parseStatements...) } } + + return statements, nil } - return statements, nil + return nil, fmt.Errorf("expected form parameter type to be a struct") } -func callParseForm() *ast.ExprStmt { - return &ast.ExprStmt{X: &ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), - Sel: ast.NewIdent("ParseForm"), +func fieldTags(imports *source.Imports, t *template.Template, field *types.Var) (string, *template.Template, error) { + inputName := field.Name() + syntaxField, err := imports.FieldTag(field.Pos()) + if err == nil && syntaxField != nil && syntaxField.Tag != nil { + unquoted, err := strconv.Unquote(syntaxField.Tag.Value) + if err != nil { + return "", nil, err + } + tags := reflect.StructTag(unquoted) + if name, found := tags.Lookup(InputAttributeNameStructTag); found { + inputName = name + } + if name, found := tags.Lookup(InputAttributeTemplateStructTag); found { + t = t.Lookup(name) + } + } + return inputName, t, nil +} + +func formVariableDeclaration(imports *source.Imports, arg *ast.Ident, tp types.Type) (*ast.DeclStmt, error) { + typeExp, err := astTypeExpression(imports, tp) + if err != nil { + return nil, err + } + return &ast.DeclStmt{ + Decl: &ast.GenDecl{ + Tok: token.VAR, + Specs: []ast.Spec{ + &ast.ValueSpec{ + Names: []*ast.Ident{ast.NewIdent(arg.Name)}, + Type: typeExp, + }, + }, }, - Args: []ast.Expr{}, - }} + }, nil } -func formInputName(field *ast.Field, name *ast.Ident) string { - if field.Tag != nil { - v, _ := strconv.Unquote(field.Tag.Value) - tags := reflect.StructTag(v) - n, hasInputTag := tags.Lookup(InputAttributeNameStructTag) - if hasInputTag { - return n - } +func formVariableAssignment(imports *source.Imports, arg *ast.Ident, tp types.Type) (*ast.DeclStmt, error) { + typeExp, err := astTypeExpression(imports, tp) + if err != nil { + return nil, err + } + return &ast.DeclStmt{ + Decl: &ast.GenDecl{ + Tok: token.VAR, + Specs: []ast.Spec{ + &ast.ValueSpec{ + Names: []*ast.Ident{ast.NewIdent(arg.Name)}, + Type: typeExp, + Values: []ast.Expr{ + &ast.SelectorExpr{ + X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), + Sel: ast.NewIdent("Form"), + }, + }, + }, + }, + }, + }, nil +} + +func httpServMuxField(imports *source.Imports) *ast.Field { + return &ast.Field{ + Names: []*ast.Ident{ast.NewIdent(muxParamName)}, + Type: &ast.StarExpr{X: &ast.SelectorExpr{X: ast.NewIdent(imports.AddNetHTTP()), Sel: ast.NewIdent("ServeMux")}}, } - return name.Name } -func formInputTemplate(field *ast.Field, t *template.Template) *template.Template { - if field.Tag != nil { - v, _ := strconv.Unquote(field.Tag.Value) - tags := reflect.StructTag(v) - n, hasInputTag := tags.Lookup(InputAttributeTemplateStructTag) - if hasInputTag { - return t.Lookup(n) +func generateParseValueFromStringStatements(imports *source.Imports, tmp string, str ast.Expr, valueType types.Type, errCheck func(expr ast.Expr) ast.Stmt, validations []ast.Stmt, assignment func(ast.Expr) ast.Stmt) ([]ast.Stmt, error) { + switch tp := valueType.(type) { + case *types.Basic: + convert := func(exp ast.Expr) ast.Stmt { + return assignment(&ast.CallExpr{ + Fun: ast.NewIdent(tp.Name()), + Args: []ast.Expr{exp}, + }) + } + switch tp.Name() { + default: + return nil, fmt.Errorf("method param type %s not supported", valueType.String()) + case "bool": + return parseBlock(tmp, imports.StrconvParseBoolCall(str), validations, errCheck, assignment), nil + case "int": + return parseBlock(tmp, imports.StrconvAtoiCall(str), validations, errCheck, assignment), nil + case "int8": + return parseBlock(tmp, imports.StrconvParseIntCall(str, 10, 8), validations, errCheck, convert), nil + case "int16": + return parseBlock(tmp, imports.StrconvParseIntCall(str, 10, 16), validations, errCheck, convert), nil + case "int32": + return parseBlock(tmp, imports.StrconvParseIntCall(str, 10, 32), validations, errCheck, convert), nil + case "int64": + return parseBlock(tmp, imports.StrconvParseIntCall(str, 10, 64), validations, errCheck, assignment), nil + case "uint": + return parseBlock(tmp, imports.StrconvParseUintCall(str, 10, 0), validations, errCheck, convert), nil + case "uint8": + return parseBlock(tmp, imports.StrconvParseUintCall(str, 10, 8), validations, errCheck, convert), nil + case "uint16": + return parseBlock(tmp, imports.StrconvParseUintCall(str, 10, 16), validations, errCheck, convert), nil + case "uint32": + return parseBlock(tmp, imports.StrconvParseUintCall(str, 10, 32), validations, errCheck, convert), nil + case "uint64": + return parseBlock(tmp, imports.StrconvParseUintCall(str, 10, 64), validations, errCheck, assignment), nil + case "string": + if len(validations) == 0 { + assign := assignment(str) + statements := slices.Concat(validations, []ast.Stmt{assign}) + return statements, nil + } + statements := slices.Concat([]ast.Stmt{&ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(tmp)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{str}, + }}, validations, []ast.Stmt{assignment(ast.NewIdent(tmp))}) + return statements, nil + } + case *types.Named: + if tp.Obj().Pkg().Path() == "time" && tp.Obj().Name() == "Time" { + return parseBlock(tmp, imports.TimeParseCall(time.DateOnly, str), validations, errCheck, assignment), nil } } - return t + tp, _ := astTypeExpression(imports, valueType) + return nil, fmt.Errorf("unsupported type: %s", source.Format(tp)) } -func generateFuncTypeFromArguments(imports *source.Imports, call *ast.CallExpr) *ast.FuncType { - method := &ast.FuncType{ - Params: &ast.FieldList{}, - Results: &ast.FieldList{List: []*ast.Field{{Type: ast.NewIdent("any")}}}, +func parseBlock(tmpIdent string, parseCall ast.Expr, validations []ast.Stmt, handleErr, handleResult func(out ast.Expr) ast.Stmt) []ast.Stmt { + const errIdent = "err" + parse := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(tmpIdent), ast.NewIdent(errIdent)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{parseCall}, } - for _, a := range call.Args { - arg, ok := a.(*ast.Ident) - if !ok { - continue + errCheck := source.ErrorCheckReturn(errIdent, handleErr(&ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent(errIdent), + Sel: ast.NewIdent("Error"), + }, + Args: []ast.Expr{}, + })) + block := &ast.BlockStmt{List: []ast.Stmt{parse, errCheck}} + block.List = append(block.List, validations...) + block.List = append(block.List, handleResult(ast.NewIdent(tmpIdent))) + return block.List +} + +func callReceiverMethod(imports *source.Imports, dataVarIdent string, method *types.Signature, call *ast.CallExpr) ([]ast.Stmt, error) { + const ( + okIdent = "ok" + ) + if method.Results().Len() == 9 { + mathodIdent := call.Fun.(*ast.Ident) + assert.NotNil(assertion, mathodIdent) + return nil, fmt.Errorf("method %s has no results it should have one or two", mathodIdent.Name) + } else if method.Results().Len() > 1 { + lastResult := method.Results().At(method.Results().Len() - 1) + + errorType := types.Universe.Lookup("error").Type().Underlying().(*types.Interface) + assert.NotNil(assertion, errorType) + + if types.Implements(lastResult.Type(), errorType) { + return []ast.Stmt{ + &ast.AssignStmt{Lhs: []ast.Expr{ast.NewIdent(dataVarIdent), ast.NewIdent(errIdent)}, Tok: token.DEFINE, Rhs: []ast.Expr{call}}, + &ast.IfStmt{ + Cond: &ast.BinaryExpr{X: ast.NewIdent(errIdent), Op: token.NEQ, Y: source.Nil()}, + Body: &ast.BlockStmt{ + List: []ast.Stmt{ + &ast.ExprStmt{X: imports.HTTPErrorCall(ast.NewIdent(httpResponseField(imports).Names[0].Name), source.CallError(errIdent), http.StatusInternalServerError)}, + &ast.ReturnStmt{}, + }, + }, + }, + }, nil } - switch arg.Name { - case TemplateNameScopeIdentifierHTTPRequest: - method.Params.List = append(method.Params.List, httpRequestField(imports)) - case TemplateNameScopeIdentifierHTTPResponse: - method.Params.List = append(method.Params.List, httpResponseField(imports)) - case TemplateNameScopeIdentifierContext: - method.Params.List = append(method.Params.List, contextContextField(imports)) - case TemplateNameScopeIdentifierForm: - method.Params.List = append(method.Params.List, urlValuesField(imports, arg.Name)) - default: - method.Params.List = append(method.Params.List, pathValueField(arg.Name)) + + if basic, ok := lastResult.Type().(*types.Basic); ok && basic.Kind() == types.Bool { + return []ast.Stmt{ + &ast.AssignStmt{Lhs: []ast.Expr{ast.NewIdent(dataVarIdent), ast.NewIdent(okIdent)}, Tok: token.DEFINE, Rhs: []ast.Expr{call}}, + &ast.IfStmt{ + Cond: &ast.UnaryExpr{Op: token.NOT, X: ast.NewIdent(okIdent)}, + Body: &ast.BlockStmt{ + List: []ast.Stmt{ + &ast.ReturnStmt{}, + }, + }, + }, + }, nil } + + return nil, fmt.Errorf("expected last result to be either an error or a bool") + } else { + return []ast.Stmt{&ast.AssignStmt{Lhs: []ast.Expr{ast.NewIdent(dataVarIdent)}, Tok: token.DEFINE, Rhs: []ast.Expr{call}}}, nil } - return method } -func fieldListTypes(fieldList *ast.FieldList) func(func(ast.Expr) bool) { - return func(yield func(ast.Expr) bool) { - for _, param := range fieldList.List { - if len(param.Names) == 0 { - if !yield(param.Type) { - return - } - continue +func astTypeExpression(imports *source.Imports, tp types.Type) (ast.Expr, error) { + switch t := tp.(type) { + case *types.Signature: + node := &ast.FuncType{ + Params: new(ast.FieldList), + Results: new(ast.FieldList), + } + + for i := 0; i < t.Params().Len(); i++ { + param := t.Params().At(i) + pt, err := astTypeExpression(imports, param.Type()) + if err != nil { + return nil, err } - for range param.Names { - if !yield(param.Type) { - return - } + var names []*ast.Ident + if param.Name() != "" { + names = []*ast.Ident{ast.NewIdent(param.Name())} } + node.Params.List = append(node.Params.List, &ast.Field{ + Names: names, + Type: pt, + }) + } + for i := 0; i < t.Results().Len(); i++ { + result := t.Results().At(i) + rt, err := astTypeExpression(imports, result.Type()) + if err != nil { + return nil, err + } + node.Results.List = append(node.Results.List, &ast.Field{ + Type: rt, + }) + } + + return node, nil + case *types.Named: + pkg := t.Obj().Pkg() + if pkg != nil && pkg.Name() != "main" && pkg.Path() != imports.OutputPackage() { + return &ast.SelectorExpr{ + X: ast.NewIdent(imports.Add(pkg.Name(), pkg.Path())), + Sel: ast.NewIdent(t.Obj().Name()), + }, nil + } + return ast.NewIdent(t.Obj().Name()), nil + case *types.Slice: + elt, err := astTypeExpression(imports, t.Elem()) + if err != nil { + return nil, err } + return &ast.ArrayType{ + Elt: elt, + }, nil + case *types.Pointer: + x, err := astTypeExpression(imports, t.Elem()) + if err != nil { + return nil, err + } + return &ast.StarExpr{X: x}, nil + case *types.Alias: + pkg := t.Obj().Pkg() + if pkg != nil && pkg.Name() != "main" { + return &ast.SelectorExpr{ + X: ast.NewIdent(imports.Add(pkg.Name(), pkg.Path())), + Sel: ast.NewIdent(t.Obj().Name()), + }, nil + } + return ast.NewIdent(t.Obj().Name()), nil + case *types.Basic: + return ast.NewIdent(t.Name()), nil + default: + assert.Failf(assertion, "", "could not generate type expression for %[1]T %[1]s", tp) + return nil, nil } } -func errWrongNumberOfArguments(def Template, method *ast.FuncType) error { - return fmt.Errorf("handler %s expects %d arguments but call %s has %d", source.Format(&ast.FuncDecl{Name: ast.NewIdent(def.fun.Name), Type: method}), method.Params.NumFields(), def.handler, len(def.call.Args)) +var assertion AssertionFailureReporter + +type AssertionFailureReporter struct{} + +func (AssertionFailureReporter) Errorf(format string, args ...interface{}) { + log.Fatalf(format, args...) } -func compareTypes(expA, expB ast.Expr) error { - if a, b, ok := matchExpressionType[*ast.Ident](expA, expB); ok && a.Name == b.Name { - return nil - } - if a, b, ok := matchExpressionType[*ast.SelectorExpr](expA, expB); ok && a.Sel == b.Sel { - if _, _, ok = matchExpressionType[*ast.Ident](a.X, b.X); ok { - return nil +func defaultTemplateNameScope(imports *source.Imports, template Template, argumentIdentifier string) (types.Type, bool) { + switch argumentIdentifier { + case TemplateNameScopeIdentifierHTTPRequest: + pkg, ok := imports.Types("net/http") + if !ok { + return nil, false + } + t := types.NewPointer(pkg.Scope().Lookup("Request").Type()) + return t, true + case TemplateNameScopeIdentifierHTTPResponse: + pkg, ok := imports.Types("net/http") + if !ok { + return nil, false + } + t := pkg.Scope().Lookup("ResponseWriter").Type() + return t, true + case TemplateNameScopeIdentifierContext: + pkg, ok := imports.Types("context") + if !ok { + return nil, false + } + t := pkg.Scope().Lookup("Context").Type() + return t, true + case TemplateNameScopeIdentifierForm: + pkg, ok := imports.Types("net/url") + if !ok { + return nil, false + } + t := pkg.Scope().Lookup("Values").Type() + return t, true + default: + if slices.Contains(template.parsePathValueNames(), argumentIdentifier) { + return types.Universe.Lookup("string").Type(), true } + return nil, false } - return fmt.Errorf("type %s is not assignable to %s", source.Format(expA), source.Format(expB)) -} - -func matchExpressionType[T ast.Expr](a, b ast.Expr) (T, T, bool) { - ax, aOk := a.(T) - bx, bOk := b.(T) - return ax, bx, aOk && bOk } -func findFormStruct(argType ast.Expr, files []*ast.File) (*ast.StructType, bool) { - if argTypeIdent, ok := argType.(*ast.Ident); ok { - for _, file := range files { - for _, d := range file.Decls { - decl, ok := d.(*ast.GenDecl) - if !ok || decl.Tok != token.TYPE { - continue - } - for _, s := range decl.Specs { - spec := s.(*ast.TypeSpec) - structType, isStruct := spec.Type.(*ast.StructType) - if isStruct && spec.Name.Name == argTypeIdent.Name { - return structType, true +func ensureMethodSignature(imports *source.Imports, t Template, receiver *types.Named, receiverInterface *ast.InterfaceType, call *ast.CallExpr) error { + switch fun := call.Fun.(type) { + case *ast.Ident: + mo, _, _ := types.LookupFieldOrMethod(receiver, true, receiver.Obj().Pkg(), fun.Name) + if mo == nil { + ms, err := createMethodSignature(imports, t, receiver, receiverInterface, call) + if err != nil { + return err + } + fn := types.NewFunc(0, receiver.Obj().Pkg(), fun.Name, ms) + receiver.AddMethod(fn) + mo = fn + } else { + for _, a := range call.Args { + switch arg := a.(type) { + case *ast.CallExpr: + if err := ensureMethodSignature(imports, t, receiver, receiverInterface, arg); err != nil { + return err } } } } - } - return nil, false -} - -func matchSelectorIdents(expr ast.Expr, pkg, name string, star bool) bool { - if star { - st, ok := expr.(*ast.StarExpr) - if !ok { - return false + exp, err := astTypeExpression(imports, mo.Type()) + if err != nil { + return err } - expr = st.X - } - sel, ok := expr.(*ast.SelectorExpr) - if !ok { - return false + receiverInterface.Methods.List = append(receiverInterface.Methods.List, &ast.Field{ + Names: []*ast.Ident{ast.NewIdent(fun.Name)}, + Type: exp, + }) + return nil + default: + return fmt.Errorf("expected a method identifier") } - id, ok := sel.X.(*ast.Ident) - return ok && sel.Sel.Name == name && id.Name == pkg } -func pathValueField(name string) *ast.Field { - return &ast.Field{ - Type: ast.NewIdent("string"), - Names: []*ast.Ident{ast.NewIdent(name)}, +func createMethodSignature(imports *source.Imports, t Template, receiver *types.Named, receiverInterface *ast.InterfaceType, call *ast.CallExpr) (*types.Signature, error) { + var params []*types.Var + for _, a := range call.Args { + switch arg := a.(type) { + case *ast.Ident: + tp, ok := defaultTemplateNameScope(imports, t, arg.Name) + if !ok { + return nil, fmt.Errorf("could not determine a type for %s", arg.Name) + } + params = append(params, types.NewVar(0, receiver.Obj().Pkg(), arg.Name, tp)) + case *ast.CallExpr: + if err := ensureMethodSignature(imports, t, receiver, receiverInterface, arg); err != nil { + return nil, err + } + } } + results := types.NewTuple(types.NewVar(0, nil, "", types.Universe.Lookup("any").Type())) + return types.NewSignatureType(types.NewVar(0, nil, "", receiver.Obj().Type()), nil, nil, types.NewTuple(params...), results, false), nil } -func contextContextField(imports *source.Imports) *ast.Field { - return &ast.Field{ - Names: []*ast.Ident{ast.NewIdent(TemplateNameScopeIdentifierContext)}, - Type: contextContextType(imports.AddContext()), +func hasHTTPResponseWriterArgument(call *ast.CallExpr) bool { + for _, a := range call.Args { + switch arg := a.(type) { + case *ast.Ident: + if arg.Name == TemplateNameScopeIdentifierHTTPResponse { + return true + } + case *ast.CallExpr: + if hasHTTPResponseWriterArgument(arg) { + return true + } + } } + return false } -func httpResponseField(imports *source.Imports) *ast.Field { - return &ast.Field{ - Names: []*ast.Ident{ast.NewIdent(TemplateNameScopeIdentifierHTTPResponse)}, - Type: &ast.SelectorExpr{X: ast.NewIdent(imports.AddNetHTTP()), Sel: ast.NewIdent(httpResponseWriterIdent)}, +func errCheck(imports *source.Imports) func(msg ast.Expr) ast.Stmt { + return func(msg ast.Expr) ast.Stmt { + return &ast.ExprStmt{ + X: imports.HTTPErrorCall(ast.NewIdent(httpResponseField(imports).Names[0].Name), msg, http.StatusBadRequest), + } } } -func routesFuncType(imports *source.Imports, receiverType ast.Expr) *ast.FuncType { - return &ast.FuncType{Params: &ast.FieldList{ - List: []*ast.Field{ - {Names: []*ast.Ident{ast.NewIdent(muxVarIdent)}, Type: &ast.StarExpr{ - X: &ast.SelectorExpr{X: ast.NewIdent(imports.AddNetHTTP()), Sel: ast.NewIdent(httpServeMuxIdent)}, - }}, - {Names: []*ast.Ident{ast.NewIdent(receiverIdent)}, Type: receiverType}, +func callParseForm() *ast.ExprStmt { + return &ast.ExprStmt{X: &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), + Sel: ast.NewIdent("ParseForm"), }, + Args: []ast.Expr{}, }} } -func urlValuesField(imports *source.Imports, ident string) *ast.Field { +func httpResponseField(imports *source.Imports) *ast.Field { return &ast.Field{ - Names: []*ast.Ident{ast.NewIdent(ident)}, - Type: &ast.SelectorExpr{X: ast.NewIdent(imports.Add("", "net/url")), Sel: ast.NewIdent("Values")}, + Names: []*ast.Ident{ast.NewIdent(TemplateNameScopeIdentifierHTTPResponse)}, + Type: &ast.SelectorExpr{X: ast.NewIdent(imports.AddNetHTTP()), Sel: ast.NewIdent(httpResponseWriterIdent)}, } } @@ -720,10 +846,6 @@ func httpHandlerFuncType(imports *source.Imports) *ast.FuncType { return &ast.FuncType{Params: &ast.FieldList{List: []*ast.Field{httpResponseField(imports), httpRequestField(imports)}}} } -func contextContextType(contextPackageIdent string) *ast.SelectorExpr { - return &ast.SelectorExpr{X: ast.NewIdent(contextPackageIdent), Sel: ast.NewIdent(contextContextTypeIdent)} -} - func contextAssignment(ident string) *ast.AssignStmt { return &ast.AssignStmt{ Tok: token.DEFINE, @@ -737,48 +859,6 @@ func contextAssignment(ident string) *ast.AssignStmt { } } -func formDeclaration(imports *source.Imports, ident string, typeExp ast.Expr) *ast.DeclStmt { - if matchSelectorIdents(typeExp, imports.Ident("net/url"), "Values", false) { - imports.Add("", "net/url") - return &ast.DeclStmt{ - Decl: &ast.GenDecl{ - Tok: token.VAR, - Specs: []ast.Spec{ - &ast.ValueSpec{ - Names: []*ast.Ident{ast.NewIdent(ident)}, - Type: typeExp, - Values: []ast.Expr{ - &ast.SelectorExpr{X: ast.NewIdent(httpRequestField(imports).Names[0].Name), Sel: ast.NewIdent("Form")}, - }, - }, - }, - }, - } - } - return &ast.DeclStmt{ - Decl: &ast.GenDecl{ - Tok: token.VAR, - Specs: []ast.Spec{ - &ast.ValueSpec{ - Names: []*ast.Ident{ast.NewIdent(ident)}, - Type: typeExp, - }, - }, - }, - } -} - -func httpPathValueAssignment(imports *source.Imports, method *ast.FuncType, i int, arg *ast.Ident, str ast.Expr, assignTok token.Token, errCheck func(stmt ast.Expr) ast.Stmt) ([]ast.Stmt, error) { - for typeIndex, typeExp := range source.IterateFieldTypes(method.Params.List) { - if typeIndex != i { - continue - } - assignment := singleAssignment(assignTok, ast.NewIdent(arg.Name)) - return source.GenerateParseValueFromStringStatements(imports, arg.Name+"Parsed", str, typeExp, errCheck, nil, assignment) - } - return nil, fmt.Errorf("type for argumement %d not found", i) -} - func singleAssignment(assignTok token.Token, result ast.Expr) func(exp ast.Expr) ast.Stmt { return func(exp ast.Expr) ast.Stmt { return &ast.AssignStmt{ @@ -789,53 +869,6 @@ func singleAssignment(assignTok token.Token, result ast.Expr) func(exp ast.Expr) } } -func appendAssignment(assignTok token.Token, result ast.Expr) func(exp ast.Expr) ast.Stmt { - return func(exp ast.Expr) ast.Stmt { - return &ast.AssignStmt{ - Lhs: []ast.Expr{result}, - Tok: assignTok, - Rhs: []ast.Expr{&ast.CallExpr{ - Fun: ast.NewIdent("append"), - Args: []ast.Expr{result, exp}, - }}, - } - } -} - -func (t Template) executeCall(status, data ast.Expr, writeHeader bool) *ast.ExprStmt { - return &ast.ExprStmt{X: &ast.CallExpr{ - Fun: ast.NewIdent(executeIdentName), - Args: []ast.Expr{ - ast.NewIdent(TemplateNameScopeIdentifierHTTPResponse), - ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), - ast.NewIdent(strconv.FormatBool(writeHeader)), - &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(t.name)}, - status, - data, - }, - }} -} - -func (t Template) httpRequestReceiverTemplateHandlerFunc(imports *source.Imports, statusCode int) *ast.FuncLit { - return &ast.FuncLit{ - Type: httpHandlerFuncType(imports), - Body: &ast.BlockStmt{List: []ast.Stmt{t.executeCall(source.HTTPStatusCode(imports, statusCode), ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), true)}}, - } -} - -func (t Template) matchReceiver(funcDecl *ast.FuncDecl, receiverTypeIdent string) bool { - if funcDecl == nil || funcDecl.Name == nil || funcDecl.Name.Name != t.fun.Name || - funcDecl.Recv == nil || len(funcDecl.Recv.List) < 1 { - return false - } - exp := funcDecl.Recv.List[0].Type - if star, ok := exp.(*ast.StarExpr); ok { - exp = star.X - } - ident, ok := exp.(*ast.Ident) - return ok && ident.Name == receiverTypeIdent -} - func executeFuncDecl(imports *source.Imports, templatesVariableIdent string) *ast.FuncDecl { const writeHeaderIdent = "writeHeader" return &ast.FuncDecl{ diff --git a/routes_test.go b/routes_test.go index 41a03ec..b3cd7d8 100644 --- a/routes_test.go +++ b/routes_test.go @@ -10,11 +10,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/tools/go/packages" "golang.org/x/tools/txtar" "github.com/crhntr/muxt" - "github.com/crhntr/muxt/internal/source" ) func TestGenerate(t *testing.T) { @@ -444,6 +442,8 @@ func routes(mux *http.ServeMux, receiver RoutesReceiver) { -- receiver.go -- package main +import "context" + type T struct{} func (T) F(ctx context.Context) int { return 30 } @@ -872,13 +872,16 @@ func routes(mux *http.ServeMux, receiver RoutesReceiver) { `, }, { - Name: "F is defined and form has two string fields", + Name: "form argument has typed parameters", Templates: `{{define "GET / F(form)"}}Hello, {{.}}!{{end}}`, ReceiverPackage: ` -- in.go -- package main -import "net/http" +import ( + "net/http" +"time" +) type ( T struct{} @@ -890,7 +893,6 @@ type ( fieldInt8 int8 fieldUint uint fieldUint64 uint64 - fieldUint16 uint16 fieldUint32 uint32 fieldUint16 uint16 fieldUint8 uint8 @@ -976,14 +978,6 @@ func routes(mux *http.ServeMux, receiver RoutesReceiver) { } form.fieldUint64 = value } - { - value, err := strconv.ParseUint(request.FormValue("fieldUint16"), 10, 16) - if err != nil { - http.Error(response, err.Error(), http.StatusBadRequest) - return - } - form.fieldUint16 = uint16(value) - } { value, err := strconv.ParseUint(request.FormValue("fieldUint32"), 10, 32) if err != nil { @@ -1131,7 +1125,6 @@ type ( fieldInt8 []int8 fieldUint []uint fieldUint64 []uint64 - fieldUint16 []uint16 fieldUint32 []uint32 fieldUint16 []uint16 fieldUint8 []uint8 @@ -1216,14 +1209,6 @@ func routes(mux *http.ServeMux, receiver RoutesReceiver) { } form.fieldUint64 = append(form.fieldUint64, value) } - for _, val := range request.Form["fieldUint16"] { - value, err := strconv.ParseUint(val, 10, 16) - if err != nil { - http.Error(response, err.Error(), http.StatusBadRequest) - return - } - form.fieldUint16 = append(form.fieldUint16, uint16(value)) - } for _, val := range request.Form["fieldUint32"] { value, err := strconv.ParseUint(val, 10, 32) if err != nil { @@ -1414,6 +1399,8 @@ func routes(mux *http.ServeMux, receiver RoutesReceiver) { ReceiverPackage: `-- in.go -- package main +import "context" + type T struct{} func (T) F(ctx context.Context) any {return nil} @@ -1714,7 +1701,7 @@ func routes(mux *http.ServeMux, receiver RoutesReceiver) { } id := idParsed data := receiver.F(ctx, result0, id) - execute(response, request, true, "GET /{id} F(ctx, Session(response, request), id)", http.StatusOK, data) + execute(response, request, false, "GET /{id} F(ctx, Session(response, request), id)", http.StatusOK, data) }) } `, @@ -1733,12 +1720,12 @@ import ( type ( T struct{} - User struct{} + Session struct{} ) -func (T) F(context.Context, S, int) any {return nil} +func (T) F(context.Context, Session, int) any {return nil} -func (T) Author(int) (User, error) {return Session{}, nil} +func (T) Author(int) (Session, error) {return Session{}, nil} func execute(response http.ResponseWriter, request *http.Request, writeHeader bool, name string, code int, data any) {} `, @@ -1751,8 +1738,8 @@ import ( ) type RoutesReceiver interface { - Author(int) (User, error) - F(context.Context, S, int) any + Author(int) (Session, error) + F(context.Context, Session, int) any } func routes(mux *http.ServeMux, receiver RoutesReceiver) { @@ -1794,7 +1781,7 @@ type ( func (T) F(context.Context, Configuration) any {return nil} -func (T) LoadConfiguration() (_ Configuration) { return } +func (T) LoadConfiguration() Configuration { return } func execute(response http.ResponseWriter, request *http.Request, writeHeader bool, name string, code int, data any) {} `, @@ -1806,7 +1793,7 @@ import ( ) type RoutesReceiver interface { - LoadConfiguration() (_ Configuration) + LoadConfiguration() Configuration F(context.Context, Configuration) any } @@ -1821,7 +1808,7 @@ func routes(mux *http.ServeMux, receiver RoutesReceiver) { `, }, { - Name: "call expression argument", + Name: "call expression argument with response argument", Templates: `{{define "GET / F(ctx, Headers(response))"}}{{end}}`, Receiver: "T", ReceiverPackage: `-- in.go -- @@ -1937,13 +1924,28 @@ func routes(mux *http.ServeMux, receiver RoutesReceiver) { } { t.Run(tt.Name, func(t *testing.T) { ts := template.Must(template.New(tt.Name).Parse(tt.Templates)) - templateNames, err := muxt.Templates(ts) + templates, err := muxt.Templates(ts) require.NoError(t, err) - logs := log.New(io.Discard, "", 0) - pkg := loadPackage(t, tt.ReceiverPackage) - out, err := muxt.Routes(templateNames, tt.PackageName, tt.TemplatesVar, tt.RoutesFunc, tt.Receiver, tt.Interface, muxt.DefaultOutputFileName, pkg, logs) + logger := log.New(io.Discard, "", 0) + + archive := txtar.Parse([]byte(tt.ReceiverPackage)) + archiveDir, err := txtar.FS(archive) + require.NoError(t, err) + + dir := t.TempDir() + require.NoError(t, os.CopyFS(dir, archiveDir)) + require.NoError(t, os.WriteFile(filepath.Join(dir, "go.mod"), []byte("module example.com\n\ngo 1.20\n"), 0644)) + out, err := muxt.TemplateRoutesFile(dir, templates, logger, muxt.RoutesFileConfiguration{ + ReceiverInterface: tt.Interface, + Package: tt.PackageName, + TemplatesVar: tt.TemplatesVar, + RoutesFunc: tt.RoutesFunc, + PackagePath: "example.com", + ReceiverType: tt.Receiver, + Output: "template_routes.go", + }) if tt.ExpectedError == "" { - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, tt.ExpectedFile, out) } else { assert.ErrorContains(t, err, tt.ExpectedError) @@ -1952,21 +1954,21 @@ func routes(mux *http.ServeMux, receiver RoutesReceiver) { } } -func loadPackage(t *testing.T, in string) []*packages.Package { - t.Helper() - archive := txtar.Parse([]byte(in)) - archiveDir, err := txtar.FS(archive) - require.NoError(t, err) - - dir := t.TempDir() - require.NoError(t, os.CopyFS(dir, archiveDir)) - require.NoError(t, os.WriteFile(filepath.Join(dir, "go.mod"), []byte("module example.com\n"), 0644)) - - packageList, err := source.Load(dir, "./...") - require.NoError(t, err) - - return packageList -} +//func loadPackage(t *testing.T, in string) []*packages.Package { +// t.Helper() +// archive := txtar.Parse([]byte(in)) +// archiveDir, err := txtar.FS(archive) +// require.NoError(t, err) +// +// dir := t.TempDir() +// require.NoError(t, os.CopyFS(dir, archiveDir)) +// require.NoError(t, os.WriteFile(filepath.Join(dir, "go.mod"), []byte("module example.com\n"), 0644)) +// +// packageList, err := source.Load(dir, "./...") +// require.NoError(t, err) +// +// return packageList +//} const executeGo = `-- execute.go -- package main diff --git a/template.go b/template.go index 370f3ed..33d8a7d 100644 --- a/template.go +++ b/template.go @@ -262,3 +262,97 @@ func patternScope() []string { TemplateNameScopeIdentifierForm, } } + +func (t Template) executeCall(status, data ast.Expr, writeHeader bool) *ast.ExprStmt { + return &ast.ExprStmt{X: &ast.CallExpr{ + Fun: ast.NewIdent(executeIdentName), + Args: []ast.Expr{ + ast.NewIdent(TemplateNameScopeIdentifierHTTPResponse), + ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), + ast.NewIdent(strconv.FormatBool(writeHeader)), + &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(t.name)}, + status, + data, + }, + }} +} + +func (t Template) httpRequestReceiverTemplateHandlerFunc(imports *source.Imports, statusCode int) *ast.FuncLit { + return &ast.FuncLit{ + Type: httpHandlerFuncType(imports), + Body: &ast.BlockStmt{List: []ast.Stmt{t.executeCall(source.HTTPStatusCode(imports, statusCode), ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), true)}}, + } +} + +func (t Template) matchReceiver(funcDecl *ast.FuncDecl, receiverTypeIdent string) bool { + if funcDecl == nil || funcDecl.Name == nil || funcDecl.Name.Name != t.fun.Name || + funcDecl.Recv == nil || len(funcDecl.Recv.List) < 1 { + return false + } + exp := funcDecl.Recv.List[0].Type + if star, ok := exp.(*ast.StarExpr); ok { + exp = star.X + } + ident, ok := exp.(*ast.Ident) + return ok && ident.Name == receiverTypeIdent +} + +func (t Template) callHandleFunc(handlerFuncLit *ast.FuncLit) *ast.ExprStmt { + return &ast.ExprStmt{X: &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent(muxVarIdent), + Sel: ast.NewIdent(httpHandleFuncIdent), + }, + Args: []ast.Expr{source.String(t.endpoint), handlerFuncLit}, + }} +} + +func (t Template) callReceiverMethod(imports *source.Imports, dataVarIdent string, method *ast.FuncType, call *ast.CallExpr) ([]ast.Stmt, error) { + const ( + okIdent = "ok" + ) + if method.Results == nil || len(method.Results.List) == 0 { + return nil, fmt.Errorf("method for endpoint %q has no results it should have one or two", t) + } else if len(method.Results.List) > 1 { + _, lastResultType, ok := source.FieldIndex(method.Results.List, method.Results.NumFields()-1) + if !ok { + return nil, fmt.Errorf("failed to get the last method result") + } + switch rt := lastResultType.(type) { + case *ast.Ident: + switch rt.Name { + case "error": + return []ast.Stmt{ + &ast.AssignStmt{Lhs: []ast.Expr{ast.NewIdent(dataVarIdent), ast.NewIdent(errIdent)}, Tok: token.DEFINE, Rhs: []ast.Expr{call}}, + &ast.IfStmt{ + Cond: &ast.BinaryExpr{X: ast.NewIdent(errIdent), Op: token.NEQ, Y: source.Nil()}, + Body: &ast.BlockStmt{ + List: []ast.Stmt{ + &ast.ExprStmt{X: imports.HTTPErrorCall(ast.NewIdent(httpResponseField(imports).Names[0].Name), source.CallError(errIdent), http.StatusInternalServerError)}, + &ast.ReturnStmt{}, + }, + }, + }, + }, nil + case "bool": + return []ast.Stmt{ + &ast.AssignStmt{Lhs: []ast.Expr{ast.NewIdent(dataVarIdent), ast.NewIdent(okIdent)}, Tok: token.DEFINE, Rhs: []ast.Expr{call}}, + &ast.IfStmt{ + Cond: &ast.UnaryExpr{Op: token.NOT, X: ast.NewIdent(okIdent)}, + Body: &ast.BlockStmt{ + List: []ast.Stmt{ + &ast.ReturnStmt{}, + }, + }, + }, + }, nil + default: + return nil, fmt.Errorf("expected last result to be either an error or a bool") + } + default: + return nil, fmt.Errorf("expected last result to be either an error or a bool") + } + } else { + return []ast.Stmt{&ast.AssignStmt{Lhs: []ast.Expr{ast.NewIdent(dataVarIdent)}, Tok: token.DEFINE, Rhs: []ast.Expr{call}}}, nil + } +}