Skip to content

Commit

Permalink
refactor: move more stuff to source pkg
Browse files Browse the repository at this point in the history
  • Loading branch information
crhntr committed Oct 1, 2024
1 parent 31d429c commit 74b7d26
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 105 deletions.
16 changes: 14 additions & 2 deletions internal/source/go.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,10 +241,12 @@ func String(s string) *ast.BasicLit {
return &ast.BasicLit{Value: strconv.Quote(s), Kind: token.STRING}
}

func Nil() *ast.Ident { return ast.NewIdent("nil") }

func ErrorCheckReturn(errVarIdent string, body ...ast.Stmt) *ast.IfStmt {
return &ast.IfStmt{
Cond: &ast.BinaryExpr{X: ast.NewIdent(errVarIdent), Op: token.NEQ, Y: ast.NewIdent("nil")},
Body: &ast.BlockStmt{List: body},
Cond: &ast.BinaryExpr{X: ast.NewIdent(errVarIdent), Op: token.NEQ, Y: Nil()},
Body: &ast.BlockStmt{List: append(body, &ast.ReturnStmt{})},
}
}

Expand All @@ -261,3 +263,13 @@ func FieldIndex(fields []*ast.Field, i int) (*ast.Ident, ast.Expr, bool) {
}
return nil, nil, false
}

func CallError(errIdent string) *ast.CallExpr {
return &ast.CallExpr{
Fun: &ast.SelectorExpr{
X: ast.NewIdent(errIdent),
Sel: ast.NewIdent("Error"),
},
Args: []ast.Expr{},
}
}
30 changes: 30 additions & 0 deletions internal/source/imports.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,22 @@ func (imports *Imports) Add(pkgIdent, pkgPath string) string {
return pkgIdent
}

func (imports *Imports) Ident(pkgPath string) string {
if imports != nil && imports.GenDecl != nil {
for _, s := range imports.GenDecl.Specs {
spec := s.(*ast.ImportSpec)
pp, _ := strconv.Unquote(spec.Path.Value)
if pp == pkgPath {
if spec.Name != nil && spec.Name.Name != "" {
return spec.Name.Name
}
return path.Base(pp)
}
}
}
return path.Base(pkgPath)
}

