Skip to content

Commit

Permalink
refactor: only parse packages once
Browse files Browse the repository at this point in the history
  • Loading branch information
crhntr committed Nov 28, 2024
1 parent 42b8a22 commit f2f6523
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 57 deletions.
57 changes: 20 additions & 37 deletions cmd/muxt/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package main

import (
"bytes"
"cmp"
"flag"
"fmt"
"go/token"
Expand All @@ -14,7 +13,6 @@ import (
"golang.org/x/tools/go/packages"

"github.com/crhntr/muxt"
"github.com/crhntr/muxt/internal/source"
)

const (
Expand All @@ -30,10 +28,9 @@ const (
)

type Generate struct {
Package *packages.Package
goPackage string
goFile string
goLine string
Package *packages.Package
goFile string
goLine string

templatesVariable string
outputFilename string
Expand All @@ -45,9 +42,8 @@ type Generate struct {

func newGenerate(args []string, getEnv func(string) string, stderr io.Writer) (Generate, error) {
g := Generate{
goPackage: getEnv("GOPACKAGE"),
goFile: getEnv("GOFILE"),
goLine: getEnv("GOLINE"),
goFile: getEnv("GOFILE"),
goLine: getEnv("GOLINE"),
}
flagSet := flag.NewFlagSet("generate", flag.ContinueOnError)
flagSet.SetOutput(stderr)
Expand Down Expand Up @@ -82,41 +78,28 @@ func generateCommand(args []string, workingDirectory string, getEnv func(string)
if err != nil {
return err
}

g.goPackage = cmp.Or(g.goPackage, filepath.Base(workingDirectory))

packageList, err := source.Load(workingDirectory, ".")
if err != nil {
return err
}
if len(packageList) > 0 {
g.Package = packageList[0]
g.goPackage = packageList[0].Name
}
ts, err := source.Templates(workingDirectory, g.templatesVariable, g.Package)
if err != nil {
return err
}
templates, err := muxt.Templates(ts)
if err != nil {
return err
}
s, err := muxt.TemplateRoutesFile(workingDirectory, templates, log.New(stdout, "", 0), muxt.RoutesFileConfiguration{
Package: g.goPackage,
PackagePath: g.Package.PkgPath,
s, err := muxt.TemplateRoutesFile(workingDirectory, log.New(stdout, "", 0), muxt.RoutesFileConfiguration{
Package: getEnv("GOPACKAGE"),
TemplatesVar: g.templatesVariable,
RoutesFunc: g.routesFunction,
ReceiverType: g.receiverIdent,
ReceiverInterface: g.receiverInterfaceIdent,
Output: g.outputFilename,
})
var sb bytes.Buffer
sb.WriteString(CodeGenerationComment)
if v, ok := cliVersion(); ok {
sb.WriteString("\n// muxt version: ")
sb.WriteString(v)
sb.WriteString("\n\n")
if err != nil {
return err
}
var sb bytes.Buffer
writeCodeGenerationComment(&sb)
sb.WriteString(s)
return os.WriteFile(filepath.Join(workingDirectory, g.outputFilename), sb.Bytes(), 0o644)
}

func writeCodeGenerationComment(w io.StringWriter) {
_, _ = w.WriteString(CodeGenerationComment)
if v, ok := cliVersion(); ok {
_, _ = w.WriteString("\n// muxt version: ")
_, _ = w.WriteString(v)
_, _ = w.WriteString("\n\n")
}
}
1 change: 0 additions & 1 deletion internal/source/imports.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ func (imports *Imports) FieldTag(pos token.Pos) (*ast.Field, error) {
}
}
}

}
return nil, fmt.Errorf("failed to find field")
}
Expand Down
12 changes: 6 additions & 6 deletions internal/source/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func Templates(workingDirectory, templatesVariable string, pkg *packages.Package
if i < 0 || i >= len(tv.Values) {
continue
}
embeddedPaths, err := relFilepaths(workingDirectory, pkg.EmbedFiles...)
embeddedPaths, err := relativeFilePaths(workingDirectory, pkg.EmbedFiles...)
if err != nil {
return nil, fmt.Errorf("failed to calculate relative path for embedded files: %w", err)
}
Expand Down Expand Up @@ -179,7 +179,7 @@ func evaluateCallParseFilesArgs(workingDirectory string, fileSet *token.FileSet,
if len(call.Args) < 1 {
return nil, contextError(workingDirectory, fileSet, call.Lparen, fmt.Errorf("missing required arguments"))
}
matches, err := embedFSFilepaths(workingDirectory, fileSet, files, call.Args[0], embeddedPaths)
matches, err := embedFSFilePaths(workingDirectory, fileSet, files, call.Args[0], embeddedPaths)
if err != nil {
return nil, err
}
Expand All @@ -201,10 +201,10 @@ func evaluateCallParseFilesArgs(workingDirectory string, fileSet *token.FileSet,
break
}
}
return joinFilepaths(workingDirectory, filtered...), nil
return joinFilePaths(workingDirectory, filtered...), nil
}

func embedFSFilepaths(dir string, fileSet *token.FileSet, files []*ast.File, exp ast.Expr, embeddedFiles []string) ([]string, error) {
func embedFSFilePaths(dir string, fileSet *token.FileSet, files []*ast.File, exp ast.Expr, embeddedFiles []string) ([]string, error) {
varIdent, ok := exp.(*ast.Ident)
if !ok {
return nil, contextError(dir, fileSet, exp.Pos(), fmt.Errorf("first argument to ParseFS must be an identifier"))
Expand Down Expand Up @@ -323,15 +323,15 @@ func contextError(workingDirectory string, set *token.FileSet, pos token.Pos, er
return fmt.Errorf("%s: %w", p, err)
}

func joinFilepaths(wd string, rel ...string) []string {
func joinFilePaths(wd string, rel ...string) []string {
result := slices.Clone(rel)
for i := range result {
result[i] = filepath.Join(wd, result[i])
}
return result
}

func relFilepaths(wd string, abs ...string) ([]string, error) {
func relativeFilePaths(wd string, abs ...string) ([]string, error) {
result := slices.Clone(abs)
for i, p := range result {
r, err := filepath.Rel(wd, p)
Expand Down
21 changes: 19 additions & 2 deletions routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,11 @@ func (config RoutesFileConfiguration) applyDefaults() RoutesFileConfiguration {
return config
}

func TemplateRoutesFile(wd string, templates []Template, logger *log.Logger, config RoutesFileConfiguration) (string, error) {
func TemplateRoutesFile(wd string, logger *log.Logger, config RoutesFileConfiguration) (string, error) {
config = config.applyDefaults()
if !token.IsIdentifier(config.Package) {
return "", fmt.Errorf("package name %q is not an identifier", config.Package)
}
imports := source.NewImports(&ast.GenDecl{Tok: token.IMPORT})

var pkg *types.Package
Expand All @@ -90,13 +93,27 @@ func TemplateRoutesFile(wd string, templates []Template, logger *log.Logger, con
for _, p := range pl {
imports.AddPackages(p.Types)
}
var currentPkg *packages.Package
if len(pl) == 2 && pl[0].PkgPath == "net/http" {
imports.SetOutputPackage(pl[1].Types)
currentPkg = pl[1]
} else if len(pl) == 2 && pl[1].PkgPath == "net/http" {
imports.SetOutputPackage(pl[0].Types)
currentPkg = pl[0]
} else {
log.Fatal("expected the current directory to have a non-test package", pl)
}
config.PackagePath = currentPkg.PkgPath
config.Package = currentPkg.Name

ts, err := source.Templates(wd, config.TemplatesVar, currentPkg)
if err != nil {
return "", err
}
templates, err := Templates(ts)
if err != nil {
return "", err
}

for _, p := range pl {
if p.Types.Path() == config.PackagePath {
Expand Down Expand Up @@ -222,7 +239,7 @@ func appendParseArgumentStatements(statements []ast.Stmt, t Template, imports *s
if !ok {
return nil, fmt.Errorf("expected method")
}
//const parsedVariableName = "parsed"
// const parsedVariableName = "parsed"
if exp := signature.Params().Len(); exp != len(call.Args) { // TODO: (signature.Variadic() && exp > len(call.Args))
sigStr := fun.Name + strings.TrimPrefix(signature.String(), "func")
return nil, fmt.Errorf("handler func %s expects %d arguments but call %s has %d", sigStr, signature.Params().Len(), source.Format(call), len(call.Args))
Expand Down
26 changes: 18 additions & 8 deletions routes_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package muxt_test

import (
"html/template"
"cmp"
"fmt"
"io"
"log"
"os"
Expand Down Expand Up @@ -1923,19 +1924,28 @@ func routes(mux *http.ServeMux, receiver RoutesReceiver) {
},
} {
t.Run(tt.Name, func(t *testing.T) {
ts := template.Must(template.New(tt.Name).Parse(tt.Templates))
templates, err := muxt.Templates(ts)
require.NoError(t, err)
logger := log.New(io.Discard, "", 0)

archive := txtar.Parse([]byte(tt.ReceiverPackage))
archiveDir, err := txtar.FS(archive)
require.NoError(t, err)

dir := t.TempDir()
require.NoError(t, os.CopyFS(dir, archiveDir))
require.NoError(t, os.WriteFile(filepath.Join(dir, "go.mod"), []byte("module example.com\n\ngo 1.20\n"), 0644))
out, err := muxt.TemplateRoutesFile(dir, templates, logger, muxt.RoutesFileConfiguration{
require.NoError(t, os.WriteFile(filepath.Join(dir, "go.mod"), []byte("module example.com\n\ngo 1.20\n"), 0o644))
require.NoError(t, os.WriteFile(filepath.Join(dir, "template.gohtml"), []byte(tt.Templates), 0o644))
require.NoError(t, os.WriteFile(filepath.Join(dir, "templates.go"), []byte(fmt.Sprintf(`package %s
import (
"embed"
"html/template"
)
//go:embed template.gohtml
var templatesDir embed.FS
var templates = template.Must(template.ParseFS(templatesDir, "template.gohtml"))
`, cmp.Or(tt.PackageName, "main"))), 0o644))
logger := log.New(io.Discard, "", 0)
out, err := muxt.TemplateRoutesFile(dir, logger, muxt.RoutesFileConfiguration{
ReceiverInterface: tt.Interface,
Package: tt.PackageName,
TemplatesVar: tt.TemplatesVar,
Expand Down
7 changes: 4 additions & 3 deletions template.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (

func Templates(ts *template.Template) ([]Template, error) {
var templateNames []Template
routes := make(map[string]struct{})
patterns := make(map[string]struct{})
for _, t := range ts.Templates() {
mt, err, ok := NewTemplateName(t.Name())
if !ok {
Expand All @@ -27,11 +27,12 @@ func Templates(ts *template.Template) ([]Template, error) {
if err != nil {
return templateNames, err
}
if _, exists := routes[mt.method+mt.path]; exists {
pattern := strings.Join([]string{mt.method, mt.host, mt.path}, " ")
if _, exists := patterns[pattern]; exists {
return templateNames, fmt.Errorf("duplicate route pattern: %s", mt.pattern)
}
mt.template = t
routes[mt.method+mt.path] = struct{}{}
patterns[pattern] = struct{}{}
templateNames = append(templateNames, mt)
}
slices.SortFunc(templateNames, Template.byPathThenMethod)
Expand Down

0 comments on commit f2f6523

Please sign in to comment.