Skip to content

Commit

Permalink
rename Pattern to TemplateName
Browse files Browse the repository at this point in the history
  • Loading branch information
crhntr committed Aug 19, 2024
1 parent 56e9d02 commit 1df23cf
Show file tree
Hide file tree
Showing 9 changed files with 152 additions and 155 deletions.
4 changes: 2 additions & 2 deletions cmd/muxt/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,12 @@ func generateCommand(args []string, workingDirectory string, getEnv func(string)
if err != nil {
return err
}
patterns, err := muxt.TemplatePatterns(ts)
templateNames, err := muxt.TemplateNames(ts)
if err != nil {
return err
}
out := log.New(stdout, "", 0)
s, err := muxt.Generate(patterns, g.goPackage, g.templatesVariable, g.routesFunction, g.receiverIdent, g.Package.Fset, g.Package.Syntax, g.Package.Syntax, out)
s, err := muxt.Generate(templateNames, g.goPackage, g.templatesVariable, g.routesFunction, g.receiverIdent, g.Package.Fset, g.Package.Syntax, g.Package.Syntax, out)
if err != nil {
return err
}
Expand Down
54 changes: 27 additions & 27 deletions generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ const (
DefaultRoutesFunctionName = "Routes"
)

func Generate(patterns []Pattern, packageName, templatesVariableName, routesFunctionName, receiverTypeIdent string, _ *token.FileSet, receiverPackage, templatesPackage []*ast.File, log *log.Logger) (string, error) {
func Generate(templateNames []TemplateName, packageName, templatesVariableName, routesFunctionName, receiverTypeIdent string, _ *token.FileSet, receiverPackage, templatesPackage []*ast.File, log *log.Logger) (string, error) {
packageName = cmp.Or(packageName, defaultPackageName)
templatesVariableName = cmp.Or(templatesVariableName, DefaultTemplatesVariableName)
routesFunctionName = cmp.Or(routesFunctionName, DefaultRoutesFunctionName)
Expand All @@ -61,7 +61,7 @@ func Generate(patterns []Pattern, packageName, templatesVariableName, routesFunc
imports := []*ast.ImportSpec{
importSpec("net/" + httpPackageIdent),
}
for _, pattern := range patterns {
for _, pattern := range templateNames {
var method *ast.FuncType
if pattern.fun != nil {
for _, funcDecl := range source.IterateFunctions(receiverPackage) {
Expand Down Expand Up @@ -113,7 +113,7 @@ func Generate(patterns []Pattern, packageName, templatesVariableName, routesFunc
return source.Format(file), nil
}

func (def Pattern) callHandleFunc(handlerFuncLit *ast.FuncLit) *ast.ExprStmt {
func (def TemplateName) callHandleFunc(handlerFuncLit *ast.FuncLit) *ast.ExprStmt {
return &ast.ExprStmt{X: &ast.CallExpr{
Fun: &ast.SelectorExpr{
X: ast.NewIdent(muxVarIdent),
Expand All @@ -129,7 +129,7 @@ func (def Pattern) callHandleFunc(handlerFuncLit *ast.FuncLit) *ast.ExprStmt {
}}
}

func (def Pattern) funcLit(templatesVariableIdent string, method *ast.FuncType) (*ast.FuncLit, []*ast.ImportSpec, error) {
func (def TemplateName) funcLit(templatesVariableIdent string, method *ast.FuncType) (*ast.FuncLit, []*ast.ImportSpec, error) {
if def.Handler == "" {
return def.httpRequestReceiverTemplateHandlerFunc(templatesVariableIdent), nil, nil
}
Expand All @@ -152,12 +152,12 @@ func (def Pattern) funcLit(templatesVariableIdent string, method *ast.FuncType)
for _, a := range def.call.Args {
arg := a.(*ast.Ident)
switch arg.Name {
case PatternScopeIdentifierHTTPRequest, PatternScopeIdentifierHTTPResponse:
case TemplateNameScopeIdentifierHTTPRequest, TemplateNameScopeIdentifierHTTPResponse:
call.Args = append(call.Args, ast.NewIdent(arg.Name))
imports = append(imports, importSpec("net/http"))
case PatternScopeIdentifierContext:
case TemplateNameScopeIdentifierContext:
lit.Body.List = append(lit.Body.List, contextAssignment())
call.Args = append(call.Args, ast.NewIdent(PatternScopeIdentifierContext))
call.Args = append(call.Args, ast.NewIdent(TemplateNameScopeIdentifierContext))
imports = append(imports, importSpec("context"))
default:
lit.Body.List = append(lit.Body.List, httpPathValueAssignment(arg))
Expand Down Expand Up @@ -204,7 +204,7 @@ func (def Pattern) funcLit(templatesVariableIdent string, method *ast.FuncType)
return lit, imports, nil
}

func (def Pattern) funcType() (*ast.FuncType, []*ast.ImportSpec) {
func (def TemplateName) funcType() (*ast.FuncType, []*ast.ImportSpec) {
method := &ast.FuncType{
Params: &ast.FieldList{},
Results: &ast.FieldList{List: []*ast.Field{{Type: ast.NewIdent("any")}}},
Expand All @@ -213,13 +213,13 @@ func (def Pattern) funcType() (*ast.FuncType, []*ast.ImportSpec) {
for _, a := range def.call.Args {
arg := a.(*ast.Ident)
switch arg.Name {
case PatternScopeIdentifierHTTPRequest:
case TemplateNameScopeIdentifierHTTPRequest:
method.Params.List = append(method.Params.List, httpRequestField())
imports = append(imports, importSpec("net/"+httpPackageIdent))
case PatternScopeIdentifierHTTPResponse:
case TemplateNameScopeIdentifierHTTPResponse:
method.Params.List = append(method.Params.List, httpResponseField())
imports = append(imports, importSpec("net/"+httpPackageIdent))
case PatternScopeIdentifierContext:
case TemplateNameScopeIdentifierContext:
method.Params.List = append(method.Params.List, contextContextField())
imports = append(imports, importSpec(contextPackageIdent))
default:
Expand Down Expand Up @@ -254,24 +254,24 @@ func fieldListTypes(fieldList *ast.FieldList) func(func(int, ast.Expr) bool) {
}
}

func errWrongNumberOfArguments(def Pattern, method *ast.FuncType) error {
func errWrongNumberOfArguments(def TemplateName, method *ast.FuncType) error {
return fmt.Errorf("handler %s expects %d arguments but call %s has %d", source.Format(&ast.FuncDecl{Name: ast.NewIdent(def.fun.Name), Type: method}), method.Params.NumFields(), def.Handler, len(def.call.Args))
}

func checkArgument(exp ast.Expr, tp ast.Expr) error {
arg := exp.(*ast.Ident)
switch arg.Name {
case PatternScopeIdentifierHTTPRequest:
case TemplateNameScopeIdentifierHTTPRequest:
if !matchSelectorIdents(tp, httpPackageIdent, httpRequestIdent, true) {
return fmt.Errorf("method expects type %s but %s is *%s.%s", source.Format(tp), arg.Name, httpPackageIdent, httpRequestIdent)
}
return nil
case PatternScopeIdentifierHTTPResponse:
case TemplateNameScopeIdentifierHTTPResponse:
if !matchSelectorIdents(tp, httpPackageIdent, httpResponseWriterIdent, false) {
return fmt.Errorf("method expects type %s but %s is %s.%s", source.Format(tp), arg.Name, httpPackageIdent, httpResponseWriterIdent)
}
return nil
case PatternScopeIdentifierContext:
case TemplateNameScopeIdentifierContext:
if !matchSelectorIdents(tp, contextPackageIdent, contextContextTypeIdent, false) {
return fmt.Errorf("method expects type %s but %s is %s.%s", source.Format(tp), arg.Name, contextPackageIdent, contextContextTypeIdent)
}
Expand Down Expand Up @@ -310,14 +310,14 @@ func pathValueField(name string) *ast.Field {

func contextContextField() *ast.Field {
return &ast.Field{
Names: []*ast.Ident{ast.NewIdent(PatternScopeIdentifierContext)},
Names: []*ast.Ident{ast.NewIdent(TemplateNameScopeIdentifierContext)},
Type: contextContextType(),
}
}

func httpResponseField() *ast.Field {
return &ast.Field{
Names: []*ast.Ident{ast.NewIdent(PatternScopeIdentifierHTTPResponse)},
Names: []*ast.Ident{ast.NewIdent(TemplateNameScopeIdentifierHTTPResponse)},
Type: &ast.SelectorExpr{X: ast.NewIdent(httpPackageIdent), Sel: ast.NewIdent(httpResponseWriterIdent)},
}
}
Expand All @@ -335,7 +335,7 @@ func routesFuncType(receiverType ast.Expr) *ast.FuncType {

func httpRequestField() *ast.Field {
return &ast.Field{
Names: []*ast.Ident{ast.NewIdent(PatternScopeIdentifierHTTPRequest)},
Names: []*ast.Ident{ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest)},
Type: &ast.StarExpr{X: &ast.SelectorExpr{X: ast.NewIdent(httpPackageIdent), Sel: ast.NewIdent(httpRequestIdent)}},
}
}
Expand Down Expand Up @@ -369,10 +369,10 @@ func httpStatusCode(name string) *ast.SelectorExpr {
func contextAssignment() *ast.AssignStmt {
return &ast.AssignStmt{
Tok: token.DEFINE,
Lhs: []ast.Expr{ast.NewIdent(PatternScopeIdentifierContext)},
Lhs: []ast.Expr{ast.NewIdent(TemplateNameScopeIdentifierContext)},
Rhs: []ast.Expr{&ast.CallExpr{
Fun: &ast.SelectorExpr{
X: ast.NewIdent(PatternScopeIdentifierHTTPRequest),
X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest),
Sel: ast.NewIdent(httpRequestContextMethod),
},
}},
Expand All @@ -385,7 +385,7 @@ func httpPathValueAssignment(arg *ast.Ident) *ast.AssignStmt {
Lhs: []ast.Expr{ast.NewIdent(arg.Name)},
Rhs: []ast.Expr{&ast.CallExpr{
Fun: &ast.SelectorExpr{
X: ast.NewIdent(PatternScopeIdentifierHTTPRequest),
X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest),
Sel: ast.NewIdent(requestPathValue),
},
Args: []ast.Expr{
Expand All @@ -398,12 +398,12 @@ func httpPathValueAssignment(arg *ast.Ident) *ast.AssignStmt {
}
}

func (def Pattern) executeCall(templatesVariable *ast.Ident, status, data ast.Expr) *ast.ExprStmt {
func (def TemplateName) executeCall(templatesVariable *ast.Ident, status, data ast.Expr) *ast.ExprStmt {
return &ast.ExprStmt{X: &ast.CallExpr{
Fun: ast.NewIdent(executeIdentName),
Args: []ast.Expr{
ast.NewIdent(PatternScopeIdentifierHTTPResponse),
ast.NewIdent(PatternScopeIdentifierHTTPRequest),
ast.NewIdent(TemplateNameScopeIdentifierHTTPResponse),
ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest),
&ast.CallExpr{
Fun: &ast.SelectorExpr{
X: ast.NewIdent(templatesVariable.Name),
Expand All @@ -417,14 +417,14 @@ func (def Pattern) executeCall(templatesVariable *ast.Ident, status, data ast.Ex
}}
}

func (def Pattern) httpRequestReceiverTemplateHandlerFunc(templatesVariableName string) *ast.FuncLit {
func (def TemplateName) httpRequestReceiverTemplateHandlerFunc(templatesVariableName string) *ast.FuncLit {
return &ast.FuncLit{
Type: httpHandlerFuncType(),
Body: &ast.BlockStmt{List: []ast.Stmt{def.executeCall(ast.NewIdent(templatesVariableName), httpStatusCode(httpStatusCode200Ident), ast.NewIdent(PatternScopeIdentifierHTTPRequest))}},
Body: &ast.BlockStmt{List: []ast.Stmt{def.executeCall(ast.NewIdent(templatesVariableName), httpStatusCode(httpStatusCode200Ident), ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest))}},
}
}

func (def Pattern) matchReceiver(funcDecl *ast.FuncDecl, receiverTypeIdent string) bool {
func (def TemplateName) matchReceiver(funcDecl *ast.FuncDecl, receiverTypeIdent string) bool {
if funcDecl == nil || funcDecl.Name == nil || funcDecl.Name.Name != def.fun.Name ||
funcDecl.Recv == nil || len(funcDecl.Recv.List) < 1 {
return false
Expand Down
8 changes: 4 additions & 4 deletions generate_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"github.com/crhntr/muxt/internal/source"
)

func TestPattern_funcLit(t *testing.T) {
func TestTemplateName_funcLit(t *testing.T) {
for _, tt := range []struct {
Name string
In string
Expand Down Expand Up @@ -102,7 +102,7 @@ func TestPattern_funcLit(t *testing.T) {
},
} {
t.Run(tt.Name, func(t *testing.T) {
pat, err, ok := NewPattern(tt.In)
pat, err, ok := NewTemplateName(tt.In)
require.True(t, ok)
require.NoError(t, err)
tv := "templates"
Expand All @@ -113,7 +113,7 @@ func TestPattern_funcLit(t *testing.T) {
}
}

func TestPattern_HandlerFuncLit_err(t *testing.T) {
func TestTemplateName_HandlerFuncLit_err(t *testing.T) {
for _, tt := range []struct {
Name string
In string
Expand Down Expand Up @@ -209,7 +209,7 @@ func TestPattern_HandlerFuncLit_err(t *testing.T) {
},
} {
t.Run(tt.Name, func(t *testing.T) {
pat, err, ok := NewPattern(tt.In)
pat, err, ok := NewTemplateName(tt.In)
require.True(t, ok)
require.NoError(t, err)
tv := "templates"
Expand Down
4 changes: 2 additions & 2 deletions generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -369,12 +369,12 @@ func execute(response http.ResponseWriter, request *http.Request, t *template.Te
} {
t.Run(tt.Name, func(t *testing.T) {
ts := template.Must(template.New(tt.Name).Parse(tt.Templates))
patterns, err := muxt.TemplatePatterns(ts)
templateNames, err := muxt.TemplateNames(ts)
require.NoError(t, err)
logs := log.New(io.Discard, "", 0)
set := token.NewFileSet()
goFiles := methodFuncTypeLoader(t, set, tt.ReceiverPackage)
out, err := muxt.Generate(patterns, tt.PackageName, tt.TemplatesVar, tt.RoutesFunc, tt.Receiver, set, goFiles, goFiles, logs)
out, err := muxt.Generate(templateNames, tt.PackageName, tt.TemplatesVar, tt.RoutesFunc, tt.Receiver, set, goFiles, goFiles, logs)
if tt.ExpectedError == "" {
assert.NoError(t, err)
assert.Equal(t, tt.ExpectedFile, out)
Expand Down
44 changes: 22 additions & 22 deletions internal/source/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,13 @@ func parseTemplates(workingDirectory, templatesVariable, templatesPackageIdent s
if err != nil {
return nil, err
}
patterns, err := parseStringLiterals(workingDirectory, fileSet, call.Args[1:])
templateNames, err := parseStringLiterals(workingDirectory, fileSet, call.Args[1:])
if err != nil {
return nil, err
}
filtered := matches[:0]
for _, ef := range matches {
for j, pattern := range patterns {
for j, pattern := range templateNames {
match, err := filepath.Match(pattern, ef)
if err != nil {
return nil, contextError(workingDirectory, fileSet, call.Args[j+1].Pos(), fmt.Errorf("bad pattern %q: %w", pattern, err))
Expand Down Expand Up @@ -164,11 +164,11 @@ func embedFSFilepaths(dir string, fileSet *token.FileSet, files []*ast.File, exp
}
var comment strings.Builder
commentNode := readComments(&comment, decl.Doc, spec.Doc)
patterns, err := parsePatterns(comment.String())
templateNames, err := parseTemplateNames(comment.String())
if err != nil {
return nil, err
}
absMat, err := embeddedFilesMatchingPatternList(dir, fileSet, commentNode, patterns, embeddedFiles)
absMat, err := embeddedFilesMatchingTemplateNameList(dir, fileSet, commentNode, templateNames, embeddedFiles)
if err != nil {
return nil, err
}
Expand All @@ -178,10 +178,10 @@ func embedFSFilepaths(dir string, fileSet *token.FileSet, files []*ast.File, exp
return nil, fmt.Errorf("variable %s not found", fsIdent.Name)
}

func embeddedFilesMatchingPatternList(dir string, set *token.FileSet, comment ast.Node, patterns, embeddedFiles []string) ([]string, error) {
func embeddedFilesMatchingTemplateNameList(dir string, set *token.FileSet, comment ast.Node, templateNames, embeddedFiles []string) ([]string, error) {
var matches []string
for _, fp := range embeddedFiles {
for _, pattern := range patterns {
for _, pattern := range templateNames {
pat := filepath.FromSlash(pattern)
if !strings.ContainsAny(pat, "*[]") {
prefix := filepath.FromSlash(pat + "/")
Expand Down Expand Up @@ -221,13 +221,13 @@ func readComments(s *strings.Builder, groups ...*ast.CommentGroup) ast.Node {
return n
}

func parsePatterns(input string) ([]string, error) {
func parseTemplateNames(input string) ([]string, error) {
// todo: refactor to use strconv.QuotedPrefix
var (
patterns []string
currentPattern strings.Builder
inQuote = false
quoteChar rune
templateNames []string
currentTemplateName strings.Builder
inQuote = false
quoteChar rune
)

for _, r := range input {
Expand All @@ -239,32 +239,32 @@ func parsePatterns(input string) ([]string, error) {
continue
}
if r != quoteChar {
currentPattern.WriteRune(r)
currentTemplateName.WriteRune(r)
continue
}
patterns = append(patterns, currentPattern.String())
currentPattern.Reset()
templateNames = append(templateNames, currentTemplateName.String())
currentTemplateName.Reset()
inQuote = false
case unicode.IsSpace(r):
if inQuote {
currentPattern.WriteRune(r)
currentTemplateName.WriteRune(r)
continue
}
if currentPattern.Len() > 0 {
patterns = append(patterns, currentPattern.String())
currentPattern.Reset()
if currentTemplateName.Len() > 0 {
templateNames = append(templateNames, currentTemplateName.String())
currentTemplateName.Reset()
}
default:
currentPattern.WriteRune(r)
currentTemplateName.WriteRune(r)
}
}

// Add any remaining pattern
if currentPattern.Len() > 0 {
patterns = append(patterns, currentPattern.String())
if currentTemplateName.Len() > 0 {
templateNames = append(templateNames, currentTemplateName.String())
}

return patterns, nil
return templateNames, nil
}

func contextError(workingDirectory string, set *token.FileSet, pos token.Pos, err error) error {
Expand Down
4 changes: 2 additions & 2 deletions internal/source/template_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"github.com/stretchr/testify/require"
)

func Test_parsePatterns(t *testing.T) {
func Test_parseTemplateNames(t *testing.T) {
for _, tt := range []struct {
name string
input string
Expand Down Expand Up @@ -40,7 +40,7 @@ func Test_parsePatterns(t *testing.T) {
},
} {
t.Run(tt.name, func(t *testing.T) {
result, err := parsePatterns(tt.input)
result, err := parseTemplateNames(tt.input)
require.NoError(t, err)
assert.EqualValues(t, tt.expected, result)
})
Expand Down
Loading

0 comments on commit 1df23cf

Please sign in to comment.