func (imports *Imports) ImportSpecs() []*ast.ImportSpec {
result := make([]*ast.ImportSpec, 0, len(imports.GenDecl.Specs))
for _, spec := range imports.GenDecl.Specs {
Expand All @@ -72,3 +88,17 @@ func (imports *Imports) SortImports() {
func (imports *Imports) AddNetHTTP() string { return imports.Add("", "net/http") }
func (imports *Imports) AddHTMLTemplate() string { return imports.Add("", "html/template") }
func (imports *Imports) AddContext() string { return imports.Add("", "context") }

func (imports *Imports) HTTPErrorCall(response ast.Expr, message ast.Expr, code int) *ast.ExprStmt {
return &ast.ExprStmt{X: &ast.CallExpr{
Fun: &ast.SelectorExpr{
X: ast.NewIdent(imports.AddNetHTTP()),
Sel: ast.NewIdent("Error"),
},
Args: []ast.Expr{
response,
message,
HTTPStatusCode(imports, code),
},
}}
}
18 changes: 4 additions & 14 deletions internal/source/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,20 +239,10 @@ func GenerateValidations(imports *Imports, variable, variableType ast.Expr, inpu
var statements []ast.Stmt
for _, validation := range validations {
statements = append(statements, validation.GenerateValidation(imports, variable, func(message string) ast.Stmt {
return &ast.ExprStmt{X: &ast.CallExpr{
Fun: &ast.SelectorExpr{
X: ast.NewIdent(imports.AddNetHTTP()),
Sel: ast.NewIdent("Error"),
},
Args: []ast.Expr{
ast.NewIdent(responseIdent),
&ast.BasicLit{
Kind: token.STRING,
Value: strconv.Quote(message),
},
HTTPStatusCode(imports, http.StatusBadRequest),
},
}}
return imports.HTTPErrorCall(ast.NewIdent(responseIdent), &ast.BasicLit{
Kind: token.STRING,
Value: strconv.Quote(message),
}, http.StatusBadRequest)
}))
}
return statements, nil, true
Expand Down
113 changes: 24 additions & 89 deletions routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ const (
receiverInterfaceIdent = "RoutesReceiver"

InputAttributeNameStructTag = "name"

errIdent = "err"
)

func Generate(templateNames []TemplateName, ts *template.Template, packageName, templatesVariableName, routesFunctionName, receiverTypeIdent, output string, fileSet *token.FileSet, receiverPackage, templatesPackage []*ast.File, log *log.Logger) (string, error) {
Expand Down Expand Up @@ -121,13 +123,7 @@ func (def TemplateName) callHandleFunc(handlerFuncLit *ast.FuncLit) *ast.ExprStm
X: ast.NewIdent(muxVarIdent),
Sel: ast.NewIdent(httpHandleFuncIdent),
},
Args: []ast.Expr{
&ast.BasicLit{
Kind: token.STRING,
Value: strconv.Quote(def.endpoint),
},
handlerFuncLit,
},
Args: []ast.Expr{source.String(def.endpoint), handlerFuncLit},
}}
}

Expand All @@ -152,7 +148,6 @@ func (def TemplateName) funcLit(imports *source.Imports, method *ast.FuncType, t
formStruct = s
}
}
const errVarIdent = "err"
writeHeader := true
for i, a := range def.call.Args {
arg := a.(*ast.Ident)
Expand Down Expand Up @@ -185,20 +180,8 @@ func (def TemplateName) funcLit(imports *source.Imports, method *ast.FuncType, t
X: ast.NewIdent(arg.Name),
Sel: ast.NewIdent(name.Name),
}
errCheck := source.ErrorCheckReturn(errVarIdent, &ast.ExprStmt{X: &ast.CallExpr{
Fun: &ast.SelectorExpr{
X: ast.NewIdent(imports.AddNetHTTP()),
Sel: ast.NewIdent("Error"),
},
Args: []ast.Expr{
ast.NewIdent(httpResponseField(imports).Names[0].Name),
&ast.CallExpr{
Fun: &ast.SelectorExpr{X: ast.NewIdent("err"), Sel: ast.NewIdent("Error")},
Args: []ast.Expr{},
},
source.HTTPStatusCode(imports, http.StatusBadRequest),
},
}}, &ast.ReturnStmt{})

errCheck := source.ErrorCheckReturn(errIdent, imports.HTTPErrorCall(ast.NewIdent(httpResponseField(imports).Names[0].Name), source.CallError(errIdent), http.StatusBadRequest))

const parsedVariableName = "value"
if fieldType, ok := field.Type.(*ast.ArrayType); ok {
Expand All @@ -217,7 +200,7 @@ func (def TemplateName) funcLit(imports *source.Imports, method *ast.FuncType, t
if ok && err != nil {
return nil, err
}
statements, err := source.GenerateParseValueFromStringStatements(imports, parsedVariableName, errVarIdent, ast.NewIdent(valVar), fieldType.Elt, errCheck, validations, assignment)
statements, err := source.GenerateParseValueFromStringStatements(imports, parsedVariableName, errIdent, ast.NewIdent(valVar), fieldType.Elt, errCheck, validations, assignment)
if err != nil {
return nil, fmt.Errorf("failed to generate parse statements for form field %s: %w", name.Name, err)
}
Expand Down Expand Up @@ -263,7 +246,7 @@ func (def TemplateName) funcLit(imports *source.Imports, method *ast.FuncType, t
if ok && err != nil {
return nil, err
}
statements, err := source.GenerateParseValueFromStringStatements(imports, parsedVariableName, errVarIdent, str, field.Type, errCheck, validations, assignment)
statements, err := source.GenerateParseValueFromStringStatements(imports, parsedVariableName, errIdent, str, field.Type, errCheck, validations, assignment)
if err != nil {
return nil, fmt.Errorf("failed to generate parse statements for form field %s: %w", name.Name, err)
}
Expand All @@ -283,28 +266,15 @@ func (def TemplateName) funcLit(imports *source.Imports, method *ast.FuncType, t
}
call.Args = append(call.Args, ast.NewIdent(arg.Name))
default:
errCheck := source.ErrorCheckReturn(errVarIdent, &ast.ExprStmt{X: &ast.CallExpr{
Fun: &ast.SelectorExpr{
X: ast.NewIdent(imports.AddNetHTTP()),
Sel: ast.NewIdent("Error"),
},
Args: []ast.Expr{
ast.NewIdent(httpResponseField(imports).Names[0].Name),
&ast.CallExpr{
Fun: &ast.SelectorExpr{X: ast.NewIdent("err"), Sel: ast.NewIdent("Error")},
Args: []ast.Expr{},
},
source.HTTPStatusCode(imports, http.StatusBadRequest),
},
}}, &ast.ReturnStmt{})
errCheck := source.ErrorCheckReturn(errIdent, imports.HTTPErrorCall(ast.NewIdent(httpResponseField(imports).Names[0].Name), source.CallError(errIdent), http.StatusBadRequest))
src := &ast.CallExpr{
Fun: &ast.SelectorExpr{
X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest),
Sel: ast.NewIdent(requestPathValue),
},
Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(arg.Name)}},
}
statements, err := httpPathValueAssignment(imports, method, i, arg, errVarIdent, src, token.DEFINE, errCheck)
statements, err := httpPathValueAssignment(imports, method, i, arg, errIdent, src, token.DEFINE, errCheck)
if err != nil {
return nil, err
}
Expand All @@ -315,31 +285,13 @@ func (def TemplateName) funcLit(imports *source.Imports, method *ast.FuncType, t

const dataVarIdent = "data"
if len(method.Results.List) > 1 {
errVar := ast.NewIdent("err")

lit.Body.List = append(lit.Body.List,
&ast.AssignStmt{Lhs: []ast.Expr{ast.NewIdent(dataVarIdent), ast.NewIdent(errVar.Name)}, Tok: token.DEFINE, Rhs: []ast.Expr{call}},
&ast.AssignStmt{Lhs: []ast.Expr{ast.NewIdent(dataVarIdent), ast.NewIdent(errIdent)}, Tok: token.DEFINE, Rhs: []ast.Expr{call}},
&ast.IfStmt{
Cond: &ast.BinaryExpr{X: ast.NewIdent(errVar.Name), Op: token.NEQ, Y: ast.NewIdent("nil")},
Cond: &ast.BinaryExpr{X: ast.NewIdent(errIdent), Op: token.NEQ, Y: source.Nil()},
Body: &ast.BlockStmt{
List: []ast.Stmt{
&ast.ExprStmt{X: &ast.CallExpr{
Fun: &ast.SelectorExpr{
X: ast.NewIdent(imports.AddNetHTTP()),
Sel: ast.NewIdent("Error"),
},
Args: []ast.Expr{
ast.NewIdent(httpResponseField(imports).Names[0].Name),
&ast.CallExpr{
Fun: &ast.SelectorExpr{
X: ast.NewIdent("err"),
Sel: ast.NewIdent("Error"),
},
Args: []ast.Expr{},
},
source.HTTPStatusCode(imports, http.StatusInternalServerError),
},
}},
imports.HTTPErrorCall(ast.NewIdent(httpResponseField(imports).Names[0].Name), source.CallError(errIdent), http.StatusInternalServerError),
&ast.ReturnStmt{},
},
},
Expand Down Expand Up @@ -379,7 +331,7 @@ func (def TemplateName) funcType(imports *source.Imports) *ast.FuncType {
case TemplateNameScopeIdentifierContext:
method.Params.List = append(method.Params.List, contextContextField(imports))
case TemplateNameScopeIdentifierForm:
method.Params.List = append(method.Params.List, urlValuesField(arg.Name))
method.Params.List = append(method.Params.List, urlValuesField(imports, arg.Name))
default:
method.Params.List = append(method.Params.List, pathValueField(arg.Name))
}
Expand Down Expand Up @@ -525,10 +477,10 @@ func routesFuncType(imports *source.Imports, receiverType ast.Expr) *ast.FuncTyp
}}
}

func urlValuesField(ident string) *ast.Field {
func urlValuesField(imports *source.Imports, ident string) *ast.Field {
return &ast.Field{
Names: []*ast.Ident{ast.NewIdent(ident)},
Type: &ast.SelectorExpr{X: ast.NewIdent("url"), Sel: ast.NewIdent("Values")},
Type: &ast.SelectorExpr{X: ast.NewIdent(imports.Add("", "net/url")), Sel: ast.NewIdent("Values")},
}
}

Expand Down Expand Up @@ -565,7 +517,8 @@ func contextAssignment() *ast.AssignStmt {
}

func formDeclaration(imports *source.Imports, ident string, typeExp ast.Expr) *ast.DeclStmt {
if matchSelectorIdents(typeExp, "url", "Values", false) {
if matchSelectorIdents(typeExp, imports.Ident("net/url"), "Values", false) {
imports.Add("", "net/url")
return &ast.DeclStmt{
Decl: &ast.GenDecl{
Tok: token.VAR,
Expand All @@ -581,7 +534,6 @@ func formDeclaration(imports *source.Imports, ident string, typeExp ast.Expr) *a
},
}
}

return &ast.DeclStmt{
Decl: &ast.GenDecl{
Tok: token.VAR,
Expand Down Expand Up @@ -665,7 +617,6 @@ func (def TemplateName) matchReceiver(funcDecl *ast.FuncDecl, receiverTypeIdent

func executeFuncDecl(imports *source.Imports, templatesVariableIdent string) *ast.FuncDecl {
const writeHeaderIdent = "writeHeader"
imports.Add("", "bytes")
return &ast.FuncDecl{
Name: ast.NewIdent(executeIdentName),
Type: &ast.FuncType{
Expand All @@ -687,15 +638,15 @@ func executeFuncDecl(imports *source.Imports, templatesVariableIdent string) *as
Tok: token.DEFINE,
Rhs: []ast.Expr{&ast.CallExpr{
Fun: &ast.SelectorExpr{
X: ast.NewIdent("bytes"),
X: ast.NewIdent(imports.Add("", "bytes")),
Sel: ast.NewIdent("NewBuffer"),
},
Args: []ast.Expr{ast.NewIdent("nil")},
Args: []ast.Expr{source.Nil()},
}},
},
&ast.IfStmt{
Init: &ast.AssignStmt{
Lhs: []ast.Expr{ast.NewIdent("err")},
Lhs: []ast.Expr{ast.NewIdent(errIdent)},
Tok: token.DEFINE,
Rhs: []ast.Expr{&ast.CallExpr{
Fun: &ast.SelectorExpr{
Expand All @@ -706,29 +657,13 @@ func executeFuncDecl(imports *source.Imports, templatesVariableIdent string) *as
}},
},
Cond: &ast.BinaryExpr{
X: ast.NewIdent("err"),
X: ast.NewIdent(errIdent),
Op: token.NEQ,
Y: ast.NewIdent("nil"),
Y: source.Nil(),
},
Body: &ast.BlockStmt{
List: []ast.Stmt{
&ast.ExprStmt{X: &ast.CallExpr{
Fun: &ast.SelectorExpr{
X: ast.NewIdent(imports.AddNetHTTP()),
Sel: ast.NewIdent("Error"),
},
Args: []ast.Expr{
ast.NewIdent(httpResponseField(imports).Names[0].Name),
&ast.CallExpr{
Fun: &ast.SelectorExpr{
X: ast.NewIdent("err"),
Sel: ast.NewIdent("Error"),
},
Args: []ast.Expr{},
},
source.HTTPStatusCode(imports, http.StatusInternalServerError),
},
}},
imports.HTTPErrorCall(ast.NewIdent(httpResponseField(imports).Names[0].Name), source.CallError(errIdent), http.StatusInternalServerError),
&ast.ReturnStmt{},
},
},
Expand All @@ -738,7 +673,7 @@ func executeFuncDecl(imports *source.Imports, templatesVariableIdent string) *as
Body: &ast.BlockStmt{List: []ast.Stmt{
&ast.ExprStmt{X: &ast.CallExpr{
Fun: &ast.SelectorExpr{X: &ast.CallExpr{Fun: &ast.SelectorExpr{X: ast.NewIdent(TemplateNameScopeIdentifierHTTPResponse), Sel: ast.NewIdent("Header")}, Args: []ast.Expr{}}, Sel: ast.NewIdent("Set")},
Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote("content-type")}, &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote("text/html; charset=utf-8")}},
Args: []ast.Expr{source.String("content-type"), source.String("text/html; charset=utf-8")},
}},
&ast.ExprStmt{X: &ast.CallExpr{
Fun: &ast.SelectorExpr{X: ast.NewIdent(httpResponseField(imports).Names[0].Name), Sel: ast.NewIdent("WriteHeader")},
Expand Down

0 comments on commit 74b7d26

Please sign in to comment.