diff --git a/pattern.go b/pattern.go index 7a1a14e..0139678 100644 --- a/pattern.go +++ b/pattern.go @@ -17,6 +17,7 @@ import ( func TemplatePatterns(ts *template.Template) ([]Pattern, error) { var patterns []Pattern + routes := make(map[string]struct{}) for _, t := range ts.Templates() { pat, err, ok := NewPattern(t.Name()) if !ok { @@ -25,9 +26,10 @@ func TemplatePatterns(ts *template.Template) ([]Pattern, error) { if err != nil { return patterns, err } - if slices.ContainsFunc(patterns, pat.sameRoute) { + if _, exists := routes[pat.Method+pat.Path]; exists { return patterns, fmt.Errorf("duplicate route pattern: %s", pat.Route) } + routes[pat.Method+pat.Path] = struct{}{} patterns = append(patterns, pat) } slices.SortFunc(patterns, Pattern.byPathThenMethod) diff --git a/pattern_test.go b/pattern_test.go index 7d3e55e..4cbdbb3 100644 --- a/pattern_test.go +++ b/pattern_test.go @@ -1,6 +1,7 @@ package muxt_test import ( + "html/template" "net/http" "testing" @@ -10,6 +11,19 @@ import ( "github.com/crhntr/muxt" ) +func TestTemplatePatterns(t *testing.T) { + t.Run("when one of the template names is a malformed pattern", func(t *testing.T) { + ts := template.Must(template.New("").Parse(`{{define "HEAD /"}}{{end}}`)) + _, err := muxt.TemplatePatterns(ts) + require.Error(t, err) + }) + t.Run("when the pattern is not unique", func(t *testing.T) { + ts := template.Must(template.New("").Parse(`{{define "GET / F1()"}}a{{end}} {{define "GET / F2()"}}b{{end}}`)) + _, err := muxt.TemplatePatterns(ts) + require.Error(t, err) + }) +} + func TestNewPattern(t *testing.T) { for _, tt := range []struct { Name string