Skip to content

Commit

Permalink
feat: use a package other than the routes package for receiver
Browse files Browse the repository at this point in the history
  • Loading branch information
crhntr committed Nov 28, 2024
1 parent 1fe9db0 commit 7d85f20
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 32 deletions.
4 changes: 2 additions & 2 deletions cmd/generate-readme/README.md.template
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,11 @@ type RoutesReceiver interface {
}

func routes(mux *http.ServeMux, receiver RoutesReceiver) {
mux.HandleFunc("GET /articles/:id", func(response http.ResponseWriter, request *http.Request) {
mux.HandleFunc("GET /articles/{id}", func(response http.ResponseWriter, request *http.Request) {
ctx := request.Context()
id := request.PathValue("id")
data := receiver.ReadArticle(ctx, id)
execute(response, request, true, "GET /articles/:id ReadArticle(ctx, id)", http.StatusOK, data)
execute(response, request, true, "GET /articles/{id} ReadArticle(ctx, id)", http.StatusOK, data)
})
}

Expand Down
1 change: 1 addition & 0 deletions cmd/muxt/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ func generateCommand(args []string, workingDirectory string, getEnv func(string)
TemplatesVar: g.TemplatesVariable,
RoutesFunc: g.RoutesFunction,
ReceiverType: g.ReceiverIdent,
ReceiverPackage: g.ReceiverStaticTypePackage,
ReceiverInterface: g.ReceiverInterfaceIdent,
Output: g.OutputFilename,
})
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@

cd internal/hypertext

muxt generate --receiver-type=Handler --receiver-type-package=crhntr.com/muxt-test/internal/endpoints --routes-func=Routes

cd ../../

exec go test

-- go.mod --
module crhntr.com/muxt-test

go 1.23

-- main_test.go --
package main_test

import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"

"crhntr.com/muxt-test/internal/endpoints"
"crhntr.com/muxt-test/internal/hypertext"
)

func Test(t *testing.T) {
mux := http.NewServeMux()
var h endpoints.Handler

hypertext.Routes(mux, h)

t.Run("GET", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

mux.ServeHTTP(rec, req)

res := rec.Result()

if res.StatusCode != http.StatusOK {
t.Fail()
}

body, err := io.ReadAll(res.Body)
if err != nil {
t.Error(err)
}

if !strings.Contains(string(body), `<h1>result</h1>`) {
t.Error("expected output text to contain result", string(body))
t.Log("got", string(body))
}
})
}
-- internal/endpoints/server.go --
package endpoints

type Handler struct{}

func (Handler) F() string {
return "result"
}
-- internal/hypertext/generate.go --
package hypertext

import (
"embed"
"html/template"
)

