From 7d85f2027ed86b1c9e4f29512d0ce0236d30a36d Mon Sep 17 00:00:00 2001 From: Chrstopher Hunter <8398225+crhntr@users.noreply.github.com> Date: Thu, 28 Nov 2024 14:19:40 -0800 Subject: [PATCH] feat: use a package other than the routes package for receiver --- cmd/generate-readme/README.md.template | 4 +- cmd/muxt/generate.go | 1 + ...and_routes_are_in_different_packages.txtar | 83 +++++++++++++++++++ internal/source/imports.go | 32 ++++++- routes.go | 62 ++++++++------ routes_test.go | 1 - 6 files changed, 151 insertions(+), 32 deletions(-) create mode 100644 cmd/muxt/testdata/generate/receiver_and_routes_are_in_different_packages.txtar diff --git a/cmd/generate-readme/README.md.template b/cmd/generate-readme/README.md.template index 4c491f7..4dfcbf4 100644 --- a/cmd/generate-readme/README.md.template +++ b/cmd/generate-readme/README.md.template @@ -173,11 +173,11 @@ type RoutesReceiver interface { } func routes(mux *http.ServeMux, receiver RoutesReceiver) { - mux.HandleFunc("GET /articles/:id", func(response http.ResponseWriter, request *http.Request) { + mux.HandleFunc("GET /articles/{id}", func(response http.ResponseWriter, request *http.Request) { ctx := request.Context() id := request.PathValue("id") data := receiver.ReadArticle(ctx, id) - execute(response, request, true, "GET /articles/:id ReadArticle(ctx, id)", http.StatusOK, data) + execute(response, request, true, "GET /articles/{id} ReadArticle(ctx, id)", http.StatusOK, data) }) } diff --git a/cmd/muxt/generate.go b/cmd/muxt/generate.go index 099db36..e580b20 100644 --- a/cmd/muxt/generate.go +++ b/cmd/muxt/generate.go @@ -23,6 +23,7 @@ func generateCommand(args []string, workingDirectory string, getEnv func(string) TemplatesVar: g.TemplatesVariable, RoutesFunc: g.RoutesFunction, ReceiverType: g.ReceiverIdent, + ReceiverPackage: g.ReceiverStaticTypePackage, ReceiverInterface: g.ReceiverInterfaceIdent, Output: g.OutputFilename, }) diff --git a/cmd/muxt/testdata/generate/receiver_and_routes_are_in_different_packages.txtar b/cmd/muxt/testdata/generate/receiver_and_routes_are_in_different_packages.txtar new file mode 100644 index 0000000..0504cea --- /dev/null +++ b/cmd/muxt/testdata/generate/receiver_and_routes_are_in_different_packages.txtar @@ -0,0 +1,83 @@ + +cd internal/hypertext + +muxt generate --receiver-type=Handler --receiver-type-package=crhntr.com/muxt-test/internal/endpoints --routes-func=Routes + +cd ../../ + +exec go test + +-- go.mod -- +module crhntr.com/muxt-test + +go 1.23 + +-- main_test.go -- +package main_test + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "crhntr.com/muxt-test/internal/endpoints" + "crhntr.com/muxt-test/internal/hypertext" +) + +func Test(t *testing.T) { + mux := http.NewServeMux() + var h endpoints.Handler + + hypertext.Routes(mux, h) + + t.Run("GET", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + mux.ServeHTTP(rec, req) + + res := rec.Result() + + if res.StatusCode != http.StatusOK { + t.Fail() + } + + body, err := io.ReadAll(res.Body) + if err != nil { + t.Error(err) + } + + if !strings.Contains(string(body), `

result

`) { + t.Error("expected output text to contain result", string(body)) + t.Log("got", string(body)) + } + }) +} +-- internal/endpoints/server.go -- +package endpoints + +type Handler struct{} + +func (Handler) F() string { + return "result" +} +-- internal/hypertext/generate.go -- +package hypertext + +import ( + "embed" + "html/template" +) + +var ( + //go:embed *.gohtml + templateFiles embed.FS + + templates = template.Must(template.ParseFS(templateFiles, "*")) +) +-- internal/hypertext/form.gohtml -- +{{- define "GET /{$} F()" -}} +

{{.}}

+{{- end -}} diff --git a/internal/source/imports.go b/internal/source/imports.go index 8f5deb3..c3f05a2 100644 --- a/internal/source/imports.go +++ b/internal/source/imports.go @@ -7,8 +7,10 @@ import ( "go/parser" "go/token" "go/types" + "golang.org/x/tools/go/packages" "log" "path" + "path/filepath" "slices" "strconv" "strings" @@ -19,6 +21,7 @@ type Imports struct { fileSet *token.FileSet types map[string]*types.Package files map[string]*ast.File + packages []*packages.Package outputPackage string } @@ -31,8 +34,33 @@ func NewImports(decl *ast.GenDecl) *Imports { return &Imports{GenDecl: decl, types: make(map[string]*types.Package), files: make(map[string]*ast.File)} } -func (imports *Imports) AddPackages(p *types.Package) { - recursivelyRegisterPackages(imports.types, p) +func (imports *Imports) Package(path string) (*packages.Package, bool) { + for _, pkg := range imports.packages { + if pkg.PkgPath == path { + return pkg, true + } + } + return nil, false +} + +func (imports *Imports) AddPackages(packages ...*packages.Package) { + imports.packages = slices.Grow(imports.packages, len(packages)) + for _, pkg := range packages { + if pkg == nil { + continue + } + recursivelyRegisterPackages(imports.types, pkg.Types) + imports.packages = append(imports.packages, pkg) + } +} + +func (imports *Imports) PackageAtFilepath(p string) (*packages.Package, bool) { + for _, pkg := range imports.packages { + if len(pkg.GoFiles) > 0 && filepath.Dir(pkg.GoFiles[0]) == p { + return pkg, true + } + } + return nil, false } func (imports *Imports) FileSet() *token.FileSet { diff --git a/routes.go b/routes.go index 692ea3d..1a2112b 100644 --- a/routes.go +++ b/routes.go @@ -60,6 +60,7 @@ type RoutesFileConfiguration struct { TemplatesVar, RoutesFunc, ReceiverType, + ReceiverPackage, ReceiverInterface, Output string } @@ -80,33 +81,51 @@ func TemplateRoutesFile(wd string, logger *log.Logger, config RoutesFileConfigur } imports := source.NewImports(&ast.GenDecl{Tok: token.IMPORT}) - var pkg *types.Package + patterns := []string{ + wd, "net/http", + } + + if config.ReceiverPackage != "" { + patterns = append(patterns, config.ReceiverPackage) + } + pl, err := packages.Load(&packages.Config{ Fset: imports.FileSet(), Mode: packages.NeedModule | packages.NeedName | packages.NeedFiles | packages.NeedTypes | packages.NeedSyntax | packages.NeedEmbedPatterns | packages.NeedEmbedFiles, Dir: wd, - }, ".", "net/http") + }, patterns...) if err != nil { return "", err } + imports.AddPackages(pl...) + routesPkg, ok := imports.PackageAtFilepath(wd) + if !ok { + return "", fmt.Errorf("could not find package in working directory %q", wd) + } + imports.SetOutputPackage(routesPkg.Types) + config.PackagePath = routesPkg.PkgPath + config.Package = routesPkg.Name var receiver *types.Named - for _, p := range pl { - imports.AddPackages(p.Types) - } - var currentPkg *packages.Package - if len(pl) == 2 && pl[0].PkgPath == "net/http" { - imports.SetOutputPackage(pl[1].Types) - currentPkg = pl[1] - } else if len(pl) == 2 && pl[1].PkgPath == "net/http" { - imports.SetOutputPackage(pl[0].Types) - currentPkg = pl[0] + if config.ReceiverType != "" { + receiverPkgPath := cmp.Or(config.ReceiverPackage, config.PackagePath) + receiverPkg, ok := imports.Package(receiverPkgPath) + if !ok { + return "", fmt.Errorf("could not determine receiver package %s", receiverPkgPath) + } + obj := receiverPkg.Types.Scope().Lookup(config.ReceiverType) + if config.ReceiverType != "" && obj == nil { + return "", fmt.Errorf("could not find receiver type %s in %s", config.ReceiverType, receiverPkg.PkgPath) + } + named, ok := obj.Type().(*types.Named) + if !ok { + return "", fmt.Errorf("expected receiver %s to be a named type", config.ReceiverType) + } + receiver = named } else { - log.Fatal("expected the current directory to have a non-test package", pl) + receiver = types.NewNamed(types.NewTypeName(0, routesPkg.Types, "Receiver", nil), types.NewStruct(nil, nil), nil) } - config.PackagePath = currentPkg.PkgPath - config.Package = currentPkg.Name - ts, err := source.Templates(wd, config.TemplatesVar, currentPkg) + ts, err := source.Templates(wd, config.TemplatesVar, routesPkg) if err != nil { return "", err } @@ -122,20 +141,9 @@ func TemplateRoutesFile(wd string, logger *log.Logger, config RoutesFileConfigur config.executeFunc = filepath.Base(p.Fset.Position(executeObj.Pos()).Filename) == config.Output } } - obj := p.Types.Scope().Lookup(config.ReceiverType) - if obj != nil { - named, ok := obj.Type().(*types.Named) - if !ok { - return "", fmt.Errorf("expected receiver to be a named type") - } - receiver = named - } break } } - if receiver == nil { - receiver = types.NewNamed(types.NewTypeName(0, pkg, "Receiver", nil), types.NewStruct(nil, nil), nil) - } receiverInterface := &ast.InterfaceType{ Methods: new(ast.FieldList), diff --git a/routes_test.go b/routes_test.go index 8de72ca..be50965 100644 --- a/routes_test.go +++ b/routes_test.go @@ -181,7 +181,6 @@ func execute(response http.ResponseWriter, request *http.Request, writeHeader bo { Name: "when the default interface name is overwritten", Templates: `{{define "GET / F()"}}Hello{{end}}`, - Receiver: "T", Interface: "Server", ExpectedFile: `package main