Skip to content

Commit

Permalink
feat: typecheck function calls
Browse files Browse the repository at this point in the history
  • Loading branch information
crhntr committed Dec 21, 2024
1 parent ef71eb3 commit dfeec51
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 29 deletions.
27 changes: 18 additions & 9 deletions internal/check/tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ type scope struct {
variables map[string]types.Type
}

func (s *scope) child() *scope {
return &scope{
global: s.global,
variables: maps.Clone(s.variables),
}
}

func (s *scope) checkNode(tree *parse.Tree, dot types.Type, node parse.Node) (types.Type, error) {
switch n := node.(type) {
case *parse.DotNode:
Expand Down Expand Up @@ -159,9 +166,14 @@ func (s *scope) checkIfNode(tree *parse.Tree, dot types.Type, n *parse.IfNode) e
if err != nil {
return err
}
if _, err := s.checkNode(tree, dot, n.List); err != nil {
if _, err := s.child().checkNode(tree, dot, n.List); err != nil {
return err
}
if n.ElseList != nil {
if _, err := s.child().checkNode(tree, dot, n.ElseList); err != nil {
return err
}
}
return nil
}

Expand Down Expand Up @@ -238,7 +250,7 @@ func (s *scope) checkCommandNode(tree *parse.Tree, dot types.Type, n *parse.Comm
pt = sig.Params().At(i).Type()
}
if !types.AssignableTo(at, pt) {
return nil, fmt.Errorf("%s argument %d has type %s expected %s", n.Args[0], i-1, at, pt)
return nil, fmt.Errorf("%s argument %d has type %s expected %s", n.Args[0], i, at, pt)
}
}
return sig.Results().At(0).Type(), nil
Expand Down Expand Up @@ -286,11 +298,8 @@ func (s *scope) checkIdentifiers(tree *parse.Tree, dot types.Type, n parse.Node,
}

func (s *scope) checkRangeNode(tree *parse.Tree, dot types.Type, n *parse.RangeNode) error {
loopScope := &scope{
global: s.global,
variables: maps.Clone(s.variables),
}
pipeType, err := loopScope.checkNode(tree, dot, n.Pipe)
child := s.child()
pipeType, err := child.checkNode(tree, dot, n.Pipe)
if err != nil {
return err
}
Expand All @@ -305,11 +314,11 @@ func (s *scope) checkRangeNode(tree *parse.Tree, dot types.Type, n *parse.RangeN
default:
return fmt.Errorf("failed to range over %s", pipeType)
}
if _, err := loopScope.checkNode(tree, x, n.List); err != nil {
if _, err := child.checkNode(tree, x, n.List); err != nil {
return err
}
if n.ElseList != nil {
if _, err := loopScope.checkNode(tree, x, n.ElseList); err != nil {
if _, err := child.checkNode(tree, x, n.ElseList); err != nil {
return err
}
}
Expand Down
43 changes: 41 additions & 2 deletions internal/check/tree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

"github.com/crhntr/muxt"
"github.com/crhntr/muxt/internal/check"
"github.com/crhntr/muxt/internal/source"
)

func TestTree(t *testing.T) {
Expand Down Expand Up @@ -278,14 +279,48 @@ func TestTree(t *testing.T) {
Template: `{{$v := 1}}{{.F $v}}`,
Data: MethodWithIntParam{},
},
{
Name: "when there is an error in the else block",
Template: `{{$x := "wrong type"}}{{if false}}{{else}}{{.F $x}}{{end}}`,
Data: MethodWithIntParam{},
Error: func(t *testing.T, checkErr, _ error, tp types.Type) {
require.Error(t, checkErr)
require.ErrorContains(t, checkErr, ".F argument 0 has type untyped string expected int")
},
},
{
Name: "variable redefined in if block",
Template: `{{$x := 1}}{{if true}}{{$x := "str"}}{{end}}{{.F $x}}`,
Data: MethodWithIntParam{},
},
{
Name: "range variable does not clobber outer scope",
Template: `{{$x := 1}}{{range .Numbers}}{{$x := "str"}}{{end}}{{square $x}}`,
Data: MethodWithKeyValForSlices{},
},
{
Name: "range variable does not override outer scope",
Template: `{{$x := "str"}}{{range $x, $y := .Numbers}}{{$.F $x $y}}{{end}}{{printf $x}}`,
Data: MethodWithKeyValForSlices{},
},
{
Name: "source provided function",
Template: `{{square 5}}`,
Data: T{},
},
} {
t.Run(tt.Name, func(t *testing.T) {
templates, parseErr := template.New("template").Parse(tt.Template)
templates, parseErr := template.New("template").Funcs(template.FuncMap{
"square": square,
}).Parse(tt.Template)
require.NoError(t, parseErr)

dataType := checkTestPackage.Types.Scope().Lookup(reflect.TypeOf(tt.Data).Name()).Type()

if checkErr := check.Tree(templates.Tree, dataType, checkTestPackage.Types, checkTestPackage.Fset, newForrest(templates), nil); tt.Error != nil {
functions := source.DefaultFunctions(checkTestPackage.Types)
functions["square"] = checkTestPackage.Types.Scope().Lookup("square").(*types.Func).Signature()

if checkErr := check.Tree(templates.Tree, dataType, checkTestPackage.Types, checkTestPackage.Fset, newForrest(templates), functions); tt.Error != nil {
execErr := templates.Execute(io.Discard, tt.Data)
tt.Error(t, checkErr, execErr, dataType)
} else {
Expand Down Expand Up @@ -431,6 +466,10 @@ type MethodWithKeyValForMap struct {

func (MethodWithKeyValForMap) F(int16, float32) (_ T) { return }

func square(n int) int {
return n * n
}

func TestExampleTemplate(t *testing.T) {
packageList, loadErr := packages.Load(&packages.Config{
Mode: packages.NeedName | packages.NeedFiles | packages.NeedDeps | packages.NeedTypes,
Expand Down
49 changes: 31 additions & 18 deletions internal/source/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,8 @@ import (
"golang.org/x/tools/go/packages"
)

type TemplateFuncMap = map[string]*types.Signature

func Templates(workingDirectory, templatesVariable string, pkg *packages.Package) (*template.Template, TemplateFuncMap, error) {
funcTypeMap := registerDefaultFunctions(pkg.Types)
func Templates(workingDirectory, templatesVariable string, pkg *packages.Package) (*template.Template, Functions, error) {
funcTypeMap := DefaultFunctions(pkg.Types)
for _, tv := range IterateValueSpecs(pkg.Syntax) {
i := slices.IndexFunc(tv.Names, func(e *ast.Ident) bool {
return e.Name == templatesVariable
Expand Down Expand Up @@ -53,19 +51,7 @@ func findPackage(pkg *types.Package, path string) (*types.Package, bool) {
return nil, false
}

func registerDefaultFunctions(pkg *types.Package) TemplateFuncMap {
funcTypeMap := make(TemplateFuncMap)
fmtPkg, ok := findPackage(pkg, "fmt")
if !ok || fmtPkg == nil {
return funcTypeMap
}
funcTypeMap["printf"] = fmtPkg.Scope().Lookup("Sprintf").Type().(*types.Signature)
funcTypeMap["print"] = fmtPkg.Scope().Lookup("Sprint").Type().(*types.Signature)
funcTypeMap["println"] = fmtPkg.Scope().Lookup("Sprintln").Type().(*types.Signature)
return funcTypeMap
}

func evaluateTemplateSelector(ts *template.Template, pkg *types.Package, expression ast.Expr, workingDirectory, templatesVariable, templatePackageIdent, rDelim, lDelim string, fileSet *token.FileSet, files []*ast.File, embeddedPaths []string, funcTypeMaps TemplateFuncMap, fm template.FuncMap) (*template.Template, error) {
func evaluateTemplateSelector(ts *template.Template, pkg *types.Package, expression ast.Expr, workingDirectory, templatesVariable, templatePackageIdent, rDelim, lDelim string, fileSet *token.FileSet, files []*ast.File, embeddedPaths []string, funcTypeMaps Functions, fm template.FuncMap) (*template.Template, error) {
call, ok := expression.(*ast.CallExpr)
if !ok {
return nil, contextError(workingDirectory, fileSet, expression.Pos(), fmt.Errorf("expected call expression"))
Expand Down Expand Up @@ -161,7 +147,7 @@ func evaluateTemplateSelector(ts *template.Template, pkg *types.Package, express
}
}

func evaluateFuncMap(workingDirectory, templatePackageIdent string, pkg *types.Package, fileSet *token.FileSet, call *ast.CallExpr, fm template.FuncMap, funcTypesMap TemplateFuncMap) error {
func evaluateFuncMap(workingDirectory, templatePackageIdent string, pkg *types.Package, fileSet *token.FileSet, call *ast.CallExpr, fm template.FuncMap, funcTypesMap Functions) error {
const funcMapTypeIdent = "FuncMap"
if len(call.Args) != 1 {
return contextError(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly 1 template.FuncMap composite literal argument"))
Expand Down Expand Up @@ -385,3 +371,30 @@ func relativeFilePaths(wd string, abs ...string) ([]string, error) {
}
return result, nil
}

type Functions map[string]*types.Signature

func NewFunctions(m map[string]*types.Signature) Functions {
return Functions(m)
}

func DefaultFunctions(pkg *types.Package) Functions {
funcTypeMap := make(Functions)
fmtPkg, ok := findPackage(pkg, "fmt")
if !ok || fmtPkg == nil {
return funcTypeMap
}
funcTypeMap["printf"] = fmtPkg.Scope().Lookup("Sprintf").Type().(*types.Signature)
funcTypeMap["print"] = fmtPkg.Scope().Lookup("Sprint").Type().(*types.Signature)
funcTypeMap["println"] = fmtPkg.Scope().Lookup("Sprintln").Type().(*types.Signature)
return funcTypeMap
}

func (functions Functions) FindFunction(name string) (*types.Signature, bool) {
m := (map[string]*types.Signature)(functions)
fn, ok := m[name]
if !ok {
return nil, false
}
return fn, true
}

0 comments on commit dfeec51

Please sign in to comment.