diff --git a/internal/check/tree.go b/internal/check/tree.go index 7ca2477..dc4e855 100644 --- a/internal/check/tree.go +++ b/internal/check/tree.go @@ -48,6 +48,13 @@ type scope struct { variables map[string]types.Type } +func (s *scope) child() *scope { + return &scope{ + global: s.global, + variables: maps.Clone(s.variables), + } +} + func (s *scope) checkNode(tree *parse.Tree, dot types.Type, node parse.Node) (types.Type, error) { switch n := node.(type) { case *parse.DotNode: @@ -159,9 +166,14 @@ func (s *scope) checkIfNode(tree *parse.Tree, dot types.Type, n *parse.IfNode) e if err != nil { return err } - if _, err := s.checkNode(tree, dot, n.List); err != nil { + if _, err := s.child().checkNode(tree, dot, n.List); err != nil { return err } + if n.ElseList != nil { + if _, err := s.child().checkNode(tree, dot, n.ElseList); err != nil { + return err + } + } return nil } @@ -238,7 +250,7 @@ func (s *scope) checkCommandNode(tree *parse.Tree, dot types.Type, n *parse.Comm pt = sig.Params().At(i).Type() } if !types.AssignableTo(at, pt) { - return nil, fmt.Errorf("%s argument %d has type %s expected %s", n.Args[0], i-1, at, pt) + return nil, fmt.Errorf("%s argument %d has type %s expected %s", n.Args[0], i, at, pt) } } return sig.Results().At(0).Type(), nil @@ -286,11 +298,8 @@ func (s *scope) checkIdentifiers(tree *parse.Tree, dot types.Type, n parse.Node, } func (s *scope) checkRangeNode(tree *parse.Tree, dot types.Type, n *parse.RangeNode) error { - loopScope := &scope{ - global: s.global, - variables: maps.Clone(s.variables), - } - pipeType, err := loopScope.checkNode(tree, dot, n.Pipe) + child := s.child() + pipeType, err := child.checkNode(tree, dot, n.Pipe) if err != nil { return err } @@ -305,11 +314,11 @@ func (s *scope) checkRangeNode(tree *parse.Tree, dot types.Type, n *parse.RangeN default: return fmt.Errorf("failed to range over %s", pipeType) } - if _, err := loopScope.checkNode(tree, x, n.List); err != nil { + if _, err := child.checkNode(tree, x, n.List); err != nil { return err } if n.ElseList != nil { - if _, err := loopScope.checkNode(tree, x, n.ElseList); err != nil { + if _, err := child.checkNode(tree, x, n.ElseList); err != nil { return err } } diff --git a/internal/check/tree_test.go b/internal/check/tree_test.go index d935d4e..9e85930 100644 --- a/internal/check/tree_test.go +++ b/internal/check/tree_test.go @@ -15,6 +15,7 @@ import ( "github.com/crhntr/muxt" "github.com/crhntr/muxt/internal/check" + "github.com/crhntr/muxt/internal/source" ) func TestTree(t *testing.T) { @@ -278,14 +279,48 @@ func TestTree(t *testing.T) { Template: `{{$v := 1}}{{.F $v}}`, Data: MethodWithIntParam{}, }, + { + Name: "when there is an error in the else block", + Template: `{{$x := "wrong type"}}{{if false}}{{else}}{{.F $x}}{{end}}`, + Data: MethodWithIntParam{}, + Error: func(t *testing.T, checkErr, _ error, tp types.Type) { + require.Error(t, checkErr) + require.ErrorContains(t, checkErr, ".F argument 0 has type untyped string expected int") + }, + }, + { + Name: "variable redefined in if block", + Template: `{{$x := 1}}{{if true}}{{$x := "str"}}{{end}}{{.F $x}}`, + Data: MethodWithIntParam{}, + }, + { + Name: "range variable does not clobber outer scope", + Template: `{{$x := 1}}{{range .Numbers}}{{$x := "str"}}{{end}}{{square $x}}`, + Data: MethodWithKeyValForSlices{}, + }, + { + Name: "range variable does not override outer scope", + Template: `{{$x := "str"}}{{range $x, $y := .Numbers}}{{$.F $x $y}}{{end}}{{printf $x}}`, + Data: MethodWithKeyValForSlices{}, + }, + { + Name: "source provided function", + Template: `{{square 5}}`, + Data: T{}, + }, } { t.Run(tt.Name, func(t *testing.T) { - templates, parseErr := template.New("template").Parse(tt.Template) + templates, parseErr := template.New("template").Funcs(template.FuncMap{ + "square": square, + }).Parse(tt.Template) require.NoError(t, parseErr) dataType := checkTestPackage.Types.Scope().Lookup(reflect.TypeOf(tt.Data).Name()).Type() - if checkErr := check.Tree(templates.Tree, dataType, checkTestPackage.Types, checkTestPackage.Fset, newForrest(templates), nil); tt.Error != nil { + functions := source.DefaultFunctions(checkTestPackage.Types) + functions["square"] = checkTestPackage.Types.Scope().Lookup("square").(*types.Func).Signature() + + if checkErr := check.Tree(templates.Tree, dataType, checkTestPackage.Types, checkTestPackage.Fset, newForrest(templates), functions); tt.Error != nil { execErr := templates.Execute(io.Discard, tt.Data) tt.Error(t, checkErr, execErr, dataType) } else { @@ -431,6 +466,10 @@ type MethodWithKeyValForMap struct { func (MethodWithKeyValForMap) F(int16, float32) (_ T) { return } +func square(n int) int { + return n * n +} + func TestExampleTemplate(t *testing.T) { packageList, loadErr := packages.Load(&packages.Config{ Mode: packages.NeedName | packages.NeedFiles | packages.NeedDeps | packages.NeedTypes, diff --git a/internal/source/template.go b/internal/source/template.go index 8cc2808..05f379b 100644 --- a/internal/source/template.go +++ b/internal/source/template.go @@ -16,10 +16,8 @@ import ( "golang.org/x/tools/go/packages" ) -type TemplateFuncMap = map[string]*types.Signature - -func Templates(workingDirectory, templatesVariable string, pkg *packages.Package) (*template.Template, TemplateFuncMap, error) { - funcTypeMap := registerDefaultFunctions(pkg.Types) +func Templates(workingDirectory, templatesVariable string, pkg *packages.Package) (*template.Template, Functions, error) { + funcTypeMap := DefaultFunctions(pkg.Types) for _, tv := range IterateValueSpecs(pkg.Syntax) { i := slices.IndexFunc(tv.Names, func(e *ast.Ident) bool { return e.Name == templatesVariable @@ -53,19 +51,7 @@ func findPackage(pkg *types.Package, path string) (*types.Package, bool) { return nil, false } -func registerDefaultFunctions(pkg *types.Package) TemplateFuncMap { - funcTypeMap := make(TemplateFuncMap) - fmtPkg, ok := findPackage(pkg, "fmt") - if !ok || fmtPkg == nil { - return funcTypeMap - } - funcTypeMap["printf"] = fmtPkg.Scope().Lookup("Sprintf").Type().(*types.Signature) - funcTypeMap["print"] = fmtPkg.Scope().Lookup("Sprint").Type().(*types.Signature) - funcTypeMap["println"] = fmtPkg.Scope().Lookup("Sprintln").Type().(*types.Signature) - return funcTypeMap -} - -func evaluateTemplateSelector(ts *template.Template, pkg *types.Package, expression ast.Expr, workingDirectory, templatesVariable, templatePackageIdent, rDelim, lDelim string, fileSet *token.FileSet, files []*ast.File, embeddedPaths []string, funcTypeMaps TemplateFuncMap, fm template.FuncMap) (*template.Template, error) { +func evaluateTemplateSelector(ts *template.Template, pkg *types.Package, expression ast.Expr, workingDirectory, templatesVariable, templatePackageIdent, rDelim, lDelim string, fileSet *token.FileSet, files []*ast.File, embeddedPaths []string, funcTypeMaps Functions, fm template.FuncMap) (*template.Template, error) { call, ok := expression.(*ast.CallExpr) if !ok { return nil, contextError(workingDirectory, fileSet, expression.Pos(), fmt.Errorf("expected call expression")) @@ -161,7 +147,7 @@ func evaluateTemplateSelector(ts *template.Template, pkg *types.Package, express } } -func evaluateFuncMap(workingDirectory, templatePackageIdent string, pkg *types.Package, fileSet *token.FileSet, call *ast.CallExpr, fm template.FuncMap, funcTypesMap TemplateFuncMap) error { +func evaluateFuncMap(workingDirectory, templatePackageIdent string, pkg *types.Package, fileSet *token.FileSet, call *ast.CallExpr, fm template.FuncMap, funcTypesMap Functions) error { const funcMapTypeIdent = "FuncMap" if len(call.Args) != 1 { return contextError(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly 1 template.FuncMap composite literal argument")) @@ -385,3 +371,30 @@ func relativeFilePaths(wd string, abs ...string) ([]string, error) { } return result, nil } + +type Functions map[string]*types.Signature + +func NewFunctions(m map[string]*types.Signature) Functions { + return Functions(m) +} + +func DefaultFunctions(pkg *types.Package) Functions { + funcTypeMap := make(Functions) + fmtPkg, ok := findPackage(pkg, "fmt") + if !ok || fmtPkg == nil { + return funcTypeMap + } + funcTypeMap["printf"] = fmtPkg.Scope().Lookup("Sprintf").Type().(*types.Signature) + funcTypeMap["print"] = fmtPkg.Scope().Lookup("Sprint").Type().(*types.Signature) + funcTypeMap["println"] = fmtPkg.Scope().Lookup("Sprintln").Type().(*types.Signature) + return funcTypeMap +} + +func (functions Functions) FindFunction(name string) (*types.Signature, bool) { + m := (map[string]*types.Signature)(functions) + fn, ok := m[name] + if !ok { + return nil, false + } + return fn, true +}