var (
//go:embed *.gohtml
templateFiles embed.FS

templates = template.Must(template.ParseFS(templateFiles, "*"))
)
-- internal/hypertext/form.gohtml --
{{- define "GET /{$} F()" -}}
<h1>{{.}}</h1>
{{- end -}}
32 changes: 30 additions & 2 deletions internal/source/imports.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ import (
"go/parser"
"go/token"
"go/types"
"golang.org/x/tools/go/packages"
"log"
"path"
"path/filepath"
"slices"
"strconv"
"strings"
Expand All @@ -19,6 +21,7 @@ type Imports struct {
fileSet *token.FileSet
types map[string]*types.Package
files map[string]*ast.File
packages []*packages.Package
outputPackage string
}

Expand All @@ -31,8 +34,33 @@ func NewImports(decl *ast.GenDecl) *Imports {
return &Imports{GenDecl: decl, types: make(map[string]*types.Package), files: make(map[string]*ast.File)}
}

func (imports *Imports) AddPackages(p *types.Package) {
recursivelyRegisterPackages(imports.types, p)
func (imports *Imports) Package(path string) (*packages.Package, bool) {
for _, pkg := range imports.packages {
if pkg.PkgPath == path {
return pkg, true
}
}
return nil, false
}

func (imports *Imports) AddPackages(packages ...*packages.Package) {
imports.packages = slices.Grow(imports.packages, len(packages))
for _, pkg := range packages {
if pkg == nil {
continue
}
recursivelyRegisterPackages(imports.types, pkg.Types)
imports.packages = append(imports.packages, pkg)
}
}

func (imports *Imports) PackageAtFilepath(p string) (*packages.Package, bool) {
for _, pkg := range imports.packages {
if len(pkg.GoFiles) > 0 && filepath.Dir(pkg.GoFiles[0]) == p {
return pkg, true
}
}
return nil, false
}

func (imports *Imports) FileSet() *token.FileSet {
Expand Down
62 changes: 35 additions & 27 deletions routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ type RoutesFileConfiguration struct {
TemplatesVar,
RoutesFunc,
ReceiverType,
ReceiverPackage,
ReceiverInterface,
Output string
}
Expand All @@ -80,33 +81,51 @@ func TemplateRoutesFile(wd string, logger *log.Logger, config RoutesFileConfigur
}
imports := source.NewImports(&ast.GenDecl{Tok: token.IMPORT})

var pkg *types.Package
patterns := []string{
wd, "net/http",
}

if config.ReceiverPackage != "" {
patterns = append(patterns, config.ReceiverPackage)
}

pl, err := packages.Load(&packages.Config{
Fset: imports.FileSet(),
Mode: packages.NeedModule | packages.NeedName | packages.NeedFiles | packages.NeedTypes | packages.NeedSyntax | packages.NeedEmbedPatterns | packages.NeedEmbedFiles,
Dir: wd,
}, ".", "net/http")
}, patterns...)
if err != nil {
return "", err
}
imports.AddPackages(pl...)
routesPkg, ok := imports.PackageAtFilepath(wd)
if !ok {
return "", fmt.Errorf("could not find package in working directory %q", wd)
}
imports.SetOutputPackage(routesPkg.Types)
config.PackagePath = routesPkg.PkgPath
config.Package = routesPkg.Name
var receiver *types.Named
for _, p := range pl {
imports.AddPackages(p.Types)
}
var currentPkg *packages.Package
if len(pl) == 2 && pl[0].PkgPath == "net/http" {
imports.SetOutputPackage(pl[1].Types)
currentPkg = pl[1]
} else if len(pl) == 2 && pl[1].PkgPath == "net/http" {
imports.SetOutputPackage(pl[0].Types)
currentPkg = pl[0]
if config.ReceiverType != "" {
receiverPkgPath := cmp.Or(config.ReceiverPackage, config.PackagePath)
receiverPkg, ok := imports.Package(receiverPkgPath)
if !ok {
return "", fmt.Errorf("could not determine receiver package %s", receiverPkgPath)
}
obj := receiverPkg.Types.Scope().Lookup(config.ReceiverType)
if config.ReceiverType != "" && obj == nil {
return "", fmt.Errorf("could not find receiver type %s in %s", config.ReceiverType, receiverPkg.PkgPath)
}
named, ok := obj.Type().(*types.Named)
if !ok {
return "", fmt.Errorf("expected receiver %s to be a named type", config.ReceiverType)
}
receiver = named
} else {
log.Fatal("expected the current directory to have a non-test package", pl)
receiver = types.NewNamed(types.NewTypeName(0, routesPkg.Types, "Receiver", nil), types.NewStruct(nil, nil), nil)
}
config.PackagePath = currentPkg.PkgPath
config.Package = currentPkg.Name

ts, err := source.Templates(wd, config.TemplatesVar, currentPkg)
ts, err := source.Templates(wd, config.TemplatesVar, routesPkg)
if err != nil {
return "", err
}
Expand All @@ -122,20 +141,9 @@ func TemplateRoutesFile(wd string, logger *log.Logger, config RoutesFileConfigur
config.executeFunc = filepath.Base(p.Fset.Position(executeObj.Pos()).Filename) == config.Output
}
}
obj := p.Types.Scope().Lookup(config.ReceiverType)
if obj != nil {
named, ok := obj.Type().(*types.Named)
if !ok {
return "", fmt.Errorf("expected receiver to be a named type")
}
receiver = named
}
break
}
}
if receiver == nil {
receiver = types.NewNamed(types.NewTypeName(0, pkg, "Receiver", nil), types.NewStruct(nil, nil), nil)
}

receiverInterface := &ast.InterfaceType{
Methods: new(ast.FieldList),
Expand Down
1 change: 0 additions & 1 deletion routes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ func execute(response http.ResponseWriter, request *http.Request, writeHeader bo
{
Name: "when the default interface name is overwritten",
Templates: `{{define "GET / F()"}}Hello{{end}}`,
Receiver: "T",
Interface: "Server",
ExpectedFile: `package main
Expand Down

0 comments on commit 7d85f20

Please sign in to comment.