Skip to content

Commit

Permalink
feat: improve if/with else node type check
Browse files Browse the repository at this point in the history
  • Loading branch information
crhntr committed Dec 21, 2024
1 parent dfeec51 commit 3bc8f7a
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 9 deletions.
55 changes: 52 additions & 3 deletions internal/check/tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ func (s *scope) checkNode(tree *parse.Tree, dot types.Type, node parse.Node) (ty
return s.checkIdentifierNode(n)
case *parse.TextNode:
return nil, nil
case *parse.WithNode:
return nil, s.checkWithNode(tree, dot, n)
default:
return nil, fmt.Errorf("missing node type check %T", n)
}
Expand Down Expand Up @@ -166,11 +168,32 @@ func (s *scope) checkIfNode(tree *parse.Tree, dot types.Type, n *parse.IfNode) e
if err != nil {
return err
}
if _, err := s.child().checkNode(tree, dot, n.List); err != nil {
ifScope := s.child()
if _, err := ifScope.checkNode(tree, dot, n.List); err != nil {
return err
}
if n.ElseList != nil {
if _, err := s.child().checkNode(tree, dot, n.ElseList); err != nil {
elseScope := s.child()
if _, err := elseScope.checkNode(tree, dot, n.ElseList); err != nil {
return err
}
}
return nil
}

func (s *scope) checkWithNode(tree *parse.Tree, dot types.Type, n *parse.WithNode) error {
child := s.child()
x, err := child.checkNode(tree, dot, n.Pipe)
if err != nil {
return err
}
withScope := child.child()
if _, err := withScope.checkNode(tree, x, n.List); err != nil {
return err
}
if n.ElseList != nil {
elseScope := child.child()
if _, err := elseScope.checkNode(tree, dot, n.ElseList); err != nil {
return err
}
}
Expand Down Expand Up @@ -201,6 +224,7 @@ func (s *scope) checkTemplateNode(tree *parse.Tree, dot types.Type, n *parse.Tem
return err
}
x = tp
x = downgradeUntyped(x)
} else {
x = types.Typ[types.UntypedNil]
}
Expand All @@ -218,6 +242,30 @@ func (s *scope) checkTemplateNode(tree *parse.Tree, dot types.Type, n *parse.Tem
return err
}

func downgradeUntyped(x types.Type) types.Type {
if x == nil {
return x
}
basic, ok := x.Underlying().(*types.Basic)
if !ok {
return x
}
switch k := basic.Kind(); k {
case types.UntypedInt:
return types.Typ[types.Int].Underlying()
case types.UntypedRune:
return types.Typ[types.Rune].Underlying()
case types.UntypedFloat:
return types.Typ[types.Float64].Underlying()
case types.UntypedComplex:
return types.Typ[types.Complex128].Underlying()
case types.UntypedString:
return types.Typ[types.String].Underlying()
default:
return x
}
}

func (s *scope) checkFieldNode(tree *parse.Tree, dot types.Type, n *parse.FieldNode) (types.Type, error) {
return s.checkIdentifiers(tree, dot, n, n.Ident)
}
Expand Down Expand Up @@ -249,7 +297,8 @@ func (s *scope) checkCommandNode(tree *parse.Tree, dot types.Type, n *parse.Comm
} else {
pt = sig.Params().At(i).Type()
}
if !types.AssignableTo(at, pt) {
assignable := types.AssignableTo(at, pt)
if !assignable {
return nil, fmt.Errorf("%s argument %d has type %s expected %s", n.Args[0], i, at, pt)
}
}
Expand Down
122 changes: 116 additions & 6 deletions internal/check/tree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"go/types"
"html/template"
"io"
"math"
"reflect"
"slices"
"testing"
Expand Down Expand Up @@ -308,19 +309,112 @@ func TestTree(t *testing.T) {
Template: `{{square 5}}`,
Data: T{},
},
{
Name: "with expression",
Template: `{{$x := 1}}{{with $x := .Numbers}}{{$x}}{{end}}`,
Data: MethodWithKeyValForSlices{},
},
{
Name: "with expression declares variable with same name as parent scope",
Template: `{{$x := 1.2}}{{with $x := ceil $x}}{{$x}}{{end}}`,
Data: MethodWithKeyValForSlices{},
},
{
Name: "with expression has action with wrong dot type used in call",
Template: `{{with $x := "wrong"}}{{expectInt .}}{{else}}{{end}}`,
Data: T{},
Error: func(t *testing.T, checkErr, execErr error, tp types.Type) {
require.ErrorContains(t, execErr, "wrong type for value; expected int; got string")
require.ErrorContains(t, checkErr, "expectInt argument 0 has type untyped string expected int")
},
},
{
Name: "with else expression has action with correct dot type used in call",
Template: `{{with $x := 12}}{{with $x := 1.2}}{{else}}{{expectInt $x}}{{end}}{{end}}`,
Data: T{},
},
{
Name: "with else expression has action with wrong dot type used in call",
Template: `{{with $outer := 12}}{{with $x := true}}{{else}}{{expectString .}}{{end}}{{end}}`,
Data: T{},
Error: func(t *testing.T, checkErr, execErr error, tp types.Type) {
require.NoError(t, execErr)
require.ErrorContains(t, checkErr, "expectString argument 0 has type untyped int expected string")
},
},
{
Name: "complex number parses",
Template: `{{$x := 2i}}{{printf "%T" $x}}`,
Data: T{},
},
{
Name: "template node without parameter",
Template: `{{define "t"}}{{end}}{{template "t"}}`,
Data: T{},
},
{
Name: "template wrong input type",
Template: `{{define "t"}}{{expectInt .}}{{end}}{{if false}}{{template "t" 1.2}}{{end}}`,
Data: T{},
Error: func(t *testing.T, checkErr, execErr error, tp types.Type) {
require.NoError(t, execErr)
require.ErrorContains(t, checkErr, "expectInt argument 0 has type float64 expected int")
},
},
{
Name: "it downgrades untyped integers",
Template: `{{define "t"}}{{expectInt8 .}}{{end}}{{if false}}{{template "t" 12}}{{end}}`,
Data: T{},
Error: func(t *testing.T, checkErr, execErr error, tp types.Type) {
require.NoError(t, execErr)
require.ErrorContains(t, checkErr, "expectInt8 argument 0 has type int expected int8")
},
},
{
Name: "it downgrades untyped floats",
Template: `{{define "t"}}{{expectFloat32 .}}{{end}}{{if false}}{{template "t" 1.2}}{{end}}`,
Data: T{},
Error: func(t *testing.T, checkErr, execErr error, tp types.Type) {
require.NoError(t, execErr)
require.ErrorContains(t, checkErr, "expectFloat32 argument 0 has type float64 expected float32")
},
},
{
Name: "it downgrades untyped complex",
Template: `{{define "t"}}{{expectComplex64 .}}{{end}}{{if false}}{{template "t" 2i}}{{end}}`,
Data: T{},
Error: func(t *testing.T, checkErr, execErr error, tp types.Type) {
require.NoError(t, execErr)
require.ErrorContains(t, checkErr, "expectComplex64 argument 0 has type complex128 expected complex64")
},
},
// not sure if I should be downgrading bool, it should be fine to let it be since there is only one basic bool type
} {
t.Run(tt.Name, func(t *testing.T) {
templates, parseErr := template.New("template").Funcs(template.FuncMap{
"square": square,
}).Parse(tt.Template)
functions := template.FuncMap{
"square": square,
"ceil": ceil,
"expectInt": expectInt,
"expectFloat64": expectFloat64,
"expectString": expectString,
"expectInt8": expectInt8,
"expectFloat32": expectFloat32,
"expectComplex64": expectComplex64,
}

templates, parseErr := template.New("template").Funcs(functions).Parse(tt.Template)
require.NoError(t, parseErr)

dataType := checkTestPackage.Types.Scope().Lookup(reflect.TypeOf(tt.Data).Name()).Type()

functions := source.DefaultFunctions(checkTestPackage.Types)
functions["square"] = checkTestPackage.Types.Scope().Lookup("square").(*types.Func).Signature()
sourceFunctions := source.DefaultFunctions(checkTestPackage.Types)
for name := range functions {
fn := checkTestPackage.Types.Scope().Lookup(name).(*types.Func).Signature()
require.NotNil(t, fn)
sourceFunctions[name] = fn
}

if checkErr := check.Tree(templates.Tree, dataType, checkTestPackage.Types, checkTestPackage.Fset, newForrest(templates), functions); tt.Error != nil {
if checkErr := check.Tree(templates.Tree, dataType, checkTestPackage.Types, checkTestPackage.Fset, newForrest(templates), sourceFunctions); tt.Error != nil {
execErr := templates.Execute(io.Discard, tt.Data)
tt.Error(t, checkErr, execErr, dataType)
} else {
Expand Down Expand Up @@ -470,6 +564,22 @@ func square(n int) int {
return n * n
}

func ceil(n float64) int {
return int(math.Ceil(n))
}

func expectInt(n int) int { return n }

func expectFloat64(n float64) float64 { return n }

func expectString(s string) string { return s }

func expectInt8(n int8) int8 { return n }

func expectFloat32(n float32) float32 { return n }

func expectComplex64(n complex64) complex64 { return n }

func TestExampleTemplate(t *testing.T) {
packageList, loadErr := packages.Load(&packages.Config{
Mode: packages.NeedName | packages.NeedFiles | packages.NeedDeps | packages.NeedTypes,
Expand Down

0 comments on commit 3bc8f7a

Please sign in to comment.