diff --git a/internal/check/tree.go b/internal/check/tree.go index dc4e855..ee72797 100644 --- a/internal/check/tree.go +++ b/internal/check/tree.go @@ -87,6 +87,8 @@ func (s *scope) checkNode(tree *parse.Tree, dot types.Type, node parse.Node) (ty return s.checkIdentifierNode(n) case *parse.TextNode: return nil, nil + case *parse.WithNode: + return nil, s.checkWithNode(tree, dot, n) default: return nil, fmt.Errorf("missing node type check %T", n) } @@ -166,11 +168,32 @@ func (s *scope) checkIfNode(tree *parse.Tree, dot types.Type, n *parse.IfNode) e if err != nil { return err } - if _, err := s.child().checkNode(tree, dot, n.List); err != nil { + ifScope := s.child() + if _, err := ifScope.checkNode(tree, dot, n.List); err != nil { return err } if n.ElseList != nil { - if _, err := s.child().checkNode(tree, dot, n.ElseList); err != nil { + elseScope := s.child() + if _, err := elseScope.checkNode(tree, dot, n.ElseList); err != nil { + return err + } + } + return nil +} + +func (s *scope) checkWithNode(tree *parse.Tree, dot types.Type, n *parse.WithNode) error { + child := s.child() + x, err := child.checkNode(tree, dot, n.Pipe) + if err != nil { + return err + } + withScope := child.child() + if _, err := withScope.checkNode(tree, x, n.List); err != nil { + return err + } + if n.ElseList != nil { + elseScope := child.child() + if _, err := elseScope.checkNode(tree, dot, n.ElseList); err != nil { return err } } @@ -201,6 +224,7 @@ func (s *scope) checkTemplateNode(tree *parse.Tree, dot types.Type, n *parse.Tem return err } x = tp + x = downgradeUntyped(x) } else { x = types.Typ[types.UntypedNil] } @@ -218,6 +242,30 @@ func (s *scope) checkTemplateNode(tree *parse.Tree, dot types.Type, n *parse.Tem return err } +func downgradeUntyped(x types.Type) types.Type { + if x == nil { + return x + } + basic, ok := x.Underlying().(*types.Basic) + if !ok { + return x + } + switch k := basic.Kind(); k { + case types.UntypedInt: + return types.Typ[types.Int].Underlying() + case types.UntypedRune: + return types.Typ[types.Rune].Underlying() + case types.UntypedFloat: + return types.Typ[types.Float64].Underlying() + case types.UntypedComplex: + return types.Typ[types.Complex128].Underlying() + case types.UntypedString: + return types.Typ[types.String].Underlying() + default: + return x + } +} + func (s *scope) checkFieldNode(tree *parse.Tree, dot types.Type, n *parse.FieldNode) (types.Type, error) { return s.checkIdentifiers(tree, dot, n, n.Ident) } @@ -249,7 +297,8 @@ func (s *scope) checkCommandNode(tree *parse.Tree, dot types.Type, n *parse.Comm } else { pt = sig.Params().At(i).Type() } - if !types.AssignableTo(at, pt) { + assignable := types.AssignableTo(at, pt) + if !assignable { return nil, fmt.Errorf("%s argument %d has type %s expected %s", n.Args[0], i, at, pt) } } diff --git a/internal/check/tree_test.go b/internal/check/tree_test.go index 9e85930..a56f4e8 100644 --- a/internal/check/tree_test.go +++ b/internal/check/tree_test.go @@ -5,6 +5,7 @@ import ( "go/types" "html/template" "io" + "math" "reflect" "slices" "testing" @@ -308,19 +309,112 @@ func TestTree(t *testing.T) { Template: `{{square 5}}`, Data: T{}, }, + { + Name: "with expression", + Template: `{{$x := 1}}{{with $x := .Numbers}}{{$x}}{{end}}`, + Data: MethodWithKeyValForSlices{}, + }, + { + Name: "with expression declares variable with same name as parent scope", + Template: `{{$x := 1.2}}{{with $x := ceil $x}}{{$x}}{{end}}`, + Data: MethodWithKeyValForSlices{}, + }, + { + Name: "with expression has action with wrong dot type used in call", + Template: `{{with $x := "wrong"}}{{expectInt .}}{{else}}{{end}}`, + Data: T{}, + Error: func(t *testing.T, checkErr, execErr error, tp types.Type) { + require.ErrorContains(t, execErr, "wrong type for value; expected int; got string") + require.ErrorContains(t, checkErr, "expectInt argument 0 has type untyped string expected int") + }, + }, + { + Name: "with else expression has action with correct dot type used in call", + Template: `{{with $x := 12}}{{with $x := 1.2}}{{else}}{{expectInt $x}}{{end}}{{end}}`, + Data: T{}, + }, + { + Name: "with else expression has action with wrong dot type used in call", + Template: `{{with $outer := 12}}{{with $x := true}}{{else}}{{expectString .}}{{end}}{{end}}`, + Data: T{}, + Error: func(t *testing.T, checkErr, execErr error, tp types.Type) { + require.NoError(t, execErr) + require.ErrorContains(t, checkErr, "expectString argument 0 has type untyped int expected string") + }, + }, + { + Name: "complex number parses", + Template: `{{$x := 2i}}{{printf "%T" $x}}`, + Data: T{}, + }, + { + Name: "template node without parameter", + Template: `{{define "t"}}{{end}}{{template "t"}}`, + Data: T{}, + }, + { + Name: "template wrong input type", + Template: `{{define "t"}}{{expectInt .}}{{end}}{{if false}}{{template "t" 1.2}}{{end}}`, + Data: T{}, + Error: func(t *testing.T, checkErr, execErr error, tp types.Type) { + require.NoError(t, execErr) + require.ErrorContains(t, checkErr, "expectInt argument 0 has type float64 expected int") + }, + }, + { + Name: "it downgrades untyped integers", + Template: `{{define "t"}}{{expectInt8 .}}{{end}}{{if false}}{{template "t" 12}}{{end}}`, + Data: T{}, + Error: func(t *testing.T, checkErr, execErr error, tp types.Type) { + require.NoError(t, execErr) + require.ErrorContains(t, checkErr, "expectInt8 argument 0 has type int expected int8") + }, + }, + { + Name: "it downgrades untyped floats", + Template: `{{define "t"}}{{expectFloat32 .}}{{end}}{{if false}}{{template "t" 1.2}}{{end}}`, + Data: T{}, + Error: func(t *testing.T, checkErr, execErr error, tp types.Type) { + require.NoError(t, execErr) + require.ErrorContains(t, checkErr, "expectFloat32 argument 0 has type float64 expected float32") + }, + }, + { + Name: "it downgrades untyped complex", + Template: `{{define "t"}}{{expectComplex64 .}}{{end}}{{if false}}{{template "t" 2i}}{{end}}`, + Data: T{}, + Error: func(t *testing.T, checkErr, execErr error, tp types.Type) { + require.NoError(t, execErr) + require.ErrorContains(t, checkErr, "expectComplex64 argument 0 has type complex128 expected complex64") + }, + }, + // not sure if I should be downgrading bool, it should be fine to let it be since there is only one basic bool type } { t.Run(tt.Name, func(t *testing.T) { - templates, parseErr := template.New("template").Funcs(template.FuncMap{ - "square": square, - }).Parse(tt.Template) + functions := template.FuncMap{ + "square": square, + "ceil": ceil, + "expectInt": expectInt, + "expectFloat64": expectFloat64, + "expectString": expectString, + "expectInt8": expectInt8, + "expectFloat32": expectFloat32, + "expectComplex64": expectComplex64, + } + + templates, parseErr := template.New("template").Funcs(functions).Parse(tt.Template) require.NoError(t, parseErr) dataType := checkTestPackage.Types.Scope().Lookup(reflect.TypeOf(tt.Data).Name()).Type() - functions := source.DefaultFunctions(checkTestPackage.Types) - functions["square"] = checkTestPackage.Types.Scope().Lookup("square").(*types.Func).Signature() + sourceFunctions := source.DefaultFunctions(checkTestPackage.Types) + for name := range functions { + fn := checkTestPackage.Types.Scope().Lookup(name).(*types.Func).Signature() + require.NotNil(t, fn) + sourceFunctions[name] = fn + } - if checkErr := check.Tree(templates.Tree, dataType, checkTestPackage.Types, checkTestPackage.Fset, newForrest(templates), functions); tt.Error != nil { + if checkErr := check.Tree(templates.Tree, dataType, checkTestPackage.Types, checkTestPackage.Fset, newForrest(templates), sourceFunctions); tt.Error != nil { execErr := templates.Execute(io.Discard, tt.Data) tt.Error(t, checkErr, execErr, dataType) } else { @@ -470,6 +564,22 @@ func square(n int) int { return n * n } +func ceil(n float64) int { + return int(math.Ceil(n)) +} + +func expectInt(n int) int { return n } + +func expectFloat64(n float64) float64 { return n } + +func expectString(s string) string { return s } + +func expectInt8(n int8) int8 { return n } + +func expectFloat32(n float32) float32 { return n } + +func expectComplex64(n complex64) complex64 { return n } + func TestExampleTemplate(t *testing.T) { packageList, loadErr := packages.Load(&packages.Config{ Mode: packages.NeedName | packages.NeedFiles | packages.NeedDeps | packages.NeedTypes,