From f199cdd5b93df8584251d07823d01eb8eb74683d Mon Sep 17 00:00:00 2001 From: sunboyy Date: Wed, 15 May 2024 20:17:53 +0700 Subject: [PATCH] Utilize go/types for code generation (#47) --- internal/code/models.go | 67 --------------- internal/code/models_test.go | 48 ----------- internal/codegen/base.go | 6 +- internal/codegen/body.go | 32 ++++++-- internal/codegen/body_test.go | 29 +++---- internal/codegen/builder.go | 11 ++- internal/codegen/builder_test.go | 18 ++-- internal/codegen/function.go | 18 ++-- internal/codegen/function_test.go | 118 ++++++++++++++++++++++++++- internal/codegen/method.go | 12 ++- internal/codegen/method_test.go | 20 +++-- internal/codegen/struct.go | 6 +- internal/codegen/struct_test.go | 26 +++--- internal/generator/generator.go | 4 +- internal/generator/generator_test.go | 2 +- internal/mongo/common.go | 19 ++++- internal/mongo/count_test.go | 8 +- internal/mongo/delete_test.go | 8 +- internal/mongo/find.go | 20 ++--- internal/mongo/find_test.go | 8 +- internal/mongo/generator.go | 42 ++++------ internal/mongo/generator_test.go | 22 ++--- internal/mongo/insert.go | 14 ++-- internal/mongo/insert_test.go | 8 +- internal/mongo/models.go | 31 +++---- internal/mongo/update_test.go | 8 +- main.go | 2 +- 27 files changed, 299 insertions(+), 308 deletions(-) delete mode 100644 internal/code/models_test.go diff --git a/internal/code/models.go b/internal/code/models.go index 105dbf1..67fa305 100644 --- a/internal/code/models.go +++ b/internal/code/models.go @@ -1,52 +1,16 @@ package code import ( - "fmt" "go/types" "reflect" ) -// Import is a model for package imports -type Import struct { - Name string - Path string -} - -// LegacyStructField is a definition of the struct field -type LegacyStructField struct { - Name string - Type Type - Tag reflect.StructTag -} - // StructField is a definition of the struct field type StructField struct { Var *types.Var Tag reflect.StructTag } -// InterfaceType is a definition of the interface -type InterfaceType struct { -} - -// Code returns token string in code format -func (intf InterfaceType) Code() string { - return `interface{}` -} - -// Type is an interface for value types -type Type interface { - Code() string -} - -// SimpleType is a type that can be called directly -type SimpleType string - -// Code returns token string in code format -func (t SimpleType) Code() string { - return string(t) -} - var ( TypeBool = types.Typ[types.Bool] TypeInt = types.Typ[types.Int] @@ -55,34 +19,3 @@ var ( TypeString = types.Typ[types.String] TypeError = types.Universe.Lookup("error").Type() ) - -// ExternalType is a type that is called to another package -type ExternalType struct { - PackageAlias string - Name string -} - -// Code returns token string in code format -func (t ExternalType) Code() string { - return fmt.Sprintf("%s.%s", t.PackageAlias, t.Name) -} - -// PointerType is a model of pointer -type PointerType struct { - ContainedType Type -} - -// Code returns token string in code format -func (t PointerType) Code() string { - return fmt.Sprintf("*%s", t.ContainedType.Code()) -} - -// ArrayType is a model of array -type ArrayType struct { - ContainedType Type -} - -// Code returns token string in code format -func (t ArrayType) Code() string { - return fmt.Sprintf("[]%s", t.ContainedType.Code()) -} diff --git a/internal/code/models_test.go b/internal/code/models_test.go deleted file mode 100644 index 594f00d..0000000 --- a/internal/code/models_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package code_test - -import ( - "testing" - - "github.com/sunboyy/repogen/internal/code" -) - -type TypeCodeTestCase struct { - Name string - Type code.Type - ExpectedCode string -} - -func TestTypeCode(t *testing.T) { - testTable := []TypeCodeTestCase{ - { - Name: "simple type", - Type: code.SimpleType("UserModel"), - ExpectedCode: "UserModel", - }, - { - Name: "external type", - Type: code.ExternalType{PackageAlias: "context", Name: "Context"}, - ExpectedCode: "context.Context", - }, - { - Name: "pointer type", - Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}, - ExpectedCode: "*UserModel", - }, - { - Name: "array type", - Type: code.ArrayType{ContainedType: code.SimpleType("UserModel")}, - ExpectedCode: "[]UserModel", - }, - } - - for _, testCase := range testTable { - t.Run(testCase.Name, func(t *testing.T) { - code := testCase.Type.Code() - - if code != testCase.ExpectedCode { - t.Errorf("Expected = %+v\nReceived = %+v", testCase.ExpectedCode, code) - } - }) - } -} diff --git a/internal/codegen/base.go b/internal/codegen/base.go index 5508dc7..899b58d 100644 --- a/internal/codegen/base.go +++ b/internal/codegen/base.go @@ -3,8 +3,6 @@ package codegen import ( "fmt" "strings" - - "github.com/sunboyy/repogen/internal/code" ) const baseTemplate = `// Code generated by {{.Program}}. DO NOT EDIT. @@ -18,7 +16,7 @@ import ( type baseTemplateData struct { Program string PackageName string - Imports [][]code.Import + Imports [][]Import } func (data baseTemplateData) GenImports() string { @@ -33,7 +31,7 @@ func (data baseTemplateData) GenImports() string { return strings.Join(sections, "\n\n") } -func (data baseTemplateData) generateImportLine(imp code.Import) string { +func (data baseTemplateData) generateImportLine(imp Import) string { if imp.Name == "" { return fmt.Sprintf("\t\"%s\"", imp.Path) } diff --git a/internal/codegen/body.go b/internal/codegen/body.go index e6ba09e..7072857 100644 --- a/internal/codegen/body.go +++ b/internal/codegen/body.go @@ -2,9 +2,8 @@ package codegen import ( "fmt" + "go/types" "strings" - - "github.com/sunboyy/repogen/internal/code" ) type FunctionBody []Statement @@ -37,12 +36,21 @@ func (id Identifier) CodeLines() []string { } type DeclStatement struct { + Pkg *types.Package Name string - Type code.Type + Type types.Type +} + +func NewDeclStatement(pkg *types.Package, name string, typ types.Type) DeclStatement { + return DeclStatement{ + Pkg: pkg, + Name: name, + Type: typ, + } } func (stmt DeclStatement) CodeLines() []string { - return []string{fmt.Sprintf("var %s %s", stmt.Name, stmt.Type.Code())} + return []string{fmt.Sprintf("var %s %s", stmt.Name, TypeToString(stmt.Pkg, stmt.Type))} } type DeclAssignStatement struct { @@ -105,12 +113,22 @@ func (stmt CallStatement) CodeLines() []string { } type SliceStatement struct { - Type code.Type + Pkg *types.Package + Type types.Type Values []Statement } +func NewSliceStatement(pkg *types.Package, typ types.Type, values []Statement) SliceStatement { + return SliceStatement{ + Pkg: pkg, + Type: typ, + Values: values, + } +} + func (stmt SliceStatement) CodeLines() []string { - lines := []string{stmt.Type.Code() + "{"} + lines := []string{TypeToString(stmt.Pkg, stmt.Type) + "{"} + for _, value := range stmt.Values { stmtLines := value.CodeLines() stmtLines[len(stmtLines)-1] += "," @@ -118,7 +136,9 @@ func (stmt SliceStatement) CodeLines() []string { lines = append(lines, "\t"+line) } } + lines = append(lines, "}") + return lines } diff --git a/internal/codegen/body_test.go b/internal/codegen/body_test.go index be0c32f..bcd4c18 100644 --- a/internal/codegen/body_test.go +++ b/internal/codegen/body_test.go @@ -1,6 +1,7 @@ package codegen_test import ( + "go/types" "reflect" "testing" @@ -20,10 +21,7 @@ func TestIdentifier(t *testing.T) { } func TestDeclStatement(t *testing.T) { - stmt := codegen.DeclStatement{ - Name: "arrs", - Type: code.ArrayType{ContainedType: code.SimpleType("int")}, - } + stmt := codegen.NewDeclStatement(nil, "arrs", types.NewSlice(code.TypeInt)) expected := []string{"var arrs []int"} actual := stmt.CodeLines() @@ -139,23 +137,18 @@ func TestCallStatement(t *testing.T) { } func TestSliceStatement(t *testing.T) { - stmt := codegen.SliceStatement{ - Type: code.ArrayType{ - ContainedType: code.SimpleType("string"), - }, - Values: []codegen.Statement{ - codegen.Identifier(`"hello"`), - codegen.ChainStatement{ - codegen.CallStatement{ - FuncName: "GetUser", - Params: codegen.StatementList{ - codegen.Identifier("userID"), - }, + stmt := codegen.NewSliceStatement(nil, types.NewSlice(code.TypeString), []codegen.Statement{ + codegen.Identifier(`"hello"`), + codegen.ChainStatement{ + codegen.CallStatement{ + FuncName: "GetUser", + Params: codegen.StatementList{ + codegen.Identifier("userID"), }, - codegen.Identifier("Name"), }, + codegen.Identifier("Name"), }, - } + }) expected := []string{ "[]string{", ` "hello",`, diff --git a/internal/codegen/builder.go b/internal/codegen/builder.go index a080921..556985e 100644 --- a/internal/codegen/builder.go +++ b/internal/codegen/builder.go @@ -4,10 +4,15 @@ import ( "bytes" "text/template" - "github.com/sunboyy/repogen/internal/code" "golang.org/x/tools/imports" ) +// Import is a model for package imports +type Import struct { + Name string + Path string +} + type Builder struct { // Program defines generator program name in the generated file. Program string @@ -17,7 +22,7 @@ type Builder struct { // Imports defines necessary imports to reduce ambiguity when generating // formatting the raw-generated code. - Imports [][]code.Import + Imports [][]Import implementers []Implementer } @@ -29,7 +34,7 @@ type Implementer interface { } // NewBuilder is a constructor of Builder struct. -func NewBuilder(program string, packageName string, imports [][]code.Import) *Builder { +func NewBuilder(program string, packageName string, imports [][]Import) *Builder { return &Builder{ Program: program, PackageName: packageName, diff --git a/internal/codegen/builder_test.go b/internal/codegen/builder_test.go index 7459799..3f98505 100644 --- a/internal/codegen/builder_test.go +++ b/internal/codegen/builder_test.go @@ -38,7 +38,7 @@ func (u User) IDHex() string { ` func TestBuilderBuild(t *testing.T) { - builder := codegen.NewBuilder("repogen", "user", [][]code.Import{ + builder := codegen.NewBuilder("repogen", "user", [][]codegen.Import{ { { Name: "_", @@ -56,23 +56,20 @@ func TestBuilderBuild(t *testing.T) { }, }) builder.AddImplementer(codegen.StructBuilder{ + Pkg: testutils.Pkg, Name: "User", - Fields: []code.LegacyStructField{ + Fields: []code.StructField{ { - Name: "ID", - Type: code.ExternalType{ - PackageAlias: "primitive", - Name: "ObjectID", - }, + Var: types.NewVar(token.NoPos, nil, "ID", testutils.TypeObjectIDNamed), Tag: `bson:"id" json:"id,omitempty"`, }, { - Name: "Username", - Type: code.SimpleType("string"), + Var: types.NewVar(token.NoPos, nil, "Username", code.TypeString), }, }, }) builder.AddImplementer(codegen.FunctionBuilder{ + Pkg: testutils.Pkg, Name: "NewUser", Params: types.NewTuple( types.NewVar(token.NoPos, nil, "username", code.TypeString), @@ -104,7 +101,8 @@ func TestBuilderBuild(t *testing.T) { }, }) builder.AddImplementer(codegen.MethodBuilder{ - Receiver: codegen.MethodReceiver{Name: "u", Type: code.SimpleType("User")}, + Pkg: testutils.Pkg, + Receiver: codegen.MethodReceiver{Name: "u", TypeName: "User"}, Name: "IDHex", Params: nil, Returns: []types.Type{code.TypeString}, diff --git a/internal/codegen/function.go b/internal/codegen/function.go index 73726dd..01c783c 100644 --- a/internal/codegen/function.go +++ b/internal/codegen/function.go @@ -51,25 +51,25 @@ func generateParams(pkg *types.Package, params *types.Tuple) string { paramList = append( paramList, - fmt.Sprintf("%s %s", param.Name(), typeToString(pkg, param.Type())), + fmt.Sprintf("%s %s", param.Name(), TypeToString(pkg, param.Type())), ) } return strings.Join(paramList, ", ") } -func typeToString(pkg *types.Package, t types.Type) string { +func TypeToString(pkg *types.Package, t types.Type) string { switch t := t.(type) { case *types.Pointer: - return fmt.Sprintf("*%s", typeToString(pkg, t.Elem())) + return fmt.Sprintf("*%s", TypeToString(pkg, t.Elem())) case *types.Slice: - return fmt.Sprintf("[]%s", typeToString(pkg, t.Elem())) + return fmt.Sprintf("[]%s", TypeToString(pkg, t.Elem())) case *types.Named: - if t.Obj().Pkg() == nil || t.Obj().Pkg().Path() == pkg.Path() { - return t.Obj().Name() + if pkg == nil || (t.Obj().Pkg() != nil && t.Obj().Pkg().Path() != pkg.Path()) { + return fmt.Sprintf("%s.%s", t.Obj().Pkg().Name(), t.Obj().Name()) } - return fmt.Sprintf("%s.%s", t.Obj().Pkg().Name(), t.Obj().Name()) + return t.Obj().Name() default: return t.String() @@ -82,12 +82,12 @@ func generateReturns(pkg *types.Package, returns []types.Type) string { } if len(returns) == 1 { - return " " + typeToString(pkg, returns[0]) + return " " + TypeToString(pkg, returns[0]) } var returnList []string for _, ret := range returns { - returnList = append(returnList, typeToString(pkg, ret)) + returnList = append(returnList, TypeToString(pkg, ret)) } return fmt.Sprintf(" (%s)", strings.Join(returnList, ", ")) diff --git a/internal/codegen/function_test.go b/internal/codegen/function_test.go index 0181608..d67c662 100644 --- a/internal/codegen/function_test.go +++ b/internal/codegen/function_test.go @@ -54,15 +54,15 @@ func init() { func TestFunctionBuilderBuild_OneReturn(t *testing.T) { fb := codegen.FunctionBuilder{ + Pkg: testutils.Pkg, Name: "NewUser", Params: types.NewTuple( types.NewVar(token.NoPos, nil, "username", code.TypeString), types.NewVar(token.NoPos, nil, "age", code.TypeInt), - types.NewVar(token.NoPos, nil, "parent", - types.NewPointer(types.NewNamed(types.NewTypeName(token.NoPos, nil, "User", nil), nil, nil))), + types.NewVar(token.NoPos, nil, "parent", types.NewPointer(testutils.TypeUserNamed)), ), Returns: []types.Type{ - types.NewNamed(types.NewTypeName(token.NoPos, nil, "User", nil), nil, nil), + testutils.TypeUserNamed, }, Body: codegen.FunctionBody{ codegen.ReturnStatement{ @@ -104,6 +104,7 @@ func NewUser(username string, age int, parent *User) User { func TestFunctionBuilderBuild_MultiReturn(t *testing.T) { fb := codegen.FunctionBuilder{ + Pkg: testutils.Pkg, Name: "Save", Params: types.NewTuple( types.NewVar(token.NoPos, nil, "user", @@ -147,3 +148,114 @@ func Save(user User) (User, error) { t.Error(err) } } + +func TestTypeToString(t *testing.T) { + internalPkg := types.NewPackage("github.com/sunboyy/repogen/internal/foo", "foo") + externalPkg := types.NewPackage("github.com/sunboyy/repogen/internal/bar", "bar") + + tests := []struct { + name string + typ types.Type + want string + }{ + { + name: "basic type", + typ: code.TypeString, + want: "string", + }, + { + name: "pointer type", + typ: types.NewPointer(code.TypeString), + want: "*string", + }, + { + name: "slice type", + typ: types.NewSlice(code.TypeString), + want: "[]string", + }, + { + name: "named type internal", + typ: types.NewNamed(types.NewTypeName(token.NoPos, internalPkg, "User", nil), nil, nil), + want: "User", + }, + { + name: "named type external", + typ: types.NewNamed(types.NewTypeName(token.NoPos, externalPkg, "User", nil), nil, nil), + want: "bar.User", + }, + { + name: "integration internal", + typ: types.NewSlice( + types.NewPointer( + types.NewNamed(types.NewTypeName(token.NoPos, internalPkg, "User", nil), nil, nil), + ), + ), + want: "[]*User", + }, + { + name: "integration external", + typ: types.NewSlice( + types.NewPointer( + types.NewNamed(types.NewTypeName(token.NoPos, externalPkg, "User", nil), nil, nil), + ), + ), + want: "[]*bar.User", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := codegen.TypeToString(internalPkg, tt.typ); got != tt.want { + t.Errorf("TypeToString() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestTypeToString_PkgNil(t *testing.T) { + externalPkg := types.NewPackage("github.com/sunboyy/repogen/internal/bar", "bar") + + tests := []struct { + name string + typ types.Type + want string + }{ + { + name: "basic type", + typ: code.TypeString, + want: "string", + }, + { + name: "pointer type", + typ: types.NewPointer(code.TypeString), + want: "*string", + }, + { + name: "slice type", + typ: types.NewSlice(code.TypeString), + want: "[]string", + }, + { + name: "named type external", + typ: types.NewNamed(types.NewTypeName(token.NoPos, externalPkg, "User", nil), nil, nil), + want: "bar.User", + }, + { + name: "integration external", + typ: types.NewSlice( + types.NewPointer( + types.NewNamed(types.NewTypeName(token.NoPos, externalPkg, "User", nil), nil, nil), + ), + ), + want: "[]*bar.User", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := codegen.TypeToString(nil, tt.typ); got != tt.want { + t.Errorf("TypeToString() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/codegen/method.go b/internal/codegen/method.go index e36a71f..122b1c8 100644 --- a/internal/codegen/method.go +++ b/internal/codegen/method.go @@ -5,8 +5,6 @@ import ( "fmt" "go/types" "text/template" - - "github.com/sunboyy/repogen/internal/code" ) const methodTemplate = ` @@ -27,9 +25,9 @@ type MethodBuilder struct { // MethodReceiver describes a specification of a method receiver. type MethodReceiver struct { - Name string - Type code.SimpleType - Pointer bool + Name string + TypeName string + Pointer bool } // Impl writes method declatation code to the buffer. @@ -54,9 +52,9 @@ func (mb MethodBuilder) GenReceiver() string { func (mb MethodBuilder) generateReceiverType() string { if !mb.Receiver.Pointer { - return mb.Receiver.Type.Code() + return mb.Receiver.TypeName } - return code.PointerType{ContainedType: mb.Receiver.Type}.Code() + return "*" + mb.Receiver.TypeName } func (mb MethodBuilder) GenParams() string { diff --git a/internal/codegen/method_test.go b/internal/codegen/method_test.go index 3269238..982e089 100644 --- a/internal/codegen/method_test.go +++ b/internal/codegen/method_test.go @@ -13,7 +13,7 @@ import ( func TestMethodBuilderBuild_IgnoreReceiverNoReturn(t *testing.T) { fb := codegen.MethodBuilder{ - Receiver: codegen.MethodReceiver{Type: "User"}, + Receiver: codegen.MethodReceiver{TypeName: "User"}, Name: "Init", Params: nil, Returns: nil, @@ -52,9 +52,10 @@ func (User) Init() { func TestMethodBuilderBuild_IgnorePoinerReceiverOneReturn(t *testing.T) { fb := codegen.MethodBuilder{ + Pkg: testutils.Pkg, Receiver: codegen.MethodReceiver{ - Type: "User", - Pointer: true, + TypeName: "User", + Pointer: true, }, Name: "Init", Params: nil, @@ -96,16 +97,17 @@ func (*User) Init() error { func TestMethodBuilderBuild_UseReceiverMultiReturn(t *testing.T) { fb := codegen.MethodBuilder{ + Pkg: testutils.Pkg, Receiver: codegen.MethodReceiver{ - Name: "u", - Type: "User", + Name: "u", + TypeName: "User", }, Name: "WithAge", Params: types.NewTuple( types.NewVar(token.NoPos, nil, "age", code.TypeInt), ), Returns: []types.Type{ - types.NewNamed(types.NewTypeName(token.NoPos, nil, "User", nil), nil, nil), + testutils.TypeUserNamed, code.TypeError, }, Body: codegen.FunctionBody{ @@ -145,9 +147,9 @@ func (u User) WithAge(age int) (User, error) { func TestMethodBuilderBuild_UsePointerReceiverNoReturn(t *testing.T) { fb := codegen.MethodBuilder{ Receiver: codegen.MethodReceiver{ - Name: "u", - Type: "User", - Pointer: true, + Name: "u", + TypeName: "User", + Pointer: true, }, Name: "SetAge", Params: types.NewTuple( diff --git a/internal/codegen/struct.go b/internal/codegen/struct.go index ce1b1cf..2a310f8 100644 --- a/internal/codegen/struct.go +++ b/internal/codegen/struct.go @@ -3,6 +3,7 @@ package codegen import ( "bytes" "fmt" + "go/types" "strings" "text/template" @@ -17,8 +18,9 @@ type {{.Name}} struct { // StructBuilder is an implementer of a struct. type StructBuilder struct { + Pkg *types.Package Name string - Fields []code.LegacyStructField + Fields []code.StructField } // Impl writes struct declatation code to the buffer. @@ -37,7 +39,7 @@ func (sb StructBuilder) Impl(buffer *bytes.Buffer) error { func (sb StructBuilder) GenFields() string { var fieldLines []string for _, field := range sb.Fields { - fieldLine := fmt.Sprintf("\t%s %s", field.Name, field.Type.Code()) + fieldLine := fmt.Sprintf("\t%s %s", field.Var.Name(), TypeToString(sb.Pkg, field.Var.Type())) if len(field.Tag) > 0 { fieldLine += fmt.Sprintf(" `%s`", string(field.Tag)) } diff --git a/internal/codegen/struct_test.go b/internal/codegen/struct_test.go index 56aa09e..45d70fa 100644 --- a/internal/codegen/struct_test.go +++ b/internal/codegen/struct_test.go @@ -2,6 +2,8 @@ package codegen_test import ( "bytes" + "go/token" + "go/types" "testing" "github.com/sunboyy/repogen/internal/code" @@ -20,31 +22,23 @@ type User struct { func TestStructBuilderBuild(t *testing.T) { sb := codegen.StructBuilder{ + Pkg: testutils.Pkg, Name: "User", - Fields: []code.LegacyStructField{ + Fields: []code.StructField{ { - Name: "ID", - Type: code.ExternalType{ - PackageAlias: "primitive", - Name: "ObjectID", - }, + Var: types.NewVar(token.NoPos, nil, "ID", testutils.TypeObjectIDNamed), Tag: `bson:"id,omitempty" json:"id,omitempty"`, }, { - Name: "Username", - Type: code.SimpleType("string"), - Tag: `bson:"username" json:"username"`, + Var: types.NewVar(token.NoPos, nil, "Username", code.TypeString), + Tag: `bson:"username" json:"username"`, }, { - Name: "Age", - Type: code.SimpleType("int"), - Tag: `bson:"age"`, + Var: types.NewVar(token.NoPos, nil, "Age", code.TypeInt), + Tag: `bson:"age"`, }, { - Name: "orderCount", - Type: code.PointerType{ - ContainedType: code.SimpleType("int"), - }, + Var: types.NewVar(token.NoPos, nil, "orderCount", types.NewPointer(code.TypeInt)), }, }, } diff --git a/internal/generator/generator.go b/internal/generator/generator.go index dc210c9..5644ced 100644 --- a/internal/generator/generator.go +++ b/internal/generator/generator.go @@ -10,10 +10,10 @@ import ( // GenerateRepository generates repository implementation code from repository // interface specification. -func GenerateRepository(pkg *types.Package, structModelName string, +func GenerateRepository(pkg *types.Package, namedStruct *types.Named, interfaceName string, methodSpecs []spec.MethodSpec) (string, error) { - generator := mongo.NewGenerator(pkg, structModelName, interfaceName) + generator := mongo.NewGenerator(pkg, namedStruct, interfaceName) codeBuilder := codegen.NewBuilder( "repogen", diff --git a/internal/generator/generator_test.go b/internal/generator/generator_test.go index 6e29f7c..7bfa8fa 100644 --- a/internal/generator/generator_test.go +++ b/internal/generator/generator_test.go @@ -299,7 +299,7 @@ func TestGenerateMongoRepository(t *testing.T) { } expectedCode := string(expectedBytes) - code, err := generator.GenerateRepository(testutils.Pkg, "User", "UserRepository", methods) + code, err := generator.GenerateRepository(testutils.Pkg, testutils.TypeUserNamed, "UserRepository", methods) if err != nil { t.Fatal(err) diff --git a/internal/mongo/common.go b/internal/mongo/common.go index 1387989..8ddcc84 100644 --- a/internal/mongo/common.go +++ b/internal/mongo/common.go @@ -1,6 +1,7 @@ package mongo import ( + "go/token" "go/types" "strings" @@ -9,6 +10,19 @@ import ( "github.com/sunboyy/repogen/internal/spec" ) +var ( + mongoCollectionType types.Type + bsonMType types.Type +) + +func init() { + bareMongoPkg := types.NewPackage("go.mongodb.org/mongo-driver/mongo", "mongo") + mongoCollectionType = types.NewNamed(types.NewTypeName(token.NoPos, bareMongoPkg, "Collection", nil), nil, nil) + + bareBsonPkg := types.NewPackage("go.mongodb.org/mongo-driver/bson", "bson") + bsonMType = types.NewNamed(types.NewTypeName(token.NoPos, bareBsonPkg, "M", nil), nil, nil) +} + var errOccurred = codegen.RawStatement("err != nil") var returnNilErr = codegen.ReturnStatement{ @@ -50,8 +64,8 @@ var ifErrReturnFalseErr = codegen.IfBlock{ } type baseMethodGenerator struct { - pkg *types.Package - structModelName string + targetPkg *types.Package + structModelNamed *types.Named } func (g baseMethodGenerator) bsonFieldReference(fieldReference spec.FieldReference) (string, error) { @@ -93,6 +107,7 @@ func (g baseMethodGenerator) convertQuerySpec(query spec.QuerySpec) (querySpec, } return querySpec{ + TargetPkg: g.targetPkg, Operator: query.Operator, Predicates: predicates, }, nil diff --git a/internal/mongo/count_test.go b/internal/mongo/count_test.go index 06b90a8..82dbd29 100644 --- a/internal/mongo/count_test.go +++ b/internal/mongo/count_test.go @@ -464,11 +464,11 @@ func TestGenerateMethod_Count(t *testing.T) { for _, testCase := range testTable { t.Run(testCase.Name, func(t *testing.T) { - generator := mongo.NewGenerator(testutils.Pkg, "User", "UserRepository") + generator := mongo.NewGenerator(testutils.Pkg, testutils.TypeUserNamed, "UserRepository") expectedReceiver := codegen.MethodReceiver{ - Name: "r", - Type: "UserRepositoryMongo", - Pointer: true, + Name: "r", + TypeName: "UserRepositoryMongo", + Pointer: true, } params := testCase.MethodSpec.Signature.Params() diff --git a/internal/mongo/delete_test.go b/internal/mongo/delete_test.go index 7ba10e7..3b49277 100644 --- a/internal/mongo/delete_test.go +++ b/internal/mongo/delete_test.go @@ -514,11 +514,11 @@ func TestGenerateMethod_Delete(t *testing.T) { for _, testCase := range testTable { t.Run(testCase.Name, func(t *testing.T) { - generator := mongo.NewGenerator(testutils.Pkg, "User", "UserRepository") + generator := mongo.NewGenerator(testutils.Pkg, testutils.TypeUserNamed, "UserRepository") expectedReceiver := codegen.MethodReceiver{ - Name: "r", - Type: "UserRepositoryMongo", - Pointer: true, + Name: "r", + TypeName: "UserRepositoryMongo", + Pointer: true, } params := testCase.MethodSpec.Signature.Params() diff --git a/internal/mongo/find.go b/internal/mongo/find.go index 2a71434..bd67a20 100644 --- a/internal/mongo/find.go +++ b/internal/mongo/find.go @@ -1,9 +1,9 @@ package mongo import ( + "go/types" "strconv" - "github.com/sunboyy/repogen/internal/code" "github.com/sunboyy/repogen/internal/codegen" "github.com/sunboyy/repogen/internal/spec" ) @@ -44,10 +44,7 @@ func (g findBodyGenerator) generateFindOneBody(querySpec querySpec, sortsCode codegen.MapStatement) codegen.FunctionBody { return codegen.FunctionBody{ - codegen.DeclStatement{ - Name: "entity", - Type: code.SimpleType(g.structModelName), - }, + codegen.NewDeclStatement(g.targetPkg, "entity", g.structModelNamed), codegen.IfBlock{ Condition: []codegen.Statement{ codegen.DeclAssignStatement{ @@ -101,14 +98,11 @@ func (g findBodyGenerator) generateFindManyBody(querySpec querySpec, codegen.DeclAssignStatement{ Vars: []string{"entities"}, Values: []codegen.Statement{ - codegen.SliceStatement{ - Type: code.ArrayType{ - ContainedType: code.PointerType{ - ContainedType: code.SimpleType(g.structModelName), - }, - }, - Values: []codegen.Statement{}, - }, + codegen.NewSliceStatement( + g.targetPkg, + types.NewSlice(types.NewPointer(g.structModelNamed)), + []codegen.Statement{}, + ), }, }, codegen.IfBlock{ diff --git a/internal/mongo/find_test.go b/internal/mongo/find_test.go index 0b356d5..2d37f89 100644 --- a/internal/mongo/find_test.go +++ b/internal/mongo/find_test.go @@ -1077,11 +1077,11 @@ func TestGenerateMethod_Find(t *testing.T) { for _, testCase := range testTable { t.Run(testCase.Name, func(t *testing.T) { - generator := mongo.NewGenerator(testutils.Pkg, "User", "UserRepository") + generator := mongo.NewGenerator(testutils.Pkg, testutils.TypeUserNamed, "UserRepository") expectedReceiver := codegen.MethodReceiver{ - Name: "r", - Type: "UserRepositoryMongo", - Pointer: true, + Name: "r", + TypeName: "UserRepositoryMongo", + Pointer: true, } params := testCase.MethodSpec.Signature.Params() diff --git a/internal/mongo/generator.go b/internal/mongo/generator.go index a3a595a..372404a 100644 --- a/internal/mongo/generator.go +++ b/internal/mongo/generator.go @@ -8,15 +8,14 @@ import ( "github.com/sunboyy/repogen/internal/code" "github.com/sunboyy/repogen/internal/codegen" "github.com/sunboyy/repogen/internal/spec" - "golang.org/x/tools/go/packages" ) // NewGenerator creates a new instance of MongoDB repository generator -func NewGenerator(pkg *types.Package, structModelName string, interfaceName string) RepositoryGenerator { +func NewGenerator(targetPkg *types.Package, structModelNamed *types.Named, interfaceName string) RepositoryGenerator { return RepositoryGenerator{ baseMethodGenerator: baseMethodGenerator{ - pkg: pkg, - structModelName: structModelName, + targetPkg: targetPkg, + structModelNamed: structModelNamed, }, InterfaceName: interfaceName, } @@ -30,8 +29,8 @@ type RepositoryGenerator struct { } // Imports returns necessary imports for the mongo repository implementation. -func (g RepositoryGenerator) Imports() [][]code.Import { - return [][]code.Import{ +func (g RepositoryGenerator) Imports() [][]codegen.Import { + return [][]codegen.Import{ { {Path: "context"}, }, @@ -48,16 +47,11 @@ func (g RepositoryGenerator) Imports() [][]code.Import { // implementation struct. func (g RepositoryGenerator) GenerateStruct() codegen.StructBuilder { return codegen.StructBuilder{ + Pkg: g.targetPkg, Name: g.repoImplStructName(), - Fields: []code.LegacyStructField{ + Fields: []code.StructField{ { - Name: "collection", - Type: code.PointerType{ - ContainedType: code.ExternalType{ - PackageAlias: "mongo", - Name: "Collection", - }, - }, + Var: types.NewVar(token.NoPos, nil, "collection", types.NewPointer(mongoCollectionType)), }, }, } @@ -66,18 +60,10 @@ func (g RepositoryGenerator) GenerateStruct() codegen.StructBuilder { // GenerateConstructor creates codegen.FunctionBuilder of a constructor for // mongo repository implementation struct. func (g RepositoryGenerator) GenerateConstructor() (codegen.FunctionBuilder, error) { - mongoPkgs, err := packages.Load(&packages.Config{Mode: packages.NeedTypes}, "go.mongodb.org/mongo-driver/mongo") - if err != nil { - return codegen.FunctionBuilder{}, err - } - mongoPkg := mongoPkgs[0] - collectionObj := mongoPkg.Types.Scope().Lookup("Collection") - collectionType := collectionObj.Type() - return codegen.FunctionBuilder{ - Pkg: g.pkg, + Pkg: g.targetPkg, Name: "New" + g.InterfaceName, - Params: types.NewTuple(types.NewVar(token.NoPos, nil, "collection", types.NewPointer(collectionType))), + Params: types.NewTuple(types.NewVar(token.NoPos, nil, "collection", types.NewPointer(mongoCollectionType))), Returns: []types.Type{ types.NewPointer(types.NewNamed( types.NewTypeName(token.NoPos, nil, g.repoImplStructName(), nil), nil, nil)), @@ -117,11 +103,11 @@ func (g RepositoryGenerator) GenerateMethod(methodSpec spec.MethodSpec) (codegen } return codegen.MethodBuilder{ - Pkg: g.pkg, + Pkg: g.targetPkg, Receiver: codegen.MethodReceiver{ - Name: "r", - Type: code.SimpleType(g.repoImplStructName()), - Pointer: true, + Name: "r", + TypeName: g.repoImplStructName(), + Pointer: true, }, Name: methodSpec.Name, Params: types.NewTuple(paramVars...), diff --git a/internal/mongo/generator_test.go b/internal/mongo/generator_test.go index 537993e..46623f3 100644 --- a/internal/mongo/generator_test.go +++ b/internal/mongo/generator_test.go @@ -15,8 +15,8 @@ import ( ) func TestImports(t *testing.T) { - generator := mongo.NewGenerator(testutils.Pkg, "User", "UserRepository") - expected := [][]code.Import{ + generator := mongo.NewGenerator(testutils.Pkg, testutils.TypeUserNamed, "UserRepository") + expected := [][]codegen.Import{ { {Path: "context"}, }, @@ -36,18 +36,14 @@ func TestImports(t *testing.T) { } func TestGenerateStruct(t *testing.T) { - generator := mongo.NewGenerator(testutils.Pkg, "User", "UserRepository") + bareMongoPkg := types.NewPackage("go.mongodb.org/mongo-driver/mongo", "mongo") + bareCollectionType := types.NewNamed(types.NewTypeName(token.NoPos, bareMongoPkg, "Collection", nil), nil, nil) + generator := mongo.NewGenerator(testutils.Pkg, testutils.TypeUserNamed, "UserRepository") expected := codegen.StructBuilder{ Name: "UserRepositoryMongo", - Fields: []code.LegacyStructField{ + Fields: []code.StructField{ { - Name: "collection", - Type: code.PointerType{ - ContainedType: code.ExternalType{ - PackageAlias: "mongo", - Name: "Collection", - }, - }, + Var: types.NewVar(token.NoPos, nil, "collection", types.NewPointer(bareCollectionType)), }, }, } @@ -71,7 +67,7 @@ func TestGenerateStruct(t *testing.T) { } func TestGenerateConstructor(t *testing.T) { - generator := mongo.NewGenerator(testutils.Pkg, "User", "UserRepository") + generator := mongo.NewGenerator(testutils.Pkg, testutils.TypeUserNamed, "UserRepository") expected := codegen.FunctionBuilder{ Name: "NewUserRepository", Params: types.NewTuple( @@ -371,7 +367,7 @@ func TestGenerateMethod_Invalid(t *testing.T) { for _, testCase := range testTable { t.Run(testCase.Name, func(t *testing.T) { - generator := mongo.NewGenerator(testutils.Pkg, "User", "UserRepository") + generator := mongo.NewGenerator(testutils.Pkg, testutils.TypeUserNamed, "UserRepository") _, err := generator.GenerateMethod(testCase.Method) diff --git a/internal/mongo/insert.go b/internal/mongo/insert.go index 996acbc..5749791 100644 --- a/internal/mongo/insert.go +++ b/internal/mongo/insert.go @@ -1,7 +1,8 @@ package mongo import ( - "github.com/sunboyy/repogen/internal/code" + "go/types" + "github.com/sunboyy/repogen/internal/codegen" "github.com/sunboyy/repogen/internal/spec" ) @@ -38,12 +39,11 @@ func (g RepositoryGenerator) generateInsertOneBody() codegen.FunctionBody { func (g RepositoryGenerator) generateInsertManyBody() codegen.FunctionBody { return codegen.FunctionBody{ - codegen.DeclStatement{ - Name: "entities", - Type: code.ArrayType{ - ContainedType: code.InterfaceType{}, - }, - }, + codegen.NewDeclStatement( + g.targetPkg, + "entities", + types.NewSlice(types.NewInterfaceType(nil, nil)), + ), codegen.RawBlock{ Header: []string{"for _, model := range arg1"}, Statements: []codegen.Statement{ diff --git a/internal/mongo/insert_test.go b/internal/mongo/insert_test.go index 2a97b53..373a68d 100644 --- a/internal/mongo/insert_test.go +++ b/internal/mongo/insert_test.go @@ -80,11 +80,11 @@ func TestGenerateMethod_Insert(t *testing.T) { for _, testCase := range testTable { t.Run(testCase.Name, func(t *testing.T) { - generator := mongo.NewGenerator(testutils.Pkg, "User", "UserRepository") + generator := mongo.NewGenerator(testutils.Pkg, testutils.TypeUserNamed, "UserRepository") expectedReceiver := codegen.MethodReceiver{ - Name: "r", - Type: "UserRepositoryMongo", - Pointer: true, + Name: "r", + TypeName: "UserRepositoryMongo", + Pointer: true, } params := testCase.MethodSpec.Signature.Params() diff --git a/internal/mongo/models.go b/internal/mongo/models.go index 4b863de..70e99d1 100644 --- a/internal/mongo/models.go +++ b/internal/mongo/models.go @@ -2,9 +2,9 @@ package mongo import ( "fmt" + "go/types" "sort" - "github.com/sunboyy/repogen/internal/code" "github.com/sunboyy/repogen/internal/codegen" "github.com/sunboyy/repogen/internal/spec" ) @@ -66,6 +66,7 @@ func (u updateFields) Code() codegen.Statement { } type querySpec struct { + TargetPkg *types.Package Operator spec.Operator Predicates []predicate } @@ -89,28 +90,20 @@ func (q querySpec) Code() codegen.Statement { case spec.OperatorOr: stmt.Pairs = append(stmt.Pairs, codegen.MapPair{ Key: "$or", - Value: codegen.SliceStatement{ - Type: code.ArrayType{ - ContainedType: code.ExternalType{ - PackageAlias: "bson", - Name: "M", - }, - }, - Values: predicateMaps, - }, + Value: codegen.NewSliceStatement( + q.TargetPkg, + types.NewSlice(bsonMType), + predicateMaps, + ), }) case spec.OperatorAnd: stmt.Pairs = append(stmt.Pairs, codegen.MapPair{ Key: "$and", - Value: codegen.SliceStatement{ - Type: code.ArrayType{ - ContainedType: code.ExternalType{ - PackageAlias: "bson", - Name: "M", - }, - }, - Values: predicateMaps, - }, + Value: codegen.NewSliceStatement( + q.TargetPkg, + types.NewSlice(bsonMType), + predicateMaps, + ), }) default: stmt.Pairs = predicatePairs diff --git a/internal/mongo/update_test.go b/internal/mongo/update_test.go index 752a46c..802428a 100644 --- a/internal/mongo/update_test.go +++ b/internal/mongo/update_test.go @@ -424,11 +424,11 @@ func TestGenerateMethod_Update(t *testing.T) { for _, testCase := range testTable { t.Run(testCase.Name, func(t *testing.T) { - generator := mongo.NewGenerator(testutils.Pkg, "User", "UserRepository") + generator := mongo.NewGenerator(testutils.Pkg, testutils.TypeUserNamed, "UserRepository") expectedReceiver := codegen.MethodReceiver{ - Name: "r", - Type: "UserRepositoryMongo", - Pointer: true, + Name: "r", + TypeName: "UserRepositoryMongo", + Pointer: true, } params := testCase.MethodSpec.Signature.Params() diff --git a/main.go b/main.go index 21d529a..10b4371 100644 --- a/main.go +++ b/main.go @@ -136,5 +136,5 @@ func generateRepository(pkg *types.Package, structModelName, repositoryInterface methodSpecs = append(methodSpecs, methodSpec) } - return generator.GenerateRepository(pkg, structModelName, repositoryInterfaceName, methodSpecs) + return generator.GenerateRepository(pkg, namedStruct, repositoryInterfaceName, methodSpecs) }