From 7d85f2027ed86b1c9e4f29512d0ce0236d30a36d Mon Sep 17 00:00:00 2001
From: Chrstopher Hunter <8398225+crhntr@users.noreply.github.com>
Date: Thu, 28 Nov 2024 14:19:40 -0800
Subject: [PATCH] feat: use a package other than the routes package for
receiver
---
cmd/generate-readme/README.md.template | 4 +-
cmd/muxt/generate.go | 1 +
...and_routes_are_in_different_packages.txtar | 83 +++++++++++++++++++
internal/source/imports.go | 32 ++++++-
routes.go | 62 ++++++++------
routes_test.go | 1 -
6 files changed, 151 insertions(+), 32 deletions(-)
create mode 100644 cmd/muxt/testdata/generate/receiver_and_routes_are_in_different_packages.txtar
diff --git a/cmd/generate-readme/README.md.template b/cmd/generate-readme/README.md.template
index 4c491f7..4dfcbf4 100644
--- a/cmd/generate-readme/README.md.template
+++ b/cmd/generate-readme/README.md.template
@@ -173,11 +173,11 @@ type RoutesReceiver interface {
}
func routes(mux *http.ServeMux, receiver RoutesReceiver) {
- mux.HandleFunc("GET /articles/:id", func(response http.ResponseWriter, request *http.Request) {
+ mux.HandleFunc("GET /articles/{id}", func(response http.ResponseWriter, request *http.Request) {
ctx := request.Context()
id := request.PathValue("id")
data := receiver.ReadArticle(ctx, id)
- execute(response, request, true, "GET /articles/:id ReadArticle(ctx, id)", http.StatusOK, data)
+ execute(response, request, true, "GET /articles/{id} ReadArticle(ctx, id)", http.StatusOK, data)
})
}
diff --git a/cmd/muxt/generate.go b/cmd/muxt/generate.go
index 099db36..e580b20 100644
--- a/cmd/muxt/generate.go
+++ b/cmd/muxt/generate.go
@@ -23,6 +23,7 @@ func generateCommand(args []string, workingDirectory string, getEnv func(string)
TemplatesVar: g.TemplatesVariable,
RoutesFunc: g.RoutesFunction,
ReceiverType: g.ReceiverIdent,
+ ReceiverPackage: g.ReceiverStaticTypePackage,
ReceiverInterface: g.ReceiverInterfaceIdent,
Output: g.OutputFilename,
})
diff --git a/cmd/muxt/testdata/generate/receiver_and_routes_are_in_different_packages.txtar b/cmd/muxt/testdata/generate/receiver_and_routes_are_in_different_packages.txtar
new file mode 100644
index 0000000..0504cea
--- /dev/null
+++ b/cmd/muxt/testdata/generate/receiver_and_routes_are_in_different_packages.txtar
@@ -0,0 +1,83 @@
+
+cd internal/hypertext
+
+muxt generate --receiver-type=Handler --receiver-type-package=crhntr.com/muxt-test/internal/endpoints --routes-func=Routes
+
+cd ../../
+
+exec go test
+
+-- go.mod --
+module crhntr.com/muxt-test
+
+go 1.23
+
+-- main_test.go --
+package main_test
+
+import (
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "crhntr.com/muxt-test/internal/endpoints"
+ "crhntr.com/muxt-test/internal/hypertext"
+)
+
+func Test(t *testing.T) {
+ mux := http.NewServeMux()
+ var h endpoints.Handler
+
+ hypertext.Routes(mux, h)
+
+ t.Run("GET", func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+
+ mux.ServeHTTP(rec, req)
+
+ res := rec.Result()
+
+ if res.StatusCode != http.StatusOK {
+ t.Fail()
+ }
+
+ body, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Error(err)
+ }
+
+ if !strings.Contains(string(body), `
result
`) {
+ t.Error("expected output text to contain result", string(body))
+ t.Log("got", string(body))
+ }
+ })
+}
+-- internal/endpoints/server.go --
+package endpoints
+
+type Handler struct{}
+
+func (Handler) F() string {
+ return "result"
+}
+-- internal/hypertext/generate.go --
+package hypertext
+
+import (
+ "embed"
+ "html/template"
+)
+
+var (
+ //go:embed *.gohtml
+ templateFiles embed.FS
+
+ templates = template.Must(template.ParseFS(templateFiles, "*"))
+)
+-- internal/hypertext/form.gohtml --
+{{- define "GET /{$} F()" -}}
+{{.}}
+{{- end -}}
diff --git a/internal/source/imports.go b/internal/source/imports.go
index 8f5deb3..c3f05a2 100644
--- a/internal/source/imports.go
+++ b/internal/source/imports.go
@@ -7,8 +7,10 @@ import (
"go/parser"
"go/token"
"go/types"
+ "golang.org/x/tools/go/packages"
"log"
"path"
+ "path/filepath"
"slices"
"strconv"
"strings"
@@ -19,6 +21,7 @@ type Imports struct {
fileSet *token.FileSet
types map[string]*types.Package
files map[string]*ast.File
+ packages []*packages.Package
outputPackage string
}
@@ -31,8 +34,33 @@ func NewImports(decl *ast.GenDecl) *Imports {
return &Imports{GenDecl: decl, types: make(map[string]*types.Package), files: make(map[string]*ast.File)}
}
-func (imports *Imports) AddPackages(p *types.Package) {
- recursivelyRegisterPackages(imports.types, p)
+func (imports *Imports) Package(path string) (*packages.Package, bool) {
+ for _, pkg := range imports.packages {
+ if pkg.PkgPath == path {
+ return pkg, true
+ }
+ }
+ return nil, false
+}
+
+func (imports *Imports) AddPackages(packages ...*packages.Package) {
+ imports.packages = slices.Grow(imports.packages, len(packages))
+ for _, pkg := range packages {
+ if pkg == nil {
+ continue
+ }
+ recursivelyRegisterPackages(imports.types, pkg.Types)
+ imports.packages = append(imports.packages, pkg)
+ }
+}
+
+func (imports *Imports) PackageAtFilepath(p string) (*packages.Package, bool) {
+ for _, pkg := range imports.packages {
+ if len(pkg.GoFiles) > 0 && filepath.Dir(pkg.GoFiles[0]) == p {
+ return pkg, true
+ }
+ }
+ return nil, false
}
func (imports *Imports) FileSet() *token.FileSet {
diff --git a/routes.go b/routes.go
index 692ea3d..1a2112b 100644
--- a/routes.go
+++ b/routes.go
@@ -60,6 +60,7 @@ type RoutesFileConfiguration struct {
TemplatesVar,
RoutesFunc,
ReceiverType,
+ ReceiverPackage,
ReceiverInterface,
Output string
}
@@ -80,33 +81,51 @@ func TemplateRoutesFile(wd string, logger *log.Logger, config RoutesFileConfigur
}
imports := source.NewImports(&ast.GenDecl{Tok: token.IMPORT})
- var pkg *types.Package
+ patterns := []string{
+ wd, "net/http",
+ }
+
+ if config.ReceiverPackage != "" {
+ patterns = append(patterns, config.ReceiverPackage)
+ }
+
pl, err := packages.Load(&packages.Config{
Fset: imports.FileSet(),
Mode: packages.NeedModule | packages.NeedName | packages.NeedFiles | packages.NeedTypes | packages.NeedSyntax | packages.NeedEmbedPatterns | packages.NeedEmbedFiles,
Dir: wd,
- }, ".", "net/http")
+ }, patterns...)
if err != nil {
return "", err
}
+ imports.AddPackages(pl...)
+ routesPkg, ok := imports.PackageAtFilepath(wd)
+ if !ok {
+ return "", fmt.Errorf("could not find package in working directory %q", wd)
+ }
+ imports.SetOutputPackage(routesPkg.Types)
+ config.PackagePath = routesPkg.PkgPath
+ config.Package = routesPkg.Name
var receiver *types.Named
- for _, p := range pl {
- imports.AddPackages(p.Types)
- }
- var currentPkg *packages.Package
- if len(pl) == 2 && pl[0].PkgPath == "net/http" {
- imports.SetOutputPackage(pl[1].Types)
- currentPkg = pl[1]
- } else if len(pl) == 2 && pl[1].PkgPath == "net/http" {
- imports.SetOutputPackage(pl[0].Types)
- currentPkg = pl[0]
+ if config.ReceiverType != "" {
+ receiverPkgPath := cmp.Or(config.ReceiverPackage, config.PackagePath)
+ receiverPkg, ok := imports.Package(receiverPkgPath)
+ if !ok {
+ return "", fmt.Errorf("could not determine receiver package %s", receiverPkgPath)
+ }
+ obj := receiverPkg.Types.Scope().Lookup(config.ReceiverType)
+ if config.ReceiverType != "" && obj == nil {
+ return "", fmt.Errorf("could not find receiver type %s in %s", config.ReceiverType, receiverPkg.PkgPath)
+ }
+ named, ok := obj.Type().(*types.Named)
+ if !ok {
+ return "", fmt.Errorf("expected receiver %s to be a named type", config.ReceiverType)
+ }
+ receiver = named
} else {
- log.Fatal("expected the current directory to have a non-test package", pl)
+ receiver = types.NewNamed(types.NewTypeName(0, routesPkg.Types, "Receiver", nil), types.NewStruct(nil, nil), nil)
}
- config.PackagePath = currentPkg.PkgPath
- config.Package = currentPkg.Name
- ts, err := source.Templates(wd, config.TemplatesVar, currentPkg)
+ ts, err := source.Templates(wd, config.TemplatesVar, routesPkg)
if err != nil {
return "", err
}
@@ -122,20 +141,9 @@ func TemplateRoutesFile(wd string, logger *log.Logger, config RoutesFileConfigur
config.executeFunc = filepath.Base(p.Fset.Position(executeObj.Pos()).Filename) == config.Output
}
}
- obj := p.Types.Scope().Lookup(config.ReceiverType)
- if obj != nil {
- named, ok := obj.Type().(*types.Named)
- if !ok {
- return "", fmt.Errorf("expected receiver to be a named type")
- }
- receiver = named
- }
break
}
}
- if receiver == nil {
- receiver = types.NewNamed(types.NewTypeName(0, pkg, "Receiver", nil), types.NewStruct(nil, nil), nil)
- }
receiverInterface := &ast.InterfaceType{
Methods: new(ast.FieldList),
diff --git a/routes_test.go b/routes_test.go
index 8de72ca..be50965 100644
--- a/routes_test.go
+++ b/routes_test.go
@@ -181,7 +181,6 @@ func execute(response http.ResponseWriter, request *http.Request, writeHeader bo
{
Name: "when the default interface name is overwritten",
Templates: `{{define "GET / F()"}}Hello{{end}}`,
- Receiver: "T",
Interface: "Server",
ExpectedFile: `package main