diff --git a/cmd/muxt/generate.go b/cmd/muxt/generate.go index 1435dae..2547a7f 100644 --- a/cmd/muxt/generate.go +++ b/cmd/muxt/generate.go @@ -104,11 +104,15 @@ func generateCommand(args []string, workingDirectory string, getEnv func(string) if err != nil { return err } - out := log.New(stdout, "", 0) - s, err := muxt.Routes(templates, g.goPackage, g.templatesVariable, g.routesFunction, g.receiverIdent, g.receiverInterfaceIdent, g.outputFilename, packageList, out) - if err != nil { - return err - } + s, err := muxt.TemplateRoutesFile(workingDirectory, templates, log.New(stdout, "", 0), muxt.RoutesFileConfiguration{ + Package: g.goPackage, + PackagePath: g.Package.PkgPath, + TemplatesVar: g.templatesVariable, + RoutesFunc: g.routesFunction, + ReceiverType: g.receiverIdent, + ReceiverInterface: g.receiverInterfaceIdent, + Output: g.outputFilename, + }) var sb bytes.Buffer sb.WriteString(CodeGenerationComment) if v, ok := cliVersion(); ok { diff --git a/internal/source/html.go b/internal/source/html.go index c45c4fd..2845388 100644 --- a/internal/source/html.go +++ b/internal/source/html.go @@ -5,6 +5,7 @@ import ( "fmt" "go/ast" "go/token" + "go/types" "regexp" "slices" "strings" @@ -17,7 +18,7 @@ type ValidationGenerator interface { GenerateValidation(imports *Imports, variable ast.Expr, handleError func(string) ast.Stmt) ast.Stmt } -func ParseInputValidations(name string, input spec.Element, tp ast.Expr) ([]ValidationGenerator, error) { +func ParseInputValidations(name string, input spec.Element, tp types.Type) ([]ValidationGenerator, error) { if tag := strings.ToLower(input.TagName()); tag != atom.Input.String() { return nil, fmt.Errorf("expected element to have tag got <%s>", tag) } diff --git a/internal/source/imports.go b/internal/source/imports.go index 36231f6..705ee2b 100644 --- a/internal/source/imports.go +++ b/internal/source/imports.go @@ -1,8 +1,11 @@ package source import ( + "fmt" "go/ast" + "go/parser" "go/token" + "go/types" "log" "path" "slices" @@ -12,6 +15,10 @@ import ( type Imports struct { *ast.GenDecl + fileSet *token.FileSet + types map[string]*types.Package + files map[string]*ast.File + outputPackage string } func NewImports(decl *ast.GenDecl) *Imports { @@ -20,7 +27,82 @@ func NewImports(decl *ast.GenDecl) *Imports { log.Panicf("expected decl to have token.IMPORT Tok got %s", got) } } - return &Imports{GenDecl: decl} + return &Imports{GenDecl: decl, types: make(map[string]*types.Package), files: make(map[string]*ast.File)} +} + +func (imports *Imports) AddPackages(p *types.Package) { + recursivelyRegisterPackages(imports.types, p) +} + +func (imports *Imports) FileSet() *token.FileSet { + if imports.fileSet == nil { + imports.fileSet = token.NewFileSet() + } + return imports.fileSet +} + +func (imports *Imports) SetOutputPackage(pkgPath string) { + imports.outputPackage = pkgPath +} + +func (imports *Imports) OutputPackage() string { + return imports.outputPackage +} + +func (imports *Imports) SyntaxFile(pos token.Pos) (*ast.File, *token.FileSet, error) { + position := imports.FileSet().Position(pos) + fSet := token.NewFileSet() + file, err := parser.ParseFile(fSet, position.Filename, nil, parser.AllErrors|parser.ParseComments|parser.SkipObjectResolution) + return file, fSet, err +} + +func (imports *Imports) FieldTag(pos token.Pos) (*ast.Field, error) { + file, fileSet, err := imports.SyntaxFile(pos) + if err != nil { + return nil, err + } + position := imports.fileSet.Position(pos) + for _, d := range file.Decls { + switch decl := d.(type) { + case *ast.GenDecl: + for _, s := range decl.Specs { + switch spec := s.(type) { + case *ast.TypeSpec: + tp, ok := spec.Type.(*ast.StructType) + if !ok { + continue + } + + for _, field := range tp.Fields.List { + for _, name := range field.Names { + p := fileSet.Position(name.Pos()) + if p != position { + continue + } + return field, nil + } + } + } + } + } + + } + return nil, fmt.Errorf("failed to find field") +} + +func (imports *Imports) Types(pkgPath string) (*types.Package, bool) { + p, ok := imports.types[pkgPath] + return p, ok +} + +func recursivelyRegisterPackages(set map[string]*types.Package, pkg *types.Package) { + if pkg == nil { + return + } + set[pkg.Path()] = pkg + for _, p := range pkg.Imports() { + recursivelyRegisterPackages(set, p) + } } func (imports *Imports) Add(pkgIdent, pkgPath string) string { @@ -31,27 +113,29 @@ func (imports *Imports) Add(pkgIdent, pkgPath string) string { if pkgIdent == "" { pkgIdent = path.Base(pkgPath) } - for _, s := range imports.GenDecl.Specs { - spec := s.(*ast.ImportSpec) - pp, _ := strconv.Unquote(spec.Path.Value) - if pp == pkgPath { - if spec.Name != nil && spec.Name.Name != "" && spec.Name.Name != pkgIdent { - return spec.Name.Name + if pkgPath != imports.outputPackage { + for _, s := range imports.GenDecl.Specs { + spec := s.(*ast.ImportSpec) + pp, _ := strconv.Unquote(spec.Path.Value) + if pp == pkgPath { + if spec.Name != nil && spec.Name.Name != "" && spec.Name.Name != pkgIdent { + return spec.Name.Name + } + return path.Base(pp) } - return path.Base(pp) } + var pi *ast.Ident + if path.Base(pkgPath) != pkgIdent { + pi = Ident(pkgIdent) + } + imports.GenDecl.Specs = append(imports.GenDecl.Specs, &ast.ImportSpec{ + Path: String(pkgPath), + Name: pi, + }) + slices.SortFunc(imports.GenDecl.Specs, func(a, b ast.Spec) int { + return strings.Compare(a.(*ast.ImportSpec).Path.Value, b.(*ast.ImportSpec).Path.Value) + }) } - var pi *ast.Ident - if path.Base(pkgPath) != pkgIdent { - pi = Ident(pkgIdent) - } - imports.GenDecl.Specs = append(imports.GenDecl.Specs, &ast.ImportSpec{ - Path: String(pkgPath), - Name: pi, - }) - slices.SortFunc(imports.GenDecl.Specs, func(a, b ast.Spec) int { - return strings.Compare(a.(*ast.ImportSpec).Path.Value, b.(*ast.ImportSpec).Path.Value) - }) return pkgIdent } diff --git a/internal/source/parse.go b/internal/source/parse.go index 6aadd41..a0dd6ca 100644 --- a/internal/source/parse.go +++ b/internal/source/parse.go @@ -4,6 +4,7 @@ import ( "fmt" "go/ast" "go/token" + "go/types" "net/http" "regexp" "slices" @@ -71,7 +72,7 @@ func GenerateParseValueFromStringStatements(imports *Imports, tmp string, str, t return nil, fmt.Errorf("unsupported type: %s", Format(typeExp)) } -func GenerateValidations(imports *Imports, variable, variableType ast.Expr, inputQuery, inputName, responseIdent string, fragment spec.DocumentFragment) ([]ast.Stmt, error, bool) { +func GenerateValidations(imports *Imports, variable ast.Expr, variableType types.Type, inputQuery, inputName, responseIdent string, fragment spec.DocumentFragment) ([]ast.Stmt, error, bool) { input := fragment.QuerySelector(inputQuery) if input == nil { return nil, nil, false diff --git a/internal/source/parse_test.go b/internal/source/parse_test.go index 6f9b237..99a2853 100644 --- a/internal/source/parse_test.go +++ b/internal/source/parse_test.go @@ -3,7 +3,7 @@ package source_test import ( "fmt" "go/ast" - "go/parser" + "go/types" "html/template" "strings" "testing" @@ -21,20 +21,20 @@ import ( func Test_inputValidations(t *testing.T) { for _, tt := range []struct { Name string - Type string + Type types.Type Template string Result string Error string }{ { - Type: "int", + Type: types.Universe.Lookup("int").Type(), Name: "no attributes", Template: ``, Result: `{ }`, }, { - Type: "int", + Type: types.Universe.Lookup("int").Type(), Name: "min", Template: ``, Result: `{ @@ -45,7 +45,7 @@ func Test_inputValidations(t *testing.T) { }`, }, { - Type: "int", + Type: types.Universe.Lookup("int").Type(), Name: "negative min", Template: ``, Result: `{ @@ -56,7 +56,7 @@ func Test_inputValidations(t *testing.T) { }`, }, { - Type: "int", + Type: types.Universe.Lookup("int").Type(), Name: "zero min", Template: ``, Result: `{ @@ -67,7 +67,7 @@ func Test_inputValidations(t *testing.T) { }`, }, { - Type: "int8", + Type: types.Universe.Lookup("int8").Type(), Name: "zero min", Template: ``, Result: `{ @@ -78,7 +78,7 @@ func Test_inputValidations(t *testing.T) { }`, }, { - Type: "int16", + Type: types.Universe.Lookup("int16").Type(), Name: "zero min", Template: ``, Result: `{ @@ -89,7 +89,7 @@ func Test_inputValidations(t *testing.T) { }`, }, { - Type: "int32", + Type: types.Universe.Lookup("int32").Type(), Name: "zero min", Template: ``, Result: `{ @@ -100,7 +100,7 @@ func Test_inputValidations(t *testing.T) { }`, }, { - Type: "int64", + Type: types.Universe.Lookup("int64").Type(), Name: "zero min", Template: ``, Result: `{ @@ -111,7 +111,7 @@ func Test_inputValidations(t *testing.T) { }`, }, { - Type: "uint", + Type: types.Universe.Lookup("uint").Type(), Name: "zero min", Template: ``, Result: `{ @@ -122,7 +122,7 @@ func Test_inputValidations(t *testing.T) { }`, }, { - Type: "uint8", + Type: types.Universe.Lookup("uint8").Type(), Name: "zero min", Template: ``, Result: `{ @@ -133,7 +133,7 @@ func Test_inputValidations(t *testing.T) { }`, }, { - Type: "uint16", + Type: types.Universe.Lookup("uint16").Type(), Name: "zero min", Template: ``, Result: `{ @@ -144,7 +144,7 @@ func Test_inputValidations(t *testing.T) { }`, }, { - Type: "uint32", + Type: types.Universe.Lookup("uint32").Type(), Name: "zero min", Template: ``, Result: `{ @@ -155,7 +155,7 @@ func Test_inputValidations(t *testing.T) { }`, }, { - Type: "uint64", + Type: types.Universe.Lookup("uint64").Type(), Name: "zero min", Template: ``, Result: `{ @@ -166,91 +166,79 @@ func Test_inputValidations(t *testing.T) { }`, }, { - Type: "int", + Type: types.Universe.Lookup("int").Type(), Name: "out of range", Template: ``, Error: `strconv.ParseInt: parsing "18446744073709551616": value out of range`, }, { - Type: "int8", + Type: types.Universe.Lookup("int8").Type(), Name: "out of range", Template: ``, Error: `strconv.ParseInt: parsing "256": value out of range`, }, { - Type: "int16", + Type: types.Universe.Lookup("int16").Type(), Name: "out of range", Template: ``, Error: `strconv.ParseInt: parsing "32768": value out of range`, }, { - Type: "int32", + Type: types.Universe.Lookup("int32").Type(), Name: "out of range", Template: ``, Error: `strconv.ParseInt: parsing "2147483648": value out of range`, }, { - Type: "int64", + Type: types.Universe.Lookup("int64").Type(), Name: "out of range", Template: ``, Error: `strconv.ParseInt: parsing "9223372036854775808": value out of range`, }, { - Type: "uint", + Type: types.Universe.Lookup("uint").Type(), Name: "out of range", Template: ``, Error: `strconv.ParseUint: parsing "-10": invalid syntax`, }, { - Type: "uint8", + Type: types.Universe.Lookup("uint8").Type(), Name: "out of range", Template: ``, Error: `strconv.ParseUint: parsing "256": value out of range`, }, { - Type: "uint16", + Type: types.Universe.Lookup("uint16").Type(), Name: "out of range", Template: ``, Error: `strconv.ParseUint: parsing "65536": value out of range`, }, { - Type: "uint32", + Type: types.Universe.Lookup("uint32").Type(), Name: "out of range", Template: ``, Error: `strconv.ParseUint: parsing "4294967296": value out of range`, }, { - Type: "uint64", + Type: types.Universe.Lookup("uint64").Type(), Name: "out of range", Template: ``, Error: `strconv.ParseUint: parsing "18446744073709551616": value out of range`, }, { - Type: "*T", - Name: "unsupported type", - Template: ``, - Error: `type *T is not supported`, - }, - { - Type: "T", - Name: "type unknown", - Template: ``, - Error: `type T unknown`, - }, - { - Type: "int", + Type: types.Universe.Lookup("int").Type(), Name: "not a number", Template: ``, Error: `strconv.ParseInt: parsing "NaN": invalid syntax`, }, { - Type: "int", + Type: types.Universe.Lookup("int").Type(), Name: "wrong tag", Template: `
`, Error: `expected element to have tag got