Skip to content

Commit

Permalink
integrate check into generate
Browse files Browse the repository at this point in the history
  • Loading branch information
crhntr committed Dec 2, 2024
1 parent acb5202 commit 36c18cd
Show file tree
Hide file tree
Showing 8 changed files with 196 additions and 89 deletions.
8 changes: 6 additions & 2 deletions cmd/muxt/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@ import (
"github.com/crhntr/muxt/internal/configuration"
)

const CodeGenerationComment = "// Code generated by muxt. DO NOT EDIT."
const (
CodeGenerationComment = "// Code generated by muxt. DO NOT EDIT."
experimentCheckTypesEnvVar = "MUXT_EXPERIMENT_CHECK_TYPES"
)

func generateCommand(args []string, workingDirectory string, stdout, stderr io.Writer) error {
func generateCommand(workingDirectory string, args []string, getEnv func(string) string, stdout, stderr io.Writer) error {
config, err := configuration.NewRoutesFileConfiguration(args, stderr)
if err != nil {
return err
}
config.ExperimentalCheckTypes = getEnv(experimentCheckTypesEnvVar) == "true"
s, err := muxt.TemplateRoutesFile(workingDirectory, log.New(stdout, "", 0), config)
if err != nil {
return err
Expand Down
4 changes: 2 additions & 2 deletions cmd/muxt/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ func main() {
os.Exit(handleError(command(wd, flag.Args(), os.Getenv, os.Stdout, os.Stderr)))
}

func command(wd string, args []string, _ func(string) string, stdout, stderr io.Writer) error {
func command(wd string, args []string, getEnv func(string) string, stdout, stderr io.Writer) error {
if len(args) > 0 {
switch cmd, cmdArgs := args[0], args[1:]; cmd {
case "generate", "gen", "g":
return generateCommand(cmdArgs, wd, stdout, stderr)
return generateCommand(wd, cmdArgs, getEnv, stdout, stderr)
case "version", "v":
return versionCommand(stdout)
}
Expand Down
3 changes: 3 additions & 0 deletions cmd/muxt/script.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ func scriptCommand() script.Cmd {
return func(state *script.State) (string, string, error) {
var stdout, stderr bytes.Buffer
err := command(state.Getwd(), args, func(s string) string {
if s == experimentCheckTypesEnvVar {
return "true"
}
e, _ := state.LookupEnv(s)
return e
}, &stdout, &stderr)
Expand Down
25 changes: 21 additions & 4 deletions internal/check/tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"go/token"
"go/types"
"maps"
"strings"
"text/template/parse"
)

Expand Down Expand Up @@ -223,7 +224,16 @@ func (s *scope) checkCommandNode(tree *parse.Tree, dot types.Type, n *parse.Comm
}
for i := 0; i < len(argTypes); i++ {
at := argTypes[i]
pt := sig.Params().At(i).Type()
var pt types.Type
isVar := sig.Variadic()
argVar := i >= sig.Params().Len()-1
if isVar && argVar {
ps := sig.Params()
v := ps.At(ps.Len() - 1).Type().(*types.Slice)
pt = v.Elem()
} else {
pt = sig.Params().At(i).Type()
}
if !types.AssignableTo(at, pt) {
return nil, fmt.Errorf("%s argument %d has type %s expected %s", n.Args[0], i-1, at, pt)
}
Expand Down Expand Up @@ -304,9 +314,16 @@ func (s *scope) checkRangeNode(tree *parse.Tree, dot types.Type, n *parse.RangeN
}

func (s *scope) checkIdentifierNode(n *parse.IdentifierNode) (types.Type, error) {
tp, ok := s.variables[n.Ident]
if strings.HasPrefix(n.Ident, "$") {
tp, ok := s.variables[n.Ident]
if !ok {
return nil, fmt.Errorf("failed to find identifier %s", n.Ident)
}
return tp, nil
}
fn, ok := s.FindFunction(n.Ident)
if !ok {
return nil, fmt.Errorf("failed to find identifier %q", n.Ident)
return nil, fmt.Errorf("failed to find function %s", n.Ident)
}
return tp, nil
return fn, nil
}
85 changes: 64 additions & 21 deletions internal/source/template.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package source

import (
"bytes"
"fmt"
"go/ast"
"go/format"
"go/token"
"go/types"
"html/template"
"path/filepath"
"slices"
Expand All @@ -13,7 +16,10 @@ import (
"golang.org/x/tools/go/packages"
)

func Templates(workingDirectory, templatesVariable string, pkg *packages.Package) (*template.Template, error) {
type TemplateFuncMap = map[string]*types.Signature

func Templates(workingDirectory, templatesVariable string, pkg *packages.Package) (*template.Template, TemplateFuncMap, error) {
funcTypeMap := registerDefaultFunctions(pkg.Types)
for _, tv := range IterateValueSpecs(pkg.Syntax) {
i := slices.IndexFunc(tv.Names, func(e *ast.Ident) bool {
return e.Name == templatesVariable
Expand All @@ -23,19 +29,43 @@ func Templates(workingDirectory, templatesVariable string, pkg *packages.Package
}
embeddedPaths, err := relativeFilePaths(workingDirectory, pkg.EmbedFiles...)
if err != nil {
return nil, fmt.Errorf("failed to calculate relative path for embedded files: %w", err)
return nil, nil, fmt.Errorf("failed to calculate relative path for embedded files: %w", err)
}
const templatePackageIdent = "template"
ts, err := evaluateTemplateSelector(nil, tv.Values[i], workingDirectory, templatesVariable, templatePackageIdent, "", "", pkg.Fset, pkg.Syntax, embeddedPaths)
ts, err := evaluateTemplateSelector(nil, pkg.Types, tv.Values[i], workingDirectory, templatesVariable, templatePackageIdent, "", "", pkg.Fset, pkg.Syntax, embeddedPaths, funcTypeMap, make(template.FuncMap))
if err != nil {
return nil, fmt.Errorf("run template %s failed at %w", templatesVariable, err)
return nil, nil, fmt.Errorf("run template %s failed at %w", templatesVariable, err)
}
return ts, nil
return ts, funcTypeMap, nil
}
return nil, nil, fmt.Errorf("variable %s not found", templatesVariable)
}

func findPackage(pkg *types.Package, path string) (*types.Package, bool) {
if pkg.Path() == path {
return pkg, true
}
for _, im := range pkg.Imports() {
if p, ok := findPackage(im, path); ok {
return p, true
}
}
return nil, false
}

func registerDefaultFunctions(pkg *types.Package) TemplateFuncMap {
funcTypeMap := make(TemplateFuncMap)
fmtPkg, ok := findPackage(pkg, "fmt")
if !ok || fmtPkg == nil {
return funcTypeMap
}
return nil, fmt.Errorf("variable %s not found", templatesVariable)
funcTypeMap["printf"] = fmtPkg.Scope().Lookup("Sprintf").Type().(*types.Signature)
funcTypeMap["print"] = fmtPkg.Scope().Lookup("Sprint").Type().(*types.Signature)
funcTypeMap["println"] = fmtPkg.Scope().Lookup("Sprintln").Type().(*types.Signature)
return funcTypeMap
}

func evaluateTemplateSelector(ts *template.Template, expression ast.Expr, workingDirectory, templatesVariable, templatePackageIdent, rDelim, lDelim string, fileSet *token.FileSet, files []*ast.File, embeddedPaths []string) (*template.Template, error) {
func evaluateTemplateSelector(ts *template.Template, pkg *types.Package, expression ast.Expr, workingDirectory, templatesVariable, templatePackageIdent, rDelim, lDelim string, fileSet *token.FileSet, files []*ast.File, embeddedPaths []string, funcTypeMaps TemplateFuncMap, fm template.FuncMap) (*template.Template, error) {
call, ok := expression.(*ast.CallExpr)
if !ok {
return nil, contextError(workingDirectory, fileSet, expression.Pos(), fmt.Errorf("expected call expression"))
Expand All @@ -56,7 +86,7 @@ func evaluateTemplateSelector(ts *template.Template, expression ast.Expr, workin
if len(call.Args) != 1 {
return nil, contextError(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly one argument %s got %d", Format(sel.X), len(call.Args)))
}
return evaluateTemplateSelector(ts, call.Args[0], workingDirectory, templatesVariable, templatePackageIdent, rDelim, lDelim, fileSet, files, embeddedPaths)
return evaluateTemplateSelector(ts, pkg, call.Args[0], workingDirectory, templatesVariable, templatePackageIdent, rDelim, lDelim, fileSet, files, embeddedPaths, funcTypeMaps, fm)
case "New":
if len(call.Args) != 1 {
return nil, contextError(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly one string literal argument"))
Expand All @@ -76,7 +106,7 @@ func evaluateTemplateSelector(ts *template.Template, expression ast.Expr, workin
return nil, contextError(workingDirectory, fileSet, call.Fun.Pos(), fmt.Errorf("unsupported function %s", sel.Sel.Name))
}
case *ast.CallExpr:
up, err := evaluateTemplateSelector(ts, sel.X, workingDirectory, templatesVariable, templatePackageIdent, rDelim, lDelim, fileSet, files, embeddedPaths)
up, err := evaluateTemplateSelector(ts, pkg, sel.X, workingDirectory, templatesVariable, templatePackageIdent, rDelim, lDelim, fileSet, files, embeddedPaths, funcTypeMaps, fm)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -121,44 +151,44 @@ func evaluateTemplateSelector(ts *template.Template, expression ast.Expr, workin
}
return up.Option(list...), nil
case "Funcs":
funcMap, err := evaluateFuncMap(workingDirectory, templatePackageIdent, fileSet, call)
if err != nil {
if err := evaluateFuncMap(workingDirectory, templatePackageIdent, pkg, fileSet, call, fm, funcTypeMaps); err != nil {
return nil, err
}
return up.Funcs(funcMap), nil
return up.Funcs(fm), nil
default:
return nil, contextError(workingDirectory, fileSet, call.Fun.Pos(), fmt.Errorf("unsupported method %s", sel.Sel.Name))
}
}
}

func evaluateFuncMap(workingDirectory, templatePackageIdent string, fileSet *token.FileSet, call *ast.CallExpr) (template.FuncMap, error) {
func evaluateFuncMap(workingDirectory, templatePackageIdent string, pkg *types.Package, fileSet *token.FileSet, call *ast.CallExpr, fm template.FuncMap, funcTypesMap TemplateFuncMap) error {
const funcMapTypeIdent = "FuncMap"
fm := make(template.FuncMap)
if len(call.Args) != 1 {
return nil, contextError(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly 1 template.FuncMap composite literal argument"))
return contextError(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly 1 template.FuncMap composite literal argument"))
}
arg := call.Args[0]
lit, ok := arg.(*ast.CompositeLit)
if !ok {
return nil, contextError(workingDirectory, fileSet, arg.Pos(), fmt.Errorf("expected a composite literal with type %s.%s got %s", templatePackageIdent, funcMapTypeIdent, Format(arg)))
return contextError(workingDirectory, fileSet, arg.Pos(), fmt.Errorf("expected a composite literal with type %s.%s got %s", templatePackageIdent, funcMapTypeIdent, Format(arg)))
}
typeSel, ok := lit.Type.(*ast.SelectorExpr)
if !ok || typeSel.Sel.Name != funcMapTypeIdent {
return nil, contextError(workingDirectory, fileSet, arg.Pos(), fmt.Errorf("expected a composite literal with type %s.%s got %s", templatePackageIdent, funcMapTypeIdent, Format(arg)))
return contextError(workingDirectory, fileSet, arg.Pos(), fmt.Errorf("expected a composite literal with type %s.%s got %s", templatePackageIdent, funcMapTypeIdent, Format(arg)))
}
if tp, ok := typeSel.X.(*ast.Ident); !ok || tp.Name != templatePackageIdent {
return nil, contextError(workingDirectory, fileSet, arg.Pos(), fmt.Errorf("expected a composite literal with type %s.%s got %s", templatePackageIdent, funcMapTypeIdent, Format(arg)))
return contextError(workingDirectory, fileSet, arg.Pos(), fmt.Errorf("expected a composite literal with type %s.%s got %s", templatePackageIdent, funcMapTypeIdent, Format(arg)))
}
var buf bytes.Buffer
for i, exp := range lit.Elts {
el, ok := exp.(*ast.KeyValueExpr)
if !ok {
return nil, contextError(workingDirectory, fileSet, exp.Pos(), fmt.Errorf("expected element at index %d to be a key value pair got %s", i, Format(exp)))
return contextError(workingDirectory, fileSet, exp.Pos(), fmt.Errorf("expected element at index %d to be a key value pair got %s", i, Format(exp)))
}
funcName, err := evaluateStringLiteralExpression(workingDirectory, fileSet, el.Key)
if err != nil {
return nil, err
return err
}

// template.Parse does not evaluate the function signature parameters;
// it ensures the function name is in scope and there is one or two results.
// we could use something like func() string { return "" } for this signature
Expand All @@ -171,8 +201,21 @@ func evaluateFuncMap(workingDirectory, templatePackageIdent string, fileSet *tok
// or
// fm[funcName] = func() (int, int) {return 0, 0} // will fail because the second result is not an error
fm[funcName] = fmt.Sprintln

if pkg == nil {
continue
}
buf.Reset()
if err := format.Node(&buf, fileSet, el.Value); err != nil {
return err
}
tv, err := types.Eval(fileSet, pkg, lit.Pos(), buf.String())
if err != nil {
return err
}
funcTypesMap[funcName] = tv.Type.(*types.Signature)
}
return fm, nil
return nil
}

func evaluateCallParseFilesArgs(workingDirectory string, fileSet *token.FileSet, call *ast.CallExpr, files []*ast.File, embeddedPaths []string) ([]string, error) {
Expand Down
Loading

0 comments on commit 36c18cd

Please sign in to comment.