diff --git a/cmd/muxt/generate.go b/cmd/muxt/generate.go index ff59073..1f4b1d6 100644 --- a/cmd/muxt/generate.go +++ b/cmd/muxt/generate.go @@ -11,13 +11,17 @@ import ( "github.com/crhntr/muxt/internal/configuration" ) -const CodeGenerationComment = "// Code generated by muxt. DO NOT EDIT." +const ( + CodeGenerationComment = "// Code generated by muxt. DO NOT EDIT." + experimentCheckTypesEnvVar = "MUXT_EXPERIMENT_CHECK_TYPES" +) -func generateCommand(args []string, workingDirectory string, stdout, stderr io.Writer) error { +func generateCommand(workingDirectory string, args []string, getEnv func(string) string, stdout, stderr io.Writer) error { config, err := configuration.NewRoutesFileConfiguration(args, stderr) if err != nil { return err } + config.ExperimentalCheckTypes = getEnv(experimentCheckTypesEnvVar) == "true" s, err := muxt.TemplateRoutesFile(workingDirectory, log.New(stdout, "", 0), config) if err != nil { return err diff --git a/cmd/muxt/main.go b/cmd/muxt/main.go index f738cbd..a17b1ef 100644 --- a/cmd/muxt/main.go +++ b/cmd/muxt/main.go @@ -1,13 +1,10 @@ package main import ( - "bytes" "flag" "fmt" "io" "os" - - "rsc.io/script" ) func main() { @@ -19,11 +16,11 @@ func main() { os.Exit(handleError(command(wd, flag.Args(), os.Getenv, os.Stdout, os.Stderr))) } -func command(wd string, args []string, _ func(string) string, stdout, stderr io.Writer) error { +func command(wd string, args []string, getEnv func(string) string, stdout, stderr io.Writer) error { if len(args) > 0 { switch cmd, cmdArgs := args[0], args[1:]; cmd { case "generate", "gen", "g": - return generateCommand(cmdArgs, wd, stdout, stderr) + return generateCommand(wd, cmdArgs, getEnv, stdout, stderr) case "version", "v": return versionCommand(stdout) } @@ -38,22 +35,3 @@ func handleError(err error) int { } return 0 } - -func scriptCommand() script.Cmd { - return script.Command(script.CmdUsage{ - Summary: "muxt", - Args: "", - }, func(state *script.State, args ...string) (script.WaitFunc, error) { - return func(state *script.State) (string, string, error) { - var stdout, stderr bytes.Buffer - err := command(state.Getwd(), args, func(s string) string { - e, _ := state.LookupEnv(s) - return e - }, &stdout, &stderr) - if err != nil { - stderr.WriteString(err.Error()) - } - return stdout.String(), stderr.String(), err - }, nil - }) -} diff --git a/cmd/muxt/script.go b/cmd/muxt/script.go new file mode 100644 index 0000000..f07f046 --- /dev/null +++ b/cmd/muxt/script.go @@ -0,0 +1,29 @@ +package main + +import ( + "bytes" + + "rsc.io/script" +) + +func scriptCommand() script.Cmd { + return script.Command(script.CmdUsage{ + Summary: "muxt", + Args: "", + }, func(state *script.State, args ...string) (script.WaitFunc, error) { + return func(state *script.State) (string, string, error) { + var stdout, stderr bytes.Buffer + err := command(state.Getwd(), args, func(s string) string { + if s == experimentCheckTypesEnvVar { + return "true" + } + e, _ := state.LookupEnv(s) + return e + }, &stdout, &stderr) + if err != nil { + stderr.WriteString(err.Error()) + } + return stdout.String(), stderr.String(), err + }, nil + }) +} diff --git a/go.sum b/go.sum index 5efcf7b..55b7544 100644 --- a/go.sum +++ b/go.sum @@ -1,14 +1,11 @@ -github.com/andybalholm/cascadia v1.3.2 h1:3Xi6Dw5lHF15JtdcmAHD3i1+T8plmv7BQ/nsViSLyss= -github.com/andybalholm/cascadia v1.3.2/go.mod h1:7gtRlve5FxPPgIgX36uWBX58OdBsSS6lUvCFb+h7KvU= github.com/andybalholm/cascadia v1.3.3 h1:AG2YHrzJIm4BZ19iwJ/DAua6Btl3IwJX+VI4kktS1LM= github.com/andybalholm/cascadia v1.3.3/go.mod h1:xNd9bqTn98Ln4DwST8/nG+H0yuB8Hmgu1YHNnWw0GeA= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/crhntr/dom v0.1.0-dev.5 h1:/7joIQhGxSKbuJyD1xDfheT0/wBdBSTXuRh+mLMqqxM= -github.com/crhntr/dom v0.1.0-dev.5/go.mod h1:4czJjiQOqiFq99bP/dfi6hjFXKnFptRyqwC8PJNQQkY= github.com/crhntr/dom v0.1.0-dev.6 h1:iUkl5c1i3QRRyYjdozGDuNnYdEQzZp1sFk9QTmFrO4c= github.com/crhntr/dom v0.1.0-dev.6/go.mod h1:V2RcN/d7pdUo5romb+mk/K4nm4QwAmwuJ259vdJGE/M= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= @@ -21,8 +18,6 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= @@ -37,23 +32,16 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= -golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4= golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= -golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo= -golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= -golang.org/x/net v0.31.0 h1:68CPQngjLL0r2AlUKiSxtQFKvzRVbnzLwMUn5SzcLHo= -golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM= golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -62,10 +50,7 @@ golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ= -golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -73,7 +58,6 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= @@ -83,7 +67,6 @@ golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXct golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= @@ -104,10 +87,6 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= -golang.org/x/tools v0.25.0 h1:oFU9pkj/iJgs+0DT+VMHrx+oBKs/LJMV+Uvg78sl+fE= -golang.org/x/tools v0.25.0/go.mod h1:/vtpO8WL1N9cQC3FN5zPqb//fRXskFHbLKk4OW1Q7rg= -golang.org/x/tools v0.27.0 h1:qEKojBykQkQ4EynWy4S8Weg69NumxKdn40Fce3uc/8o= -golang.org/x/tools v0.27.0/go.mod h1:sUi0ZgbwW9ZPAq26Ekut+weQPR5eIM6GQLQ1Yjm1H0Q= golang.org/x/tools v0.28.0 h1:WuB6qZ4RPCQo5aP3WdKZS7i595EdWqWR8vqJTlwTVK8= golang.org/x/tools v0.28.0/go.mod h1:dcIOrVd3mfQKTgrDVQHqCPMWy6lnhfhtX3hLXYVLfRw= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/assert/assert.go b/internal/assert/assert.go new file mode 100644 index 0000000..af0b600 --- /dev/null +++ b/internal/assert/assert.go @@ -0,0 +1,27 @@ +package assert + +import ( + "fmt" + "os" +) + +const pleaseCreateAnIssue = `If you get this error, please file an issue on https://github.com/crhntr/muxt/issues/new with a description of the inputs. I did not expect this to be possible.` + +func Exit(description, message string, in ...any) { + _, _ = fmt.Fprintf(os.Stderr, message, in...) + _, _ = fmt.Fprintf(os.Stderr, "\n"+description) + _, _ = fmt.Fprintln(os.Stderr, "\n"+pleaseCreateAnIssue) + os.Exit(1) +} + +func Len[T any](in []T, n int, description string) { + if len(in) != n { + Exit(description, "expected length %d got %d\n", n, len(in)) + } +} + +func MaxLen[T any](in []T, n int, description string) { + if len(in) > n { + Exit(description, "expected length less than %d got %d\n", n, len(in)) + } +} diff --git a/internal/check/exec_test.go b/internal/check/exec_test.go new file mode 100644 index 0000000..1d9e499 --- /dev/null +++ b/internal/check/exec_test.go @@ -0,0 +1,556 @@ +package check_test + +import ( + "bytes" + "fmt" + "go/types" + "io" + "reflect" + "testing" + "text/template" + "text/template/parse" + + "github.com/stretchr/testify/require" + "golang.org/x/tools/go/packages" + + "github.com/crhntr/muxt/internal/check" + "github.com/crhntr/muxt/internal/source" +) + +func findTextTree(tmpl *template.Template) check.FindTreeFunc { + return func(name string) (*parse.Tree, bool) { + ts := tmpl.Lookup(name) + if ts == nil { + return nil, false + } + return ts.Tree, true + } +} + +// bigInt and bigUint are hex string representing numbers either side +// of the max int boundary. +// We do it this way so the test doesn't depend on ints being 32 bits. +var ( + bigInt = fmt.Sprintf("0x%x", int(1<", tVal, true}, + {"map .one interface", "{{.MXI.one}}", "1", tVal, true}, + {"map .WRONG args", "{{.MSI.one 1}}", "", tVal, false}, + {"map .WRONG type", "{{.MII.one}}", "", tVal, false}, + + // Dots of all kinds to test basic evaluation. + {"dot int", "<{{.}}>", "<13>", 13, true}, + {"dot uint", "<{{.}}>", "<14>", uint(14), true}, + {"dot float", "<{{.}}>", "<15.1>", 15.1, true}, + {"dot bool", "<{{.}}>", "", true, true}, + {"dot complex", "<{{.}}>", "<(16.2-17i)>", 16.2 - 17i, true}, + {"dot string", "<{{.}}>", "", "hello", true}, + {"dot slice", "<{{.}}>", "<[-1 -2 -3]>", []int{-1, -2, -3}, true}, + {"dot map", "<{{.}}>", "", map[string]int{"two": 22}, true}, + {"dot struct", "<{{.}}>", "<{7 seven}>", struct { + a int + b string + }{7, "seven"}, true}, + + // Variables. + {"$ int", "{{$}}", "123", 123, true}, + {"$.I", "{{$.I}}", "17", tVal, true}, + {"$.U.V", "{{$.U.V}}", "v", tVal, true}, + {"declare in action", "{{$x := $.U.V}}{{$x}}", "v", tVal, true}, + {"simple assignment", "{{$x := 2}}{{$x = 3}}{{$x}}", "3", tVal, true}, + {"nested assignment", + "{{$x := 2}}{{if true}}{{$x = 3}}{{end}}{{$x}}", + "3", tVal, true}, + {"nested assignment changes the last declaration", + "{{$x := 1}}{{if true}}{{$x := 2}}{{if true}}{{$x = 3}}{{end}}{{end}}{{$x}}", + "1", tVal, true}, + + // Type with String method. + {"V{6666}.String()", "-{{.V0}}-", "-<6666>-", tVal, true}, + {"&V{7777}.String()", "-{{.V1}}-", "-<7777>-", tVal, true}, + {"(*V)(nil).String()", "-{{.V2}}-", "-nilV-", tVal, true}, + + // Type with Error method. + {"W{888}.Error()", "-{{.W0}}-", "-[888]-", tVal, true}, + {"&W{999}.Error()", "-{{.W1}}-", "-[999]-", tVal, true}, + {"(*W)(nil).Error()", "-{{.W2}}-", "-nilW-", tVal, true}, + + // Pointers. + {"*int", "{{.PI}}", "23", tVal, true}, + {"*string", "{{.PS}}", "a string", tVal, true}, + {"*[]int", "{{.PSI}}", "[21 22 23]", tVal, true}, + {"*[]int[1]", "{{index .PSI 1}}", "22", tVal, true}, + {"NIL", "{{.NIL}}", "", tVal, true}, + + // Empty interfaces holding values. + {"empty nil", "{{.Empty0}}", "", tVal, true}, + {"empty with int", "{{.Empty1}}", "3", tVal, true}, + {"empty with string", "{{.Empty2}}", "empty2", tVal, true}, + {"empty with slice", "{{.Empty3}}", "[7 8]", tVal, true}, + {"empty with struct", "{{.Empty4}}", "{UinEmpty}", tVal, true}, + {"empty with struct, field", "{{.Empty4.V}}", "UinEmpty", tVal, true}, + + // Edge cases with with an interface value + {"field on interface", "{{.foo}}", "", nil, true}, + {"field on parenthesized interface", "{{(.).foo}}", "", nil, true}, + + // Issue 31810: Parenthesized first element of pipeline with arguments. + // See also TestIssue31810. + {"unparenthesized non-function", "{{1 2}}", "", nil, false}, + {"parenthesized non-function", "{{(1) 2}}", "", nil, false}, + {"parenthesized non-function with no args", "{{(1)}}", "1", nil, true}, // This is fine. + + // Method calls. + {".Method0", "-{{.Method0}}-", "-M0-", tVal, true}, + {".Method1(1234)", "-{{.Method1 1234}}-", "-1234-", tVal, true}, + {".Method1(.I)", "-{{.Method1 .I}}-", "-17-", tVal, true}, + {".Method2(3, .X)", "-{{.Method2 3 .X}}-", "-Method2: 3 x-", tVal, true}, + {".Method2(.U16, `str`)", "-{{.Method2 .U16 `str`}}-", "-Method2: 16 str-", tVal, true}, + {".Method2(.U16, $x)", "{{if $x := .X}}-{{.Method2 .U16 $x}}{{end}}-", "-Method2: 16 x-", tVal, true}, + {".Method3(nil constant)", "-{{.Method3 nil}}-", "-Method3: -", tVal, true}, + {".Method3(nil value)", "-{{.Method3 .MXI.unset}}-", "-Method3: -", tVal, true}, + {"method on var", "{{if $x := .}}-{{$x.Method2 .U16 $x.X}}{{end}}-", "-Method2: 16 x-", tVal, true}, + {"method on chained var", + "{{range .MSIone}}{{if $.U.TrueFalse $.True}}{{$.U.TrueFalse $.True}}{{else}}WRONG{{end}}{{end}}", + "true", tVal, true}, + {"chained method", + "{{range .MSIone}}{{if $.GetU.TrueFalse $.True}}{{$.U.TrueFalse $.True}}{{else}}WRONG{{end}}{{end}}", + "true", tVal, true}, + {"chained method on variable", + "{{with $x := .}}{{with .SI}}{{$.GetU.TrueFalse $.True}}{{end}}{{end}}", + "true", tVal, true}, + {".NilOKFunc not nil", "{{call .NilOKFunc .PI}}", "false", tVal, true}, + {".NilOKFunc nil", "{{call .NilOKFunc nil}}", "true", tVal, true}, + {"method on nil value from slice", "-{{range .}}{{.Method1 1234}}{{end}}-", "-1234-", tSliceOfNil, true}, + {"method on typed nil interface value", "{{.NonEmptyInterfaceTypedNil.Method0}}", "M0", tVal, true}, + + // Function call builtin. + {".BinaryFunc", "{{call .BinaryFunc `1` `2`}}", "[1=2]", tVal, true}, + {".VariadicFunc0", "{{call .VariadicFunc}}", "<>", tVal, true}, + {".VariadicFunc2", "{{call .VariadicFunc `he` `llo`}}", "", tVal, true}, + {".VariadicFuncInt", "{{call .VariadicFuncInt 33 `he` `llo`}}", "33=", tVal, true}, + {"if .BinaryFunc call", "{{ if .BinaryFunc}}{{call .BinaryFunc `1` `2`}}{{end}}", "[1=2]", tVal, true}, + {"if not .BinaryFunc call", "{{ if not .BinaryFunc}}{{call .BinaryFunc `1` `2`}}{{else}}No{{end}}", "No", tVal, true}, + {"Interface Call", `{{stringer .S}}`, "foozle", map[string]any{"S": bytes.NewBufferString("foozle")}, true}, + {".ErrFunc", "{{call .ErrFunc}}", "bla", tVal, true}, + {"call nil", "{{call nil}}", "", tVal, false}, + + // Erroneous function calls (check args). + {".BinaryFuncTooFew", "{{call .BinaryFunc `1`}}", "", tVal, false}, + {".BinaryFuncTooMany", "{{call .BinaryFunc `1` `2` `3`}}", "", tVal, false}, + {".BinaryFuncBad0", "{{call .BinaryFunc 1 3}}", "", tVal, false}, + {".BinaryFuncBad1", "{{call .BinaryFunc `1` 3}}", "", tVal, false}, + {".VariadicFuncBad0", "{{call .VariadicFunc 3}}", "", tVal, false}, + {".VariadicFuncIntBad0", "{{call .VariadicFuncInt}}", "", tVal, false}, + {".VariadicFuncIntBad`", "{{call .VariadicFuncInt `x`}}", "", tVal, false}, + {".VariadicFuncNilBad", "{{call .VariadicFunc nil}}", "", tVal, false}, + + // Pipelines. + {"pipeline", "-{{.Method0 | .Method2 .U16}}-", "-Method2: 16 M0-", tVal, true}, + {"pipeline func", "-{{call .VariadicFunc `llo` | call .VariadicFunc `he` }}-", "->-", tVal, true}, + + // Nil values aren't missing arguments. + {"nil pipeline", "{{ .Empty0 | call .NilOKFunc }}", "true", tVal, true}, + {"nil call arg", "{{ call .NilOKFunc .Empty0 }}", "true", tVal, true}, + {"bad nil pipeline", "{{ .Empty0 | .VariadicFunc }}", "", tVal, false}, + + // Parenthesized expressions + {"parens in pipeline", "{{printf `%d %d %d` (1) (2 | add 3) (add 4 (add 5 6))}}", "1 5 15", tVal, true}, + + // Parenthesized expressions with field accesses + {"parens: $ in paren", "{{($).X}}", "x", tVal, true}, + {"parens: $.GetU in paren", "{{($.GetU).V}}", "v", tVal, true}, + {"parens: $ in paren in pipe", "{{($ | echo).X}}", "x", tVal, true}, + {"parens: spaces and args", `{{(makemap "up" "down" "left" "right").left}}`, "right", tVal, true}, + + // If. + {"if true", "{{if true}}TRUE{{end}}", "TRUE", tVal, true}, + {"if false", "{{if false}}TRUE{{else}}FALSE{{end}}", "FALSE", tVal, true}, + {"if nil", "{{if nil}}TRUE{{end}}", "", tVal, false}, + {"if on typed nil interface value", "{{if .NonEmptyInterfaceTypedNil}}TRUE{{ end }}", "", tVal, true}, + {"if 1", "{{if 1}}NON-ZERO{{else}}ZERO{{end}}", "NON-ZERO", tVal, true}, + {"if 0", "{{if 0}}NON-ZERO{{else}}ZERO{{end}}", "ZERO", tVal, true}, + {"if 1.5", "{{if 1.5}}NON-ZERO{{else}}ZERO{{end}}", "NON-ZERO", tVal, true}, + {"if 0.0", "{{if .FloatZero}}NON-ZERO{{else}}ZERO{{end}}", "ZERO", tVal, true}, + {"if 1.5i", "{{if 1.5i}}NON-ZERO{{else}}ZERO{{end}}", "NON-ZERO", tVal, true}, + {"if 0.0i", "{{if .ComplexZero}}NON-ZERO{{else}}ZERO{{end}}", "ZERO", tVal, true}, + {"if emptystring", "{{if ``}}NON-EMPTY{{else}}EMPTY{{end}}", "EMPTY", tVal, true}, + {"if string", "{{if `notempty`}}NON-EMPTY{{else}}EMPTY{{end}}", "NON-EMPTY", tVal, true}, + {"if emptyslice", "{{if .SIEmpty}}NON-EMPTY{{else}}EMPTY{{end}}", "EMPTY", tVal, true}, + {"if slice", "{{if .SI}}NON-EMPTY{{else}}EMPTY{{end}}", "NON-EMPTY", tVal, true}, + {"if emptymap", "{{if .MSIEmpty}}NON-EMPTY{{else}}EMPTY{{end}}", "EMPTY", tVal, true}, + {"if map", "{{if .MSI}}NON-EMPTY{{else}}EMPTY{{end}}", "NON-EMPTY", tVal, true}, + {"if map unset", "{{if .MXI.none}}NON-ZERO{{else}}ZERO{{end}}", "ZERO", tVal, true}, + {"if map not unset", "{{if not .MXI.none}}ZERO{{else}}NON-ZERO{{end}}", "ZERO", tVal, true}, + {"if $x with $y int", "{{if $x := true}}{{with $y := .I}}{{$x}},{{$y}}{{end}}{{end}}", "true,17", tVal, true}, + {"if $x with $x int", "{{if $x := true}}{{with $x := .I}}{{$x}},{{end}}{{$x}}{{end}}", "17,true", tVal, true}, + {"if else if", "{{if false}}FALSE{{else if true}}TRUE{{end}}", "TRUE", tVal, true}, + {"if else chain", "{{if eq 1 3}}1{{else if eq 2 3}}2{{else if eq 3 3}}3{{end}}", "3", tVal, true}, + + // Print etc. + {"print", `{{print "hello, print"}}`, "hello, print", tVal, true}, + {"print 123", `{{print 1 2 3}}`, "1 2 3", tVal, true}, + {"print nil", `{{print nil}}`, "", tVal, true}, + {"println", `{{println 1 2 3}}`, "1 2 3\n", tVal, true}, + {"printf int", `{{printf "%04x" 127}}`, "007f", tVal, true}, + {"printf float", `{{printf "%g" 3.5}}`, "3.5", tVal, true}, + {"printf complex", `{{printf "%g" 1+7i}}`, "(1+7i)", tVal, true}, + {"printf string", `{{printf "%s" "hello"}}`, "hello", tVal, true}, + {"printf function", `{{printf "%#q" zeroArgs}}`, "`zeroArgs`", tVal, true}, + {"printf field", `{{printf "%s" .U.V}}`, "v", tVal, true}, + {"printf method", `{{printf "%s" .Method0}}`, "M0", tVal, true}, + {"printf dot", `{{with .I}}{{printf "%d" .}}{{end}}`, "17", tVal, true}, + {"printf var", `{{with $x := .I}}{{printf "%d" $x}}{{end}}`, "17", tVal, true}, + {"printf lots", `{{printf "%d %s %g %s" 127 "hello" 7-3i .Method0}}`, "127 hello (7-3i) M0", tVal, true}, + + // HTML. + {"html", `{{html ""}}`, + "<script>alert("XSS");</script>", nil, true}, + {"html pipeline", `{{printf "" | html}}`, + "<script>alert("XSS");</script>", nil, true}, + {"html", `{{html .PS}}`, "a string", tVal, true}, + {"html typed nil", `{{html .NIL}}`, "<nil>", tVal, true}, + {"html untyped nil", `{{html .Empty0}}`, "<no value>", tVal, true}, + + // JavaScript. + {"js", `{{js .}}`, `It\'d be nice.`, `It'd be nice.`, true}, + + // URL query. + {"urlquery", `{{"http://www.example.org/"|urlquery}}`, "http%3A%2F%2Fwww.example.org%2F", nil, true}, + + // Booleans + {"not", "{{not true}} {{not false}}", "false true", nil, true}, + {"and", "{{and false 0}} {{and 1 0}} {{and 0 true}} {{and 1 1}}", "false 0 0 1", nil, true}, + {"or", "{{or 0 0}} {{or 1 0}} {{or 0 true}} {{or 1 1}}", "0 1 true 1", nil, true}, + {"or short-circuit", "{{or 0 1 (die)}}", "1", nil, true}, + {"and short-circuit", "{{and 1 0 (die)}}", "0", nil, true}, + {"or short-circuit2", "{{or 0 0 (die)}}", "", nil, false}, + {"and short-circuit2", "{{and 1 1 (die)}}", "", nil, false}, + {"and pipe-true", "{{1 | and 1}}", "1", nil, true}, + {"and pipe-false", "{{0 | and 1}}", "0", nil, true}, + {"or pipe-true", "{{1 | or 0}}", "1", nil, true}, + {"or pipe-false", "{{0 | or 0}}", "0", nil, true}, + {"and undef", "{{and 1 .Unknown}}", "", nil, true}, + {"or undef", "{{or 0 .Unknown}}", "", nil, true}, + {"boolean if", "{{if and true 1 `hi`}}TRUE{{else}}FALSE{{end}}", "TRUE", tVal, true}, + {"boolean if not", "{{if and true 1 `hi` | not}}TRUE{{else}}FALSE{{end}}", "FALSE", nil, true}, + {"boolean if pipe", "{{if true | not | and 1}}TRUE{{else}}FALSE{{end}}", "FALSE", nil, true}, + + // Indexing. + {"slice[0]", "{{index .SI 0}}", "3", tVal, true}, + {"slice[1]", "{{index .SI 1}}", "4", tVal, true}, + {"slice[HUGE]", "{{index .SI 10}}", "", tVal, false}, + {"slice[WRONG]", "{{index .SI `hello`}}", "", tVal, false}, + {"slice[nil]", "{{index .SI nil}}", "", tVal, false}, + {"map[one]", "{{index .MSI `one`}}", "1", tVal, true}, + {"map[two]", "{{index .MSI `two`}}", "2", tVal, true}, + {"map[NO]", "{{index .MSI `XXX`}}", "0", tVal, true}, + {"map[nil]", "{{index .MSI nil}}", "", tVal, false}, + {"map[``]", "{{index .MSI ``}}", "0", tVal, true}, + {"map[WRONG]", "{{index .MSI 10}}", "", tVal, false}, + {"double index", "{{index .SMSI 1 `eleven`}}", "11", tVal, true}, + {"nil[1]", "{{index nil 1}}", "", tVal, false}, + {"map MI64S", "{{index .MI64S 2}}", "i642", tVal, true}, + {"map MI32S", "{{index .MI32S 2}}", "two", tVal, true}, + {"map MUI64S", "{{index .MUI64S 3}}", "ui643", tVal, true}, + {"map MI8S", "{{index .MI8S 3}}", "i83", tVal, true}, + {"map MUI8S", "{{index .MUI8S 2}}", "u82", tVal, true}, + {"index of an interface field", "{{index .Empty3 0}}", "7", tVal, true}, + + // Slicing. + {"slice[:]", "{{slice .SI}}", "[3 4 5]", tVal, true}, + {"slice[1:]", "{{slice .SI 1}}", "[4 5]", tVal, true}, + {"slice[1:2]", "{{slice .SI 1 2}}", "[4]", tVal, true}, + {"slice[-1:]", "{{slice .SI -1}}", "", tVal, false}, + {"slice[1:-2]", "{{slice .SI 1 -2}}", "", tVal, false}, + {"slice[1:2:-1]", "{{slice .SI 1 2 -1}}", "", tVal, false}, + {"slice[2:1]", "{{slice .SI 2 1}}", "", tVal, false}, + {"slice[2:2:1]", "{{slice .SI 2 2 1}}", "", tVal, false}, + {"out of range", "{{slice .SI 4 5}}", "", tVal, false}, + {"out of range", "{{slice .SI 2 2 5}}", "", tVal, false}, + {"len(s) < indexes < cap(s)", "{{slice .SICap 6 10}}", "[0 0 0 0]", tVal, true}, + {"len(s) < indexes < cap(s)", "{{slice .SICap 6 10 10}}", "[0 0 0 0]", tVal, true}, + {"indexes > cap(s)", "{{slice .SICap 10 11}}", "", tVal, false}, + {"indexes > cap(s)", "{{slice .SICap 6 10 11}}", "", tVal, false}, + {"array[:]", "{{slice .AI}}", "[3 4 5]", tVal, true}, + {"array[1:]", "{{slice .AI 1}}", "[4 5]", tVal, true}, + {"array[1:2]", "{{slice .AI 1 2}}", "[4]", tVal, true}, + {"string[:]", "{{slice .S}}", "xyz", tVal, true}, + {"string[0:1]", "{{slice .S 0 1}}", "x", tVal, true}, + {"string[1:]", "{{slice .S 1}}", "yz", tVal, true}, + {"string[1:2]", "{{slice .S 1 2}}", "y", tVal, true}, + {"out of range", "{{slice .S 1 5}}", "", tVal, false}, + {"3-index slice of string", "{{slice .S 1 2 2}}", "", tVal, false}, + {"slice of an interface field", "{{slice .Empty3 0 1}}", "[7]", tVal, true}, + + // Len. + {"slice", "{{len .SI}}", "3", tVal, true}, + {"map", "{{len .MSI }}", "3", tVal, true}, + {"len of int", "{{len 3}}", "", tVal, false}, + {"len of nothing", "{{len .Empty0}}", "", tVal, false}, + {"len of an interface field", "{{len .Empty3}}", "2", tVal, true}, + + // With. + {"with true", "{{with true}}{{.}}{{end}}", "true", tVal, true}, + {"with false", "{{with false}}{{.}}{{else}}FALSE{{end}}", "FALSE", tVal, true}, + {"with 1", "{{with 1}}{{.}}{{else}}ZERO{{end}}", "1", tVal, true}, + {"with 0", "{{with 0}}{{.}}{{else}}ZERO{{end}}", "ZERO", tVal, true}, + {"with 1.5", "{{with 1.5}}{{.}}{{else}}ZERO{{end}}", "1.5", tVal, true}, + {"with 0.0", "{{with .FloatZero}}{{.}}{{else}}ZERO{{end}}", "ZERO", tVal, true}, + {"with 1.5i", "{{with 1.5i}}{{.}}{{else}}ZERO{{end}}", "(0+1.5i)", tVal, true}, + {"with 0.0i", "{{with .ComplexZero}}{{.}}{{else}}ZERO{{end}}", "ZERO", tVal, true}, + {"with emptystring", "{{with ``}}{{.}}{{else}}EMPTY{{end}}", "EMPTY", tVal, true}, + {"with string", "{{with `notempty`}}{{.}}{{else}}EMPTY{{end}}", "notempty", tVal, true}, + {"with emptyslice", "{{with .SIEmpty}}{{.}}{{else}}EMPTY{{end}}", "EMPTY", tVal, true}, + {"with slice", "{{with .SI}}{{.}}{{else}}EMPTY{{end}}", "[3 4 5]", tVal, true}, + {"with emptymap", "{{with .MSIEmpty}}{{.}}{{else}}EMPTY{{end}}", "EMPTY", tVal, true}, + {"with map", "{{with .MSIone}}{{.}}{{else}}EMPTY{{end}}", "map[one:1]", tVal, true}, + {"with empty interface, struct field", "{{with .Empty4}}{{.V}}{{end}}", "UinEmpty", tVal, true}, + {"with $x int", "{{with $x := .I}}{{$x}}{{end}}", "17", tVal, true}, + {"with $x struct.U.V", "{{with $x := $}}{{$x.U.V}}{{end}}", "v", tVal, true}, + {"with variable and action", "{{with $x := $}}{{$y := $.U.V}}{{$y}}{{end}}", "v", tVal, true}, + {"with on typed nil interface value", "{{with .NonEmptyInterfaceTypedNil}}TRUE{{ end }}", "", tVal, true}, + {"with else with", "{{with 0}}{{.}}{{else with true}}{{.}}{{end}}", "true", tVal, true}, + {"with else with chain", "{{with 0}}{{.}}{{else with false}}{{.}}{{else with `notempty`}}{{.}}{{end}}", "notempty", tVal, true}, + + // Range. + {"range []int", "{{range .SI}}-{{.}}-{{end}}", "-3--4--5-", tVal, true}, + {"range empty no else", "{{range .SIEmpty}}-{{.}}-{{end}}", "", tVal, true}, + {"range []int else", "{{range .SI}}-{{.}}-{{else}}EMPTY{{end}}", "-3--4--5-", tVal, true}, + {"range empty else", "{{range .SIEmpty}}-{{.}}-{{else}}EMPTY{{end}}", "EMPTY", tVal, true}, + {"range []int break else", "{{range .SI}}-{{.}}-{{break}}NOTREACHED{{else}}EMPTY{{end}}", "-3-", tVal, true}, + {"range []int continue else", "{{range .SI}}-{{.}}-{{continue}}NOTREACHED{{else}}EMPTY{{end}}", "-3--4--5-", tVal, true}, + {"range []bool", "{{range .SB}}-{{.}}-{{end}}", "-true--false-", tVal, true}, + {"range []int method", "{{range .SI | .MAdd .I}}-{{.}}-{{end}}", "-20--21--22-", tVal, true}, + {"range map", "{{range .MSI}}-{{.}}-{{end}}", "-1--3--2-", tVal, true}, + {"range empty map no else", "{{range .MSIEmpty}}-{{.}}-{{end}}", "", tVal, true}, + {"range map else", "{{range .MSI}}-{{.}}-{{else}}EMPTY{{end}}", "-1--3--2-", tVal, true}, + {"range empty map else", "{{range .MSIEmpty}}-{{.}}-{{else}}EMPTY{{end}}", "EMPTY", tVal, true}, + {"range empty interface", "{{range .Empty3}}-{{.}}-{{else}}EMPTY{{end}}", "-7--8-", tVal, true}, + {"range empty nil", "{{range .Empty0}}-{{.}}-{{end}}", "", tVal, true}, + {"range $x SI", "{{range $x := .SI}}<{{$x}}>{{end}}", "<3><4><5>", tVal, true}, + {"range $x $y SI", "{{range $x, $y := .SI}}<{{$x}}={{$y}}>{{end}}", "<0=3><1=4><2=5>", tVal, true}, + {"range $x MSIone", "{{range $x := .MSIone}}<{{$x}}>{{end}}", "<1>", tVal, true}, + {"range $x $y MSIone", "{{range $x, $y := .MSIone}}<{{$x}}={{$y}}>{{end}}", "", tVal, true}, + {"range $x PSI", "{{range $x := .PSI}}<{{$x}}>{{end}}", "<21><22><23>", tVal, true}, + {"declare in range", "{{range $x := .PSI}}<{{$foo:=$x}}{{$x}}>{{end}}", "<21><22><23>", tVal, true}, + {"range count", `{{range $i, $x := count 5}}[{{$i}}]{{$x}}{{end}}`, "[0]a[1]b[2]c[3]d[4]e", tVal, true}, + {"range nil count", `{{range $i, $x := count 0}}{{else}}empty{{end}}`, "empty", tVal, true}, + + // Cute examples. + {"or as if true", `{{or .SI "slice is empty"}}`, "[3 4 5]", tVal, true}, + {"or as if false", `{{or .SIEmpty "slice is empty"}}`, "slice is empty", tVal, true}, + + // Error handling. + {"error method, error", "{{.MyError true}}", "", tVal, false}, + {"error method, no error", "{{.MyError false}}", "false", tVal, true}, + + // Numbers + {"decimal", "{{print 1234}}", "1234", tVal, true}, + {"decimal _", "{{print 12_34}}", "1234", tVal, true}, + {"binary", "{{print 0b101}}", "5", tVal, true}, + {"binary _", "{{print 0b_1_0_1}}", "5", tVal, true}, + {"BINARY", "{{print 0B101}}", "5", tVal, true}, + {"octal0", "{{print 0377}}", "255", tVal, true}, + {"octal", "{{print 0o377}}", "255", tVal, true}, + {"octal _", "{{print 0o_3_7_7}}", "255", tVal, true}, + {"OCTAL", "{{print 0O377}}", "255", tVal, true}, + {"hex", "{{print 0x123}}", "291", tVal, true}, + {"hex _", "{{print 0x1_23}}", "291", tVal, true}, + {"HEX", "{{print 0X123ABC}}", "1194684", tVal, true}, + {"float", "{{print 123.4}}", "123.4", tVal, true}, + {"float _", "{{print 0_0_1_2_3.4}}", "123.4", tVal, true}, + {"hex float", "{{print +0x1.ep+2}}", "7.5", tVal, true}, + {"hex float _", "{{print +0x_1.e_0p+0_2}}", "7.5", tVal, true}, + {"HEX float", "{{print +0X1.EP+2}}", "7.5", tVal, true}, + {"print multi", "{{print 1_2_3_4 7.5_00_00_00}}", "1234 7.5", tVal, true}, + {"print multi2", "{{print 1234 0x0_1.e_0p+02}}", "1234 7.5", tVal, true}, + + // Fixed bugs. + // Must separate dot and receiver; otherwise args are evaluated with dot set to variable. + {"bug0", "{{range .MSIone}}{{if $.Method1 .}}X{{end}}{{end}}", "X", tVal, true}, + // Do not loop endlessly in indirect for non-empty interfaces. + // The bug appears with *interface only; looped forever. + {"bug1", "{{.Method0}}", "M0", &iVal, true}, + // Was taking address of interface field, so method set was empty. + {"bug2", "{{$.NonEmptyInterface.Method0}}", "M0", tVal, true}, + // Struct values were not legal in with - mere oversight. + {"bug3", "{{with $}}{{.Method0}}{{end}}", "M0", tVal, true}, + // Nil interface values in if. + {"bug4", "{{if .Empty0}}non-nil{{else}}nil{{end}}", "nil", tVal, true}, + // Stringer. + {"bug5", "{{.Str}}", "foozle", tVal, true}, + {"bug5a", "{{.Err}}", "erroozle", tVal, true}, + // Args need to be indirected and dereferenced sometimes. + {"bug6a", "{{vfunc .V0 .V1}}", "vfunc", tVal, true}, + {"bug6b", "{{vfunc .V0 .V0}}", "vfunc", tVal, true}, + {"bug6c", "{{vfunc .V1 .V0}}", "vfunc", tVal, true}, + {"bug6d", "{{vfunc .V1 .V1}}", "vfunc", tVal, true}, + // Legal parse but illegal execution: non-function should have no arguments. + {"bug7a", "{{3 2}}", "", tVal, false}, + {"bug7b", "{{$x := 1}}{{$x 2}}", "", tVal, false}, + {"bug7c", "{{$x := 1}}{{3 | $x}}", "", tVal, false}, + // Pipelined arg was not being type-checked. + {"bug8a", "{{3|oneArg}}", "", tVal, false}, + {"bug8b", "{{4|dddArg 3}}", "", tVal, false}, + // A bug was introduced that broke map lookups for lower-case names. + {"bug9", "{{.cause}}", "neglect", map[string]string{"cause": "neglect"}, true}, + // Field chain starting with function did not work. + {"bug10", "{{mapOfThree.three}}-{{(mapOfThree).three}}", "3-3", 0, true}, + // Dereferencing nil pointer while evaluating function arguments should not panic. Issue 7333. + {"bug11", "{{valueString .PS}}", "", T{}, false}, + // 0xef gave constant type float64. Issue 8622. + {"bug12xe", "{{printf `%T` 0xef}}", "int", T{}, true}, + {"bug12xE", "{{printf `%T` 0xEE}}", "int", T{}, true}, + {"bug12Xe", "{{printf `%T` 0Xef}}", "int", T{}, true}, + {"bug12XE", "{{printf `%T` 0XEE}}", "int", T{}, true}, + // Chained nodes did not work as arguments. Issue 8473. + {"bug13", "{{print (.Copy).I}}", "17", tVal, true}, + // Didn't protect against nil or literal values in field chains. + {"bug14a", "{{(nil).True}}", "", tVal, false}, + {"bug14b", "{{$x := nil}}{{$x.anything}}", "", tVal, false}, + {"bug14c", `{{$x := (1.0)}}{{$y := ("hello")}}{{$x.anything}}{{$y.true}}`, "", tVal, false}, + // Didn't call validateType on function results. Issue 10800. + {"bug15", "{{valueString returnInt}}", "", tVal, false}, + // Variadic function corner cases. Issue 10946. + {"bug16a", "{{true|printf}}", "", tVal, false}, + {"bug16b", "{{1|printf}}", "", tVal, false}, + {"bug16c", "{{1.1|printf}}", "", tVal, false}, + {"bug16d", "{{'x'|printf}}", "", tVal, false}, + {"bug16e", "{{0i|printf}}", "", tVal, false}, + {"bug16f", "{{true|twoArgs \"xxx\"}}", "", tVal, false}, + {"bug16g", "{{\"aaa\" |twoArgs \"bbb\"}}", "twoArgs=bbbaaa", tVal, true}, + {"bug16h", "{{1|oneArg}}", "", tVal, false}, + {"bug16i", "{{\"aaa\"|oneArg}}", "oneArg=aaa", tVal, true}, + {"bug16j", "{{1+2i|printf \"%v\"}}", "(1+2i)", tVal, true}, + {"bug16k", "{{\"aaa\"|printf }}", "aaa", tVal, true}, + {"bug17a", "{{.NonEmptyInterface.X}}", "x", tVal, true}, + {"bug17b", "-{{.NonEmptyInterface.Method1 1234}}-", "-1234-", tVal, true}, + {"bug17c", "{{len .NonEmptyInterfacePtS}}", "2", tVal, true}, + {"bug17d", "{{index .NonEmptyInterfacePtS 0}}", "a", tVal, true}, + {"bug17e", "{{range .NonEmptyInterfacePtS}}-{{.}}-{{end}}", "-a--b-", tVal, true}, + + // More variadic function corner cases. Some runes would get evaluated + // as constant floats instead of ints. Issue 34483. + {"bug18a", "{{eq . '.'}}", "true", '.', true}, + {"bug18b", "{{eq . 'e'}}", "true", 'e', true}, + {"bug18c", "{{eq . 'P'}}", "true", 'P', true}, + + {"issue56490", "{{$i := 0}}{{$x := 0}}{{range $i = .AI}}{{end}}{{$i}}", "5", tVal, true}, + {"issue60801", "{{$k := 0}}{{$v := 0}}{{range $k, $v = .AI}}{{$k}}={{$v}} {{end}}", "0=3 1=4 2=5 ", tVal, true}, + } { + t.Run(tt.name, func(t *testing.T) { + tmpl, err := template.New(tt.name).Funcs(funcExec).Parse(tt.input) + if err != nil { + t.Errorf("%s: parse error: %s", tt.name, err) + return + } + err = tmpl.Execute(io.Discard, tt.data) + + var dataType types.Type + switch d := tt.data.(type) { + case T: + dataType = checkTestPackage.Types.Scope().Lookup("T").Type() + case *T: + dataType = types.NewPointer(checkTestPackage.Types.Scope().Lookup("T").Type()) + case nil: + dataType = types.Universe.Lookup("nil").Type() + case *I: + dataType = types.NewPointer(checkTestPackage.Types.Scope().Lookup("I").Type()) + default: + typeName := reflect.TypeOf(tt.data).Name() + obj := types.Universe.Lookup(typeName) + require.NotNil(t, obj) + dt := obj.Type() + if dt == nil { + t.Fatal("unexpected type", reflect.TypeOf(d)) + } + dataType = dt + } + require.NotNil(t, dataType) + + checkErr := check.Tree(tmpl.Tree, dataType, checkTestPackage.Types, checkTestPackage.Fset, findTextTree(tmpl), funcSource) + switch { + case !tt.ok && checkErr == nil: + t.Errorf("%s: expected error; got none", tt.name) + return + case tt.ok && checkErr != nil: + t.Errorf("%s: unexpected execute error: %s", tt.name, checkErr) + return + case !tt.ok && checkErr != nil: + // expected error, got one + if testing.Verbose() { + fmt.Printf("%s: %s\n\t%s\n", tt.name, tt.input, checkErr) + } + } + }) + } +} diff --git a/internal/check/tree.go b/internal/check/tree.go new file mode 100644 index 0000000..210c13f --- /dev/null +++ b/internal/check/tree.go @@ -0,0 +1,491 @@ +package check + +import ( + "fmt" + "go/token" + "go/types" + "maps" + "strconv" + "strings" + "text/template/parse" +) + +type TreeFinder interface { + FindTree(name string) (*parse.Tree, bool) +} + +type FindTreeFunc func(name string) (*parse.Tree, bool) + +func (fn FindTreeFunc) FindTree(name string) (*parse.Tree, bool) { + return fn(name) +} + +type FunctionFinder interface { + FindFunction(name string) (*types.Signature, bool) +} + +func Tree(tree *parse.Tree, data types.Type, pkg *types.Package, fileSet *token.FileSet, trees TreeFinder, functions FunctionFinder) error { + s := &scope{ + global: global{ + TreeFinder: trees, + FunctionFinder: functions, + pkg: pkg, + fileSet: fileSet, + }, + variables: map[string]types.Type{ + "$": data, + }, + } + _, err := s.walk(tree, data, tree.Root) + return err +} + +type global struct { + TreeFinder + FunctionFinder + + pkg *types.Package + fileSet *token.FileSet +} + +type scope struct { + global + variables map[string]types.Type +} + +func (s *scope) child() *scope { + return &scope{ + global: s.global, + variables: maps.Clone(s.variables), + } +} + +func (s *scope) walk(tree *parse.Tree, dot types.Type, node parse.Node) (types.Type, error) { + switch n := node.(type) { + case *parse.DotNode: + return dot, nil + case *parse.ListNode: + return nil, s.checkListNode(tree, dot, n) + case *parse.ActionNode: + return nil, s.checkActionNode(tree, dot, n) + case *parse.CommandNode: + return s.checkCommandNode(tree, dot, n) + case *parse.FieldNode: + return s.checkFieldNode(tree, dot, n) + case *parse.PipeNode: + return s.checkPipeNode(tree, dot, n) + case *parse.IfNode: + return nil, s.checkIfNode(tree, dot, n) + case *parse.RangeNode: + return nil, s.checkRangeNode(tree, dot, n) + case *parse.TemplateNode: + return nil, s.checkTemplateNode(tree, dot, n) + case *parse.BoolNode: + return types.Typ[types.UntypedBool], nil + case *parse.StringNode: + return types.Typ[types.UntypedString], nil + case *parse.NumberNode: + return newNumberNodeType(n) + case *parse.VariableNode: + return s.checkVariableNode(tree, n) + case *parse.IdentifierNode: + return s.checkIdentifierNode(n) + case *parse.TextNode: + return nil, nil + case *parse.WithNode: + return nil, s.checkWithNode(tree, dot, n) + case *parse.CommentNode: + return nil, nil + case *parse.NilNode: + return types.Typ[types.UntypedNil], nil + case *parse.ChainNode: + return s.checkChainNode(tree, dot, n) + case *parse.BranchNode: + return nil, nil + case *parse.BreakNode: + return nil, nil + case *parse.ContinueNode: + return nil, nil + default: + return nil, fmt.Errorf("missing node type check %T", n) + } +} + +func (s *scope) checkChainNode(tree *parse.Tree, dot types.Type, n *parse.ChainNode) (types.Type, error) { + x, err := s.walk(tree, dot, n.Node) + if err != nil { + return nil, err + } + return s.checkIdentifiers(tree, x, n, n.Field) +} + +func (s *scope) checkVariableNode(tree *parse.Tree, n *parse.VariableNode) (types.Type, error) { + tp, ok := s.variables[n.Ident[0]] + if !ok { + return nil, fmt.Errorf("variable %s not found", n.Ident[0]) + } + return s.checkIdentifiers(tree, tp, n, n.Ident[1:]) +} + +func (s *scope) checkListNode(tree *parse.Tree, dot types.Type, n *parse.ListNode) error { + for _, child := range n.Nodes { + if _, err := s.walk(tree, dot, child); err != nil { + return err + } + } + return nil +} + +func (s *scope) checkActionNode(tree *parse.Tree, dot types.Type, n *parse.ActionNode) error { + _, err := s.walk(tree, dot, n.Pipe) + return err +} + +func (s *scope) checkPipeNode(tree *parse.Tree, dot types.Type, n *parse.PipeNode) (types.Type, error) { + x := dot + for _, cmd := range n.Cmds { + tp, err := s.walk(tree, x, cmd) + if err != nil { + return nil, err + } + x = tp + } + if len(n.Decl) > 0 { + switch r := x.(type) { + case *types.Slice: + if l := len(n.Decl); l == 1 { + s.variables[n.Decl[0].Ident[0]] = r.Elem() + } else if l == 2 { + s.variables[n.Decl[0].Ident[0]] = types.Typ[types.Int] + s.variables[n.Decl[1].Ident[0]] = r.Elem() + } else { + return nil, fmt.Errorf("expected 1 or 2 declaration") + } + case *types.Array: + if l := len(n.Decl); l == 1 { + s.variables[n.Decl[0].Ident[0]] = r.Elem() + } else if l == 2 { + s.variables[n.Decl[0].Ident[0]] = types.Typ[types.Int] + s.variables[n.Decl[1].Ident[0]] = r.Elem() + } else { + return nil, fmt.Errorf("expected 1 or 2 declaration") + } + case *types.Map: + if l := len(n.Decl); l == 1 { + s.variables[n.Decl[0].Ident[0]] = r.Elem() + } else if l == 2 { + s.variables[n.Decl[0].Ident[0]] = r.Key() + s.variables[n.Decl[1].Ident[0]] = r.Elem() + } else { + return nil, fmt.Errorf("expected 1 or 2 declaration") + } + default: + // assert.MaxLen(n.Decl, 1, "too many variable declarations in a pipe node") + if len(n.Decl) == 1 { + s.variables[n.Decl[0].Ident[0]] = x + } + } + } + return x, nil +} + +func (s *scope) checkIfNode(tree *parse.Tree, dot types.Type, n *parse.IfNode) error { + _, err := s.walk(tree, dot, n.Pipe) + if err != nil { + return err + } + ifScope := s.child() + if _, err := ifScope.walk(tree, dot, n.List); err != nil { + return err + } + if n.ElseList != nil { + elseScope := s.child() + if _, err := elseScope.walk(tree, dot, n.ElseList); err != nil { + return err + } + } + return nil +} + +func (s *scope) checkWithNode(tree *parse.Tree, dot types.Type, n *parse.WithNode) error { + child := s.child() + x, err := child.walk(tree, dot, n.Pipe) + if err != nil { + return err + } + withScope := child.child() + if _, err := withScope.walk(tree, x, n.List); err != nil { + return err + } + if n.ElseList != nil { + elseScope := child.child() + if _, err := elseScope.walk(tree, dot, n.ElseList); err != nil { + return err + } + } + return nil +} + +func newNumberNodeType(n *parse.NumberNode) (types.Type, error) { + if n.IsInt || n.IsUint { + tp := types.Typ[types.UntypedInt] + return tp, nil + } + if n.IsFloat { + tp := types.Typ[types.UntypedFloat] + return tp, nil + } + if n.IsComplex { + tp := types.Typ[types.UntypedComplex] + return tp, nil + } + return nil, fmt.Errorf("failed to evaluate template *parse.NumberNode type") +} + +func (s *scope) checkTemplateNode(tree *parse.Tree, dot types.Type, n *parse.TemplateNode) error { + x := dot + if n.Pipe != nil { + tp, err := s.walk(tree, x, n.Pipe) + if err != nil { + return err + } + x = tp + x = downgradeUntyped(x) + } else { + x = types.Typ[types.UntypedNil] + } + childTree, ok := s.FindTree(n.Name) + if !ok { + return fmt.Errorf("template %q not found", n.Name) + } + childScope := scope{ + global: s.global, + variables: map[string]types.Type{ + "$": x, + }, + } + _, err := childScope.walk(childTree, x, childTree.Root) + return err +} + +func downgradeUntyped(x types.Type) types.Type { + if x == nil { + return x + } + basic, ok := x.Underlying().(*types.Basic) + if !ok { + return x + } + switch k := basic.Kind(); k { + case types.UntypedInt: + return types.Typ[types.Int].Underlying() + case types.UntypedRune: + return types.Typ[types.Rune].Underlying() + case types.UntypedFloat: + return types.Typ[types.Float64].Underlying() + case types.UntypedComplex: + return types.Typ[types.Complex128].Underlying() + case types.UntypedString: + return types.Typ[types.String].Underlying() + default: + return x + } +} + +func (s *scope) checkFieldNode(tree *parse.Tree, dot types.Type, n *parse.FieldNode) (types.Type, error) { + return s.checkIdentifiers(tree, dot, n, n.Ident) +} + +func (s *scope) checkCommandNode(tree *parse.Tree, dot types.Type, n *parse.CommandNode) (types.Type, error) { + if _, ok := n.Args[0].(*parse.NilNode); len(n.Args) == 1 && ok { + loc, _ := tree.ErrorContext(n) + return nil, fmt.Errorf("%s: executing %q at <%s>: nil is not a command", loc, tree.Name, n.Args[0].String()) + } + argTypes := make([]types.Type, 0, len(n.Args)) + for _, arg := range n.Args[1:] { + argType, err := s.walk(tree, dot, arg) + if err != nil { + return nil, err + } + argTypes = append(argTypes, argType) + } + if ident, ok := n.Args[0].(*parse.IdentifierNode); ok { + switch ident.Ident { + case "slice": + var result types.Type + if slice, ok := argTypes[0].(*types.Slice); ok { + result = slice.Elem() + } else if array, ok := argTypes[0].(*types.Array); ok { + result = array.Elem() + } + if len(argTypes) > 1 { + first, ok := argTypes[1].(*types.Basic) + if !ok { + return nil, fmt.Errorf("slice expected int") + } + switch first.Kind() { + case types.UntypedInt, types.Int: + default: + } + } + if len(argTypes) > 2 { + second, ok := argTypes[1].(*types.Basic) + if !ok { + return nil, fmt.Errorf("slice expected int") + } + switch second.Kind() { + case types.UntypedInt, types.Int: + default: + } + } + return result, nil + case "index": + } + } + cmdType, err := s.walk(tree, dot, n.Args[0]) + if err != nil { + return nil, err + } + switch cmd := cmdType.(type) { + case *types.Signature: + for i := 0; i < len(argTypes); i++ { + at := argTypes[i] + var pt types.Type + isVar := cmd.Variadic() + argVar := i >= cmd.Params().Len()-1 + if isVar && argVar { + ps := cmd.Params() + v := ps.At(ps.Len() - 1).Type().(*types.Slice) + pt = v.Elem() + } else { + pt = cmd.Params().At(i).Type() + } + assignable := types.AssignableTo(at, pt) + if !assignable { + return nil, fmt.Errorf("%s argument %d has type %s expected %s", n.Args[0], i, at, pt) + } + } + return cmd.Results().At(0).Type(), nil + default: + return cmd, nil + } +} + +func (s *scope) checkIdentifiers(tree *parse.Tree, dot types.Type, n parse.Node, idents []string) (types.Type, error) { + x := dot + for i, ident := range idents { + for { + ptr, ok := x.(*types.Pointer) + if !ok { + break + } + x = ptr.Elem() + } + switch xx := x.(type) { + case *types.Map: + switch key := xx.Key().Underlying().(type) { + case *types.Basic: + switch key.Kind() { + // case types.Int, types.Int64, types.Int32, types.Int16, types.Int8, + // types.Uint, types.Uint64, types.Uint32, types.Uint16, types.Uint8: + case types.Int: + x = xx.Elem() + _, err := strconv.Atoi(ident) + if err != nil { + loc, _ := tree.ErrorContext(n) + return nil, fmt.Errorf(`%s: executing %q at <%s>: can't evaluate field one in type %s`, loc, tree.Name, n.String(), xx.String()) + } + case types.String: + x = xx.Elem() + default: + } + continue + default: + x = xx.Elem() + } + continue + case *types.Named: + obj, _, _ := types.LookupFieldOrMethod(x, true, nil, ident) + if obj == nil { + loc, _ := tree.ErrorContext(n) + return nil, fmt.Errorf("type check failed: %s: %s not found on %s", loc, ident, x) + } + switch o := obj.(type) { + default: + x = obj.Type() + case *types.Func: + sig := o.Signature() + resultLen := sig.Results().Len() + if resultLen < 1 || resultLen > 2 { + loc, _ := tree.ErrorContext(n) + methodPos := s.fileSet.Position(o.Pos()) + return nil, fmt.Errorf("type check failed: %s: function %s has %d return values; should be 1 or 2: incorrect signature at %s", loc, ident, resultLen, methodPos) + } + if resultLen > 1 { + loc, _ := tree.ErrorContext(n) + methodPos := s.fileSet.Position(obj.Pos()) + finalResult := sig.Results().At(sig.Results().Len() - 1) + errorType := types.Universe.Lookup("error") + if !types.Identical(errorType.Type(), finalResult.Type()) { + return nil, fmt.Errorf("type check failed: %s: invalid function signature for %s: second return value should be error; is %s: incorrect signature at %s", loc, ident, finalResult.Type(), methodPos) + } + } + if i == len(idents)-1 { + return o.Type(), nil + } + x = sig.Results().At(0).Type() + } + if _, ok := x.(*types.Signature); ok && i < len(idents)-1 { + loc, _ := tree.ErrorContext(n) + return nil, fmt.Errorf("type check failed: %s: can't evaluate field %s in type %s", loc, ident, x) + } + default: + loc, _ := tree.ErrorContext(n) + return nil, fmt.Errorf("type check failed: %s: identifier chain not supported for type %s", loc, x.String()) + } + } + return x, nil +} + +func (s *scope) checkRangeNode(tree *parse.Tree, dot types.Type, n *parse.RangeNode) error { + child := s.child() + pipeType, err := child.walk(tree, dot, n.Pipe) + if err != nil { + return err + } + var x types.Type + switch pt := pipeType.(type) { + case *types.Slice: + x = pt.Elem() + case *types.Array: + x = pt.Elem() + case *types.Map: + x = pt.Elem() + default: + return fmt.Errorf("failed to range over %s", pipeType) + } + if _, err := child.walk(tree, x, n.List); err != nil { + return err + } + if n.ElseList != nil { + if _, err := child.walk(tree, x, n.ElseList); err != nil { + return err + } + } + return nil +} + +func (s *scope) checkIdentifierNode(n *parse.IdentifierNode) (types.Type, error) { + if strings.HasPrefix(n.Ident, "$") { + tp, ok := s.variables[n.Ident] + if !ok { + return nil, fmt.Errorf("failed to find identifier %s", n.Ident) + } + return tp, nil + } + fn, ok := s.FindFunction(n.Ident) + if !ok { + return nil, fmt.Errorf("failed to find function %s", n.Ident) + } + return fn, nil +} diff --git a/internal/check/tree_test.go b/internal/check/tree_test.go new file mode 100644 index 0000000..eee8f65 --- /dev/null +++ b/internal/check/tree_test.go @@ -0,0 +1,614 @@ +package check_test + +import ( + "fmt" + "go/types" + "html/template" + "io" + "reflect" + "slices" + "strings" + "sync" + "testing" + "text/template/parse" + + "github.com/stretchr/testify/require" + "golang.org/x/tools/go/packages" + + "github.com/crhntr/muxt" + "github.com/crhntr/muxt/internal/check" + "github.com/crhntr/muxt/internal/source" +) + +var loadPkg = sync.OnceValue(func() []*packages.Package { + packageList, loadErr := packages.Load(&packages.Config{ + Mode: packages.NeedName | packages.NeedFiles | packages.NeedDeps | packages.NeedTypes, + Tests: true, + }, ".") + if loadErr != nil { + panic(loadErr) + } + return packageList +}) + +func findHTMLTree(tmpl *template.Template) check.FindTreeFunc { + return func(name string) (*parse.Tree, bool) { + ts := tmpl.Lookup(name) + if ts == nil { + return nil, false + } + return ts.Tree, true + } +} + +func TestTree(t *testing.T) { + checkTestPackage := find(t, loadPkg(), func(p *packages.Package) bool { + return p.Name == "check_test" + }) + for _, tt := range []struct { + Name string + Template string + Data any + Error func(t *testing.T, checkErr, execErr error, tp types.Type) + }{ + { + Name: "when accessing nil on an empty struct", + Template: `{{.Field}}`, + Data: Void{}, + Error: func(t *testing.T, err, _ error, tp types.Type) { + require.EqualError(t, err, fmt.Sprintf(`type check failed: template:1:2: Field not found on %s`, tp)) + }, + }, + { + Name: "when accessing the dot", + Template: `{{.}}`, + Data: Void{}, + }, + { + Name: "when a method does not any results", + Template: `{{.Method}}`, + Data: TypeWithMethodSignatureNoResultMethod{}, + Error: func(t *testing.T, err, _ error, tp types.Type) { + method, _, _ := types.LookupFieldOrMethod(tp, true, checkTestPackage.Types, "Method") + require.NotNil(t, method) + methodPos := checkTestPackage.Fset.Position(method.Pos()) + + require.EqualError(t, err, fmt.Sprintf(`type check failed: template:1:2: function Method has 0 return values; should be 1 or 2: incorrect signature at %s`, methodPos)) + }, + }, + { + Name: "when a method does has a result", + Template: `{{.Method}}`, + Data: TypeWithMethodSignatureResult{}, + }, + { + Name: "when a method also has an error", + Template: `{{.Method}}`, + Data: TypeWithMethodSignatureResultAndError{}, + }, + { + Name: "when a method has a second result that is not an error", + Template: `{{.Method}}`, + Data: TypeWithMethodSignatureResultAndNonError{}, + Error: func(t *testing.T, err, _ error, tp types.Type) { + method, _, _ := types.LookupFieldOrMethod(tp, true, checkTestPackage.Types, "Method") + require.NotNil(t, method) + methodPos := checkTestPackage.Fset.Position(method.Pos()) + + require.EqualError(t, err, fmt.Sprintf(`type check failed: template:1:2: invalid function signature for Method: second return value should be error; is int: incorrect signature at %s`, methodPos)) + }, + }, + { + Name: "when a method with too many results", + Template: `{{.Method}}`, + Data: TypeWithMethodSignatureThreeResults{}, + Error: func(t *testing.T, err, _ error, tp types.Type) { + method, _, _ := types.LookupFieldOrMethod(tp, true, checkTestPackage.Types, "Method") + require.NotNil(t, method) + methodPos := checkTestPackage.Fset.Position(method.Pos()) + + require.EqualError(t, err, fmt.Sprintf(`type check failed: template:1:2: function Method has 3 return values; should be 1 or 2: incorrect signature at %s`, methodPos)) + }, + }, + { + Name: "when a method is part of a field node list", + Template: `{{.Method.Method}}`, + Data: TypeWithMethodSignatureResultHasMethod{}, + }, + { + Name: "when result method does not have a method", + Template: `{{.Method.Method}}`, + Data: TypeWithMethodSignatureResultHasMethodWithNoResults{}, + Error: func(t *testing.T, err, _ error, tp types.Type) { + m1, _, _ := types.LookupFieldOrMethod(tp, true, checkTestPackage.Types, "Method") + require.NotNil(t, m1) + m2, _, _ := types.LookupFieldOrMethod(m1.Type().(*types.Signature).Results().At(0).Type(), true, checkTestPackage.Types, "Method") + require.NotNil(t, m2) + methodPos := checkTestPackage.Fset.Position(m2.Pos()) + + require.EqualError(t, err, fmt.Sprintf(`type check failed: template:1:9: function Method has 0 return values; should be 1 or 2: incorrect signature at %s`, methodPos)) + }, + }, + { + Name: "when the struct has the field", + Template: `{{.Field}}`, + Data: StructWithField{}, + }, + { + Name: "when the struct has the field and the field has a method", + Template: `{{.Field.Method}}`, + Data: StructWithFieldWithMethod{}, + }, + { + Name: "when the struct has the field and the field has a method", + Template: `{{.Field}}`, + Data: StructWithFieldWithMethod{}, + }, + { + Name: "when the struct has the field of kind func", + Template: `{{.Func.Method}}`, + Data: StructWithFuncFieldWithResultWithMethod{ + Func: func() (_ TypeWithMethodSignatureResult) { return }, + }, + Error: func(t *testing.T, err, _ error, tp types.Type) { + fn, _, _ := types.LookupFieldOrMethod(tp, true, checkTestPackage.Types, "Func") + require.NotNil(t, fn) + require.ErrorContains(t, err, fmt.Sprintf("type check failed: template:1:7: can't evaluate field Func in type %s", fn.Type())) + }, + }, + { + Name: "when a method has an int parameter", + Template: `{{.F 21}}`, + Data: MethodWithIntParam{}, + }, + { + Name: "when a method argument is an bool but param is int", + Template: `{{.F false}}`, + Data: MethodWithIntParam{}, + Error: func(t *testing.T, checkErr, _ error, tp types.Type) { + require.Error(t, checkErr) + require.ErrorContains(t, checkErr, "expected int") + }, + }, + { + Name: "when a method has a bool parameter", + Template: `{{.F true}}`, + Data: MethodWithBoolParam{}, + }, + { + Name: "when a method argument is an int but param is bool", + Template: `{{.F 32}}`, + Data: MethodWithBoolParam{}, + Error: func(t *testing.T, checkErr, execErr error, tp types.Type) { + require.Error(t, checkErr) + require.ErrorContains(t, checkErr, "expected bool") + require.Error(t, execErr) + }, + }, + { + Name: "when a method receives a 64 bit floating point literal", + Template: `{{.F 3.2}}`, + Data: MethodWithFloat64Param{}, + }, + { + Name: "when a method receives a 32 bit floating point literal", + Template: `{{.F 3.2}}`, + Data: MethodWithFloat32Param{}, + }, + { + Name: "when the method parameter is an int8", + Template: `{{.F 32}}`, + Data: MethodWithInt8Param{}, + }, + { + Name: "when the method parameter is an int16", + Template: `{{.F 32}}`, + Data: MethodWithInt16Param{}, + }, + { + Name: "when the method parameter is an int32", + Template: `{{.F 32}}`, + Data: MethodWithInt32Param{}, + }, + { + Name: "when the method parameter is an int64", + Template: `{{.F 32}}`, + Data: MethodWithInt64Param{}, + }, + { + Name: "when the method parameter is an uint", + Template: `{{.F 32}}`, + Data: MethodWithUintParam{}, + }, + { + Name: "when the method parameter is an uint8", + Template: `{{.F 32}}`, + Data: MethodWithUint8Param{}, + }, + { + Name: "when the method parameter is an uint16", + Template: `{{.F 32}}`, + Data: MethodWithUint16Param{}, + }, + { + Name: "when the method parameter is an uint32", + Template: `{{.F 32}}`, + Data: MethodWithUint32Param{}, + }, + { + Name: "when the method parameter is an uint64", + Template: `{{.F 32}}`, + Data: MethodWithUint64Param{}, + }, + { + Name: "when a method is on the dollar variable", + Template: `{{$.F 32}}`, + Data: MethodWithUint64Param{}, + }, + { + Name: "when accessing the dollar variable in an underlying template", + Template: `{{define "t1"}}{{$.F 3.2}}{{end}}{{template "t1" $.Method}}`, + Data: TypeWithMethodSignatureResultMethodWithFloat32Param{}, + }, + { + Name: "when ranging over a slice field", + Template: `{{range .Numbers}}{{$.F .}}{{end}}`, + Data: TypeWithMethodAndSliceFloat64{ + Numbers: []float64{1, 2, 3}, + }, + }, + { + Name: "when ranging over an array field", + Template: `{{range .Numbers}}{{$.F .}}{{end}}`, + Data: TypeWithMethodAndArrayFloat64{ + Numbers: [...]float64{1, 2}, + }, + }, + { + Name: "when passing key value range variables for slice", + Template: `{{range $k, $v := .Numbers}}{{$.F $k $v}}{{end}}`, + Data: MethodWithKeyValForSlices{ + Numbers: []float64{1, 2}, + }, + }, + { + Name: "when passing key value range variables for array", + Template: `{{range $k, $v := .Numbers}}{{$.F $k $v}}{{end}}`, + Data: MethodWithKeyValForArray{ + Numbers: [...]float64{1, 2}, + }, + }, + { + Name: "when passing key value range variables for map", + Template: `{{range $k, $v := .Numbers}}{{$.F $k $v}}{{end}}`, + Data: MethodWithKeyValForMap{ + Numbers: map[int16]float32{}, + }, + }, + { + Name: "when a variable is used", + Template: `{{$v := 1}}{{.F $v}}`, + Data: MethodWithIntParam{}, + }, + { + Name: "when there is an error in the else block", + Template: `{{$x := "wrong type"}}{{if false}}{{else}}{{.F $x}}{{end}}`, + Data: MethodWithIntParam{}, + Error: func(t *testing.T, checkErr, _ error, tp types.Type) { + require.Error(t, checkErr) + require.ErrorContains(t, checkErr, ".F argument 0 has type untyped string expected int") + }, + }, + { + Name: "variable redefined in if block", + Template: `{{$x := 1}}{{if true}}{{$x := "str"}}{{end}}{{.F $x}}`, + Data: MethodWithIntParam{}, + }, + { + Name: "range variable does not clobber outer scope", + Template: `{{$x := 1}}{{range .Numbers}}{{$x := "str"}}{{end}}{{square $x}}`, + Data: MethodWithKeyValForSlices{}, + }, + { + Name: "range variable does not override outer scope", + Template: `{{$x := "str"}}{{range $x, $y := .Numbers}}{{$.F $x $y}}{{end}}{{printf $x}}`, + Data: MethodWithKeyValForSlices{}, + }, + { + Name: "source provided function", + Template: `{{square 5}}`, + Data: Void{}, + }, + { + Name: "with expression", + Template: `{{$x := 1}}{{with $x := .Numbers}}{{$x}}{{end}}`, + Data: MethodWithKeyValForSlices{}, + }, + { + Name: "with expression declares variable with same name as parent scope", + Template: `{{$x := 1.2}}{{with $x := ceil $x}}{{$x}}{{end}}`, + Data: MethodWithKeyValForSlices{}, + }, + { + Name: "with expression has action with wrong dot type used in call", + Template: `{{with $x := "wrong"}}{{expectInt .}}{{else}}{{end}}`, + Data: Void{}, + Error: func(t *testing.T, checkErr, execErr error, tp types.Type) { + require.ErrorContains(t, execErr, "wrong type for value; expected int; got string") + require.ErrorContains(t, checkErr, "expectInt argument 0 has type untyped string expected int") + }, + }, + { + Name: "with else expression has action with correct dot type used in call", + Template: `{{with $x := 12}}{{with $x := 1.2}}{{else}}{{expectInt $x}}{{end}}{{end}}`, + Data: Void{}, + }, + { + Name: "with else expression has action with wrong dot type used in call", + Template: `{{with $outer := 12}}{{with $x := true}}{{else}}{{expectString .}}{{end}}{{end}}`, + Data: Void{}, + Error: func(t *testing.T, checkErr, execErr error, tp types.Type) { + require.NoError(t, execErr) + require.ErrorContains(t, checkErr, "expectString argument 0 has type untyped int expected string") + }, + }, + { + Name: "complex number parses", + Template: `{{$x := 2i}}{{printf "%T" $x}}`, + Data: Void{}, + }, + { + Name: "template node without parameter", + Template: `{{define "t"}}{{end}}{{template "t"}}`, + Data: Void{}, + }, + { + Name: "template wrong input type", + Template: `{{define "t"}}{{expectInt .}}{{end}}{{if false}}{{template "t" 1.2}}{{end}}`, + Data: Void{}, + Error: func(t *testing.T, checkErr, execErr error, tp types.Type) { + require.NoError(t, execErr) + require.ErrorContains(t, checkErr, "expectInt argument 0 has type float64 expected int") + }, + }, + { + Name: "it downgrades untyped integers", + Template: `{{define "t"}}{{expectInt8 .}}{{end}}{{if false}}{{template "t" 12}}{{end}}`, + Data: Void{}, + Error: func(t *testing.T, checkErr, execErr error, tp types.Type) { + require.NoError(t, execErr) + require.ErrorContains(t, checkErr, "expectInt8 argument 0 has type int expected int8") + }, + }, + { + Name: "it downgrades untyped floats", + Template: `{{define "t"}}{{expectFloat32 .}}{{end}}{{if false}}{{template "t" 1.2}}{{end}}`, + Data: Void{}, + Error: func(t *testing.T, checkErr, execErr error, tp types.Type) { + require.NoError(t, execErr) + require.ErrorContains(t, checkErr, "expectFloat32 argument 0 has type float64 expected float32") + }, + }, + { + Name: "it downgrades untyped complex", + Template: `{{define "t"}}{{expectComplex64 .}}{{end}}{{if false}}{{template "t" 2i}}{{end}}`, + Data: Void{}, + Error: func(t *testing.T, checkErr, execErr error, tp types.Type) { + require.NoError(t, execErr) + require.ErrorContains(t, checkErr, "expectComplex64 argument 0 has type complex128 expected complex64") + }, + }, + // not sure if I should be downgrading bool, it should be fine to let it be since there is only one basic bool type + { + Name: "chain node", + Template: `{{(.).A.B.C.D}}`, + Data: LetterChainA{}, + Error: func(t *testing.T, checkErr, execErr error, tp types.Type) { + require.NoError(t, execErr) + require.NoError(t, checkErr) + }, + }, + { + Name: "chain node with type change in term", + Template: `{{(.A).B.C.D}}`, + Data: LetterChainA{}, + Error: func(t *testing.T, checkErr, execErr error, tp types.Type) { + require.NoError(t, execErr) + require.NoError(t, checkErr) + }, + }, + + // stdlib exec tests + + // Trivial cases. + // {"empty", "", "", nil, true}, + { + Name: "empty", + Template: "", + Data: Void{}, + }, + // {"text", "some text", "some text", nil, true}, + { + Name: "text", + Template: "some text", + Data: Void{}, + }, + // {"nil action", "{{nil}}", "", nil, false}, + { + Name: "nil action", + Template: `{{nil}}`, + Data: Void{}, + Error: func(t *testing.T, checkErr, execErr error, tp types.Type) { + require.ErrorContains(t, checkErr, strings.TrimPrefix(execErr.Error(), "template: ")) + }, + }, + + // Ideal constants. + // {"ideal int", "{{typeOf 3}}", "int", 0, true}, + { + Name: "ideal int", + Template: `{{expectInt 3}}`, + Data: Void{}, + }, + // {"ideal float", "{{typeOf 1.0}}", "float64", 0, true}, + { + Name: "ideal int", + Template: `{{expectFloat64 1.0}}}`, + Data: Void{}, + }, + // {"ideal exp float", "{{typeOf 1e1}}", "float64", 0, true}, + { + Name: "ideal float", + Template: `{{expectFloat64 1e1}}`, + Data: Void{}, + }, + // {"ideal complex", "{{typeOf 1i}}", "complex128", 0, true}, + { + Name: "ideal complex", + Template: `{{expectComplex128 1i}}`, + Data: Void{}, + }, + // {"ideal int", "{{typeOf " + bigInt + "}}", "int", 0, true}, + { + Name: "ideal big int", + Template: fmt.Sprintf(`{{expectInt 0x%x}}}`, 1<= 0 { + return list[i] + } else { + var zero T + t.Fatalf("failed to find") + return zero + } +} diff --git a/internal/check/types_test.go b/internal/check/types_test.go new file mode 100644 index 0000000..4e6ac77 --- /dev/null +++ b/internal/check/types_test.go @@ -0,0 +1,519 @@ +package check_test + +import ( + "bytes" + "errors" + "fmt" + "math" + "strings" + "text/template" +) + +type Void struct{} + +type TypeWithMethodSignatureNoResultMethod struct{} + +func (TypeWithMethodSignatureNoResultMethod) Method() {} + +type TypeWithMethodSignatureResult struct{} + +func (TypeWithMethodSignatureResult) Method() struct{} { return struct{}{} } + +type TypeWithMethodSignatureResultAndError struct{} + +func (TypeWithMethodSignatureResultAndError) Method() (struct{}, error) { return struct{}{}, nil } + +type TypeWithMethodSignatureResultAndNonError struct{} + +func (TypeWithMethodSignatureResultAndNonError) Method() (struct{}, int) { return struct{}{}, 0 } + +type TypeWithMethodSignatureThreeResults struct{} + +func (TypeWithMethodSignatureThreeResults) Method() (struct{}, struct{}, error) { + return struct{}{}, struct{}{}, nil +} + +type TypeWithMethodSignatureResultHasMethod struct{} + +func (TypeWithMethodSignatureResultHasMethod) Method() (_ TypeWithMethodSignatureResult) { + return +} + +type TypeWithMethodSignatureResultHasMethodWithNoResults struct{} + +func (TypeWithMethodSignatureResultHasMethodWithNoResults) Method() (_ TypeWithMethodSignatureNoResultMethod) { + return +} + +type StructWithField struct { + Field struct{} +} + +type StructWithFieldWithMethod struct { + Field TypeWithMethodSignatureResultAndError +} + +type StructWithFuncFieldWithResultWithMethod struct { + Func func() TypeWithMethodSignatureResult +} + +type MethodWithIntParam struct{} + +func (MethodWithIntParam) F(int) (_ Void) { return } + +type MethodWithInt8Param struct{} + +func (MethodWithInt8Param) F(int8) (_ Void) { return } + +type MethodWithInt16Param struct{} + +func (MethodWithInt16Param) F(int16) (_ Void) { return } + +type MethodWithInt32Param struct{} + +func (MethodWithInt32Param) F(int32) (_ Void) { return } + +type MethodWithInt64Param struct{} + +func (MethodWithInt64Param) F(int64) (_ Void) { return } + +type MethodWithUintParam struct{} + +func (MethodWithUintParam) F(uint) (_ Void) { return } + +type MethodWithUint8Param struct{} + +func (MethodWithUint8Param) F(uint8) (_ Void) { return } + +type MethodWithUint16Param struct{} + +func (MethodWithUint16Param) F(uint16) (_ Void) { return } + +type MethodWithUint32Param struct{} + +func (MethodWithUint32Param) F(uint32) (_ Void) { return } + +type MethodWithUint64Param struct{} + +func (MethodWithUint64Param) F(uint64) (_ Void) { return } + +type MethodWithBoolParam struct{} + +func (MethodWithBoolParam) F(bool) (_ Void) { return } + +type MethodWithFloat64Param struct{} + +func (MethodWithFloat64Param) F(float64) (_ Void) { return } + +type MethodWithFloat32Param struct{} + +func (MethodWithFloat32Param) F(float32) (_ Void) { return } + +type TypeWithMethodSignatureResultMethodWithFloat32Param struct{} + +func (TypeWithMethodSignatureResultMethodWithFloat32Param) Method() (_ MethodWithFloat32Param) { + return +} + +type TypeWithMethodAndSliceFloat64 struct { + MethodWithFloat64Param + Numbers []float64 +} + +type TypeWithMethodAndArrayFloat64 struct { + MethodWithFloat64Param + Numbers [2]float64 +} + +type MethodWithKeyValForSlices struct { + Numbers []float64 +} + +func (MethodWithKeyValForSlices) F(int, float64) (_ Void) { return } + +type MethodWithKeyValForArray struct { + Numbers [2]float64 +} + +func (MethodWithKeyValForArray) F(int, float64) (_ Void) { return } + +type MethodWithKeyValForMap struct { + Numbers map[int16]float32 +} + +func (MethodWithKeyValForMap) F(int16, float32) (_ Void) { return } + +func square(n int) int { + return n * n +} + +func ceil(n float64) int { + return int(math.Ceil(n)) +} + +func expectInt(n int) int { return n } + +func expectFloat64(n float64) float64 { return n } + +func expectString(s string) string { return s } + +func expectInt8(n int8) int8 { return n } + +func expectFloat32(n float32) float32 { return n } + +func expectComplex64(n complex64) complex64 { return n } + +func expectComplex128(n complex128) complex128 { return n } + +type ( + LetterChainA struct { + A LetterChainB + } + LetterChainB struct { + B LetterChainC + } + LetterChainC struct { + C LetterChainD + } + LetterChainD struct { + D Void + } +) + +// T has lots of interesting pieces to use to test execution. +type T struct { + // Basics + True bool + I int + U16 uint16 + X, S string + FloatZero float64 + ComplexZero complex128 + // Nested structs. + U *U + // Struct with String method. + V0 V + V1, V2 *V + // Struct with Error method. + W0 W + W1, W2 *W + // Slices + SI []int + SICap []int + SIEmpty []int + SB []bool + // Arrays + AI [3]int + // Maps + MSI map[string]int + MSIone map[string]int // one element, for deterministic output + MSIEmpty map[string]int + MXI map[any]int + MII map[int]int + MI32S map[int32]string + MI64S map[int64]string + MUI32S map[uint32]string + MUI64S map[uint64]string + MI8S map[int8]string + MUI8S map[uint8]string + SMSI []map[string]int + // Empty interfaces; used to see if we can dig inside one. + Empty0 any // nil + Empty1 any + Empty2 any + Empty3 any + Empty4 any + // Non-empty interfaces. + NonEmptyInterface I + NonEmptyInterfacePtS *I + NonEmptyInterfaceNil I + NonEmptyInterfaceTypedNil I + // Stringer. + Str fmt.Stringer + Err error + // Pointers + PI *int + PS *string + PSI *[]int + NIL *int + // Function (not method) + BinaryFunc func(string, string) string + VariadicFunc func(...string) string + VariadicFuncInt func(int, ...string) string + NilOKFunc func(*int) bool + ErrFunc func() (string, error) + PanicFunc func() string + TooFewReturnCountFunc func() + TooManyReturnCountFunc func() (string, error, int) + InvalidReturnTypeFunc func() (string, bool) + // Template to test evaluation of templates. + Tmpl *template.Template + // Unexported field; cannot be accessed by template. + unexported int +} + +type S []string + +func (S) Method0() string { + return "M0" +} + +type U struct { + V string +} + +type V struct { + j int +} + +func (v *V) String() string { + if v == nil { + return "nilV" + } + return fmt.Sprintf("<%d>", v.j) +} + +type W struct { + k int +} + +func (w *W) Error() string { + if w == nil { + return "nilW" + } + return fmt.Sprintf("[%d]", w.k) +} + +var siVal = I(S{"a", "b"}) + +var tVal = &T{ + True: true, + I: 17, + U16: 16, + X: "x", + S: "xyz", + U: &U{"v"}, + V0: V{6666}, + V1: &V{7777}, // leave V2 as nil + W0: W{888}, + W1: &W{999}, // leave W2 as nil + SI: []int{3, 4, 5}, + SICap: make([]int, 5, 10), + AI: [3]int{3, 4, 5}, + SB: []bool{true, false}, + MSI: map[string]int{"one": 1, "two": 2, "three": 3}, + MSIone: map[string]int{"one": 1}, + MXI: map[any]int{"one": 1}, + MII: map[int]int{1: 1}, + MI32S: map[int32]string{1: "one", 2: "two"}, + MI64S: map[int64]string{2: "i642", 3: "i643"}, + MUI32S: map[uint32]string{2: "u322", 3: "u323"}, + MUI64S: map[uint64]string{2: "ui642", 3: "ui643"}, + MI8S: map[int8]string{2: "i82", 3: "i83"}, + MUI8S: map[uint8]string{2: "u82", 3: "u83"}, + SMSI: []map[string]int{ + {"one": 1, "two": 2}, + {"eleven": 11, "twelve": 12}, + }, + Empty1: 3, + Empty2: "empty2", + Empty3: []int{7, 8}, + Empty4: &U{"UinEmpty"}, + NonEmptyInterface: &T{X: "x"}, + NonEmptyInterfacePtS: &siVal, + NonEmptyInterfaceTypedNil: (*T)(nil), + Str: bytes.NewBuffer([]byte("foozle")), + Err: errors.New("erroozle"), + PI: newInt(23), + PS: newString("a string"), + PSI: newIntSlice(21, 22, 23), + BinaryFunc: func(a, b string) string { return fmt.Sprintf("[%s=%s]", a, b) }, + VariadicFunc: func(s ...string) string { return fmt.Sprint("<", strings.Join(s, "+"), ">") }, + VariadicFuncInt: func(a int, s ...string) string { return fmt.Sprint(a, "=<", strings.Join(s, "+"), ">") }, + NilOKFunc: func(s *int) bool { return s == nil }, + ErrFunc: func() (string, error) { return "bla", nil }, + PanicFunc: func() string { panic("test panic") }, + TooFewReturnCountFunc: func() {}, + TooManyReturnCountFunc: func() (string, error, int) { return "", nil, 0 }, + InvalidReturnTypeFunc: func() (string, bool) { return "", false }, + Tmpl: template.Must(template.New("x").Parse("test template")), // "x" is the value of .X +} + +var tSliceOfNil = []*T{nil} + +// A non-empty interface. +type I interface { + Method0() string +} + +var iVal I = tVal + +// Helpers for creation. +func newInt(n int) *int { + return &n +} + +func newString(s string) *string { + return &s +} + +func newIntSlice(n ...int) *[]int { + p := new([]int) + *p = make([]int, len(n)) + copy(*p, n) + return p +} + +// Simple methods with and without arguments. +func (t *T) Method0() string { + return "M0" +} + +func (t *T) Method1(a int) int { + return a +} + +func (t *T) Method2(a uint16, b string) string { + return fmt.Sprintf("Method2: %d %s", a, b) +} + +func (t *T) Method3(v any) string { + return fmt.Sprintf("Method3: %v", v) +} + +func (t *T) Copy() *T { + n := new(T) + *n = *t + return n +} + +func (t *T) MAdd(a int, b []int) []int { + v := make([]int, len(b)) + for i, x := range b { + v[i] = x + a + } + return v +} + +var myError = errors.New("my error") + +// MyError returns a value and an error according to its argument. +func (t *T) MyError(error bool) (bool, error) { + if error { + return true, myError + } + return false, nil +} + +// A few methods to test chaining. +func (t *T) GetU() *U { + return t.U +} + +func (u *U) TrueFalse(b bool) string { + if b { + return "true" + } + return "" +} + +func typeOf(arg any) string { + return fmt.Sprintf("%T", arg) +} + +func zeroArgs() string { + return "zeroArgs" +} + +func oneArg(a string) string { + return "oneArg=" + a +} + +func twoArgs(a, b string) string { + return "twoArgs=" + a + b +} + +func dddArg(a int, b ...string) string { + return fmt.Sprintln(a, b) +} + +// count returns a channel that will deliver n sequential 1-letter strings starting at "a" +func count(n int) chan string { + if n == 0 { + return nil + } + c := make(chan string) + go func() { + for i := 0; i < n; i++ { + c <- "abcdefghijklmnop"[i : i+1] + } + close(c) + }() + return c +} + +// vfunc takes a *V and a V +func vfunc(V, *V) string { + return "vfunc" +} + +// valueString takes a string, not a pointer. +func valueString(v string) string { + return "value is ignored" +} + +// returnInt returns an int +func returnInt() int { + return 7 +} + +func add(args ...int) int { + sum := 0 + for _, x := range args { + sum += x + } + return sum +} + +func echo(arg any) any { + return arg +} + +func makemap(arg ...string) map[string]string { + if len(arg)%2 != 0 { + panic("bad makemap") + } + m := make(map[string]string) + for i := 0; i < len(arg); i += 2 { + m[arg[i]] = arg[i+1] + } + return m +} + +func stringer(s fmt.Stringer) string { + return s.String() +} + +func mapOfThree() any { + return map[string]int{"three": 3} +} + +func die() bool { panic("die") } + +func print(in ...any) string { + return fmt.Sprint(in...) +} + +func println(in ...any) string { + return fmt.Sprintln(in...) +} + +func printf(f string, in ...any) string { + return fmt.Sprintf(f, in...) +} + +func not(in bool) bool { return !in } + +func and(...any) bool { return false } + +func or(...any) bool { return false } diff --git a/internal/source/template.go b/internal/source/template.go index 42b1139..05f379b 100644 --- a/internal/source/template.go +++ b/internal/source/template.go @@ -1,9 +1,12 @@ package source import ( + "bytes" "fmt" "go/ast" + "go/format" "go/token" + "go/types" "html/template" "path/filepath" "slices" @@ -13,7 +16,8 @@ import ( "golang.org/x/tools/go/packages" ) -func Templates(workingDirectory, templatesVariable string, pkg *packages.Package) (*template.Template, error) { +func Templates(workingDirectory, templatesVariable string, pkg *packages.Package) (*template.Template, Functions, error) { + funcTypeMap := DefaultFunctions(pkg.Types) for _, tv := range IterateValueSpecs(pkg.Syntax) { i := slices.IndexFunc(tv.Names, func(e *ast.Ident) bool { return e.Name == templatesVariable @@ -23,19 +27,31 @@ func Templates(workingDirectory, templatesVariable string, pkg *packages.Package } embeddedPaths, err := relativeFilePaths(workingDirectory, pkg.EmbedFiles...) if err != nil { - return nil, fmt.Errorf("failed to calculate relative path for embedded files: %w", err) + return nil, nil, fmt.Errorf("failed to calculate relative path for embedded files: %w", err) } const templatePackageIdent = "template" - ts, err := evaluateTemplateSelector(nil, tv.Values[i], workingDirectory, templatesVariable, templatePackageIdent, "", "", pkg.Fset, pkg.Syntax, embeddedPaths) + ts, err := evaluateTemplateSelector(nil, pkg.Types, tv.Values[i], workingDirectory, templatesVariable, templatePackageIdent, "", "", pkg.Fset, pkg.Syntax, embeddedPaths, funcTypeMap, make(template.FuncMap)) if err != nil { - return nil, fmt.Errorf("run template %s failed at %w", templatesVariable, err) + return nil, nil, fmt.Errorf("run template %s failed at %w", templatesVariable, err) } - return ts, nil + return ts, funcTypeMap, nil } - return nil, fmt.Errorf("variable %s not found", templatesVariable) + return nil, nil, fmt.Errorf("variable %s not found", templatesVariable) } -func evaluateTemplateSelector(ts *template.Template, expression ast.Expr, workingDirectory, templatesVariable, templatePackageIdent, rDelim, lDelim string, fileSet *token.FileSet, files []*ast.File, embeddedPaths []string) (*template.Template, error) { +func findPackage(pkg *types.Package, path string) (*types.Package, bool) { + if pkg == nil || pkg.Path() == path { + return pkg, true + } + for _, im := range pkg.Imports() { + if p, ok := findPackage(im, path); ok { + return p, true + } + } + return nil, false +} + +func evaluateTemplateSelector(ts *template.Template, pkg *types.Package, expression ast.Expr, workingDirectory, templatesVariable, templatePackageIdent, rDelim, lDelim string, fileSet *token.FileSet, files []*ast.File, embeddedPaths []string, funcTypeMaps Functions, fm template.FuncMap) (*template.Template, error) { call, ok := expression.(*ast.CallExpr) if !ok { return nil, contextError(workingDirectory, fileSet, expression.Pos(), fmt.Errorf("expected call expression")) @@ -56,7 +72,7 @@ func evaluateTemplateSelector(ts *template.Template, expression ast.Expr, workin if len(call.Args) != 1 { return nil, contextError(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly one argument %s got %d", Format(sel.X), len(call.Args))) } - return evaluateTemplateSelector(ts, call.Args[0], workingDirectory, templatesVariable, templatePackageIdent, rDelim, lDelim, fileSet, files, embeddedPaths) + return evaluateTemplateSelector(ts, pkg, call.Args[0], workingDirectory, templatesVariable, templatePackageIdent, rDelim, lDelim, fileSet, files, embeddedPaths, funcTypeMaps, fm) case "New": if len(call.Args) != 1 { return nil, contextError(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly one string literal argument")) @@ -76,7 +92,7 @@ func evaluateTemplateSelector(ts *template.Template, expression ast.Expr, workin return nil, contextError(workingDirectory, fileSet, call.Fun.Pos(), fmt.Errorf("unsupported function %s", sel.Sel.Name)) } case *ast.CallExpr: - up, err := evaluateTemplateSelector(ts, sel.X, workingDirectory, templatesVariable, templatePackageIdent, rDelim, lDelim, fileSet, files, embeddedPaths) + up, err := evaluateTemplateSelector(ts, pkg, sel.X, workingDirectory, templatesVariable, templatePackageIdent, rDelim, lDelim, fileSet, files, embeddedPaths, funcTypeMaps, fm) if err != nil { return nil, err } @@ -121,44 +137,44 @@ func evaluateTemplateSelector(ts *template.Template, expression ast.Expr, workin } return up.Option(list...), nil case "Funcs": - funcMap, err := evaluateFuncMap(workingDirectory, templatePackageIdent, fileSet, call) - if err != nil { + if err := evaluateFuncMap(workingDirectory, templatePackageIdent, pkg, fileSet, call, fm, funcTypeMaps); err != nil { return nil, err } - return up.Funcs(funcMap), nil + return up.Funcs(fm), nil default: return nil, contextError(workingDirectory, fileSet, call.Fun.Pos(), fmt.Errorf("unsupported method %s", sel.Sel.Name)) } } } -func evaluateFuncMap(workingDirectory, templatePackageIdent string, fileSet *token.FileSet, call *ast.CallExpr) (template.FuncMap, error) { +func evaluateFuncMap(workingDirectory, templatePackageIdent string, pkg *types.Package, fileSet *token.FileSet, call *ast.CallExpr, fm template.FuncMap, funcTypesMap Functions) error { const funcMapTypeIdent = "FuncMap" - fm := make(template.FuncMap) if len(call.Args) != 1 { - return nil, contextError(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly 1 template.FuncMap composite literal argument")) + return contextError(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly 1 template.FuncMap composite literal argument")) } arg := call.Args[0] lit, ok := arg.(*ast.CompositeLit) if !ok { - return nil, contextError(workingDirectory, fileSet, arg.Pos(), fmt.Errorf("expected a composite literal with type %s.%s got %s", templatePackageIdent, funcMapTypeIdent, Format(arg))) + return contextError(workingDirectory, fileSet, arg.Pos(), fmt.Errorf("expected a composite literal with type %s.%s got %s", templatePackageIdent, funcMapTypeIdent, Format(arg))) } typeSel, ok := lit.Type.(*ast.SelectorExpr) if !ok || typeSel.Sel.Name != funcMapTypeIdent { - return nil, contextError(workingDirectory, fileSet, arg.Pos(), fmt.Errorf("expected a composite literal with type %s.%s got %s", templatePackageIdent, funcMapTypeIdent, Format(arg))) + return contextError(workingDirectory, fileSet, arg.Pos(), fmt.Errorf("expected a composite literal with type %s.%s got %s", templatePackageIdent, funcMapTypeIdent, Format(arg))) } if tp, ok := typeSel.X.(*ast.Ident); !ok || tp.Name != templatePackageIdent { - return nil, contextError(workingDirectory, fileSet, arg.Pos(), fmt.Errorf("expected a composite literal with type %s.%s got %s", templatePackageIdent, funcMapTypeIdent, Format(arg))) + return contextError(workingDirectory, fileSet, arg.Pos(), fmt.Errorf("expected a composite literal with type %s.%s got %s", templatePackageIdent, funcMapTypeIdent, Format(arg))) } + var buf bytes.Buffer for i, exp := range lit.Elts { el, ok := exp.(*ast.KeyValueExpr) if !ok { - return nil, contextError(workingDirectory, fileSet, exp.Pos(), fmt.Errorf("expected element at index %d to be a key value pair got %s", i, Format(exp))) + return contextError(workingDirectory, fileSet, exp.Pos(), fmt.Errorf("expected element at index %d to be a key value pair got %s", i, Format(exp))) } funcName, err := evaluateStringLiteralExpression(workingDirectory, fileSet, el.Key) if err != nil { - return nil, err + return err } + // template.Parse does not evaluate the function signature parameters; // it ensures the function name is in scope and there is one or two results. // we could use something like func() string { return "" } for this signature @@ -171,8 +187,21 @@ func evaluateFuncMap(workingDirectory, templatePackageIdent string, fileSet *tok // or // fm[funcName] = func() (int, int) {return 0, 0} // will fail because the second result is not an error fm[funcName] = fmt.Sprintln + + if pkg == nil { + continue + } + buf.Reset() + if err := format.Node(&buf, fileSet, el.Value); err != nil { + return err + } + tv, err := types.Eval(fileSet, pkg, lit.Pos(), buf.String()) + if err != nil { + return err + } + funcTypesMap[funcName] = tv.Type.(*types.Signature) } - return fm, nil + return nil } func evaluateCallParseFilesArgs(workingDirectory string, fileSet *token.FileSet, call *ast.CallExpr, files []*ast.File, embeddedPaths []string) ([]string, error) { @@ -342,3 +371,30 @@ func relativeFilePaths(wd string, abs ...string) ([]string, error) { } return result, nil } + +type Functions map[string]*types.Signature + +func NewFunctions(m map[string]*types.Signature) Functions { + return Functions(m) +} + +func DefaultFunctions(pkg *types.Package) Functions { + funcTypeMap := make(Functions) + fmtPkg, ok := findPackage(pkg, "fmt") + if !ok || fmtPkg == nil { + return funcTypeMap + } + funcTypeMap["printf"] = fmtPkg.Scope().Lookup("Sprintf").Type().(*types.Signature) + funcTypeMap["print"] = fmtPkg.Scope().Lookup("Sprint").Type().(*types.Signature) + funcTypeMap["println"] = fmtPkg.Scope().Lookup("Sprintln").Type().(*types.Signature) + return funcTypeMap +} + +func (functions Functions) FindFunction(name string) (*types.Signature, bool) { + m := (map[string]*types.Signature)(functions) + fn, ok := m[name] + if !ok { + return nil, false + } + return fn, true +} diff --git a/internal/source/template_test.go b/internal/source/template_test.go index 480dfbb..08934b0 100644 --- a/internal/source/template_test.go +++ b/internal/source/template_test.go @@ -23,14 +23,14 @@ func TestTemplates(t *testing.T) { t.Run("non call", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir) - _, err := source.Templates(dir, "templatesIdent", pkg) + _, _, err := source.Templates(dir, "templatesIdent", pkg) require.ErrorContains(t, err, "run template templatesIdent failed at template.go:32:19: expected call expression") }) t.Run("call ParseFS", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/template_ParseFS.txtar")) pkg := parseGo(t, dir, "index.gohtml", "form.gohtml") - ts, err := source.Templates(dir, "templates", pkg) + ts, _, err := source.Templates(dir, "templates", pkg) require.NoError(t, err) var names []string for _, t := range ts.Templates() { @@ -43,7 +43,7 @@ func TestTemplates(t *testing.T) { t.Run("call ParseFS with assets dir", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/assets_dir.txtar")) pkg := parseGo(t, dir, "assets/index.gohtml", "assets/form.gohtml") - ts, err := source.Templates(dir, "templates", pkg) + ts, _, err := source.Templates(dir, "templates", pkg) require.NoError(t, err) var names []string for _, t := range ts.Templates() { @@ -56,7 +56,7 @@ func TestTemplates(t *testing.T) { t.Run("call New", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - ts, err := source.Templates(dir, "templateNew", pkg) + ts, _, err := source.Templates(dir, "templateNew", pkg) require.NoError(t, err) var names []string for _, t := range ts.Templates() { @@ -69,7 +69,7 @@ func TestTemplates(t *testing.T) { t.Run("call New after calling ParseFS", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - ts, err := source.Templates(dir, "templateParseFSNew", pkg) + ts, _, err := source.Templates(dir, "templateParseFSNew", pkg) require.NoError(t, err) var names []string for _, t := range ts.Templates() { @@ -82,7 +82,7 @@ func TestTemplates(t *testing.T) { t.Run("call New before calling ParseFS", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - ts, err := source.Templates(dir, "templateNewParseFS", pkg) + ts, _, err := source.Templates(dir, "templateNewParseFS", pkg) require.NoError(t, err) var names []string @@ -96,7 +96,7 @@ func TestTemplates(t *testing.T) { t.Run("call new with non args", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "templateNewMissingArg", pkg) + _, _, err := source.Templates(dir, "templateNewMissingArg", pkg) require.ErrorContains(t, err, "expected exactly one string literal argument") }) @@ -104,7 +104,7 @@ func TestTemplates(t *testing.T) { t.Run("call New on unknown X", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "templateWrongX", pkg) + _, _, err := source.Templates(dir, "templateWrongX", pkg) require.ErrorContains(t, err, "template.go:20:19: expected template got UNKNOWN") }) @@ -112,7 +112,7 @@ func TestTemplates(t *testing.T) { t.Run("call New with wrong arg count", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "templateWrongArgCount", pkg) + _, _, err := source.Templates(dir, "templateWrongArgCount", pkg) require.ErrorContains(t, err, "template.go:22:38: expected exactly one string literal argument") }) @@ -120,7 +120,7 @@ func TestTemplates(t *testing.T) { t.Run("call New on unexpected X", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "templateNewOnIndexed", pkg) + _, _, err := source.Templates(dir, "templateNewOnIndexed", pkg) require.ErrorContains(t, err, "template.go:24:25: expected exactly one argument ts[0] got 2") }) @@ -128,7 +128,7 @@ func TestTemplates(t *testing.T) { t.Run("call New with non string literal arg", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "templateNewArg42", pkg) + _, _, err := source.Templates(dir, "templateNewArg42", pkg) require.ErrorContains(t, err, "template.go:26:34: expected string literal got 42") }) @@ -136,7 +136,7 @@ func TestTemplates(t *testing.T) { t.Run("call New with non literal arg", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "templateNewArgIdent", pkg) + _, _, err := source.Templates(dir, "templateNewArgIdent", pkg) require.ErrorContains(t, err, "template.go:28:37: expected string literal got TemplateName") }) @@ -144,7 +144,7 @@ func TestTemplates(t *testing.T) { t.Run("call New with upstream error", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "templateNewErrUpstream", pkg) + _, _, err := source.Templates(dir, "templateNewErrUpstream", pkg) require.ErrorContains(t, err, "run template templateNewErrUpstream failed at template.go:30:40: expected string literal got fail") }) @@ -152,7 +152,7 @@ func TestTemplates(t *testing.T) { t.Run("unknown templates variable", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "variableDoesNotExist", pkg) + _, _, err := source.Templates(dir, "variableDoesNotExist", pkg) require.NotNil(t, err) require.Equal(t, "variable variableDoesNotExist not found", err.Error()) @@ -161,7 +161,7 @@ func TestTemplates(t *testing.T) { t.Run("unknown templates variable", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "unsupportedMethod", pkg) + _, _, err := source.Templates(dir, "unsupportedMethod", pkg) require.ErrorContains(t, err, "run template unsupportedMethod failed at template.go:34:22: unsupported function Unknown") }) @@ -169,7 +169,7 @@ func TestTemplates(t *testing.T) { t.Run("call Must with unexpected function expression", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "unexpectedFunExpression", pkg) + _, _, err := source.Templates(dir, "unexpectedFunExpression", pkg) require.ErrorContains(t, err, "run template unexpectedFunExpression failed at template.go:36:28: unexpected expression *ast.IndexExpr: x[3]") }) @@ -177,91 +177,91 @@ func TestTemplates(t *testing.T) { t.Run("call Must on non ident receiver", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "templateMustNonIdentReceiver", pkg) + _, _, err := source.Templates(dir, "templateMustNonIdentReceiver", pkg) require.ErrorContains(t, err, "run template templateMustNonIdentReceiver failed at template.go:38:33: unexpected expression *ast.Ident: f") }) t.Run("call Must with two arguments", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "templateMustCalledWithTwoArgs", pkg) + _, _, err := source.Templates(dir, "templateMustCalledWithTwoArgs", pkg) require.ErrorContains(t, err, "run template templateMustCalledWithTwoArgs failed at template.go:40:47: expected exactly one argument template got 2") }) t.Run("call Must with one argument", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "templateMustCalledWithNoArg", pkg) + _, _, err := source.Templates(dir, "templateMustCalledWithNoArg", pkg) require.ErrorContains(t, err, "run template templateMustCalledWithNoArg failed at template.go:42:47: expected exactly one argument template got 0") }) t.Run("call Must wrong template package ident", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "templateMustWrongPackageIdent", pkg) + _, _, err := source.Templates(dir, "templateMustWrongPackageIdent", pkg) require.ErrorContains(t, err, "run template templateMustWrongPackageIdent failed at template.go:44:34: expected template got wrong") }) t.Run("call ParseFS wrong template package ident", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "templateParseFSWrongPackageIdent", pkg) + _, _, err := source.Templates(dir, "templateParseFSWrongPackageIdent", pkg) require.ErrorContains(t, err, "run template templateParseFSWrongPackageIdent failed at template.go:46:37: expected template got wrong") }) t.Run("call ParseFS receiver errored", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "templateParseFSReceiverErr", pkg) + _, _, err := source.Templates(dir, "templateParseFSReceiverErr", pkg) require.ErrorContains(t, err, "run template templateParseFSReceiverErr failed at template.go:48:43: expected exactly one string literal argument") }) t.Run("call ParseFS unexpected receiver", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "templateParseFSUnexpectedReceiver", pkg) + _, _, err := source.Templates(dir, "templateParseFSUnexpectedReceiver", pkg) require.ErrorContains(t, err, "run template templateParseFSUnexpectedReceiver failed at template.go:50:38: expected exactly one argument x[0] got 2") }) t.Run("call ParseFS with no arguments", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "templateParseFSNoArgs", pkg) + _, _, err := source.Templates(dir, "templateParseFSNoArgs", pkg) require.ErrorContains(t, err, "template.go:52:42: missing required arguments") }) t.Run("call ParseFS with first arg non ident", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "templateParseFSFirstArgNonIdent", pkg) + _, _, err := source.Templates(dir, "templateParseFSFirstArgNonIdent", pkg) require.ErrorContains(t, err, "template.go:54:53: first argument to ParseFS must be an identifier") }) t.Run("call ParseFS with first arg non ident", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "templateParseFSNonStringLiteralGlob", pkg) + _, _, err := source.Templates(dir, "templateParseFSNonStringLiteralGlob", pkg) require.ErrorContains(t, err, "template.go:56:78: expected string literal got 42") }) t.Run("call ParseFS with bad glob", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "templateParseFSWithBadGlob", pkg) + _, _, err := source.Templates(dir, "templateParseFSWithBadGlob", pkg) require.ErrorContains(t, err, `template.go:58:64: bad pattern "[fail": syntax error in pattern`) }) t.Run("call ParseFS and fail to get relative template path", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/template_ParseFS.txtar")) pkg := parseGo(t, dir) pkg.EmbedFiles = []string{"\x00/index.gohtml"} // null must not be in a path - _, err := source.Templates(dir, "templates", pkg) + _, _, err := source.Templates(dir, "templates", pkg) require.ErrorContains(t, err, `failed to calculate relative path for embedded files: Rel: can't make`) }) t.Run("call ParseFS and filter filepaths by globs", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/template_ParseFS.txtar")) pkg := parseGo(t, dir, "index.gohtml", "script.html") - tsHTML, err := source.Templates(dir, "templatesHTML", pkg) + tsHTML, _, err := source.Templates(dir, "templatesHTML", pkg) require.NoError(t, err) - tsGoHTML, err := source.Templates(dir, "templatesGoHTML", pkg) + tsGoHTML, _, err := source.Templates(dir, "templatesGoHTML", pkg) assert.NotNil(t, tsHTML.Lookup("script.html")) assert.NotNil(t, tsHTML.Lookup("console_log")) assert.Nil(t, tsGoHTML.Lookup("script.html")) @@ -270,19 +270,19 @@ func TestTemplates(t *testing.T) { t.Run("call bad embed pattern", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/bad_embed_pattern.txtar")) pkg := parseGo(t, dir, "greeting.gohtml") - _, err := source.Templates(dir, "templates", pkg) + _, _, err := source.Templates(dir, "templates", pkg) require.ErrorContains(t, err, `template.go:9:2: embed comment malformed: syntax error in pattern`) }) t.Run("call bad embed pattern", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/template_ParseFS.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "templateEmbedVariableNotFound", pkg) + _, _, err := source.Templates(dir, "templateEmbedVariableNotFound", pkg) require.ErrorContains(t, err, `template.go:22:65: variable hiding not found`) }) t.Run("multiple delimiter types", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/delims.txtar")) pkg := parseGo(t, dir, "default.gohtml", "triple_parens.gohtml", "double_square.gohtml") - templates, err := source.Templates(dir, "templates", pkg) + templates, _, err := source.Templates(dir, "templates", pkg) require.NoError(t, err) var names []string for _, ts := range templates.Templates() { @@ -293,142 +293,142 @@ func TestTemplates(t *testing.T) { t.Run("Run method call gets no args", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "templateNewHasWrongNumberOfArgs", pkg) + _, _, err := source.Templates(dir, "templateNewHasWrongNumberOfArgs", pkg) require.ErrorContains(t, err, `template.go:60:101: expected exactly one string literal argument`) }) t.Run("Run method call gets wrong type of args", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "templateNewHasWrongTypeOfArgs", pkg) + _, _, err := source.Templates(dir, "templateNewHasWrongTypeOfArgs", pkg) require.ErrorContains(t, err, `template.go:62:56: expected string literal got 9000`) }) t.Run("Run method call gets too many args", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "templateNewHasTooManyArgs", pkg) + _, _, err := source.Templates(dir, "templateNewHasTooManyArgs", pkg) require.ErrorContains(t, err, `template.go:64:51: expected exactly one string literal argument`) }) t.Run("Delims method call gets no args", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "templateDelimsGetsNoArgs", pkg) + _, _, err := source.Templates(dir, "templateDelimsGetsNoArgs", pkg) require.ErrorContains(t, err, `template.go:66:53: expected exactly two string literal arguments`) }) t.Run("Delims method call gets too many args", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "templateDelimsGetsTooMany", pkg) + _, _, err := source.Templates(dir, "templateDelimsGetsTooMany", pkg) require.ErrorContains(t, err, `template.go:68:54: expected exactly two string literal arguments`) }) t.Run("Delims have wrong type of argument expressions", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "templateDelimsWrongExpressionArg", pkg) + _, _, err := source.Templates(dir, "templateDelimsWrongExpressionArg", pkg) require.ErrorContains(t, err, `template.go:70:67: expected string literal got y`) }) t.Run("ParseFS method fails", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "templateParseFSMethodFails", pkg) + _, _, err := source.Templates(dir, "templateParseFSMethodFails", pkg) require.ErrorContains(t, err, `template.go:72:73: expected string literal got fail`) }) t.Run("Options method requires string literals", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "templateOptionsRequiresStringLiterals", pkg) + _, _, err := source.Templates(dir, "templateOptionsRequiresStringLiterals", pkg) require.ErrorContains(t, err, `template.go:74:67: expected string literal got fail`) }) t.Run("unknown method", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "templateUnknownMethod", pkg) + _, _, err := source.Templates(dir, "templateUnknownMethod", pkg) require.ErrorContains(t, err, `template.go:76:26: unsupported method Unknown`) }) t.Run("Option call", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") - _, err := source.Templates(dir, "templateOptionCall", pkg) + _, _, err := source.Templates(dir, "templateOptionCall", pkg) require.NoError(t, err) }) t.Run("Option call wrong argument", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/templates.txtar")) pkg := parseGo(t, dir, "index.gohtml") assert.Panics(t, func() { - _, _ = source.Templates(dir, "templateOptionCallUnknownArg", pkg) + _, _, _ = source.Templates(dir, "templateOptionCallUnknownArg", pkg) }) }) t.Run("Funcs call", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/funcs.txtar")) pkg := parseGo(t, dir, "greet.gohtml") - _, err := source.Templates(dir, "templates", pkg) + _, _, err := source.Templates(dir, "templates", pkg) require.NoError(t, err) }) t.Run("Func not defined", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/funcs.txtar")) pkg := parseGo(t, dir, "missing_func.gohtml", "greet.gohtml") - _, err := source.Templates(dir, "templatesFuncNotDefined", pkg) + _, _, err := source.Templates(dir, "templatesFuncNotDefined", pkg) require.ErrorContains(t, err, `missing_func.gohtml:1: function "enemy" not defined`) }) t.Run("Func wrong parameter kind", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/funcs.txtar")) pkg := parseGo(t, dir, "missing_func.gohtml", "greet.gohtml") - _, err := source.Templates(dir, "templatesWrongArg", pkg) + _, _, err := source.Templates(dir, "templatesWrongArg", pkg) require.ErrorContains(t, err, `expected a composite literal with type template.FuncMap got wrong`) }) t.Run("Func wrong too many args", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/funcs.txtar")) pkg := parseGo(t, dir, "missing_func.gohtml", "greet.gohtml") - _, err := source.Templates(dir, "templatesTwoArgs", pkg) + _, _, err := source.Templates(dir, "templatesTwoArgs", pkg) require.ErrorContains(t, err, `expected exactly 1 template.FuncMap composite literal argument`) }) t.Run("Func wrong too no args", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/funcs.txtar")) pkg := parseGo(t, dir, "missing_func.gohtml", "greet.gohtml") - _, err := source.Templates(dir, "templatesNoArgs", pkg) + _, _, err := source.Templates(dir, "templatesNoArgs", pkg) require.ErrorContains(t, err, `expected exactly 1 template.FuncMap composite literal argument`) }) t.Run("Func wrong package ident", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/funcs.txtar")) pkg := parseGo(t, dir, "missing_func.gohtml", "greet.gohtml") - _, err := source.Templates(dir, "templatesWrongTypePackageName", pkg) + _, _, err := source.Templates(dir, "templatesWrongTypePackageName", pkg) require.ErrorContains(t, err, `expected a composite literal with type template.FuncMap got wrong.FuncMap{}`) }) t.Run("Func wrong Type ident", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/funcs.txtar")) pkg := parseGo(t, dir, "missing_func.gohtml", "greet.gohtml") - _, err := source.Templates(dir, "templatesWrongTypeName", pkg) + _, _, err := source.Templates(dir, "templatesWrongTypeName", pkg) require.ErrorContains(t, err, `expected a composite literal with type template.FuncMap got template.Wrong{}`) }) t.Run("Func wrong Type", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/funcs.txtar")) pkg := parseGo(t, dir, "missing_func.gohtml", "greet.gohtml") - _, err := source.Templates(dir, "templatesWrongTypeExpression", pkg) + _, _, err := source.Templates(dir, "templatesWrongTypeExpression", pkg) require.ErrorContains(t, err, `expected a composite literal with type template.FuncMap got wrong{}`) }) t.Run("Func wrong elem", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/funcs.txtar")) pkg := parseGo(t, dir, "missing_func.gohtml", "greet.gohtml") - _, err := source.Templates(dir, "templatesWrongTypeElem", pkg) + _, _, err := source.Templates(dir, "templatesWrongTypeElem", pkg) require.ErrorContains(t, err, `expected element at index 0 to be a key value pair got wrong`) }) t.Run("Func wrong elem key", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/funcs.txtar")) pkg := parseGo(t, dir, "missing_func.gohtml", "greet.gohtml") - _, err := source.Templates(dir, "templatesWrongElemKey", pkg) + _, _, err := source.Templates(dir, "templatesWrongElemKey", pkg) require.ErrorContains(t, err, `expected string literal got wrong`) }) t.Run("Parse template name from new", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/parse.txtar")) pkg := parseGo(t, dir) - ts, err := source.Templates(dir, "templates", pkg) + ts, _, err := source.Templates(dir, "templates", pkg) require.NoError(t, err) assert.NotNil(t, ts.Lookup("GET /")) }) t.Run("Parse string has multiple routes", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/parse.txtar")) pkg := parseGo(t, dir) - ts, err := source.Templates(dir, "multiple", pkg) + ts, _, err := source.Templates(dir, "multiple", pkg) require.NoError(t, err) assert.NotNil(t, ts.Lookup("GET /")) assert.NotNil(t, ts.Lookup("GET /{name}")) @@ -436,13 +436,13 @@ func TestTemplates(t *testing.T) { t.Run("Parse is missing argument", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/parse.txtar")) pkg := parseGo(t, dir) - _, err := source.Templates(dir, "noArg", pkg) + _, _, err := source.Templates(dir, "noArg", pkg) require.ErrorContains(t, err, "run template noArg failed at parse.go:12:35: expected exactly one string literal argument") }) t.Run("Parse gets wrong argument type", func(t *testing.T) { dir := createTestDir(t, filepath.FromSlash("testdata/template/parse.txtar")) pkg := parseGo(t, dir) - _, err := source.Templates(dir, "wrongArg", pkg) + _, _, err := source.Templates(dir, "wrongArg", pkg) require.ErrorContains(t, err, "run template wrongArg failed at parse.go:14:40: expected string literal got 500") }) } diff --git a/routes.go b/routes.go index c2b4992..a330a56 100644 --- a/routes.go +++ b/routes.go @@ -14,6 +14,7 @@ import ( "slices" "strconv" "strings" + "text/template/parse" "time" "github.com/crhntr/dom" @@ -22,6 +23,7 @@ import ( "golang.org/x/net/html/atom" "golang.org/x/tools/go/packages" + "github.com/crhntr/muxt/internal/check" "github.com/crhntr/muxt/internal/source" ) @@ -54,6 +56,7 @@ const ( ) type RoutesFileConfiguration struct { + ExperimentalCheckTypes, executeFunc bool PackageName, PackagePath, @@ -82,7 +85,7 @@ func TemplateRoutesFile(wd string, logger *log.Logger, config RoutesFileConfigur imports := source.NewImports(&ast.GenDecl{Tok: token.IMPORT}) patterns := []string{ - wd, "net/http", + wd, "net/http", "fmt", } if config.ReceiverPackage != "" { @@ -125,7 +128,7 @@ func TemplateRoutesFile(wd string, logger *log.Logger, config RoutesFileConfigur receiver = types.NewNamed(types.NewTypeName(0, routesPkg.Types, "Receiver", nil), types.NewStruct(nil, nil), nil) } - ts, err := source.Templates(wd, config.TemplatesVariable, routesPkg) + ts, fm, err := source.Templates(wd, config.TemplatesVariable, routesPkg) if err != nil { return "", err } @@ -205,6 +208,16 @@ func TemplateRoutesFile(wd string, logger *log.Logger, config RoutesFileConfigur handlerFunc.Body.List = append(handlerFunc.Body.List, receiverCallStatements...) handlerFunc.Body.List = append(handlerFunc.Body.List, t.executeCall(source.HTTPStatusCode(imports, t.statusCode), ast.NewIdent(dataVarIdent), writeHeader)) routesFunc.Body.List = append(routesFunc.Body.List, t.callHandleFunc(handlerFunc)) + + if config.ExperimentalCheckTypes { + dataVar := sig.Results().At(0) + if types.Identical(dataVar.Type(), types.Universe.Lookup("any").Type()) { + continue + } + if err := check.Tree(t.template.Tree, dataVar.Type(), dataVar.Pkg(), routesPkg.Fset, newForrest(ts), functionMap(fm)); err != nil { + return "", err + } + } } imports.SortImports() @@ -980,3 +993,28 @@ func executeFuncDecl(imports *source.Imports, templatesVariableIdent string) *as }, } } + +type forest template.Template + +func newForrest(templates *template.Template) *forest { + return (*forest)(templates) +} + +func (f *forest) FindTree(name string) (*parse.Tree, bool) { + ts := (*template.Template)(f).Lookup(name) + if ts == nil { + return nil, false + } + return ts.Tree, true +} + +type functionMap map[string]*types.Signature + +func (fm functionMap) FindFunction(name string) (*types.Signature, bool) { + m := (map[string]*types.Signature)(fm) + fn, ok := m[name] + if !ok { + return nil, false + } + return fn, true +} diff --git a/routes_test.go b/routes_test.go index 1671aa2..a1861ed 100644 --- a/routes_test.go +++ b/routes_test.go @@ -1952,6 +1952,8 @@ var templates = template.Must(template.ParseFS(templatesDir, "template.gohtml")) PackagePath: "example.com", ReceiverType: tt.Receiver, OutputFileName: "template_routes.go", + + ExperimentalCheckTypes: true, }) if tt.ExpectedError == "" { require.NoError(t, err) diff --git a/template.go b/template.go index 1f6b57b..2161806 100644 --- a/template.go +++ b/template.go @@ -152,6 +152,17 @@ func checkPathValueNames(in []string) error { func (t Template) String() string { return t.name } +func (t Template) Method() string { + if t.fun == nil { + return "" + } + return t.fun.Name +} + +func (t Template) Template() *template.Template { + return t.template +} + func (t Template) byPathThenMethod(d Template) int { if n := cmp.Compare(t.path, d.path); n != 0 { return n