diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a4f6a05..b555d40 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -14,7 +14,7 @@ jobs: - uses: actions/setup-go@v4 with: - go-version: 1.19 + go-version: 1.22 - name: Build run: go build -v ./... diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 7462990..e16f404 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -14,9 +14,9 @@ jobs: - uses: actions/setup-go@v4 with: - go-version: 1.19 + go-version: 1.22 - name: golangci-lint uses: golangci/golangci-lint-action@v3 with: - version: v1.52.2 + version: v1.58.1 diff --git a/.golangci.yml b/.golangci.yml index 45af5a6..d1d63d8 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -2,7 +2,7 @@ linters: enable: - errname - errorlint - - goerr113 + - err113 - lll - stylecheck linters-settings: diff --git a/go.mod b/go.mod index 70563df..663fd0b 100644 --- a/go.mod +++ b/go.mod @@ -1,11 +1,13 @@ module github.com/sunboyy/repogen -go 1.21 +go 1.22 + +toolchain go1.22.3 require ( github.com/fatih/camelcase v1.0.0 - go.mongodb.org/mongo-driver v1.13.1 - golang.org/x/tools v0.17.0 + go.mongodb.org/mongo-driver v1.15.0 + golang.org/x/tools v0.21.0 ) require ( @@ -16,8 +18,8 @@ require ( github.com/xdg-go/scram v1.1.2 // indirect github.com/xdg-go/stringprep v1.0.4 // indirect github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect - golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect - golang.org/x/mod v0.14.0 // indirect - golang.org/x/sync v0.6.0 // indirect - golang.org/x/text v0.7.0 // indirect + golang.org/x/crypto v0.17.0 // indirect + golang.org/x/mod v0.17.0 // indirect + golang.org/x/sync v0.7.0 // indirect + golang.org/x/text v0.14.0 // indirect ) diff --git a/go.sum b/go.sum index 4c94321..ef7cd8f 100644 --- a/go.sum +++ b/go.sum @@ -4,8 +4,8 @@ github.com/fatih/camelcase v1.0.0 h1:hxNvNX/xYBp0ovncs8WyWZrOrpBNub/JfaMvbURyft8 github.com/fatih/camelcase v1.0.0/go.mod h1:yN2Sb0lFhZJUdVvtELVWefmrXpuZESvPmqwoZc+/fpc= github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/google/go-cmp v0.5.2 h1:X2ev0eStA3AbceY54o37/0PQ/UWqKEiiO2dKL5OPaFM= -github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/klauspost/compress v1.13.6 h1:P76CopJELS0TiO2mebmnzgWaajssP/EszplttgQxcgc= github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe h1:iruDEfMl2E6fbMZ9s0scYfZQ84/6SPL6zC8ACM2oIL0= @@ -19,26 +19,24 @@ github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gi github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d h1:splanxYIlg+5LfHAM6xpdFEAYOk8iySO56hMFq6uLyA= github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.mongodb.org/mongo-driver v1.13.1 h1:YIc7HTYsKndGK4RFzJ3covLz1byri52x0IoMB0Pt/vk= -go.mongodb.org/mongo-driver v1.13.1/go.mod h1:wcDf1JBCXy2mOW0bWHwO/IOYqdca1MPCwDtFu/Z9+eo= +go.mongodb.org/mongo-driver v1.15.0 h1:rJCKC8eEliewXjZGf0ddURtl7tTVy1TK3bfl0gkUSLc= +go.mongodb.org/mongo-driver v1.15.0/go.mod h1:Vzb0Mk/pa7e6cWw85R4F/endUC3u0U9jGcNU603k65c= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d h1:sK3txAijHtOK88l68nt020reeT1ZdKLIYetKl95FzVY= -golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= +golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0= -golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= -golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -46,16 +44,13 @@ golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9sn golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= -golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= -golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.17.0 h1:FvmRgNOcs3kOa+T20R1uhfP9F6HgG2mfxDv1vrx1Htc= -golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps= +golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw= +golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/code/errors.go b/internal/code/errors.go deleted file mode 100644 index c6bfd1f..0000000 --- a/internal/code/errors.go +++ /dev/null @@ -1,21 +0,0 @@ -package code - -import "fmt" - -type DuplicateStructError string - -func (err DuplicateStructError) Error() string { - return fmt.Sprintf( - "code: duplicate implementation of struct '%s'", - string(err), - ) -} - -type DuplicateInterfaceError string - -func (err DuplicateInterfaceError) Error() string { - return fmt.Sprintf( - "code: duplicate implementation of interface '%s'", - string(err), - ) -} diff --git a/internal/code/errors_test.go b/internal/code/errors_test.go deleted file mode 100644 index 3446807..0000000 --- a/internal/code/errors_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package code_test - -import ( - "testing" - - "github.com/sunboyy/repogen/internal/code" -) - -type ErrorTestCase struct { - Name string - Error error - ExpectedString string -} - -func TestError(t *testing.T) { - testTable := []ErrorTestCase{ - { - Name: "DuplicateStructError", - Error: code.DuplicateStructError("User"), - ExpectedString: "code: duplicate implementation of struct 'User'", - }, - { - Name: "DuplicateInterfaceError", - Error: code.DuplicateInterfaceError("UserRepository"), - ExpectedString: "code: duplicate implementation of interface 'UserRepository'", - }, - } - - for _, testCase := range testTable { - t.Run(testCase.Name, func(t *testing.T) { - if testCase.Error.Error() != testCase.ExpectedString { - t.Errorf("Expected = %+v\nReceived = %+v", testCase.ExpectedString, testCase.Error.Error()) - } - }) - } -} diff --git a/internal/code/extractor.go b/internal/code/extractor.go deleted file mode 100644 index f3b064a..0000000 --- a/internal/code/extractor.go +++ /dev/null @@ -1,170 +0,0 @@ -package code - -import ( - "fmt" - "go/ast" - "reflect" - "strconv" - "strings" -) - -// ExtractComponents converts ast file into code components model -func ExtractComponents(f *ast.File) File { - file := File{ - Interfaces: map[string]InterfaceType{}, - } - - for _, decl := range f.Decls { - genDecl, ok := decl.(*ast.GenDecl) - if !ok { - continue - } - - for _, spec := range genDecl.Specs { - switch spec := spec.(type) { - case *ast.ImportSpec: - var imp Import - if spec.Name != nil { - imp.Name = spec.Name.Name - } - importPath, err := strconv.Unquote(spec.Path.Value) - if err != nil { - fmt.Printf("cannot unquote import %s : %s \n", spec.Path.Value, err) - continue - } - imp.Path = importPath - - file.Imports = append(file.Imports, imp) - - case *ast.TypeSpec: - switch t := spec.Type.(type) { - case *ast.StructType: - file.Structs = append(file.Structs, extractStructType(spec.Name.Name, t)) - case *ast.InterfaceType: - file.Interfaces[spec.Name.Name] = extractInterfaceType(t) - } - } - } - } - return file -} - -func extractStructType(name string, structType *ast.StructType) Struct { - str := Struct{ - Name: name, - } - - for _, field := range structType.Fields.List { - var strField StructField - for _, name := range field.Names { - strField.Name = name.Name - break - } - strField.Type = getType(field.Type) - if field.Tag != nil { - strField.Tag = extractStructTag(field.Tag.Value) - } - - str.Fields = append(str.Fields, strField) - } - - return str -} - -func extractInterfaceType(interfaceType *ast.InterfaceType) InterfaceType { - intf := InterfaceType{} - - for _, method := range interfaceType.Methods.List { - funcType, ok := method.Type.(*ast.FuncType) - if !ok { - continue - } - - var name string - for _, n := range method.Names { - name = n.Name - break - } - - var comments []string - if method.Doc != nil { - for _, comment := range method.Doc.List { - commentRunes := []rune(comment.Text) - commentText := strings.TrimSpace(string(commentRunes[2:])) - comments = append(comments, commentText) - } - } - - meth := extractFunction(name, comments, funcType) - - intf.Methods = append(intf.Methods, meth) - } - - return intf -} - -func extractStructTag(tagValue string) reflect.StructTag { - return reflect.StructTag(tagValue[1 : len(tagValue)-1]) -} - -func extractFunction(name string, comments []string, funcType *ast.FuncType) Method { - meth := Method{ - Name: name, - Comments: comments, - } - for _, param := range funcType.Params.List { - paramType := getType(param.Type) - - if len(param.Names) == 0 { - meth.Params = append(meth.Params, Param{Type: paramType}) - continue - } - - for _, name := range param.Names { - meth.Params = append(meth.Params, Param{ - Name: name.Name, - Type: paramType, - }) - } - } - - if funcType.Results != nil { - for _, result := range funcType.Results.List { - meth.Returns = append(meth.Returns, getType(result.Type)) - } - } - - return meth -} - -func getType(expr ast.Expr) Type { - switch expr := expr.(type) { - case *ast.Ident: - return SimpleType(expr.Name) - - case *ast.SelectorExpr: - xExpr, ok := expr.X.(*ast.Ident) - if !ok { - return ExternalType{Name: expr.Sel.Name} - } - return ExternalType{PackageAlias: xExpr.Name, Name: expr.Sel.Name} - - case *ast.StarExpr: - containedType := getType(expr.X) - return PointerType{ContainedType: containedType} - - case *ast.ArrayType: - containedType := getType(expr.Elt) - return ArrayType{ContainedType: containedType} - - case *ast.MapType: - keyType := getType(expr.Key) - valueType := getType(expr.Value) - return MapType{KeyType: keyType, ValueType: valueType} - - case *ast.InterfaceType: - return extractInterfaceType(expr) - } - - return nil -} diff --git a/internal/code/extractor_test.go b/internal/code/extractor_test.go deleted file mode 100644 index 682229f..0000000 --- a/internal/code/extractor_test.go +++ /dev/null @@ -1,323 +0,0 @@ -package code_test - -import ( - "go/parser" - "go/token" - "reflect" - "testing" - - "github.com/sunboyy/repogen/internal/code" -) - -type TestCase struct { - Name string - Source string - ExpectedOutput code.File -} - -func TestExtractComponents(t *testing.T) { - testTable := []TestCase{ - { - Name: "package name", - Source: `package user`, - ExpectedOutput: code.File{ - Interfaces: map[string]code.InterfaceType{}, - }, - }, - { - Name: "single line imports", - Source: `package user - -import ctx "context" -import "go.mongodb.org/mongo-driver/bson/primitive"`, - ExpectedOutput: code.File{ - Imports: []code.Import{ - {Name: "ctx", Path: "context"}, - {Path: "go.mongodb.org/mongo-driver/bson/primitive"}, - }, - Interfaces: map[string]code.InterfaceType{}, - }, - }, - { - Name: "multiple line imports", - Source: `package user - -import ( - ctx "context" - "go.mongodb.org/mongo-driver/bson/primitive" -)`, - ExpectedOutput: code.File{ - Imports: []code.Import{ - {Name: "ctx", Path: "context"}, - {Path: "go.mongodb.org/mongo-driver/bson/primitive"}, - }, - Interfaces: map[string]code.InterfaceType{}, - }, - }, - { - Name: "struct declaration", - Source: `package user - -type UserModel struct { - ID primitive.ObjectID ` + "`bson:\"_id,omitempty\" json:\"id\"`" + ` - Username string ` + "`bson:\"username\" json:\"username\"`" + ` - Password string ` + "`bson:\"password\" json:\"-\" note:\"This should be hidden.\"`" + ` -}`, - ExpectedOutput: code.File{ - Structs: []code.Struct{ - { - Name: "UserModel", - Fields: code.StructFields{ - code.StructField{ - Name: "ID", - Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}, - Tag: `bson:"_id,omitempty" json:"id"`, - }, - code.StructField{ - Name: "Username", - Type: code.TypeString, - Tag: `bson:"username" json:"username"`, - }, - code.StructField{ - Name: "Password", - Type: code.TypeString, - Tag: `bson:"password" json:"-" note:"This should be hidden."`, - }, - }, - }, - }, - Interfaces: map[string]code.InterfaceType{}, - }, - }, - { - Name: "interface declaration", - Source: `package user - -type UserRepository interface { - FindByID(ctx context.Context, id primitive.ObjectID) (*UserModel, error) - FindAll(context.Context) ([]*UserModel, error) - FindByAgeBetween(ctx context.Context, fromAge, toAge int) ([]*UserModel, error) - InsertOne(ctx context.Context, user *UserModel) (interface{}, error) - UpdateAgreementByID(ctx context.Context, agreement map[string]bool, id primitive.ObjectID) (bool, error) - // CustomMethod does custom things. - CustomMethod(interface { - Run(arg1 int) - }) interface { - Do(arg2 string) - } -}`, - ExpectedOutput: code.File{ - Interfaces: map[string]code.InterfaceType{ - "UserRepository": { - Methods: []code.Method{ - { - Name: "FindByID", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.PointerType{ContainedType: code.SimpleType("UserModel")}, - code.TypeError, - }, - }, - { - Name: "FindAll", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.ArrayType{ - ContainedType: code.PointerType{ - ContainedType: code.SimpleType("UserModel"), - }, - }, - code.TypeError, - }, - }, - { - Name: "FindByAgeBetween", - Params: []code.Param{ - { - Name: "ctx", - Type: code.ExternalType{PackageAlias: "context", Name: "Context"}, - }, - { - Name: "fromAge", - Type: code.TypeInt, - }, - { - Name: "toAge", - Type: code.TypeInt, - }, - }, - Returns: []code.Type{ - code.ArrayType{ - ContainedType: code.PointerType{ - ContainedType: code.SimpleType("UserModel"), - }, - }, - code.TypeError, - }, - }, - { - Name: "InsertOne", - Params: []code.Param{ - { - Name: "ctx", - Type: code.ExternalType{PackageAlias: "context", Name: "Context"}, - }, - { - Name: "user", - Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}, - }, - }, - Returns: []code.Type{ - code.InterfaceType{}, - code.TypeError, - }, - }, - { - Name: "UpdateAgreementByID", - Params: []code.Param{ - { - Name: "ctx", - Type: code.ExternalType{PackageAlias: "context", Name: "Context"}, - }, - { - Name: "agreement", - Type: code.MapType{KeyType: code.TypeString, ValueType: code.TypeBool}, - }, - { - Name: "id", - Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}, - }, - }, - Returns: []code.Type{ - code.TypeBool, - code.TypeError, - }, - }, - { - Name: "CustomMethod", - Comments: []string{"CustomMethod does custom things."}, - Params: []code.Param{ - { - Type: code.InterfaceType{ - Methods: []code.Method{ - { - Name: "Run", - Params: []code.Param{ - {Name: "arg1", Type: code.TypeInt}, - }, - }, - }, - }, - }, - }, - Returns: []code.Type{ - code.InterfaceType{ - Methods: []code.Method{ - { - Name: "Do", - Params: []code.Param{ - {Name: "arg2", Type: code.TypeString}, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - { - Name: "integration", - Source: `package user - -import ( - "context" - - "go.mongodb.org/mongo-driver/bson/primitive" -) - -type UserModel struct { - ID primitive.ObjectID ` + "`bson:\"_id,omitempty\" json:\"id\"`" + ` - Username string ` + "`bson:\"username\" json:\"username\"`" + ` -} - -type UserRepository interface { - FindByID(ctx context.Context, id primitive.ObjectID) (*UserModel, error) - FindAll(ctx context.Context) ([]*UserModel, error) -} -`, - ExpectedOutput: code.File{ - Imports: []code.Import{ - {Path: "context"}, - {Path: "go.mongodb.org/mongo-driver/bson/primitive"}, - }, - Structs: []code.Struct{ - { - Name: "UserModel", - Fields: code.StructFields{ - code.StructField{ - Name: "ID", - Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}, - Tag: `bson:"_id,omitempty" json:"id"`, - }, - code.StructField{ - Name: "Username", - Type: code.TypeString, - Tag: `bson:"username" json:"username"`, - }, - }, - }, - }, - Interfaces: map[string]code.InterfaceType{ - "UserRepository": { - Methods: []code.Method{ - { - Name: "FindByID", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.PointerType{ContainedType: code.SimpleType("UserModel")}, - code.TypeError, - }, - }, - { - Name: "FindAll", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.ArrayType{ - ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}, - }, - code.TypeError, - }, - }, - }, - }, - }, - }, - }, - } - - for _, testCase := range testTable { - t.Run(testCase.Name, func(t *testing.T) { - fset := token.NewFileSet() - f, _ := parser.ParseFile(fset, "", testCase.Source, parser.ParseComments) - - file := code.ExtractComponents(f) - - if !reflect.DeepEqual(file, testCase.ExpectedOutput) { - t.Errorf("Expected = %+v\nReceived = %+v", testCase.ExpectedOutput, file) - } - }) - } -} diff --git a/internal/code/models.go b/internal/code/models.go index 8760cec..105dbf1 100644 --- a/internal/code/models.go +++ b/internal/code/models.go @@ -2,56 +2,31 @@ package code import ( "fmt" + "go/types" "reflect" ) -// File is a container of all required components for code generation in the file -type File struct { - Imports []Import - Structs []Struct - Interfaces map[string]InterfaceType -} - // Import is a model for package imports type Import struct { Name string Path string } -// Struct is a definition of the struct -type Struct struct { - Name string - Fields StructFields -} - -// ReferencedType returns a type variable of this struct -func (str Struct) ReferencedType() Type { - return SimpleType(str.Name) -} - -// StructFields is a group of the StructField model -type StructFields []StructField - -// ByName return struct field with matching name -func (fields StructFields) ByName(name string) (StructField, bool) { - for _, field := range fields { - if field.Name == name { - return field, true - } - } - return StructField{}, false +// LegacyStructField is a definition of the struct field +type LegacyStructField struct { + Name string + Type Type + Tag reflect.StructTag } // StructField is a definition of the struct field type StructField struct { - Name string - Type Type - Tag reflect.StructTag + Var *types.Var + Tag reflect.StructTag } // InterfaceType is a definition of the interface type InterfaceType struct { - Methods []Method } // Code returns token string in code format @@ -59,29 +34,9 @@ func (intf InterfaceType) Code() string { return `interface{}` } -// IsNumber returns false -func (intf InterfaceType) IsNumber() bool { - return false -} - -// Method is a definition of the method inside the interface -type Method struct { - Name string - Comments []string - Params []Param - Returns []Type -} - -// Param is a model of method parameter -type Param struct { - Name string - Type Type -} - // Type is an interface for value types type Type interface { Code() string - IsNumber() bool } // SimpleType is a type that can be called directly @@ -92,20 +47,13 @@ func (t SimpleType) Code() string { return string(t) } -// IsNumber returns true id a SimpleType is integer or float variants. -func (t SimpleType) IsNumber() bool { - return t == "uint" || t == "uint8" || t == "uint16" || t == "uint32" || t == "uint64" || - t == "int" || t == "int8" || t == "int16" || t == "int32" || t == "int64" || - t == "float32" || t == "float64" -} - -// commonly-used types -const ( - TypeBool = SimpleType("bool") - TypeInt = SimpleType("int") - TypeFloat64 = SimpleType("float64") - TypeString = SimpleType("string") - TypeError = SimpleType("error") +var ( + TypeBool = types.Typ[types.Bool] + TypeInt = types.Typ[types.Int] + TypeInt64 = types.Typ[types.Int64] + TypeFloat64 = types.Typ[types.Float64] + TypeString = types.Typ[types.String] + TypeError = types.Universe.Lookup("error").Type() ) // ExternalType is a type that is called to another package @@ -119,11 +67,6 @@ func (t ExternalType) Code() string { return fmt.Sprintf("%s.%s", t.PackageAlias, t.Name) } -// IsNumber returns false -func (t ExternalType) IsNumber() bool { - return false -} - // PointerType is a model of pointer type PointerType struct { ContainedType Type @@ -134,11 +77,6 @@ func (t PointerType) Code() string { return fmt.Sprintf("*%s", t.ContainedType.Code()) } -// IsNumber returns IsNumber of its contained type -func (t PointerType) IsNumber() bool { - return t.ContainedType.IsNumber() -} - // ArrayType is a model of array type ArrayType struct { ContainedType Type @@ -148,24 +86,3 @@ type ArrayType struct { func (t ArrayType) Code() string { return fmt.Sprintf("[]%s", t.ContainedType.Code()) } - -// IsNumber returns false -func (t ArrayType) IsNumber() bool { - return false -} - -// MapType is a model of map -type MapType struct { - KeyType Type - ValueType Type -} - -// Code returns token string in code format -func (t MapType) Code() string { - return fmt.Sprintf("map[%s]%s", t.KeyType.Code(), t.ValueType.Code()) -} - -// IsNumber returns false -func (t MapType) IsNumber() bool { - return false -} diff --git a/internal/code/models_test.go b/internal/code/models_test.go index b64a337..594f00d 100644 --- a/internal/code/models_test.go +++ b/internal/code/models_test.go @@ -1,37 +1,11 @@ package code_test import ( - "reflect" "testing" "github.com/sunboyy/repogen/internal/code" ) -func TestStructFieldsByName(t *testing.T) { - idField := code.StructField{Name: "ID", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}} - usernameField := code.StructField{Name: "Username", Type: code.TypeString} - fields := code.StructFields{idField, usernameField} - - t.Run("struct field found", func(t *testing.T) { - field, ok := fields.ByName("Username") - - if !ok { - t.Fail() - } - if !reflect.DeepEqual(field, usernameField) { - t.Errorf("Expected = %+v\nReceived = %+v", usernameField, field) - } - }) - - t.Run("struct field not found", func(t *testing.T) { - _, ok := fields.ByName("Password") - - if ok { - t.Fail() - } - }) -} - type TypeCodeTestCase struct { Name string Type code.Type @@ -60,14 +34,6 @@ func TestTypeCode(t *testing.T) { Type: code.ArrayType{ContainedType: code.SimpleType("UserModel")}, ExpectedCode: "[]UserModel", }, - { - Name: "map type", - Type: code.MapType{ - KeyType: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}, - ValueType: code.PointerType{ContainedType: code.SimpleType("UserModel")}, - }, - ExpectedCode: "map[primitive.ObjectID]*UserModel", - }, } for _, testCase := range testTable { @@ -80,94 +46,3 @@ func TestTypeCode(t *testing.T) { }) } } - -type TypeIsNumberTestCase struct { - Name string - Type code.Type - IsNumber bool -} - -func TestTypeIsNumber(t *testing.T) { - testTable := []TypeIsNumberTestCase{ - { - Name: "simple type: int", - Type: code.TypeInt, - IsNumber: true, - }, - { - Name: "simple type: other integer variants", - Type: code.SimpleType("int64"), - IsNumber: true, - }, - { - Name: "simple type: uint", - Type: code.SimpleType("uint"), - IsNumber: true, - }, - { - Name: "simple type: other unsigned integer variants", - Type: code.SimpleType("uint64"), - IsNumber: true, - }, - { - Name: "simple type: float32", - Type: code.SimpleType("float32"), - IsNumber: true, - }, - { - Name: "simple type: other float variants", - Type: code.TypeFloat64, - IsNumber: true, - }, - { - Name: "simple type: non-number primitive type", - Type: code.TypeString, - IsNumber: false, - }, - { - Name: "simple type: non-number custom type", - Type: code.SimpleType("UserModel"), - IsNumber: false, - }, - { - Name: "external type", - Type: code.ExternalType{PackageAlias: "context", Name: "Context"}, - IsNumber: false, - }, - { - Name: "pointer type: number", - Type: code.PointerType{ContainedType: code.TypeInt}, - IsNumber: true, - }, - { - Name: "pointer type: non-number", - Type: code.PointerType{ContainedType: code.TypeString}, - IsNumber: false, - }, - { - Name: "array type", - Type: code.ArrayType{ContainedType: code.TypeInt}, - IsNumber: false, - }, - { - Name: "map type", - Type: code.MapType{KeyType: code.TypeInt, ValueType: code.TypeFloat64}, - IsNumber: false, - }, - { - Name: "interface type", - Type: code.InterfaceType{}, - IsNumber: false, - }, - } - - for _, testCase := range testTable { - t.Run(testCase.Name, func(t *testing.T) { - isNumber := testCase.Type.IsNumber() - - if isNumber != testCase.IsNumber { - t.Errorf("Expected = %+v\nReceived = %+v", testCase.IsNumber, isNumber) - } - }) - } -} diff --git a/internal/code/package.go b/internal/code/package.go deleted file mode 100644 index 1e5a4f4..0000000 --- a/internal/code/package.go +++ /dev/null @@ -1,54 +0,0 @@ -package code - -import "golang.org/x/tools/go/packages" - -// ParsePackage extracts package name, struct and interface implementations from -// *packages.Package. -func ParsePackage(pkgPkg *packages.Package) (Package, error) { - pkg := NewPackage(pkgPkg.Name) - - for _, file := range pkgPkg.Syntax { - if err := pkg.addFile(ExtractComponents(file)); err != nil { - return Package{}, err - } - } - - return pkg, nil -} - -// Package stores package name, struct and interface implementations as a result -// from ParsePackage. -type Package struct { - Name string - Structs map[string]Struct - Interfaces map[string]InterfaceType -} - -// NewPackage is a constructor function for Package. -func NewPackage(name string) Package { - return Package{ - Name: name, - Structs: map[string]Struct{}, - Interfaces: map[string]InterfaceType{}, - } -} - -// addFile alters the Package by adding struct and interface implementations in -// the extracted file. If the package name conflicts, it will return error. -func (pkg *Package) addFile(file File) error { - for _, structImpl := range file.Structs { - if _, ok := pkg.Structs[structImpl.Name]; ok { - return DuplicateStructError(structImpl.Name) - } - pkg.Structs[structImpl.Name] = structImpl - } - - for interfaceName, interfaceImpl := range file.Interfaces { - if _, ok := pkg.Interfaces[interfaceName]; ok { - return DuplicateInterfaceError(interfaceName) - } - pkg.Interfaces[interfaceName] = interfaceImpl - } - - return nil -} diff --git a/internal/code/package_test.go b/internal/code/package_test.go deleted file mode 100644 index 9aaca4b..0000000 --- a/internal/code/package_test.go +++ /dev/null @@ -1,203 +0,0 @@ -package code_test - -import ( - "errors" - "go/ast" - "go/parser" - "go/token" - "testing" - - "github.com/sunboyy/repogen/internal/code" - "golang.org/x/tools/go/packages" -) - -const goImplFile1Data = ` -package codepkgsuccess - -import ( - "math" - "time" - - "go.mongodb.org/mongo-driver/bson/primitive" -) - -type Gender string - -const ( - GenderMale Gender = "MALE" - GenderFemale Gender = "FEMALE" -) - -type User struct { - ID primitive.ObjectID ` + "`json:\"id\"`" + ` - Name string ` + "`json:\"name\"`" + ` - Gender Gender ` + "`json:\"gender\"`" + ` - Birthday time.Time ` + "`json:\"birthday\"`" + ` -} - -func (u User) Age() int { - return int(math.Floor(time.Since(u.Birthday).Hours() / 24 / 365)) -} - -type ( - Product struct { - ID primitive.ObjectID ` + "`json:\"id\"`" + ` - Name string ` + "`json:\"name\"`" + ` - Price float64 ` + "`json:\"price\"`" + ` - } - - Order struct { - ID primitive.ObjectID ` + "`json:\"id\"`" + ` - ItemIDs map[primitive.ObjectID]int ` + "`json:\"itemIds\"`" + ` - TotalPrice float64 ` + "`json:\"totalPrice\"`" + ` - UserID primitive.ObjectID ` + "`json:\"userId\"`" + ` - CreatedAt time.Time ` + "`json:\"createdAt\"`" + ` - } -) -` - -const goImplFile2Data = ` -package codepkgsuccess - -import ( - "time" - - "go.mongodb.org/mongo-driver/bson/primitive" -) - -type OrderService interface { - CreateOrder(u User, products map[Product]int) Order -} - -type OrderServiceImpl struct{} - -func (s *OrderServiceImpl) CreateOrder(u User, products map[Product]int) Order { - itemIDs := map[primitive.ObjectID]int{} - var totalPrice float64 - for product, amount := range products { - itemIDs[product.ID] = amount - totalPrice += product.Price * float64(amount) - } - - return Order{ - ID: primitive.NewObjectID(), - ItemIDs: map[primitive.ObjectID]int{}, - TotalPrice: totalPrice, - UserID: u.ID, - CreatedAt: time.Now(), - } -} -` - -const goImplFile3Data = ` -package success -` - -const goImplFile4Data = ` -package codepkgsuccess - -type User struct { - Name string -} -` - -const goImplFile5Data = ` -package codepkgsuccess - -import "go.mongodb.org/mongo-driver/bson/primitive" - -type OrderService interface { - CancelOrder(orderID primitive.ObjectID) error -} -` - -var ( - goImplFile1 *ast.File - goImplFile2 *ast.File - goImplFile3 *ast.File - goImplFile4 *ast.File - goImplFile5 *ast.File -) - -func init() { - fset := token.NewFileSet() - goImplFile1, _ = parser.ParseFile(fset, "", goImplFile1Data, parser.ParseComments) - goImplFile2, _ = parser.ParseFile(fset, "", goImplFile2Data, parser.ParseComments) - goImplFile3, _ = parser.ParseFile(fset, "", goImplFile3Data, parser.ParseComments) - goImplFile4, _ = parser.ParseFile(fset, "", goImplFile4Data, parser.ParseComments) - goImplFile5, _ = parser.ParseFile(fset, "", goImplFile5Data, parser.ParseComments) -} - -func TestParsePackage_Success(t *testing.T) { - pkg, err := code.ParsePackage(&packages.Package{ - Name: "codepkgsuccess", - Syntax: []*ast.File{ - goImplFile1, - goImplFile2, - goImplFile3, - }, - }) - if err != nil { - t.Fatal(err) - } - - if pkg.Name != "codepkgsuccess" { - t.Errorf("expected package name 'codepkgsuccess', got '%s'", pkg.Name) - } - if _, ok := pkg.Structs["User"]; !ok { - t.Error("struct 'User' not found") - } - if _, ok := pkg.Structs["Product"]; !ok { - t.Error("struct 'Product' not found") - } - if _, ok := pkg.Structs["Order"]; !ok { - t.Error("struct 'Order' not found") - } - if _, ok := pkg.Structs["OrderServiceImpl"]; !ok { - t.Error("struct 'OrderServiceImpl' not found") - } - if _, ok := pkg.Interfaces["OrderService"]; !ok { - t.Error("interface 'OrderService' not found") - } - if _, ok := pkg.Structs["TestCase"]; ok { - t.Error("unexpected struct 'TestCase' in test file") - } -} - -func TestParsePackage_DuplicateStructs(t *testing.T) { - _, err := code.ParsePackage(&packages.Package{ - Name: "codepkgsuccess", - Syntax: []*ast.File{ - goImplFile1, - goImplFile2, - goImplFile4, - }, - }) - - if !errors.Is(err, code.DuplicateStructError("User")) { - t.Errorf( - "expected error '%s', got '%s'", - code.DuplicateStructError("User").Error(), - err.Error(), - ) - } -} - -func TestParsePackage_DuplicateInterfaces(t *testing.T) { - _, err := code.ParsePackage(&packages.Package{ - Name: "codepkgsuccess", - Syntax: []*ast.File{ - goImplFile1, - goImplFile2, - goImplFile5, - }, - }) - - if !errors.Is(err, code.DuplicateInterfaceError("OrderService")) { - t.Errorf( - "expected error '%s', got '%s'", - code.DuplicateInterfaceError("OrderService").Error(), - err.Error(), - ) - } -} diff --git a/internal/codegen/builder_test.go b/internal/codegen/builder_test.go index 84c0849..7459799 100644 --- a/internal/codegen/builder_test.go +++ b/internal/codegen/builder_test.go @@ -1,6 +1,8 @@ package codegen_test import ( + "go/token" + "go/types" "testing" "github.com/sunboyy/repogen/internal/code" @@ -55,7 +57,7 @@ func TestBuilderBuild(t *testing.T) { }) builder.AddImplementer(codegen.StructBuilder{ Name: "User", - Fields: code.StructFields{ + Fields: []code.LegacyStructField{ { Name: "ID", Type: code.ExternalType{ @@ -66,16 +68,18 @@ func TestBuilderBuild(t *testing.T) { }, { Name: "Username", - Type: code.TypeString, + Type: code.SimpleType("string"), }, }, }) builder.AddImplementer(codegen.FunctionBuilder{ Name: "NewUser", - Params: []code.Param{ - {Name: "username", Type: code.TypeString}, + Params: types.NewTuple( + types.NewVar(token.NoPos, nil, "username", code.TypeString), + ), + Returns: []types.Type{ + types.NewNamed(types.NewTypeName(token.NoPos, nil, "User", nil), nil, nil), }, - Returns: []code.Type{code.SimpleType("User")}, Body: codegen.FunctionBody{ codegen.ReturnStatement{ codegen.StructStatement{ @@ -103,7 +107,7 @@ func TestBuilderBuild(t *testing.T) { Receiver: codegen.MethodReceiver{Name: "u", Type: code.SimpleType("User")}, Name: "IDHex", Params: nil, - Returns: []code.Type{code.TypeString}, + Returns: []types.Type{code.TypeString}, Body: codegen.FunctionBody{ codegen.ReturnStatement{ codegen.ChainStatement{ diff --git a/internal/codegen/function.go b/internal/codegen/function.go index 73d7684..73726dd 100644 --- a/internal/codegen/function.go +++ b/internal/codegen/function.go @@ -3,10 +3,9 @@ package codegen import ( "bytes" "fmt" + "go/types" "strings" "text/template" - - "github.com/sunboyy/repogen/internal/code" ) const functionTemplate = ` @@ -17,9 +16,10 @@ func {{.Name}}({{.GenParams}}){{.GenReturns}} { // FunctionBuilder is an implementer of a function. type FunctionBuilder struct { + Pkg *types.Package Name string - Params []code.Param - Returns []code.Type + Params *types.Tuple + Returns []types.Type Body FunctionBody } @@ -37,36 +37,57 @@ func (fb FunctionBuilder) Impl(buffer *bytes.Buffer) error { } func (fb FunctionBuilder) GenParams() string { - return generateParams(fb.Params) + return generateParams(fb.Pkg, fb.Params) } func (fb FunctionBuilder) GenReturns() string { - return generateReturns(fb.Returns) + return generateReturns(fb.Pkg, fb.Returns) } -func generateParams(params []code.Param) string { +func generateParams(pkg *types.Package, params *types.Tuple) string { var paramList []string - for _, param := range params { + for i := 0; i < params.Len(); i++ { + param := params.At(i) + paramList = append( paramList, - fmt.Sprintf("%s %s", param.Name, param.Type.Code()), + fmt.Sprintf("%s %s", param.Name(), typeToString(pkg, param.Type())), ) } return strings.Join(paramList, ", ") } -func generateReturns(returns []code.Type) string { +func typeToString(pkg *types.Package, t types.Type) string { + switch t := t.(type) { + case *types.Pointer: + return fmt.Sprintf("*%s", typeToString(pkg, t.Elem())) + + case *types.Slice: + return fmt.Sprintf("[]%s", typeToString(pkg, t.Elem())) + + case *types.Named: + if t.Obj().Pkg() == nil || t.Obj().Pkg().Path() == pkg.Path() { + return t.Obj().Name() + } + return fmt.Sprintf("%s.%s", t.Obj().Pkg().Name(), t.Obj().Name()) + + default: + return t.String() + } +} + +func generateReturns(pkg *types.Package, returns []types.Type) string { if len(returns) == 0 { return "" } if len(returns) == 1 { - return " " + returns[0].Code() + return " " + typeToString(pkg, returns[0]) } var returnList []string for _, ret := range returns { - returnList = append(returnList, ret.Code()) + returnList = append(returnList, typeToString(pkg, ret)) } return fmt.Sprintf(" (%s)", strings.Join(returnList, ", ")) diff --git a/internal/codegen/function_test.go b/internal/codegen/function_test.go index aa9a4ee..0181608 100644 --- a/internal/codegen/function_test.go +++ b/internal/codegen/function_test.go @@ -2,6 +2,8 @@ package codegen_test import ( "bytes" + "go/token" + "go/types" "testing" "github.com/sunboyy/repogen/internal/code" @@ -12,7 +14,7 @@ import ( func TestFunctionBuilderBuild_NoReturn(t *testing.T) { fb := codegen.FunctionBuilder{ Name: "init", - Params: nil, + Params: types.NewTuple(), Returns: nil, Body: codegen.FunctionBody{ codegen.ChainStatement{ @@ -53,22 +55,14 @@ func init() { func TestFunctionBuilderBuild_OneReturn(t *testing.T) { fb := codegen.FunctionBuilder{ Name: "NewUser", - Params: []code.Param{ - { - Name: "username", - Type: code.TypeString, - }, - { - Name: "age", - Type: code.TypeInt, - }, - { - Name: "parent", - Type: code.PointerType{ContainedType: code.SimpleType("User")}, - }, - }, - Returns: []code.Type{ - code.SimpleType("User"), + Params: types.NewTuple( + types.NewVar(token.NoPos, nil, "username", code.TypeString), + types.NewVar(token.NoPos, nil, "age", code.TypeInt), + types.NewVar(token.NoPos, nil, "parent", + types.NewPointer(types.NewNamed(types.NewTypeName(token.NoPos, nil, "User", nil), nil, nil))), + ), + Returns: []types.Type{ + types.NewNamed(types.NewTypeName(token.NoPos, nil, "User", nil), nil, nil), }, Body: codegen.FunctionBody{ codegen.ReturnStatement{ @@ -111,14 +105,12 @@ func NewUser(username string, age int, parent *User) User { func TestFunctionBuilderBuild_MultiReturn(t *testing.T) { fb := codegen.FunctionBuilder{ Name: "Save", - Params: []code.Param{ - { - Name: "user", - Type: code.SimpleType("User"), - }, - }, - Returns: []code.Type{ - code.SimpleType("User"), + Params: types.NewTuple( + types.NewVar(token.NoPos, nil, "user", + types.NewNamed(types.NewTypeName(token.NoPos, nil, "User", nil), nil, nil)), + ), + Returns: []types.Type{ + types.NewNamed(types.NewTypeName(token.NoPos, nil, "User", nil), nil, nil), code.TypeError, }, Body: codegen.FunctionBody{ diff --git a/internal/codegen/method.go b/internal/codegen/method.go index 8730aa1..e36a71f 100644 --- a/internal/codegen/method.go +++ b/internal/codegen/method.go @@ -3,6 +3,7 @@ package codegen import ( "bytes" "fmt" + "go/types" "text/template" "github.com/sunboyy/repogen/internal/code" @@ -16,10 +17,11 @@ func ({{.GenReceiver}}) {{.Name}}({{.GenParams}}){{.GenReturns}} { // MethodBuilder is an implementer of a method. type MethodBuilder struct { + Pkg *types.Package Receiver MethodReceiver Name string - Params []code.Param - Returns []code.Type + Params *types.Tuple + Returns []types.Type Body FunctionBody } @@ -58,9 +60,9 @@ func (mb MethodBuilder) generateReceiverType() string { } func (mb MethodBuilder) GenParams() string { - return generateParams(mb.Params) + return generateParams(mb.Pkg, mb.Params) } func (mb MethodBuilder) GenReturns() string { - return generateReturns(mb.Returns) + return generateReturns(mb.Pkg, mb.Returns) } diff --git a/internal/codegen/method_test.go b/internal/codegen/method_test.go index fb0a8d9..3269238 100644 --- a/internal/codegen/method_test.go +++ b/internal/codegen/method_test.go @@ -2,6 +2,8 @@ package codegen_test import ( "bytes" + "go/token" + "go/types" "testing" "github.com/sunboyy/repogen/internal/code" @@ -56,7 +58,7 @@ func TestMethodBuilderBuild_IgnorePoinerReceiverOneReturn(t *testing.T) { }, Name: "Init", Params: nil, - Returns: []code.Type{code.TypeError}, + Returns: []types.Type{code.TypeError}, Body: codegen.FunctionBody{ codegen.ReturnStatement{ codegen.ChainStatement{ @@ -99,10 +101,13 @@ func TestMethodBuilderBuild_UseReceiverMultiReturn(t *testing.T) { Type: "User", }, Name: "WithAge", - Params: []code.Param{ - {Name: "age", Type: code.TypeInt}, + Params: types.NewTuple( + types.NewVar(token.NoPos, nil, "age", code.TypeInt), + ), + Returns: []types.Type{ + types.NewNamed(types.NewTypeName(token.NoPos, nil, "User", nil), nil, nil), + code.TypeError, }, - Returns: []code.Type{code.SimpleType("User"), code.TypeError}, Body: codegen.FunctionBody{ codegen.AssignStatement{ Vars: []string{"u.Age"}, @@ -145,9 +150,9 @@ func TestMethodBuilderBuild_UsePointerReceiverNoReturn(t *testing.T) { Pointer: true, }, Name: "SetAge", - Params: []code.Param{ - {Name: "age", Type: code.TypeInt}, - }, + Params: types.NewTuple( + types.NewVar(token.NoPos, nil, "age", code.TypeInt), + ), Returns: nil, Body: codegen.FunctionBody{ codegen.AssignStatement{ diff --git a/internal/codegen/struct.go b/internal/codegen/struct.go index 82530a6..ce1b1cf 100644 --- a/internal/codegen/struct.go +++ b/internal/codegen/struct.go @@ -18,7 +18,7 @@ type {{.Name}} struct { // StructBuilder is an implementer of a struct. type StructBuilder struct { Name string - Fields code.StructFields + Fields []code.LegacyStructField } // Impl writes struct declatation code to the buffer. diff --git a/internal/codegen/struct_test.go b/internal/codegen/struct_test.go index c51697b..56aa09e 100644 --- a/internal/codegen/struct_test.go +++ b/internal/codegen/struct_test.go @@ -21,7 +21,7 @@ type User struct { func TestStructBuilderBuild(t *testing.T) { sb := codegen.StructBuilder{ Name: "User", - Fields: []code.StructField{ + Fields: []code.LegacyStructField{ { Name: "ID", Type: code.ExternalType{ @@ -32,18 +32,18 @@ func TestStructBuilderBuild(t *testing.T) { }, { Name: "Username", - Type: code.TypeString, + Type: code.SimpleType("string"), Tag: `bson:"username" json:"username"`, }, { Name: "Age", - Type: code.TypeInt, + Type: code.SimpleType("int"), Tag: `bson:"age"`, }, { Name: "orderCount", Type: code.PointerType{ - ContainedType: code.TypeInt, + ContainedType: code.SimpleType("int"), }, }, }, diff --git a/internal/generator/generator.go b/internal/generator/generator.go index 1fa7d4b..dc210c9 100644 --- a/internal/generator/generator.go +++ b/internal/generator/generator.go @@ -1,7 +1,8 @@ package generator import ( - "github.com/sunboyy/repogen/internal/code" + "go/types" + "github.com/sunboyy/repogen/internal/codegen" "github.com/sunboyy/repogen/internal/mongo" "github.com/sunboyy/repogen/internal/spec" @@ -9,14 +10,14 @@ import ( // GenerateRepository generates repository implementation code from repository // interface specification. -func GenerateRepository(packageName string, structModel code.Struct, +func GenerateRepository(pkg *types.Package, structModelName string, interfaceName string, methodSpecs []spec.MethodSpec) (string, error) { - generator := mongo.NewGenerator(structModel, interfaceName) + generator := mongo.NewGenerator(pkg, structModelName, interfaceName) codeBuilder := codegen.NewBuilder( "repogen", - packageName, + pkg.Name(), generator.Imports(), ) diff --git a/internal/generator/generator_test.go b/internal/generator/generator_test.go index dc051a4..6e29f7c 100644 --- a/internal/generator/generator_test.go +++ b/internal/generator/generator_test.go @@ -1,6 +1,8 @@ package generator_test import ( + "go/token" + "go/types" "os" "testing" @@ -10,52 +12,43 @@ import ( "github.com/sunboyy/repogen/internal/testutils" ) -var ( - idField = code.StructField{ - Name: "ID", - Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}, - Tag: `bson:"_id,omitempty"`, - } - genderField = code.StructField{ - Name: "Gender", - Type: code.SimpleType("Gender"), - Tag: `bson:"gender"`, - } - ageField = code.StructField{ - Name: "Age", - Type: code.TypeInt, - Tag: `bson:"age"`, - } -) +func createSignature(params []*types.Var, results []*types.Var) *types.Signature { + return types.NewSignatureType(nil, nil, nil, types.NewTuple(params...), types.NewTuple(results...), false) +} + +func createTypeVar(t types.Type) *types.Var { + return types.NewVar(token.NoPos, nil, "", t) +} func TestGenerateMongoRepository(t *testing.T) { - userModel := code.Struct{ - Name: "UserModel", - Fields: code.StructFields{ - idField, - code.StructField{ - Name: "Username", - Type: code.TypeString, - Tag: `bson:"username"`, - }, - genderField, - ageField, - }, - } methods := []spec.MethodSpec{ // test find: One mode { Name: "FindByID", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{code.PointerType{ContainedType: code.SimpleType("UserModel")}, code.TypeError}, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(testutils.TypeObjectIDNamed), + }, + []*types.Var{ + createTypeVar(types.NewPointer(testutils.TypeUserNamed)), + createTypeVar(code.TypeError), + }, + ), Operation: spec.FindOperation{ Mode: spec.QueryModeOne, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ - {FieldReference: spec.FieldReference{idField}, Comparator: spec.ComparatorEqual, ParamIndex: 1}, + { + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "ID"), + Tag: `bson:"_id,omitempty"`, + }, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 1, + }, }, }, }, @@ -63,29 +56,41 @@ func TestGenerateMongoRepository(t *testing.T) { // test find: Many mode, And operator, NOT and LessThan comparator { Name: "FindByGenderNotAndAgeLessThan", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "gender", Type: code.SimpleType("Gender")}, - {Name: "age", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.PointerType{ContainedType: code.SimpleType("UserModel")}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(testutils.TypeGenderNamed), + createTypeVar(code.TypeInt), + }, + []*types.Var{ + createTypeVar(types.NewPointer(testutils.TypeUserNamed)), + createTypeVar(code.TypeError), + }, + ), Operation: spec.FindOperation{ Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Operator: spec.OperatorAnd, Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{genderField}, - Comparator: spec.ComparatorNot, - ParamIndex: 1, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Gender"), + Tag: `bson:"gender"`, + }, + }, + Comparator: spec.ComparatorNot, + ParamIndex: 1, }, { - FieldReference: spec.FieldReference{ageField}, - Comparator: spec.ComparatorLessThan, - ParamIndex: 2, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + Comparator: spec.ComparatorLessThan, + ParamIndex: 2, }, }, }, @@ -93,101 +98,153 @@ func TestGenerateMongoRepository(t *testing.T) { }, { Name: "FindByAgeLessThanEqualOrderByAge", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "age", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(code.TypeInt), + }, + []*types.Var{ + createTypeVar(types.NewSlice(types.NewPointer(testutils.TypeUserNamed))), + createTypeVar(code.TypeError), + }, + ), Operation: spec.FindOperation{ Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{ageField}, - Comparator: spec.ComparatorLessThanEqual, - ParamIndex: 1, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + Comparator: spec.ComparatorLessThanEqual, + ParamIndex: 1, }, }, }, Sorts: []spec.Sort{ - {FieldReference: spec.FieldReference{ageField}, Ordering: spec.OrderingAscending}, + { + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + Ordering: spec.OrderingAscending, + }, }, }, }, { Name: "FindByAgeGreaterThanOrderByAgeAsc", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "age", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(code.TypeInt), + }, + []*types.Var{ + createTypeVar(types.NewSlice(types.NewPointer(testutils.TypeUserNamed))), + createTypeVar(code.TypeError), + }, + ), Operation: spec.FindOperation{ Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{ageField}, - Comparator: spec.ComparatorGreaterThan, - ParamIndex: 1, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + Comparator: spec.ComparatorGreaterThan, + ParamIndex: 1, }, }, }, Sorts: []spec.Sort{ - {FieldReference: spec.FieldReference{ageField}, Ordering: spec.OrderingAscending}, + { + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + Ordering: spec.OrderingAscending, + }, }, }, }, { Name: "FindByAgeGreaterThanEqualOrderByAgeDesc", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "age", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(code.TypeInt), + }, + []*types.Var{ + createTypeVar(types.NewSlice(types.NewPointer(testutils.TypeUserNamed))), + createTypeVar(code.TypeError), + }, + ), Operation: spec.FindOperation{ Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{ageField}, - Comparator: spec.ComparatorGreaterThanEqual, - ParamIndex: 1, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + Comparator: spec.ComparatorGreaterThanEqual, + ParamIndex: 1, }, }, }, Sorts: []spec.Sort{ - {FieldReference: spec.FieldReference{ageField}, Ordering: spec.OrderingDescending}, + { + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + Ordering: spec.OrderingDescending, + }, }, }, }, { Name: "FindByAgeBetween", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "fromAge", Type: code.TypeInt}, - {Name: "toAge", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(code.TypeInt), + createTypeVar(code.TypeInt), + }, + []*types.Var{ + createTypeVar(types.NewSlice(types.NewPointer(testutils.TypeUserNamed))), + createTypeVar(code.TypeError), + }, + ), Operation: spec.FindOperation{ Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{ageField}, - Comparator: spec.ComparatorBetween, - ParamIndex: 1, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + Comparator: spec.ComparatorBetween, + ParamIndex: 1, }, }, }, @@ -195,29 +252,41 @@ func TestGenerateMongoRepository(t *testing.T) { }, { Name: "FindByGenderOrAge", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "gender", Type: code.SimpleType("Gender")}, - {Name: "age", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(testutils.TypeGenderNamed), + createTypeVar(code.TypeInt), + }, + []*types.Var{ + createTypeVar(types.NewSlice(types.NewPointer(testutils.TypeUserNamed))), + createTypeVar(code.TypeError), + }, + ), Operation: spec.FindOperation{ Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Operator: spec.OperatorOr, Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{genderField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 1, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Gender"), + Tag: `bson:"gender"`, + }, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 1, }, { - FieldReference: spec.FieldReference{ageField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 2, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 2, }, }, }, @@ -230,7 +299,7 @@ func TestGenerateMongoRepository(t *testing.T) { } expectedCode := string(expectedBytes) - code, err := generator.GenerateRepository("user", userModel, "UserRepository", methods) + code, err := generator.GenerateRepository(testutils.Pkg, "User", "UserRepository", methods) if err != nil { t.Fatal(err) diff --git a/internal/mongo/common.go b/internal/mongo/common.go index b47d5a8..1387989 100644 --- a/internal/mongo/common.go +++ b/internal/mongo/common.go @@ -1,6 +1,7 @@ package mongo import ( + "go/types" "strings" "github.com/sunboyy/repogen/internal/code" @@ -49,7 +50,8 @@ var ifErrReturnFalseErr = codegen.IfBlock{ } type baseMethodGenerator struct { - structModel code.Struct + pkg *types.Package + structModelName string } func (g baseMethodGenerator) bsonFieldReference(fieldReference spec.FieldReference) (string, error) { @@ -67,7 +69,7 @@ func (g baseMethodGenerator) bsonFieldReference(fieldReference spec.FieldReferen func (g baseMethodGenerator) bsonTagFromField(field code.StructField) (string, error) { bsonTag, ok := field.Tag.Lookup("bson") if !ok { - return "", NewBsonTagNotFoundError(field.Name) + return "", NewBsonTagNotFoundError(field.Var.Name()) } documentKey := strings.Split(bsonTag, ",")[0] diff --git a/internal/mongo/count_test.go b/internal/mongo/count_test.go index 87c9c0c..af7c195 100644 --- a/internal/mongo/count_test.go +++ b/internal/mongo/count_test.go @@ -2,6 +2,8 @@ package mongo_test import ( "fmt" + "go/token" + "go/types" "reflect" "testing" @@ -18,21 +20,28 @@ func TestGenerateMethod_Count(t *testing.T) { Name: "simple count method", MethodSpec: spec.MethodSpec{ Name: "CountByGender", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("Gender")}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(testutils.TypeGenderNamed), + }, + []*types.Var{ + createTypeVar(code.TypeInt), + createTypeVar(code.TypeError), + }, + ), Operation: spec.CountOperation{ Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{genderField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 1, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Gender"), + Tag: `bson:"gender"`, + }, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 1, }, }, }, @@ -50,28 +59,40 @@ func TestGenerateMethod_Count(t *testing.T) { Name: "count with And operator", MethodSpec: spec.MethodSpec{ Name: "CountByGenderAndCity", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("Gender")}, - {Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(testutils.TypeGenderNamed), + createTypeVar(code.TypeInt), + }, + []*types.Var{ + createTypeVar(code.TypeInt), + createTypeVar(code.TypeError), + }, + ), Operation: spec.CountOperation{ Query: spec.QuerySpec{ Operator: spec.OperatorAnd, Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{genderField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 1, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Gender"), + Tag: `bson:"gender"`, + }, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 1, }, { - FieldReference: spec.FieldReference{ageField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 2, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 2, }, }, }, @@ -96,28 +117,40 @@ func TestGenerateMethod_Count(t *testing.T) { Name: "count with Or operator", MethodSpec: spec.MethodSpec{ Name: "CountByGenderOrCity", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("Gender")}, - {Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(testutils.TypeGenderNamed), + createTypeVar(code.TypeInt), + }, + []*types.Var{ + createTypeVar(code.TypeInt), + createTypeVar(code.TypeError), + }, + ), Operation: spec.CountOperation{ Query: spec.QuerySpec{ Operator: spec.OperatorOr, Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{genderField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 1, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Gender"), + Tag: `bson:"gender"`, + }, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 1, }, { - FieldReference: spec.FieldReference{ageField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 2, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 2, }, }, }, @@ -142,21 +175,28 @@ func TestGenerateMethod_Count(t *testing.T) { Name: "count with Not comparator", MethodSpec: spec.MethodSpec{ Name: "CountByGenderNot", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("Gender")}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(testutils.TypeGenderNamed), + }, + []*types.Var{ + createTypeVar(code.TypeInt), + createTypeVar(code.TypeError), + }, + ), Operation: spec.CountOperation{ Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{genderField}, - Comparator: spec.ComparatorNot, - ParamIndex: 1, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Gender"), + Tag: `bson:"gender"`, + }, + }, + Comparator: spec.ComparatorNot, + ParamIndex: 1, }, }, }, @@ -176,21 +216,28 @@ func TestGenerateMethod_Count(t *testing.T) { Name: "count with LessThan comparator", MethodSpec: spec.MethodSpec{ Name: "CountByAgeLessThan", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(code.TypeInt), + }, + []*types.Var{ + createTypeVar(code.TypeInt), + createTypeVar(code.TypeError), + }, + ), Operation: spec.CountOperation{ Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{ageField}, - Comparator: spec.ComparatorLessThan, - ParamIndex: 1, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + Comparator: spec.ComparatorLessThan, + ParamIndex: 1, }, }, }, @@ -210,21 +257,28 @@ func TestGenerateMethod_Count(t *testing.T) { Name: "count with LessThanEqual comparator", MethodSpec: spec.MethodSpec{ Name: "CountByAgeLessThanEqual", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(code.TypeInt), + }, + []*types.Var{ + createTypeVar(code.TypeInt), + createTypeVar(code.TypeError), + }, + ), Operation: spec.CountOperation{ Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{ageField}, - Comparator: spec.ComparatorLessThanEqual, - ParamIndex: 1, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + Comparator: spec.ComparatorLessThanEqual, + ParamIndex: 1, }, }, }, @@ -244,21 +298,28 @@ func TestGenerateMethod_Count(t *testing.T) { Name: "count with GreaterThan comparator", MethodSpec: spec.MethodSpec{ Name: "CountByAgeGreaterThan", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(code.TypeInt), + }, + []*types.Var{ + createTypeVar(code.TypeInt), + createTypeVar(code.TypeError), + }, + ), Operation: spec.CountOperation{ Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{ageField}, - Comparator: spec.ComparatorGreaterThan, - ParamIndex: 1, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + Comparator: spec.ComparatorGreaterThan, + ParamIndex: 1, }, }, }, @@ -278,21 +339,28 @@ func TestGenerateMethod_Count(t *testing.T) { Name: "count with GreaterThanEqual comparator", MethodSpec: spec.MethodSpec{ Name: "CountByAgeGreaterThanEqual", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(code.TypeInt), + }, + []*types.Var{ + createTypeVar(code.TypeInt), + createTypeVar(code.TypeError), + }, + ), Operation: spec.CountOperation{ Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{ageField}, - Comparator: spec.ComparatorGreaterThanEqual, - ParamIndex: 1, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + Comparator: spec.ComparatorGreaterThanEqual, + ParamIndex: 1, }, }, }, @@ -312,22 +380,29 @@ func TestGenerateMethod_Count(t *testing.T) { Name: "count with Between comparator", MethodSpec: spec.MethodSpec{ Name: "CountByAgeBetween", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeInt}, - {Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(code.TypeInt), + createTypeVar(code.TypeInt), + }, + []*types.Var{ + createTypeVar(code.TypeInt), + createTypeVar(code.TypeError), + }, + ), Operation: spec.CountOperation{ Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{ageField}, - Comparator: spec.ComparatorBetween, - ParamIndex: 1, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + Comparator: spec.ComparatorBetween, + ParamIndex: 1, }, }, }, @@ -348,21 +423,28 @@ func TestGenerateMethod_Count(t *testing.T) { Name: "count with In comparator", MethodSpec: spec.MethodSpec{ Name: "CountByAgeIn", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.ArrayType{ContainedType: code.TypeInt}}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(testutils.TypeObjectIDNamed), + }, + []*types.Var{ + createTypeVar(code.TypeInt), + createTypeVar(code.TypeError), + }, + ), Operation: spec.CountOperation{ Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{ageField}, - Comparator: spec.ComparatorIn, - ParamIndex: 1, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + Comparator: spec.ComparatorIn, + ParamIndex: 1, }, }, }, @@ -382,18 +464,24 @@ func TestGenerateMethod_Count(t *testing.T) { for _, testCase := range testTable { t.Run(testCase.Name, func(t *testing.T) { - generator := mongo.NewGenerator(userModel, "UserRepository") + generator := mongo.NewGenerator(testutils.Pkg, "UserModel", "UserRepository") expectedReceiver := codegen.MethodReceiver{ Name: "r", Type: "UserRepositoryMongo", Pointer: true, } - var expectedParams []code.Param - for i, param := range testCase.MethodSpec.Params { - expectedParams = append(expectedParams, code.Param{ - Name: fmt.Sprintf("arg%d", i), - Type: param.Type, - }) + + params := testCase.MethodSpec.Signature.Params() + var expectedParamVars []*types.Var + for i := 0; i < params.Len(); i++ { + expectedParamVars = append(expectedParamVars, types.NewVar(token.NoPos, nil, fmt.Sprintf("arg%d", i), + params.At(i).Type())) + } + expectedParams := types.NewTuple(expectedParamVars...) + returns := testCase.MethodSpec.Signature.Results() + var expectedReturns []types.Type + for i := 0; i < returns.Len(); i++ { + expectedReturns = append(expectedReturns, returns.At(i).Type()) } actual, err := generator.GenerateMethod(testCase.MethodSpec) @@ -422,10 +510,10 @@ func TestGenerateMethod_Count(t *testing.T) { actual.Params, ) } - if !reflect.DeepEqual(testCase.MethodSpec.Returns, actual.Returns) { + if !reflect.DeepEqual(expectedReturns, actual.Returns) { t.Errorf( "incorrect struct returns: expected %+v, got %+v", - testCase.MethodSpec.Returns, + expectedReturns, actual.Returns, ) } diff --git a/internal/mongo/delete_test.go b/internal/mongo/delete_test.go index 1f16aa6..c852e3f 100644 --- a/internal/mongo/delete_test.go +++ b/internal/mongo/delete_test.go @@ -2,6 +2,8 @@ package mongo_test import ( "fmt" + "go/token" + "go/types" "reflect" "testing" @@ -18,19 +20,29 @@ func TestGenerateMethod_Delete(t *testing.T) { Name: "simple delete one method", MethodSpec: spec.MethodSpec{ Name: "DeleteByID", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{code.TypeBool, code.TypeError}, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(testutils.TypeObjectIDNamed), + }, + []*types.Var{ + createTypeVar(code.TypeBool), + createTypeVar(code.TypeError), + }, + ), Operation: spec.DeleteOperation{ Mode: spec.QueryModeOne, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - Comparator: spec.ComparatorEqual, - FieldReference: spec.FieldReference{idField}, - ParamIndex: 1, + Comparator: spec.ComparatorEqual, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "ID"), + Tag: `bson:"_id,omitempty"`, + }, + }, + ParamIndex: 1, }, }, }, @@ -48,22 +60,29 @@ func TestGenerateMethod_Delete(t *testing.T) { Name: "simple delete many method", MethodSpec: spec.MethodSpec{ Name: "DeleteByGender", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "gender", Type: code.SimpleType("Gender")}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(testutils.TypeGenderNamed), + }, + []*types.Var{ + createTypeVar(code.TypeInt), + createTypeVar(code.TypeError), + }, + ), Operation: spec.DeleteOperation{ Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - Comparator: spec.ComparatorEqual, - FieldReference: spec.FieldReference{genderField}, - ParamIndex: 1, + Comparator: spec.ComparatorEqual, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Gender"), + Tag: `bson:"gender"`, + }, + }, + ParamIndex: 1, }, }, }, @@ -81,29 +100,41 @@ func TestGenerateMethod_Delete(t *testing.T) { Name: "delete with And operator", MethodSpec: spec.MethodSpec{ Name: "DeleteByGenderAndAge", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "gender", Type: code.SimpleType("Gender")}, - {Name: "age", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(testutils.TypeGenderNamed), + createTypeVar(code.TypeInt), + }, + []*types.Var{ + createTypeVar(code.TypeInt), + createTypeVar(code.TypeError), + }, + ), Operation: spec.DeleteOperation{ Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Operator: spec.OperatorAnd, Predicates: []spec.Predicate{ { - Comparator: spec.ComparatorEqual, - FieldReference: spec.FieldReference{genderField}, - ParamIndex: 1, + Comparator: spec.ComparatorEqual, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Gender"), + Tag: `bson:"gender"`, + }, + }, + ParamIndex: 1, }, { - Comparator: spec.ComparatorEqual, - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 2, + Comparator: spec.ComparatorEqual, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + ParamIndex: 2, }, }, }, @@ -128,29 +159,41 @@ func TestGenerateMethod_Delete(t *testing.T) { Name: "delete with Or operator", MethodSpec: spec.MethodSpec{ Name: "DeleteByGenderOrAge", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "gender", Type: code.SimpleType("Gender")}, - {Name: "age", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(testutils.TypeGenderNamed), + createTypeVar(code.TypeInt), + }, + []*types.Var{ + createTypeVar(code.TypeInt), + createTypeVar(code.TypeError), + }, + ), Operation: spec.DeleteOperation{ Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Operator: spec.OperatorOr, Predicates: []spec.Predicate{ { - Comparator: spec.ComparatorEqual, - FieldReference: spec.FieldReference{genderField}, - ParamIndex: 1, + Comparator: spec.ComparatorEqual, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Gender"), + Tag: `bson:"gender"`, + }, + }, + ParamIndex: 1, }, { - Comparator: spec.ComparatorEqual, - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 2, + Comparator: spec.ComparatorEqual, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + ParamIndex: 2, }, }, }, @@ -175,22 +218,29 @@ func TestGenerateMethod_Delete(t *testing.T) { Name: "delete with Not comparator", MethodSpec: spec.MethodSpec{ Name: "DeleteByGenderNot", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "gender", Type: code.SimpleType("Gender")}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(testutils.TypeGenderNamed), + }, + []*types.Var{ + createTypeVar(code.TypeInt), + createTypeVar(code.TypeError), + }, + ), Operation: spec.DeleteOperation{ Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - Comparator: spec.ComparatorNot, - FieldReference: spec.FieldReference{genderField}, - ParamIndex: 1, + Comparator: spec.ComparatorNot, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Gender"), + Tag: `bson:"gender"`, + }, + }, + ParamIndex: 1, }, }, }, @@ -210,22 +260,29 @@ func TestGenerateMethod_Delete(t *testing.T) { Name: "delete with LessThan comparator", MethodSpec: spec.MethodSpec{ Name: "DeleteByAgeLessThan", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "age", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(code.TypeInt), + }, + []*types.Var{ + createTypeVar(code.TypeInt), + createTypeVar(code.TypeError), + }, + ), Operation: spec.DeleteOperation{ Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - Comparator: spec.ComparatorLessThan, - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 1, + Comparator: spec.ComparatorLessThan, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + ParamIndex: 1, }, }, }, @@ -245,22 +302,29 @@ func TestGenerateMethod_Delete(t *testing.T) { Name: "delete with LessThanEqual comparator", MethodSpec: spec.MethodSpec{ Name: "DeleteByAgeLessThanEqual", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "age", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(code.TypeInt), + }, + []*types.Var{ + createTypeVar(code.TypeInt), + createTypeVar(code.TypeError), + }, + ), Operation: spec.DeleteOperation{ Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - Comparator: spec.ComparatorLessThanEqual, - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 1, + Comparator: spec.ComparatorLessThanEqual, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + ParamIndex: 1, }, }, }, @@ -280,22 +344,29 @@ func TestGenerateMethod_Delete(t *testing.T) { Name: "delete with GreaterThan comparator", MethodSpec: spec.MethodSpec{ Name: "DeleteByAgeGreaterThan", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "age", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(code.TypeInt), + }, + []*types.Var{ + createTypeVar(code.TypeInt), + createTypeVar(code.TypeError), + }, + ), Operation: spec.DeleteOperation{ Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - Comparator: spec.ComparatorGreaterThan, - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 1, + Comparator: spec.ComparatorGreaterThan, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + ParamIndex: 1, }, }, }, @@ -315,22 +386,29 @@ func TestGenerateMethod_Delete(t *testing.T) { Name: "delete with GreaterThanEqual comparator", MethodSpec: spec.MethodSpec{ Name: "DeleteByAgeGreaterThanEqual", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "age", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(code.TypeInt), + }, + []*types.Var{ + createTypeVar(code.TypeInt), + createTypeVar(code.TypeError), + }, + ), Operation: spec.DeleteOperation{ Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - Comparator: spec.ComparatorGreaterThanEqual, - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 1, + Comparator: spec.ComparatorGreaterThanEqual, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + ParamIndex: 1, }, }, }, @@ -350,23 +428,30 @@ func TestGenerateMethod_Delete(t *testing.T) { Name: "delete with Between comparator", MethodSpec: spec.MethodSpec{ Name: "DeleteByAgeBetween", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "fromAge", Type: code.TypeInt}, - {Name: "toAge", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(code.TypeInt), + createTypeVar(code.TypeInt), + }, + []*types.Var{ + createTypeVar(code.TypeInt), + createTypeVar(code.TypeError), + }, + ), Operation: spec.DeleteOperation{ Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - Comparator: spec.ComparatorBetween, - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 1, + Comparator: spec.ComparatorBetween, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + ParamIndex: 1, }, }, }, @@ -387,22 +472,29 @@ func TestGenerateMethod_Delete(t *testing.T) { Name: "delete with In comparator", MethodSpec: spec.MethodSpec{ Name: "DeleteByGenderIn", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "gender", Type: code.ArrayType{ContainedType: code.SimpleType("Gender")}}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(testutils.TypeGenderNamed), + }, + []*types.Var{ + createTypeVar(code.TypeInt), + createTypeVar(code.TypeError), + }, + ), Operation: spec.DeleteOperation{ Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - Comparator: spec.ComparatorIn, - FieldReference: spec.FieldReference{genderField}, - ParamIndex: 1, + Comparator: spec.ComparatorIn, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Gender"), + Tag: `bson:"gender"`, + }, + }, + ParamIndex: 1, }, }, }, @@ -422,18 +514,24 @@ func TestGenerateMethod_Delete(t *testing.T) { for _, testCase := range testTable { t.Run(testCase.Name, func(t *testing.T) { - generator := mongo.NewGenerator(userModel, "UserRepository") + generator := mongo.NewGenerator(testutils.Pkg, "UserModel", "UserRepository") expectedReceiver := codegen.MethodReceiver{ Name: "r", Type: "UserRepositoryMongo", Pointer: true, } - var expectedParams []code.Param - for i, param := range testCase.MethodSpec.Params { - expectedParams = append(expectedParams, code.Param{ - Name: fmt.Sprintf("arg%d", i), - Type: param.Type, - }) + + params := testCase.MethodSpec.Signature.Params() + var expectedParamVars []*types.Var + for i := 0; i < params.Len(); i++ { + expectedParamVars = append(expectedParamVars, types.NewVar(token.NoPos, nil, fmt.Sprintf("arg%d", i), + params.At(i).Type())) + } + expectedParams := types.NewTuple(expectedParamVars...) + returns := testCase.MethodSpec.Signature.Results() + var expectedReturns []types.Type + for i := 0; i < returns.Len(); i++ { + expectedReturns = append(expectedReturns, returns.At(i).Type()) } actual, err := generator.GenerateMethod(testCase.MethodSpec) @@ -462,10 +560,10 @@ func TestGenerateMethod_Delete(t *testing.T) { actual.Params, ) } - if !reflect.DeepEqual(testCase.MethodSpec.Returns, actual.Returns) { + if !reflect.DeepEqual(expectedReturns, actual.Returns) { t.Errorf( "incorrect struct returns: expected %+v, got %+v", - testCase.MethodSpec.Returns, + expectedReturns, actual.Returns, ) } diff --git a/internal/mongo/find.go b/internal/mongo/find.go index 632fc08..2a71434 100644 --- a/internal/mongo/find.go +++ b/internal/mongo/find.go @@ -46,7 +46,7 @@ func (g findBodyGenerator) generateFindOneBody(querySpec querySpec, return codegen.FunctionBody{ codegen.DeclStatement{ Name: "entity", - Type: code.SimpleType(g.structModel.Name), + Type: code.SimpleType(g.structModelName), }, codegen.IfBlock{ Condition: []codegen.Statement{ @@ -104,7 +104,7 @@ func (g findBodyGenerator) generateFindManyBody(querySpec querySpec, codegen.SliceStatement{ Type: code.ArrayType{ ContainedType: code.PointerType{ - ContainedType: code.SimpleType(g.structModel.Name), + ContainedType: code.SimpleType(g.structModelName), }, }, Values: []codegen.Statement{}, diff --git a/internal/mongo/find_test.go b/internal/mongo/find_test.go index 862cf68..850db7f 100644 --- a/internal/mongo/find_test.go +++ b/internal/mongo/find_test.go @@ -2,6 +2,8 @@ package mongo_test import ( "fmt" + "go/token" + "go/types" "reflect" "testing" @@ -18,22 +20,29 @@ func TestGenerateMethod_Find(t *testing.T) { Name: "simple find one method", MethodSpec: spec.MethodSpec{ Name: "FindByID", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.PointerType{ContainedType: code.SimpleType("UserModel")}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(testutils.TypeObjectIDNamed), + }, + []*types.Var{ + createTypeVar(types.NewPointer(testutils.TypeUserNamed)), + createTypeVar(code.TypeError), + }, + ), Operation: spec.FindOperation{ Mode: spec.QueryModeOne, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - Comparator: spec.ComparatorEqual, - FieldReference: spec.FieldReference{idField}, - ParamIndex: 1, + Comparator: spec.ComparatorEqual, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "ID"), + Tag: `bson:"_id,omitempty"`, + }, + }, + ParamIndex: 1, }, }, }, @@ -52,22 +61,29 @@ func TestGenerateMethod_Find(t *testing.T) { Name: "simple find many method", MethodSpec: spec.MethodSpec{ Name: "FindByGender", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "gender", Type: code.SimpleType("Gender")}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(testutils.TypeGenderNamed), + }, + []*types.Var{ + createTypeVar(types.NewSlice(types.NewPointer(testutils.TypeUserStruct))), + createTypeVar(code.TypeError), + }, + ), Operation: spec.FindOperation{ Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - Comparator: spec.ComparatorEqual, - FieldReference: spec.FieldReference{genderField}, - ParamIndex: 1, + Comparator: spec.ComparatorEqual, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Gender"), + Tag: `bson:"gender"`, + }, + }, + ParamIndex: 1, }, }, }, @@ -91,22 +107,33 @@ func TestGenerateMethod_Find(t *testing.T) { Name: "find with deep field reference", MethodSpec: spec.MethodSpec{ Name: "FindByNameFirst", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "firstName", Type: code.TypeString}, - }, - Returns: []code.Type{ - code.PointerType{ContainedType: code.SimpleType("UserModel")}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(code.TypeString), + }, + []*types.Var{ + createTypeVar(types.NewPointer(testutils.TypeUserStruct)), + createTypeVar(code.TypeError), + }, + ), Operation: spec.FindOperation{ Mode: spec.QueryModeOne, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - Comparator: spec.ComparatorEqual, - FieldReference: spec.FieldReference{nameField, firstNameField}, - ParamIndex: 1, + Comparator: spec.ComparatorEqual, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Name"), + Tag: `bson:"name"`, + }, + { + Var: testutils.FindStructFieldByName(testutils.TypeNameStruct, "First"), + Tag: `bson:"first"`, + }, + }, + ParamIndex: 1, }, }, }, @@ -125,29 +152,41 @@ func TestGenerateMethod_Find(t *testing.T) { Name: "find with And operator", MethodSpec: spec.MethodSpec{ Name: "FindByGenderAndAge", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "gender", Type: code.SimpleType("Gender")}, - {Name: "age", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(testutils.TypeGenderNamed), + createTypeVar(code.TypeInt), + }, + []*types.Var{ + createTypeVar(types.NewSlice(types.NewPointer(testutils.TypeUserStruct))), + createTypeVar(code.TypeError), + }, + ), Operation: spec.FindOperation{ Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Operator: spec.OperatorAnd, Predicates: []spec.Predicate{ { - Comparator: spec.ComparatorEqual, - FieldReference: spec.FieldReference{genderField}, - ParamIndex: 1, + Comparator: spec.ComparatorEqual, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Gender"), + Tag: `bson:"gender"`, + }, + }, + ParamIndex: 1, }, { - Comparator: spec.ComparatorEqual, - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 2, + Comparator: spec.ComparatorEqual, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + ParamIndex: 2, }, }, }, @@ -178,29 +217,41 @@ func TestGenerateMethod_Find(t *testing.T) { Name: "find with Or operator", MethodSpec: spec.MethodSpec{ Name: "FindByGenderOrAge", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "gender", Type: code.SimpleType("Gender")}, - {Name: "age", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(testutils.TypeGenderNamed), + createTypeVar(code.TypeInt), + }, + []*types.Var{ + createTypeVar(types.NewSlice(types.NewPointer(testutils.TypeUserStruct))), + createTypeVar(code.TypeError), + }, + ), Operation: spec.FindOperation{ Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Operator: spec.OperatorOr, Predicates: []spec.Predicate{ { - Comparator: spec.ComparatorEqual, - FieldReference: spec.FieldReference{genderField}, - ParamIndex: 1, + Comparator: spec.ComparatorEqual, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Gender"), + Tag: `bson:"gender"`, + }, + }, + ParamIndex: 1, }, { - Comparator: spec.ComparatorEqual, - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 2, + Comparator: spec.ComparatorEqual, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + ParamIndex: 2, }, }, }, @@ -231,22 +282,29 @@ func TestGenerateMethod_Find(t *testing.T) { Name: "find with Not comparator", MethodSpec: spec.MethodSpec{ Name: "FindByGenderNot", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "gender", Type: code.SimpleType("Gender")}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(testutils.TypeGenderNamed), + }, + []*types.Var{ + createTypeVar(types.NewSlice(types.NewPointer(testutils.TypeUserStruct))), + createTypeVar(code.TypeError), + }, + ), Operation: spec.FindOperation{ Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - Comparator: spec.ComparatorNot, - FieldReference: spec.FieldReference{genderField}, - ParamIndex: 1, + Comparator: spec.ComparatorNot, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Gender"), + Tag: `bson:"gender"`, + }, + }, + ParamIndex: 1, }, }, }, @@ -272,22 +330,29 @@ func TestGenerateMethod_Find(t *testing.T) { Name: "find with LessThan comparator", MethodSpec: spec.MethodSpec{ Name: "FindByAgeLessThan", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "age", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(code.TypeInt), + }, + []*types.Var{ + createTypeVar(types.NewSlice(types.NewPointer(testutils.TypeUserStruct))), + createTypeVar(code.TypeError), + }, + ), Operation: spec.FindOperation{ Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - Comparator: spec.ComparatorLessThan, - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 1, + Comparator: spec.ComparatorLessThan, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + ParamIndex: 1, }, }, }, @@ -313,22 +378,29 @@ func TestGenerateMethod_Find(t *testing.T) { Name: "find with LessThanEqual comparator", MethodSpec: spec.MethodSpec{ Name: "FindByAgeLessThanEqual", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "age", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(code.TypeInt), + }, + []*types.Var{ + createTypeVar(types.NewSlice(types.NewPointer(testutils.TypeUserStruct))), + createTypeVar(code.TypeError), + }, + ), Operation: spec.FindOperation{ Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - Comparator: spec.ComparatorLessThanEqual, - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 1, + Comparator: spec.ComparatorLessThanEqual, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + ParamIndex: 1, }, }, }, @@ -354,22 +426,29 @@ func TestGenerateMethod_Find(t *testing.T) { Name: "find with GreaterThan comparator", MethodSpec: spec.MethodSpec{ Name: "FindByAgeGreaterThan", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "age", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(code.TypeInt), + }, + []*types.Var{ + createTypeVar(types.NewSlice(types.NewPointer(testutils.TypeUserStruct))), + createTypeVar(code.TypeError), + }, + ), Operation: spec.FindOperation{ Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - Comparator: spec.ComparatorGreaterThan, - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 1, + Comparator: spec.ComparatorGreaterThan, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + ParamIndex: 1, }, }, }, @@ -395,22 +474,29 @@ func TestGenerateMethod_Find(t *testing.T) { Name: "find with GreaterThanEqual comparator", MethodSpec: spec.MethodSpec{ Name: "FindByAgeGreaterThanEqual", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "age", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(code.TypeInt), + }, + []*types.Var{ + createTypeVar(types.NewSlice(types.NewPointer(testutils.TypeUserStruct))), + createTypeVar(code.TypeError), + }, + ), Operation: spec.FindOperation{ Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - Comparator: spec.ComparatorGreaterThanEqual, - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 1, + Comparator: spec.ComparatorGreaterThanEqual, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + ParamIndex: 1, }, }, }, @@ -436,23 +522,30 @@ func TestGenerateMethod_Find(t *testing.T) { Name: "find with Between comparator", MethodSpec: spec.MethodSpec{ Name: "FindByAgeBetween", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "fromAge", Type: code.TypeInt}, - {Name: "toAge", Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(code.TypeInt), + createTypeVar(code.TypeInt), + }, + []*types.Var{ + createTypeVar(types.NewSlice(types.NewPointer(testutils.TypeUserStruct))), + createTypeVar(code.TypeError), + }, + ), Operation: spec.FindOperation{ Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - Comparator: spec.ComparatorBetween, - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 1, + Comparator: spec.ComparatorBetween, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + ParamIndex: 1, }, }, }, @@ -479,22 +572,29 @@ func TestGenerateMethod_Find(t *testing.T) { Name: "find with In comparator", MethodSpec: spec.MethodSpec{ Name: "FindByGenderIn", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "gender", Type: code.ArrayType{ContainedType: code.SimpleType("Gender")}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(testutils.TypeGenderNamed), + }, + []*types.Var{ + createTypeVar(types.NewSlice(types.NewPointer(testutils.TypeUserStruct))), + createTypeVar(code.TypeError), + }, + ), Operation: spec.FindOperation{ Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - Comparator: spec.ComparatorIn, - FieldReference: spec.FieldReference{genderField}, - ParamIndex: 1, + Comparator: spec.ComparatorIn, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Gender"), + Tag: `bson:"gender"`, + }, + }, + ParamIndex: 1, }, }, }, @@ -520,22 +620,29 @@ func TestGenerateMethod_Find(t *testing.T) { Name: "find with NotIn comparator", MethodSpec: spec.MethodSpec{ Name: "FindByGenderNotIn", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "gender", Type: code.ArrayType{ContainedType: code.SimpleType("Gender")}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(testutils.TypeGenderNamed), + }, + []*types.Var{ + createTypeVar(types.NewSlice(types.NewPointer(testutils.TypeUserStruct))), + createTypeVar(code.TypeError), + }, + ), Operation: spec.FindOperation{ Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - Comparator: spec.ComparatorNotIn, - FieldReference: spec.FieldReference{genderField}, - ParamIndex: 1, + Comparator: spec.ComparatorNotIn, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Gender"), + Tag: `bson:"gender"`, + }, + }, + ParamIndex: 1, }, }, }, @@ -561,21 +668,28 @@ func TestGenerateMethod_Find(t *testing.T) { Name: "find with True comparator", MethodSpec: spec.MethodSpec{ Name: "FindByEnabledTrue", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + }, + []*types.Var{ + createTypeVar(types.NewSlice(types.NewPointer(testutils.TypeUserStruct))), + createTypeVar(code.TypeError), + }, + ), Operation: spec.FindOperation{ Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - Comparator: spec.ComparatorTrue, - FieldReference: spec.FieldReference{enabledField}, - ParamIndex: 1, + Comparator: spec.ComparatorTrue, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Enabled"), + Tag: `bson:"enabled"`, + }, + }, + ParamIndex: 1, }, }, }, @@ -599,21 +713,28 @@ func TestGenerateMethod_Find(t *testing.T) { Name: "find with False comparator", MethodSpec: spec.MethodSpec{ Name: "FindByEnabledFalse", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + }, + []*types.Var{ + createTypeVar(types.NewSlice(types.NewPointer(testutils.TypeUserStruct))), + createTypeVar(code.TypeError), + }, + ), Operation: spec.FindOperation{ Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - Comparator: spec.ComparatorFalse, - FieldReference: spec.FieldReference{enabledField}, - ParamIndex: 1, + Comparator: spec.ComparatorFalse, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Enabled"), + Tag: `bson:"enabled"`, + }, + }, + ParamIndex: 1, }, }, }, @@ -637,21 +758,28 @@ func TestGenerateMethod_Find(t *testing.T) { Name: "find with Exists comparator", MethodSpec: spec.MethodSpec{ Name: "FindByReferrerExists", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + }, + []*types.Var{ + createTypeVar(types.NewSlice(types.NewPointer(testutils.TypeUserStruct))), + createTypeVar(code.TypeError), + }, + ), Operation: spec.FindOperation{ Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - Comparator: spec.ComparatorExists, - FieldReference: spec.FieldReference{referrerField}, - ParamIndex: 1, + Comparator: spec.ComparatorExists, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Referrer"), + Tag: `bson:"referrer"`, + }, + }, + ParamIndex: 1, }, }, }, @@ -677,21 +805,28 @@ func TestGenerateMethod_Find(t *testing.T) { Name: "find with NotExists comparator", MethodSpec: spec.MethodSpec{ Name: "FindByReferrerNotExists", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + }, + []*types.Var{ + createTypeVar(types.NewSlice(types.NewPointer(testutils.TypeUserStruct))), + createTypeVar(code.TypeError), + }, + ), Operation: spec.FindOperation{ Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - Comparator: spec.ComparatorNotExists, - FieldReference: spec.FieldReference{referrerField}, - ParamIndex: 1, + Comparator: spec.ComparatorNotExists, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Referrer"), + Tag: `bson:"referrer"`, + }, + }, + ParamIndex: 1, }, }, }, @@ -717,17 +852,27 @@ func TestGenerateMethod_Find(t *testing.T) { Name: "find with sort ascending", MethodSpec: spec.MethodSpec{ Name: "FindAllOrderByAge", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + }, + []*types.Var{ + createTypeVar(types.NewSlice(types.NewPointer(testutils.TypeUserStruct))), + createTypeVar(code.TypeError), + }, + ), Operation: spec.FindOperation{ Mode: spec.QueryModeMany, Sorts: []spec.Sort{ - {FieldReference: spec.FieldReference{ageField}, Ordering: spec.OrderingAscending}, + { + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + Ordering: spec.OrderingAscending, + }, }, }, }, @@ -749,17 +894,27 @@ func TestGenerateMethod_Find(t *testing.T) { Name: "find with sort descending", MethodSpec: spec.MethodSpec{ Name: "FindAllOrderByAgeDesc", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + }, + []*types.Var{ + createTypeVar(types.NewSlice(types.NewPointer(testutils.TypeUserStruct))), + createTypeVar(code.TypeError), + }, + ), Operation: spec.FindOperation{ Mode: spec.QueryModeMany, Sorts: []spec.Sort{ - {FieldReference: spec.FieldReference{ageField}, Ordering: spec.OrderingDescending}, + { + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + Ordering: spec.OrderingDescending, + }, }, }, }, @@ -781,19 +936,30 @@ func TestGenerateMethod_Find(t *testing.T) { Name: "find with deep sort ascending", MethodSpec: spec.MethodSpec{ Name: "FindAllOrderByNameFirst", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + }, + []*types.Var{ + createTypeVar(types.NewSlice(types.NewPointer(testutils.TypeUserStruct))), + createTypeVar(code.TypeError), + }, + ), Operation: spec.FindOperation{ Mode: spec.QueryModeMany, Sorts: []spec.Sort{ { - FieldReference: spec.FieldReference{nameField, firstNameField}, - Ordering: spec.OrderingAscending, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Name"), + Tag: `bson:"name"`, + }, + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "First"), + Tag: `bson:"first"`, + }, + }, + Ordering: spec.OrderingAscending, }, }, }, @@ -816,18 +982,36 @@ func TestGenerateMethod_Find(t *testing.T) { Name: "find with multiple sorts", MethodSpec: spec.MethodSpec{ Name: "FindAllOrderByGenderAndAgeDesc", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + }, + []*types.Var{ + createTypeVar(types.NewSlice(types.NewPointer(testutils.TypeUserStruct))), + createTypeVar(code.TypeError), + }, + ), Operation: spec.FindOperation{ Mode: spec.QueryModeMany, Sorts: []spec.Sort{ - {FieldReference: spec.FieldReference{genderField}, Ordering: spec.OrderingAscending}, - {FieldReference: spec.FieldReference{ageField}, Ordering: spec.OrderingDescending}, + { + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Gender"), + Tag: `bson:"gender"`, + }, + }, + Ordering: spec.OrderingAscending, + }, + { + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + Ordering: spec.OrderingDescending, + }, }, }, }, @@ -850,17 +1034,27 @@ func TestGenerateMethod_Find(t *testing.T) { Name: "find with limit", MethodSpec: spec.MethodSpec{ Name: "FindTop5AllOrderByAgeDesc", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + }, + []*types.Var{ + createTypeVar(types.NewSlice(types.NewPointer(testutils.TypeUserStruct))), + createTypeVar(code.TypeError), + }, + ), Operation: spec.FindOperation{ Mode: spec.QueryModeMany, Sorts: []spec.Sort{ - {FieldReference: spec.FieldReference{ageField}, Ordering: spec.OrderingDescending}, + { + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + Ordering: spec.OrderingDescending, + }, }, Limit: 5, }, @@ -883,18 +1077,24 @@ func TestGenerateMethod_Find(t *testing.T) { for _, testCase := range testTable { t.Run(testCase.Name, func(t *testing.T) { - generator := mongo.NewGenerator(userModel, "UserRepository") + generator := mongo.NewGenerator(testutils.Pkg, "UserModel", "UserRepository") expectedReceiver := codegen.MethodReceiver{ Name: "r", Type: "UserRepositoryMongo", Pointer: true, } - var expectedParams []code.Param - for i, param := range testCase.MethodSpec.Params { - expectedParams = append(expectedParams, code.Param{ - Name: fmt.Sprintf("arg%d", i), - Type: param.Type, - }) + + params := testCase.MethodSpec.Signature.Params() + var expectedParamVars []*types.Var + for i := 0; i < params.Len(); i++ { + expectedParamVars = append(expectedParamVars, types.NewVar(token.NoPos, nil, fmt.Sprintf("arg%d", i), + params.At(i).Type())) + } + expectedParams := types.NewTuple(expectedParamVars...) + returns := testCase.MethodSpec.Signature.Results() + var expectedReturns []types.Type + for i := 0; i < returns.Len(); i++ { + expectedReturns = append(expectedReturns, returns.At(i).Type()) } actual, err := generator.GenerateMethod(testCase.MethodSpec) @@ -923,10 +1123,10 @@ func TestGenerateMethod_Find(t *testing.T) { actual.Params, ) } - if !reflect.DeepEqual(testCase.MethodSpec.Returns, actual.Returns) { + if !reflect.DeepEqual(expectedReturns, actual.Returns) { t.Errorf( "incorrect struct returns: expected %+v, got %+v", - testCase.MethodSpec.Returns, + expectedReturns, actual.Returns, ) } diff --git a/internal/mongo/generator.go b/internal/mongo/generator.go index 44437a1..a3a595a 100644 --- a/internal/mongo/generator.go +++ b/internal/mongo/generator.go @@ -2,17 +2,21 @@ package mongo import ( "fmt" + "go/token" + "go/types" "github.com/sunboyy/repogen/internal/code" "github.com/sunboyy/repogen/internal/codegen" "github.com/sunboyy/repogen/internal/spec" + "golang.org/x/tools/go/packages" ) // NewGenerator creates a new instance of MongoDB repository generator -func NewGenerator(structModel code.Struct, interfaceName string) RepositoryGenerator { +func NewGenerator(pkg *types.Package, structModelName string, interfaceName string) RepositoryGenerator { return RepositoryGenerator{ baseMethodGenerator: baseMethodGenerator{ - structModel: structModel, + pkg: pkg, + structModelName: structModelName, }, InterfaceName: interfaceName, } @@ -45,7 +49,7 @@ func (g RepositoryGenerator) Imports() [][]code.Import { func (g RepositoryGenerator) GenerateStruct() codegen.StructBuilder { return codegen.StructBuilder{ Name: g.repoImplStructName(), - Fields: code.StructFields{ + Fields: []code.LegacyStructField{ { Name: "collection", Type: code.PointerType{ @@ -62,23 +66,21 @@ func (g RepositoryGenerator) GenerateStruct() codegen.StructBuilder { // GenerateConstructor creates codegen.FunctionBuilder of a constructor for // mongo repository implementation struct. func (g RepositoryGenerator) GenerateConstructor() (codegen.FunctionBuilder, error) { + mongoPkgs, err := packages.Load(&packages.Config{Mode: packages.NeedTypes}, "go.mongodb.org/mongo-driver/mongo") + if err != nil { + return codegen.FunctionBuilder{}, err + } + mongoPkg := mongoPkgs[0] + collectionObj := mongoPkg.Types.Scope().Lookup("Collection") + collectionType := collectionObj.Type() + return codegen.FunctionBuilder{ - Name: "New" + g.InterfaceName, - Params: []code.Param{ - { - Name: "collection", - Type: code.PointerType{ - ContainedType: code.ExternalType{ - PackageAlias: "mongo", - Name: "Collection", - }, - }, - }, - }, - Returns: []code.Type{ - code.PointerType{ - ContainedType: code.SimpleType(g.repoImplStructName()), - }, + Pkg: g.pkg, + Name: "New" + g.InterfaceName, + Params: types.NewTuple(types.NewVar(token.NoPos, nil, "collection", types.NewPointer(collectionType))), + Returns: []types.Type{ + types.NewPointer(types.NewNamed( + types.NewTypeName(token.NoPos, nil, g.repoImplStructName(), nil), nil, nil)), }, Body: codegen.FunctionBody{ codegen.ReturnStatement{ @@ -97,12 +99,16 @@ func (g RepositoryGenerator) GenerateConstructor() (codegen.FunctionBuilder, err // GenerateMethod creates codegen.MethodBuilder of repository method from the // provided method specification. func (g RepositoryGenerator) GenerateMethod(methodSpec spec.MethodSpec) (codegen.MethodBuilder, error) { - var params []code.Param - for i, param := range methodSpec.Params { - params = append(params, code.Param{ - Name: fmt.Sprintf("arg%d", i), - Type: param.Type, - }) + var paramVars []*types.Var + for i := 0; i < methodSpec.Signature.Params().Len(); i++ { + param := types.NewVar(token.NoPos, nil, fmt.Sprintf("arg%d", i), + methodSpec.Signature.Params().At(i).Type()) + paramVars = append(paramVars, param) + } + + var returns []types.Type + for i := 0; i < methodSpec.Signature.Results().Len(); i++ { + returns = append(returns, methodSpec.Signature.Results().At(i).Type()) } implementation, err := g.generateMethodImplementation(methodSpec) @@ -111,14 +117,15 @@ func (g RepositoryGenerator) GenerateMethod(methodSpec spec.MethodSpec) (codegen } return codegen.MethodBuilder{ + Pkg: g.pkg, Receiver: codegen.MethodReceiver{ Name: "r", Type: code.SimpleType(g.repoImplStructName()), Pointer: true, }, Name: methodSpec.Name, - Params: params, - Returns: methodSpec.Returns, + Params: types.NewTuple(paramVars...), + Returns: returns, Body: implementation, }, nil } diff --git a/internal/mongo/generator_test.go b/internal/mongo/generator_test.go index 9904a2b..bdb1afb 100644 --- a/internal/mongo/generator_test.go +++ b/internal/mongo/generator_test.go @@ -2,6 +2,8 @@ package mongo_test import ( "errors" + "go/token" + "go/types" "reflect" "testing" @@ -9,77 +11,11 @@ import ( "github.com/sunboyy/repogen/internal/codegen" "github.com/sunboyy/repogen/internal/mongo" "github.com/sunboyy/repogen/internal/spec" + "github.com/sunboyy/repogen/internal/testutils" ) -var ( - idField = code.StructField{ - Name: "ID", - Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}, - Tag: `bson:"_id,omitempty"`, - } - genderField = code.StructField{ - Name: "Gender", - Type: code.SimpleType("Gender"), - Tag: `bson:"gender"`, - } - ageField = code.StructField{ - Name: "Age", - Type: code.TypeInt, - Tag: `bson:"age"`, - } - nameField = code.StructField{ - Name: "Name", - Type: code.SimpleType("NameModel"), - Tag: `bson:"name"`, - } - referrerField = code.StructField{ - Name: "Referrer", - Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}, - Tag: `bson:"referrer"`, - } - consentHistoryField = code.StructField{ - Name: "ConsentHistory", - Type: code.ArrayType{ContainedType: code.SimpleType("ConsentHistory")}, - Tag: `bson:"consent_history"`, - } - enabledField = code.StructField{ - Name: "Enabled", - Type: code.TypeBool, - Tag: `bson:"enabled"`, - } - accessTokenField = code.StructField{ - Name: "AccessToken", - Type: code.TypeString, - } - - firstNameField = code.StructField{ - Name: "First", - Type: code.TypeString, - Tag: `bson:"first"`, - } -) - -var userModel = code.Struct{ - Name: "UserModel", - Fields: code.StructFields{ - idField, - code.StructField{ - Name: "Username", - Type: code.TypeString, - Tag: `bson:"username"`, - }, - genderField, - ageField, - nameField, - referrerField, - consentHistoryField, - enabledField, - accessTokenField, - }, -} - func TestImports(t *testing.T) { - generator := mongo.NewGenerator(userModel, "UserRepository") + generator := mongo.NewGenerator(testutils.Pkg, "User", "UserRepository") expected := [][]code.Import{ { {Path: "context"}, @@ -100,10 +36,10 @@ func TestImports(t *testing.T) { } func TestGenerateStruct(t *testing.T) { - generator := mongo.NewGenerator(userModel, "UserRepository") + generator := mongo.NewGenerator(testutils.Pkg, "User", "UserRepository") expected := codegen.StructBuilder{ Name: "UserRepositoryMongo", - Fields: []code.StructField{ + Fields: []code.LegacyStructField{ { Name: "collection", Type: code.PointerType{ @@ -135,22 +71,14 @@ func TestGenerateStruct(t *testing.T) { } func TestGenerateConstructor(t *testing.T) { - generator := mongo.NewGenerator(userModel, "UserRepository") + generator := mongo.NewGenerator(testutils.Pkg, "User", "UserRepository") expected := codegen.FunctionBuilder{ Name: "NewUserRepository", - Params: []code.Param{ - { - Name: "collection", - Type: code.PointerType{ - ContainedType: code.ExternalType{ - PackageAlias: "mongo", - Name: "Collection", - }, - }, - }, - }, - Returns: []code.Type{ - code.SimpleType("UserRepository"), + Params: types.NewTuple( + types.NewVar(token.NoPos, nil, "collection", types.NewPointer(testutils.TypeCollectionNamed)), + ), + Returns: []types.Type{ + types.NewNamed(types.NewTypeName(token.NoPos, nil, "UserRepository", nil), nil, nil), }, Body: codegen.FunctionBody{ codegen.ReturnStatement{ @@ -177,13 +105,30 @@ func TestGenerateConstructor(t *testing.T) { actual.Name, ) } - if !reflect.DeepEqual(expected.Params, actual.Params) { + if expected.Params.Len() != actual.Params.Len() { t.Errorf( - "incorrect struct params: expected %+v, got %+v", - expected.Params, - actual.Params, + "incorrect function params length: expected %d, got %d", + expected.Params.Len(), + actual.Params.Len(), ) } + for i := 0; i < expected.Params.Len(); i++ { + if expected.Params.At(i).Name() != actual.Params.At(i).Name() { + t.Errorf( + "incorrect function param name: expected %s, got %s", + expected.Params.At(i).Name(), + actual.Params.At(i).Name(), + ) + } + if expected.Params.At(i).Type().String() != actual.Params.At(i).Type().String() { + t.Errorf( + "incorrect function param type at %d: expected %s, got %s", + i, + expected.Params.At(i).Type(), + actual.Params.At(i).Type(), + ) + } + } if !reflect.DeepEqual(expected.Body, actual.Body) { t.Errorf("incorrect function body: expected %+v got %+v", expected.Body, @@ -217,14 +162,16 @@ func TestGenerateMethod_Invalid(t *testing.T) { Name: "operation not supported", Method: spec.MethodSpec{ Name: "SearchByID", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(testutils.TypeObjectIDNamed), + }, + []*types.Var{ + createTypeVar(types.NewPointer(testutils.TypeUserNamed)), + createTypeVar(code.TypeError), + }, + ), Operation: StubOperation{}, }, ExpectedError: mongo.NewOperationNotSupportedError("Stub"), @@ -233,22 +180,28 @@ func TestGenerateMethod_Invalid(t *testing.T) { Name: "bson tag not found in query", Method: spec.MethodSpec{ Name: "FindByAccessToken", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(code.TypeString), + }, + []*types.Var{ + createTypeVar(types.NewSlice(types.NewPointer(testutils.TypeUserNamed))), + createTypeVar(code.TypeError), + }, + ), Operation: spec.FindOperation{ Mode: spec.QueryModeOne, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{accessTokenField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 1, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "AccessToken"), + }, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 1, }, }, }, @@ -260,17 +213,26 @@ func TestGenerateMethod_Invalid(t *testing.T) { Name: "bson tag not found in sort", Method: spec.MethodSpec{ Name: "FindAllOrderByAccessToken", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + }, + []*types.Var{ + createTypeVar(types.NewSlice(types.NewPointer(testutils.TypeUserNamed))), + createTypeVar(code.TypeError), + }, + ), Operation: spec.FindOperation{ Mode: spec.QueryModeOne, Sorts: []spec.Sort{ - {FieldReference: spec.FieldReference{accessTokenField}, Ordering: spec.OrderingAscending}, + { + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "AccessToken"), + }, + }, + Ordering: spec.OrderingAscending, + }, }, }, }, @@ -280,30 +242,41 @@ func TestGenerateMethod_Invalid(t *testing.T) { Name: "bson tag not found in update field", Method: spec.MethodSpec{ Name: "UpdateAccessTokenByID", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.TypeBool, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(code.TypeString), + createTypeVar(testutils.TypeObjectIDNamed), + }, + []*types.Var{ + createTypeVar(code.TypeBool), + createTypeVar(code.TypeError), + }, + ), Operation: spec.UpdateOperation{ Update: spec.UpdateFields{ spec.UpdateField{ - FieldReference: spec.FieldReference{accessTokenField}, - ParamIndex: 1, - Operator: spec.UpdateOperatorSet, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "AccessToken"), + }, + }, + ParamIndex: 1, + Operator: spec.UpdateOperatorSet, }, }, Mode: spec.QueryModeOne, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{idField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 2, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "ID"), + Tag: `bson:"_id,omitempty"`, + }, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 2, }, }, }, @@ -315,24 +288,31 @@ func TestGenerateMethod_Invalid(t *testing.T) { Name: "update type not supported", Method: spec.MethodSpec{ Name: "UpdateAgeByID", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeInt}, - {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.TypeBool, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(code.TypeInt), + createTypeVar(testutils.TypeObjectIDNamed), + }, + []*types.Var{ + createTypeVar(code.TypeBool), + createTypeVar(code.TypeError), + }, + ), Operation: spec.UpdateOperation{ Update: StubUpdate{}, Mode: spec.QueryModeOne, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{idField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 2, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "ID"), + Tag: `bson:"_id,omitempty"`, + }, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 2, }, }, }, @@ -344,30 +324,42 @@ func TestGenerateMethod_Invalid(t *testing.T) { Name: "update operator not supported", Method: spec.MethodSpec{ Name: "UpdateConsentHistoryAppendByID", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeInt}, - {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.TypeBool, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(code.TypeInt), + createTypeVar(testutils.TypeObjectIDNamed), + }, + []*types.Var{ + createTypeVar(code.TypeBool), + createTypeVar(code.TypeError), + }, + ), Operation: spec.UpdateOperation{ Update: spec.UpdateFields{ spec.UpdateField{ - FieldReference: spec.FieldReference{consentHistoryField}, - ParamIndex: 1, - Operator: "APPEND", + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "ConsentHistory"), + Tag: `bson:"consent_history"`, + }, + }, + ParamIndex: 1, + Operator: "APPEND", }, }, Mode: spec.QueryModeOne, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{idField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 2, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "ID"), + Tag: `bson:"_id,omitempty"`, + }, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 2, }, }, }, @@ -379,7 +371,7 @@ func TestGenerateMethod_Invalid(t *testing.T) { for _, testCase := range testTable { t.Run(testCase.Name, func(t *testing.T) { - generator := mongo.NewGenerator(userModel, "UserRepository") + generator := mongo.NewGenerator(testutils.Pkg, "UserModel", "UserRepository") _, err := generator.GenerateMethod(testCase.Method) diff --git a/internal/mongo/insert_test.go b/internal/mongo/insert_test.go index c921192..d5039e6 100644 --- a/internal/mongo/insert_test.go +++ b/internal/mongo/insert_test.go @@ -2,6 +2,8 @@ package mongo_test import ( "fmt" + "go/token" + "go/types" "reflect" "testing" @@ -12,20 +14,30 @@ import ( "github.com/sunboyy/repogen/internal/testutils" ) +func createSignature(params []*types.Var, results []*types.Var) *types.Signature { + return types.NewSignatureType(nil, nil, nil, types.NewTuple(params...), types.NewTuple(results...), false) +} + +func createTypeVar(t types.Type) *types.Var { + return types.NewVar(token.NoPos, nil, "", t) +} + func TestGenerateMethod_Insert(t *testing.T) { testTable := []GenerateMethodTestCase{ { Name: "insert one method", MethodSpec: spec.MethodSpec{ Name: "InsertOne", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "userModel", Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - }, - Returns: []code.Type{ - code.InterfaceType{}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(types.NewPointer(testutils.TypeUserNamed)), + }, + []*types.Var{ + createTypeVar(types.NewInterfaceType(nil, nil)), + createTypeVar(code.TypeError), + }, + ), Operation: spec.InsertOperation{ Mode: spec.QueryModeOne, }, @@ -40,16 +52,16 @@ func TestGenerateMethod_Insert(t *testing.T) { Name: "insert many method", MethodSpec: spec.MethodSpec{ Name: "Insert", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "userModel", Type: code.ArrayType{ - ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}, - }}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.InterfaceType{}}, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(types.NewSlice(types.NewPointer(testutils.TypeUserNamed))), + }, + []*types.Var{ + createTypeVar(types.NewSlice(types.NewInterfaceType(nil, nil))), + createTypeVar(code.TypeError), + }, + ), Operation: spec.InsertOperation{ Mode: spec.QueryModeMany, }, @@ -68,18 +80,24 @@ func TestGenerateMethod_Insert(t *testing.T) { for _, testCase := range testTable { t.Run(testCase.Name, func(t *testing.T) { - generator := mongo.NewGenerator(userModel, "UserRepository") + generator := mongo.NewGenerator(testutils.Pkg, "UserModel", "UserRepository") expectedReceiver := codegen.MethodReceiver{ Name: "r", Type: "UserRepositoryMongo", Pointer: true, } - var expectedParams []code.Param - for i, param := range testCase.MethodSpec.Params { - expectedParams = append(expectedParams, code.Param{ - Name: fmt.Sprintf("arg%d", i), - Type: param.Type, - }) + + params := testCase.MethodSpec.Signature.Params() + var expectedParamVars []*types.Var + for i := 0; i < params.Len(); i++ { + expectedParamVars = append(expectedParamVars, types.NewVar(token.NoPos, nil, fmt.Sprintf("arg%d", i), + params.At(i).Type())) + } + expectedParams := types.NewTuple(expectedParamVars...) + returns := testCase.MethodSpec.Signature.Results() + var expectedReturns []types.Type + for i := 0; i < returns.Len(); i++ { + expectedReturns = append(expectedReturns, returns.At(i).Type()) } actual, err := generator.GenerateMethod(testCase.MethodSpec) @@ -108,10 +126,10 @@ func TestGenerateMethod_Insert(t *testing.T) { actual.Params, ) } - if !reflect.DeepEqual(testCase.MethodSpec.Returns, actual.Returns) { + if !reflect.DeepEqual(expectedReturns, actual.Returns) { t.Errorf( "incorrect struct returns: expected %+v, got %+v", - testCase.MethodSpec.Returns, + expectedReturns, actual.Returns, ) } diff --git a/internal/mongo/update_test.go b/internal/mongo/update_test.go index d67a6e5..e4ebfc0 100644 --- a/internal/mongo/update_test.go +++ b/internal/mongo/update_test.go @@ -2,6 +2,8 @@ package mongo_test import ( "fmt" + "go/token" + "go/types" "reflect" "testing" @@ -18,24 +20,31 @@ func TestGenerateMethod_Update(t *testing.T) { Name: "update model method", MethodSpec: spec.MethodSpec{ Name: "UpdateByID", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "model", Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.TypeBool, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(testutils.TypeUserNamed), + createTypeVar(testutils.TypeObjectIDNamed), + }, + []*types.Var{ + createTypeVar(code.TypeBool), + createTypeVar(code.TypeError), + }, + ), Operation: spec.UpdateOperation{ Update: spec.UpdateModel{}, Mode: spec.QueryModeOne, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{idField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 2, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "ID"), + Tag: `bson:"_id,omitempty"`, + }, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 2, }, }, }, @@ -55,30 +64,42 @@ func TestGenerateMethod_Update(t *testing.T) { Name: "simple update one method", MethodSpec: spec.MethodSpec{ Name: "UpdateAgeByID", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "age", Type: code.TypeInt}, - {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.TypeBool, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(code.TypeInt), + createTypeVar(testutils.TypeObjectIDNamed), + }, + []*types.Var{ + createTypeVar(code.TypeBool), + createTypeVar(code.TypeError), + }, + ), Operation: spec.UpdateOperation{ Update: spec.UpdateFields{ spec.UpdateField{ - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 1, - Operator: spec.UpdateOperatorSet, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + ParamIndex: 1, + Operator: spec.UpdateOperatorSet, }, }, Mode: spec.QueryModeOne, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{idField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 2, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "ID"), + Tag: `bson:"_id,omitempty"`, + }, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 2, }, }, }, @@ -100,30 +121,42 @@ func TestGenerateMethod_Update(t *testing.T) { Name: "simple update many method", MethodSpec: spec.MethodSpec{ Name: "UpdateAgeByGender", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "age", Type: code.TypeInt}, - {Name: "gender", Type: code.SimpleType("Gender")}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(code.TypeInt), + createTypeVar(testutils.TypeGenderNamed), + }, + []*types.Var{ + createTypeVar(code.TypeInt), + createTypeVar(code.TypeError), + }, + ), Operation: spec.UpdateOperation{ Update: spec.UpdateFields{ spec.UpdateField{ - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 1, - Operator: spec.UpdateOperatorSet, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + ParamIndex: 1, + Operator: spec.UpdateOperatorSet, }, }, Mode: spec.QueryModeMany, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{genderField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 2, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Gender"), + Tag: `bson:"gender"`, + }, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 2, }, }, }, @@ -145,30 +178,42 @@ func TestGenerateMethod_Update(t *testing.T) { Name: "simple update push method", MethodSpec: spec.MethodSpec{ Name: "UpdateConsentHistoryPushByID", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "consentHistory", Type: code.SimpleType("ConsentHistory")}, - {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.TypeBool, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(testutils.TypeConsentHistoryNamed), + createTypeVar(testutils.TypeObjectIDNamed), + }, + []*types.Var{ + createTypeVar(code.TypeBool), + createTypeVar(code.TypeError), + }, + ), Operation: spec.UpdateOperation{ Update: spec.UpdateFields{ spec.UpdateField{ - FieldReference: spec.FieldReference{consentHistoryField}, - ParamIndex: 1, - Operator: spec.UpdateOperatorPush, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "ConsentHistory"), + Tag: `bson:"consent_history"`, + }, + }, + ParamIndex: 1, + Operator: spec.UpdateOperatorPush, }, }, Mode: spec.QueryModeOne, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{idField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 2, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "ID"), + Tag: `bson:"_id,omitempty"`, + }, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 2, }, }, }, @@ -190,30 +235,42 @@ func TestGenerateMethod_Update(t *testing.T) { Name: "simple update inc method", MethodSpec: spec.MethodSpec{ Name: "UpdateAgeIncByID", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "age", Type: code.TypeInt}, - {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.TypeBool, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(code.TypeInt), + createTypeVar(testutils.TypeObjectIDNamed), + }, + []*types.Var{ + createTypeVar(code.TypeBool), + createTypeVar(code.TypeError), + }, + ), Operation: spec.UpdateOperation{ Update: spec.UpdateFields{ spec.UpdateField{ - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 1, - Operator: spec.UpdateOperatorInc, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age"), + Tag: `bson:"age"`, + }, + }, + ParamIndex: 1, + Operator: spec.UpdateOperatorInc, }, }, Mode: spec.QueryModeOne, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{idField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 2, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "ID"), + Tag: `bson:"_id,omitempty"`, + }, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 2, }, }, }, @@ -235,36 +292,53 @@ func TestGenerateMethod_Update(t *testing.T) { Name: "simple update set and push method", MethodSpec: spec.MethodSpec{ Name: "UpdateEnabledAndConsentHistoryPushByID", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "enabled", Type: code.TypeBool}, - {Name: "consentHistory", Type: code.SimpleType("ConsentHistory")}, - {Name: "gender", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.TypeBool, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(code.TypeBool), + createTypeVar(testutils.TypeConsentHistoryNamed), + createTypeVar(testutils.TypeGenderNamed), + }, + []*types.Var{ + createTypeVar(code.TypeBool), + createTypeVar(code.TypeError), + }, + ), Operation: spec.UpdateOperation{ Update: spec.UpdateFields{ spec.UpdateField{ - FieldReference: spec.FieldReference{enabledField}, - ParamIndex: 1, - Operator: spec.UpdateOperatorSet, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Enabled"), + Tag: `bson:"enabled"`, + }, + }, + ParamIndex: 1, + Operator: spec.UpdateOperatorSet, }, spec.UpdateField{ - FieldReference: spec.FieldReference{consentHistoryField}, - ParamIndex: 2, - Operator: spec.UpdateOperatorPush, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "ConsentHistory"), + Tag: `bson:"consent_history"`, + }, + }, + ParamIndex: 2, + Operator: spec.UpdateOperatorPush, }, }, Mode: spec.QueryModeOne, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{idField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 3, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "ID"), + Tag: `bson:"_id,omitempty"`, + }, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 3, }, }, }, @@ -289,30 +363,46 @@ func TestGenerateMethod_Update(t *testing.T) { Name: "update with deeply referenced field", MethodSpec: spec.MethodSpec{ Name: "UpdateNameFirstByID", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "firstName", Type: code.TypeString}, - {Name: "id", Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.TypeBool, - code.TypeError, - }, + Signature: createSignature( + []*types.Var{ + createTypeVar(testutils.TypeContextNamed), + createTypeVar(code.TypeString), + createTypeVar(testutils.TypeObjectIDNamed), + }, + []*types.Var{ + createTypeVar(code.TypeBool), + createTypeVar(code.TypeError), + }, + ), Operation: spec.UpdateOperation{ Update: spec.UpdateFields{ spec.UpdateField{ - FieldReference: spec.FieldReference{nameField, firstNameField}, - ParamIndex: 1, - Operator: spec.UpdateOperatorSet, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Name"), + Tag: `bson:"name"`, + }, + { + Var: testutils.FindStructFieldByName(testutils.TypeNameStruct, "FirstName"), + Tag: `bson:"first"`, + }, + }, + ParamIndex: 1, + Operator: spec.UpdateOperatorSet, }, }, Mode: spec.QueryModeOne, Query: spec.QuerySpec{ Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{idField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 2, + FieldReference: spec.FieldReference{ + { + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "ID"), + Tag: `bson:"_id,omitempty"`, + }, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 2, }, }, }, @@ -334,18 +424,24 @@ func TestGenerateMethod_Update(t *testing.T) { for _, testCase := range testTable { t.Run(testCase.Name, func(t *testing.T) { - generator := mongo.NewGenerator(userModel, "UserRepository") + generator := mongo.NewGenerator(testutils.Pkg, "UserModel", "UserRepository") expectedReceiver := codegen.MethodReceiver{ Name: "r", Type: "UserRepositoryMongo", Pointer: true, } - var expectedParams []code.Param - for i, param := range testCase.MethodSpec.Params { - expectedParams = append(expectedParams, code.Param{ - Name: fmt.Sprintf("arg%d", i), - Type: param.Type, - }) + + params := testCase.MethodSpec.Signature.Params() + var expectedParamVars []*types.Var + for i := 0; i < params.Len(); i++ { + expectedParamVars = append(expectedParamVars, types.NewVar(token.NoPos, nil, fmt.Sprintf("arg%d", i), + params.At(i).Type())) + } + expectedParams := types.NewTuple(expectedParamVars...) + returns := testCase.MethodSpec.Signature.Results() + var expectedReturns []types.Type + for i := 0; i < returns.Len(); i++ { + expectedReturns = append(expectedReturns, returns.At(i).Type()) } actual, err := generator.GenerateMethod(testCase.MethodSpec) @@ -374,10 +470,10 @@ func TestGenerateMethod_Update(t *testing.T) { actual.Params, ) } - if !reflect.DeepEqual(testCase.MethodSpec.Returns, actual.Returns) { + if !reflect.DeepEqual(expectedReturns, actual.Returns) { t.Errorf( "incorrect struct returns: expected %+v, got %+v", - testCase.MethodSpec.Returns, + expectedReturns, actual.Returns, ) } diff --git a/internal/spec/errors.go b/internal/spec/errors.go index 37fb336..07aea07 100644 --- a/internal/spec/errors.go +++ b/internal/spec/errors.go @@ -3,6 +3,7 @@ package spec import ( "errors" "fmt" + "go/types" "strings" "github.com/sunboyy/repogen/internal/code" @@ -20,7 +21,7 @@ var ( ) // NewUnsupportedReturnError creates unsupportedReturnError -func NewUnsupportedReturnError(givenType code.Type, index int) error { +func NewUnsupportedReturnError(givenType types.Type, index int) error { return unsupportedReturnError{ GivenType: givenType, Index: index, @@ -28,12 +29,12 @@ func NewUnsupportedReturnError(givenType code.Type, index int) error { } type unsupportedReturnError struct { - GivenType code.Type + GivenType types.Type Index int } func (err unsupportedReturnError) Error() string { - return fmt.Sprintf("return type '%s' at index %d is not supported", err.GivenType.Code(), err.Index) + return fmt.Sprintf("return type '%s' at index %d is not supported", err.GivenType.String(), err.Index) } // NewOperationReturnCountUnmatchedError creates @@ -79,7 +80,7 @@ func (err invalidSortError) Error() string { } // NewArgumentTypeNotMatchedError creates argumentTypeNotMatchedError -func NewArgumentTypeNotMatchedError(fieldName string, requiredType code.Type, givenType code.Type) error { +func NewArgumentTypeNotMatchedError(fieldName string, requiredType types.Type, givenType types.Type) error { return argumentTypeNotMatchedError{ FieldName: fieldName, RequiredType: requiredType, @@ -89,13 +90,13 @@ func NewArgumentTypeNotMatchedError(fieldName string, requiredType code.Type, gi type argumentTypeNotMatchedError struct { FieldName string - RequiredType code.Type - GivenType code.Type + RequiredType types.Type + GivenType types.Type } func (err argumentTypeNotMatchedError) Error() string { return fmt.Sprintf("field '%s' requires an argument of type '%s' (got '%s')", - err.FieldName, err.RequiredType.Code(), err.GivenType.Code()) + err.FieldName, err.RequiredType.String(), err.GivenType.String()) } // NewUnknownOperationError creates unknownOperationError @@ -139,7 +140,7 @@ type incompatibleComparatorError struct { func (err incompatibleComparatorError) Error() string { return fmt.Sprintf("cannot use comparator %s with struct field '%s' of type '%s'", - err.Comparator, err.Field.Name, err.Field.Type.Code()) + err.Comparator, err.Field.Var.Name(), err.Field.Var.Type()) } // NewIncompatibleUpdateOperatorError creates incompatibleUpdateOperatorError @@ -147,17 +148,17 @@ func NewIncompatibleUpdateOperatorError(updateOperator UpdateOperator, fieldRefe return incompatibleUpdateOperatorError{ UpdateOperator: updateOperator, ReferencingCode: fieldReference.ReferencingCode(), - ReferencedType: fieldReference.ReferencedField().Type, + ReferencedType: fieldReference.ReferencedField().Var.Type(), } } type incompatibleUpdateOperatorError struct { UpdateOperator UpdateOperator ReferencingCode string - ReferencedType code.Type + ReferencedType types.Type } func (err incompatibleUpdateOperatorError) Error() string { return fmt.Sprintf("cannot use update operator %s with struct field '%s' of type '%s'", - err.UpdateOperator, err.ReferencingCode, err.ReferencedType.Code()) + err.UpdateOperator, err.ReferencingCode, err.ReferencedType.String()) } diff --git a/internal/spec/errors_test.go b/internal/spec/errors_test.go index 9b75340..09dce14 100644 --- a/internal/spec/errors_test.go +++ b/internal/spec/errors_test.go @@ -1,6 +1,8 @@ package spec_test import ( + "go/token" + "go/types" "testing" "github.com/sunboyy/repogen/internal/code" @@ -26,8 +28,9 @@ func TestError(t *testing.T) { ExpectedString: "struct field 'PhoneNumber' not found", }, { - Name: "UnsupportedReturnError", - Error: spec.NewUnsupportedReturnError(code.SimpleType("User"), 0), + Name: "UnsupportedReturnError", + Error: spec.NewUnsupportedReturnError(types.NewNamed( + types.NewTypeName(token.NoPos, nil, "User", nil), nil, nil), 0), ExpectedString: "return type 'User' at index 0 is not supported", }, { @@ -43,8 +46,7 @@ func TestError(t *testing.T) { { Name: "IncompatibleComparatorError", Error: spec.NewIncompatibleComparatorError(spec.ComparatorTrue, code.StructField{ - Name: "Age", - Type: code.TypeInt, + Var: types.NewVar(token.NoPos, nil, "Age", code.TypeInt), }), ExpectedString: "cannot use comparator EQUAL_TRUE with struct field 'Age' of type 'int'", }, @@ -62,8 +64,7 @@ func TestError(t *testing.T) { Name: "IncompatibleUpdateOperatorError", Error: spec.NewIncompatibleUpdateOperatorError(spec.UpdateOperatorInc, spec.FieldReference{ code.StructField{ - Name: "City", - Type: code.TypeString, + Var: types.NewVar(token.NoPos, nil, "City", code.TypeString), }, }), ExpectedString: "cannot use update operator INC with struct field 'City' of type 'string'", diff --git a/internal/spec/field.go b/internal/spec/field.go index 0d4d4a7..bac5950 100644 --- a/internal/spec/field.go +++ b/internal/spec/field.go @@ -1,6 +1,8 @@ package spec import ( + "go/types" + "reflect" "strings" "github.com/sunboyy/repogen/internal/code" @@ -19,56 +21,71 @@ func (r FieldReference) ReferencedField() code.StructField { func (r FieldReference) ReferencingCode() string { var fieldNames []string for _, field := range r { - fieldNames = append(fieldNames, field.Name) + fieldNames = append(fieldNames, field.Var.Name()) } return strings.Join(fieldNames, ".") } -type fieldResolver struct { - Structs map[string]code.Struct -} - -func (r fieldResolver) ResolveStructField(structModel code.Struct, tokens []string) (FieldReference, bool) { +func resolveStructField(structModel *types.Struct, tokens []string) (FieldReference, bool) { fieldName := strings.Join(tokens, "") - field, ok := structModel.Fields.ByName(fieldName) - if ok { - return FieldReference{field}, true + for i := 0; i < structModel.NumFields(); i++ { + field := structModel.Field(i) + if field.Name() == fieldName { + return FieldReference{ + code.StructField{ + Var: field, + Tag: reflect.StructTag(structModel.Tag(i)), + }, + }, true + } } for i := len(tokens) - 1; i > 0; i-- { fieldName := strings.Join(tokens[:i], "") - field, ok := structModel.Fields.ByName(fieldName) - if !ok { - continue + var foundField *types.Var + var foundFieldIndex int + for j := 0; j < structModel.NumFields(); j++ { + field := structModel.Field(j) + if field.Name() == fieldName { + foundField = field + foundFieldIndex = j + break + } } - - fieldSimpleType, ok := getSimpleType(field.Type) - if !ok { + if foundField == nil { continue } - childStruct, ok := r.Structs[fieldSimpleType.Code()] + underlyingStructType, ok := getUnderlyingStructType(foundField.Type()) if !ok { continue } - fields, ok := r.ResolveStructField(childStruct, tokens[i:]) + fields, ok := resolveStructField(underlyingStructType, tokens[i:]) if !ok { continue } - return append(FieldReference{field}, fields...), true + + return append(FieldReference{ + code.StructField{ + Var: foundField, + Tag: reflect.StructTag(structModel.Tag(foundFieldIndex)), + }, + }, fields...), true } return nil, false } -func getSimpleType(t code.Type) (code.SimpleType, bool) { +func getUnderlyingStructType(t types.Type) (*types.Struct, bool) { switch t := t.(type) { - case code.SimpleType: + case *types.Named: + return getUnderlyingStructType(t.Underlying()) + case *types.Struct: return t, true - case code.PointerType: - return getSimpleType(t.ContainedType) + case *types.Pointer: + return getUnderlyingStructType(t.Elem()) default: - return "", false + return nil, false } } diff --git a/internal/spec/models.go b/internal/spec/models.go index 5c856cf..cdb581e 100644 --- a/internal/spec/models.go +++ b/internal/spec/models.go @@ -1,8 +1,6 @@ package spec -import ( - "github.com/sunboyy/repogen/internal/code" -) +import "go/types" // QueryMode one or many type QueryMode string @@ -16,8 +14,7 @@ const ( // MethodSpec is a method specification inside repository specification type MethodSpec struct { Name string - Params []code.Param - Returns []code.Type + Signature *types.Signature Operation Operation } diff --git a/internal/spec/parser.go b/internal/spec/parser.go index aa5d9c6..ec01730 100644 --- a/internal/spec/parser.go +++ b/internal/spec/parser.go @@ -1,6 +1,7 @@ package spec import ( + "go/types" "strconv" "github.com/fatih/camelcase" @@ -9,24 +10,24 @@ import ( // ParseInterfaceMethod returns repository method spec from declared interface // method. -func ParseInterfaceMethod(structs map[string]code.Struct, structModel code.Struct, - method code.Method) (MethodSpec, error) { +func ParseInterfaceMethod(pkg *types.Package, namedStruct *types.Named, + method *types.Func) (MethodSpec, error) { parser := interfaceMethodParser{ - fieldResolver: fieldResolver{ - Structs: structs, - }, - StructModel: structModel, - Method: method, + NamedStruct: namedStruct, + UnderlyingStruct: namedStruct.Underlying().(*types.Struct), + Method: method, + Signature: method.Type().(*types.Signature), } return parser.Parse() } type interfaceMethodParser struct { - fieldResolver fieldResolver - StructModel code.Struct - Method code.Method + NamedStruct *types.Named + UnderlyingStruct *types.Struct + Method *types.Func + Signature *types.Signature } func (p interfaceMethodParser) Parse() (MethodSpec, error) { @@ -36,15 +37,14 @@ func (p interfaceMethodParser) Parse() (MethodSpec, error) { } return MethodSpec{ - Name: p.Method.Name, - Params: p.Method.Params, - Returns: p.Method.Returns, + Name: p.Method.Name(), + Signature: p.Signature, Operation: operation, }, nil } func (p interfaceMethodParser) parseMethod() (Operation, error) { - methodNameTokens := camelcase.Split(p.Method.Name) + methodNameTokens := camelcase.Split(p.Method.Name()) switch methodNameTokens[0] { case "Insert": return p.parseInsertOperation(methodNameTokens[1:]) @@ -61,7 +61,9 @@ func (p interfaceMethodParser) parseMethod() (Operation, error) { } func (p interfaceMethodParser) parseInsertOperation(tokens []string) (Operation, error) { - mode, err := p.extractInsertReturns(p.Method.Returns) + signature := p.Method.Type().(*types.Signature) + + mode, err := p.extractInsertReturns(signature.Results()) if err != nil { return nil, err } @@ -70,13 +72,13 @@ func (p interfaceMethodParser) parseInsertOperation(tokens []string) (Operation, return nil, err } - pointerType := code.PointerType{ContainedType: p.StructModel.ReferencedType()} - if mode == QueryModeOne && p.Method.Params[1].Type != pointerType { + pointerType := types.NewPointer(p.NamedStruct) + if mode == QueryModeOne && !types.Identical(signature.Params().At(1).Type(), pointerType) { return nil, ErrInvalidParam } - arrayType := code.ArrayType{ContainedType: pointerType} - if mode == QueryModeMany && p.Method.Params[1].Type != arrayType { + arrayType := types.NewSlice(pointerType) + if mode == QueryModeMany && !types.Identical(signature.Params().At(1).Type(), arrayType) { return nil, ErrInvalidParam } @@ -85,33 +87,35 @@ func (p interfaceMethodParser) parseInsertOperation(tokens []string) (Operation, }, nil } -func (p interfaceMethodParser) extractInsertReturns(returns []code.Type) (QueryMode, error) { - if len(returns) != 2 { +func (p interfaceMethodParser) extractInsertReturns(returns *types.Tuple) (QueryMode, error) { + if returns.Len() != 2 { return "", NewOperationReturnCountUnmatchedError(2) } - if returns[1] != code.TypeError { - return "", NewUnsupportedReturnError(returns[1], 1) + if !types.Identical(returns.At(1).Type(), code.TypeError) { + return "", NewUnsupportedReturnError(returns.At(1).Type(), 1) } - switch t := returns[0].(type) { - case code.InterfaceType: - if len(t.Methods) == 0 { + switch t := returns.At(0).Type().(type) { + case *types.Interface: + if t.Empty() { return QueryModeOne, nil } - case code.ArrayType: - interfaceType, ok := t.ContainedType.(code.InterfaceType) - if ok && len(interfaceType.Methods) == 0 { + case *types.Slice: + interfaceType, ok := t.Elem().(*types.Interface) + if ok && interfaceType.Empty() { return QueryModeMany, nil } } - return "", NewUnsupportedReturnError(returns[0], 0) + return "", NewUnsupportedReturnError(returns.At(0).Type(), 0) } func (p interfaceMethodParser) parseFindOperation(tokens []string) (Operation, error) { - mode, err := p.extractModelOrSliceReturns(p.Method.Returns) + signature := p.Method.Type().(*types.Signature) + + mode, err := p.extractModelOrSliceReturns(signature.Results()) if err != nil { return nil, err } @@ -140,7 +144,7 @@ func (p interfaceMethodParser) parseFindOperation(tokens []string) (Operation, e return nil, err } - if err := p.validateQueryFromParams(p.Method.Params[1:], querySpec); err != nil { + if err := p.validateQueryFromParams(signature.Params(), 1, querySpec); err != nil { return nil, err } @@ -207,7 +211,7 @@ func (p interfaceMethodParser) parseSortToken(t []string) (Sort, error) { } func (p interfaceMethodParser) createSort(t []string, ordering Ordering) (Sort, error) { - fields, ok := p.fieldResolver.ResolveStructField(p.StructModel, t) + fields, ok := resolveStructField(p.UnderlyingStruct, t) if !ok { return Sort{}, NewStructFieldNotFoundError(t) } @@ -234,33 +238,31 @@ func (p interfaceMethodParser) splitQueryAndSortTokens(tokens []string) ([]strin return queryTokens, sortTokens } -func (p interfaceMethodParser) extractModelOrSliceReturns(returns []code.Type) (QueryMode, error) { - if len(returns) != 2 { +func (p interfaceMethodParser) extractModelOrSliceReturns(returns *types.Tuple) (QueryMode, error) { + if returns.Len() != 2 { return "", NewOperationReturnCountUnmatchedError(2) } - if returns[1] != code.TypeError { - return "", NewUnsupportedReturnError(returns[1], 1) + if !types.Identical(returns.At(1).Type(), code.TypeError) { + return "", NewUnsupportedReturnError(returns.At(1).Type(), 1) } - switch t := returns[0].(type) { - case code.PointerType: - simpleType := t.ContainedType - if simpleType == code.SimpleType(p.StructModel.Name) { + switch t := returns.At(0).Type().(type) { + case *types.Pointer: + if t.Elem() == p.NamedStruct { return QueryModeOne, nil } - case code.ArrayType: - pointerType, ok := t.ContainedType.(code.PointerType) + case *types.Slice: + pointerType, ok := t.Elem().(*types.Pointer) if ok { - simpleType := pointerType.ContainedType - if simpleType == code.SimpleType(p.StructModel.Name) { + if pointerType.Elem() == p.NamedStruct { return QueryModeMany, nil } } } - return "", NewUnsupportedReturnError(returns[0], 0) + return "", NewUnsupportedReturnError(returns.At(0).Type(), 0) } func splitByAnd(tokens []string) ([][]string, bool) { @@ -302,7 +304,9 @@ func (p interfaceMethodParser) splitUpdateAndQueryTokens(tokens []string) ([]str } func (p interfaceMethodParser) parseDeleteOperation(tokens []string) (Operation, error) { - mode, err := p.extractIntOrBoolReturns(p.Method.Returns) + signature := p.Method.Type().(*types.Signature) + + mode, err := p.extractIntOrBoolReturns(signature.Results()) if err != nil { return nil, err } @@ -316,7 +320,7 @@ func (p interfaceMethodParser) parseDeleteOperation(tokens []string) (Operation, return nil, err } - if err := p.validateQueryFromParams(p.Method.Params[1:], querySpec); err != nil { + if err := p.validateQueryFromParams(signature.Params(), 1, querySpec); err != nil { return nil, err } @@ -327,7 +331,8 @@ func (p interfaceMethodParser) parseDeleteOperation(tokens []string) (Operation, } func (p interfaceMethodParser) parseCountOperation(tokens []string) (Operation, error) { - if err := p.validateCountReturns(p.Method.Returns); err != nil { + signature := p.Method.Type().(*types.Signature) + if err := p.validateCountReturns(signature.Results()); err != nil { return nil, err } @@ -340,7 +345,7 @@ func (p interfaceMethodParser) parseCountOperation(tokens []string) (Operation, return nil, err } - if err := p.validateQueryFromParams(p.Method.Params[1:], querySpec); err != nil { + if err := p.validateQueryFromParams(signature.Params(), 1, querySpec); err != nil { return nil, err } @@ -349,73 +354,73 @@ func (p interfaceMethodParser) parseCountOperation(tokens []string) (Operation, }, nil } -func (p interfaceMethodParser) validateCountReturns(returns []code.Type) error { - if len(returns) != 2 { +func (p interfaceMethodParser) validateCountReturns(returns *types.Tuple) error { + if returns.Len() != 2 { return NewOperationReturnCountUnmatchedError(2) } - if returns[0] != code.TypeInt { - return NewUnsupportedReturnError(returns[0], 0) + if !types.Identical(returns.At(0).Type(), code.TypeInt) { + return NewUnsupportedReturnError(returns.At(0).Type(), 0) } - if returns[1] != code.TypeError { - return NewUnsupportedReturnError(returns[1], 1) + if !types.Identical(returns.At(1).Type(), code.TypeError) { + return NewUnsupportedReturnError(returns.At(1).Type(), 1) } return nil } -func (p interfaceMethodParser) extractIntOrBoolReturns(returns []code.Type) (QueryMode, error) { - if len(returns) != 2 { +func (p interfaceMethodParser) extractIntOrBoolReturns(returns *types.Tuple) (QueryMode, error) { + if returns.Len() != 2 { return "", NewOperationReturnCountUnmatchedError(2) } - if returns[1] != code.TypeError { - return "", NewUnsupportedReturnError(returns[1], 1) + if !types.Identical(returns.At(1).Type(), code.TypeError) { + return "", NewUnsupportedReturnError(returns.At(1).Type(), 1) } - simpleType, ok := returns[0].(code.SimpleType) + basicType, ok := returns.At(0).Type().(*types.Basic) if ok { - if simpleType == code.TypeBool { + if types.Identical(basicType, code.TypeBool) { return QueryModeOne, nil } - if simpleType == code.TypeInt { + if types.Identical(basicType, code.TypeInt) { return QueryModeMany, nil } } - return "", NewUnsupportedReturnError(returns[0], 0) + return "", NewUnsupportedReturnError(returns.At(0).Type(), 0) } func (p interfaceMethodParser) validateContextParam() error { - contextType := code.ExternalType{PackageAlias: "context", Name: "Context"} - if len(p.Method.Params) == 0 || p.Method.Params[0].Type != contextType { + signature := p.Method.Type().(*types.Signature) + if signature.Params().Len() == 0 || signature.Params().At(0).Type().String() != "context.Context" { return ErrContextParamRequired } return nil } -func (p interfaceMethodParser) validateQueryFromParams(params []code.Param, querySpec QuerySpec) error { - if querySpec.NumberOfArguments() != len(params) { +func (p interfaceMethodParser) validateQueryFromParams(params *types.Tuple, startIndex int, querySpec QuerySpec) error { + if params.Len()-startIndex != querySpec.NumberOfArguments() { return ErrInvalidParam } - var currentParamIndex int + currentParamIndex := startIndex for _, predicate := range querySpec.Predicates { if (predicate.Comparator == ComparatorTrue || predicate.Comparator == ComparatorFalse) && - predicate.FieldReference.ReferencedField().Type != code.TypeBool { + !types.Identical(predicate.FieldReference.ReferencedField().Var.Type(), code.TypeBool) { return NewIncompatibleComparatorError(predicate.Comparator, predicate.FieldReference.ReferencedField()) } for i := 0; i < predicate.Comparator.NumberOfArguments(); i++ { requiredType := predicate.Comparator.ArgumentTypeFromFieldType( - predicate.FieldReference.ReferencedField().Type, + predicate.FieldReference.ReferencedField().Var.Type(), ) - if params[currentParamIndex].Type != requiredType { + if !types.Identical(params.At(currentParamIndex).Type(), requiredType) { return NewArgumentTypeNotMatchedError(predicate.FieldReference.ReferencingCode(), requiredType, - params[currentParamIndex].Type) + params.At(currentParamIndex).Type()) } currentParamIndex++ } @@ -426,8 +431,7 @@ func (p interfaceMethodParser) validateQueryFromParams(params []code.Param, quer func (p interfaceMethodParser) parseQuery(queryTokens []string, paramIndex int) (QuerySpec, error) { queryParser := queryParser{ - fieldResolver: p.fieldResolver, - StructModel: p.StructModel, + UnderlyingStruct: p.UnderlyingStruct, } return queryParser.parseQuery(queryTokens, paramIndex) } diff --git a/internal/spec/parser_test.go b/internal/spec/parser_test.go index 916e1d2..7716a1c 100644 --- a/internal/spec/parser_test.go +++ b/internal/spec/parser_test.go @@ -2,1467 +2,995 @@ package spec_test import ( "errors" + "go/types" "reflect" "testing" "github.com/sunboyy/repogen/internal/code" "github.com/sunboyy/repogen/internal/spec" + "github.com/sunboyy/repogen/internal/testutils" ) -var ( - idField = code.StructField{ - Name: "ID", - Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}, - } - phoneNumberField = code.StructField{ - Name: "PhoneNumber", - Type: code.TypeString, - } - genderField = code.StructField{ - Name: "Gender", - Type: code.SimpleType("Gender"), - } - cityField = code.StructField{ - Name: "City", - Type: code.TypeString, - } - ageField = code.StructField{ - Name: "Age", - Type: code.TypeInt, - } - nameField = code.StructField{ - Name: "Name", - Type: code.SimpleType("NameModel"), - } - contactField = code.StructField{ - Name: "Contact", - Type: code.SimpleType("ContactModel"), - } - referrerField = code.StructField{ - Name: "Referrer", - Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}, - } - defaultPaymentField = code.StructField{ - Name: "DefaultPayment", - Type: code.ExternalType{PackageAlias: "payment", Name: "Payment"}, - } - enabledField = code.StructField{ - Name: "Enabled", - Type: code.TypeBool, - } - consentHistoryField = code.StructField{ - Name: "ConsentHistory", - Type: code.ArrayType{ContainedType: code.SimpleType("ConsentHistoryItem")}, - } - - firstNameField = code.StructField{ - Name: "First", - Type: code.TypeString, - } - lastNameField = code.StructField{ - Name: "Last", - Type: code.TypeString, - } -) +func TestParseInterfaceMethod_Insert(t *testing.T) { + repoIntf := testutils.Pkg.Scope().Lookup("UserRepositoryInsert").Type().Underlying().(*types.Interface) -var ( - nameStruct = code.Struct{ - Name: "NameModel", - Fields: code.StructFields{ - firstNameField, - lastNameField, + expectedOperations := []spec.Operation{ + // InsertMany + spec.InsertOperation{ + Mode: spec.QueryModeMany, }, - } - - structModel = code.Struct{ - Name: "UserModel", - Fields: code.StructFields{ - idField, - phoneNumberField, - genderField, - cityField, - ageField, - nameField, - contactField, - referrerField, - defaultPaymentField, - consentHistoryField, - enabledField, + // InsertOne + spec.InsertOperation{ + Mode: spec.QueryModeOne, }, } -) - -var structs = map[string]code.Struct{ - nameStruct.Name: nameStruct, - structModel.Name: structModel, -} - -type ParseInterfaceMethodTestCase struct { - Name string - Method code.Method - ExpectedOperation spec.Operation -} -func TestParseInterfaceMethod_Insert(t *testing.T) { - testTable := []ParseInterfaceMethodTestCase{ - { - Name: "InsertOne method", - Method: code.Method{ - Name: "InsertOne", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - }, - Returns: []code.Type{ - code.InterfaceType{}, - code.TypeError, - }, - }, - ExpectedOperation: spec.InsertOperation{ - Mode: spec.QueryModeOne, - }, - }, - { - Name: "InsertMany method", - Method: code.Method{ - Name: "InsertMany", - Params: []code.Param{ - { - Type: code.ExternalType{ - PackageAlias: "context", - Name: "Context", - }, - }, - { - Type: code.ArrayType{ - ContainedType: code.PointerType{ - ContainedType: code.SimpleType("UserModel"), - }, - }, - }, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.InterfaceType{}}, - code.TypeError, - }, - }, - ExpectedOperation: spec.InsertOperation{ - Mode: spec.QueryModeMany, - }, - }, - } + for i := 0; i < repoIntf.NumMethods(); i++ { + method := repoIntf.Method(i) - for _, testCase := range testTable { - t.Run(testCase.Name, func(t *testing.T) { - actualSpec, err := spec.ParseInterfaceMethod(structs, structModel, testCase.Method) + t.Run(method.Name(), func(t *testing.T) { + actualSpec, err := spec.ParseInterfaceMethod(testutils.Pkg, testutils.TypeUserNamed, method) if err != nil { t.Errorf("Error = %s", err) } - expectedOutput := spec.MethodSpec{ - Name: testCase.Method.Name, - Params: testCase.Method.Params, - Returns: testCase.Method.Returns, - Operation: testCase.ExpectedOperation, + if method.Name() != actualSpec.Name { + t.Errorf("Expected = %+v\nReceived = %+v", method.Name(), actualSpec.Name) } - if !reflect.DeepEqual(actualSpec, expectedOutput) { - t.Errorf("Expected = %+v\nReceived = %+v", expectedOutput, actualSpec) + if !types.Identical(method.Type(), actualSpec.Signature) { + t.Errorf("Expected = %+v\nReceived = %+v", method.Type(), actualSpec.Signature) + } + if !reflect.DeepEqual(expectedOperations[i], actualSpec.Operation) { + t.Errorf("Expected = %+v\nReceived = %+v", expectedOperations[i], actualSpec.Operation) } }) } } func TestParseInterfaceMethod_Find(t *testing.T) { - testTable := []ParseInterfaceMethodTestCase{ - { - Name: "FindByArg one-mode method", - Method: code.Method{ - Name: "FindByID", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.PointerType{ContainedType: code.SimpleType("UserModel")}, - code.TypeError, - }, - }, - ExpectedOperation: spec.FindOperation{ - Mode: spec.QueryModeOne, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {FieldReference: spec.FieldReference{idField}, Comparator: spec.ComparatorEqual, ParamIndex: 1}, - }}, - }, - }, - { - Name: "FindByArg many-mode method", - Method: code.Method{ - Name: "FindByCity", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, + repoIntf := testutils.Pkg.Scope().Lookup("UserRepositoryFind").Type().Underlying().(*types.Interface) + + expectedOperations := []spec.Operation{ + // FindAll + spec.FindOperation{ + Mode: spec.QueryModeMany, + }, + // FindByAgeBetween + spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age")}, + }, + Comparator: spec.ComparatorBetween, + ParamIndex: 1, + }, + }}, + }, + // FindByAgeGreaterThan + spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age")}, + }, + Comparator: spec.ComparatorGreaterThan, + ParamIndex: 1, + }, + }}, + }, + // FindByAgeGreaterThanEqual + spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age")}, + }, + Comparator: spec.ComparatorGreaterThanEqual, + ParamIndex: 1, + }, + }}, + }, + // FindByAgeLessThan + spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age")}, + }, + Comparator: spec.ComparatorLessThan, + ParamIndex: 1, + }, + }}, + }, + // FindByAgeLessThanEqual + spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age")}, + }, + Comparator: spec.ComparatorLessThanEqual, + ParamIndex: 1, + }, + }}, + }, + // FindByCity + spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "City")}, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 1, }, - }, - ExpectedOperation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {FieldReference: spec.FieldReference{cityField}, Comparator: spec.ComparatorEqual, ParamIndex: 1}, - }}, - }, + }}, }, - { - Name: "FindByMultiWordArg method", - Method: code.Method{ - Name: "FindByPhoneNumber", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - }, - Returns: []code.Type{ - code.PointerType{ContainedType: code.SimpleType("UserModel")}, - code.TypeError, - }, - }, - ExpectedOperation: spec.FindOperation{ - Mode: spec.QueryModeOne, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ + // FindByCityAndGender + spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Operator: spec.OperatorAnd, + Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{phoneNumberField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 1, + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "City")}, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 1, }, - }}, - }, - }, - { - Name: "FindByDeepArg method", - Method: code.Method{ - Name: "FindByNameFirst", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - }, - ExpectedOperation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{nameField, firstNameField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 1, + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Gender")}, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 2, }, - }}, - }, - }, - { - Name: "FindByDeepPointerArg method", - Method: code.Method{ - Name: "FindByReferrerID", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, }, }, - ExpectedOperation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{referrerField, idField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 1, - }, - }}, - }, }, - { - Name: "FindAll method", - Method: code.Method{ - Name: "FindAll", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, + // FindByCityIn + spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "City")}, + }, + Comparator: spec.ComparatorIn, + ParamIndex: 1, + }, + }}, + }, + // FindByCityNot + spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "City")}, + }, + Comparator: spec.ComparatorNot, + ParamIndex: 1, + }, + }}, + }, + // FindByCityNotIn + spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "City")}, + }, + Comparator: spec.ComparatorNotIn, + ParamIndex: 1, }, - }, - ExpectedOperation: spec.FindOperation{ - Mode: spec.QueryModeMany, - }, + }}, }, - { - Name: "FindByArgAndArg method", - Method: code.Method{ - Name: "FindByCityAndGender", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - {Type: code.SimpleType("Gender")}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - }, - ExpectedOperation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Operator: spec.OperatorAnd, - Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{cityField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 1, + // FindByCityOrGender + spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Operator: spec.OperatorOr, + Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "City")}, }, - { - FieldReference: spec.FieldReference{genderField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 2, + Comparator: spec.ComparatorEqual, + ParamIndex: 1, + }, + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Gender")}, }, + Comparator: spec.ComparatorEqual, + ParamIndex: 2, }, }, }, }, - { - Name: "FindByArgOrArg method", - Method: code.Method{ - Name: "FindByCityOrGender", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - {Type: code.SimpleType("Gender")}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - }, - ExpectedOperation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Operator: spec.OperatorOr, - Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{cityField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 1, - }, - { - FieldReference: spec.FieldReference{genderField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 2, - }, + // FindByCityOrderByAge + spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "City")}, }, + Comparator: spec.ComparatorEqual, + ParamIndex: 1, + }, + }}, + Sorts: []spec.Sort{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age")}, + }, + Ordering: spec.OrderingAscending, }, }, }, - { - Name: "FindByArgNot method", - Method: code.Method{ - Name: "FindByCityNot", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, + // FindByCityOrderByAgeAsc + spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "City")}, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 1, + }, + }}, + Sorts: []spec.Sort{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age")}, + }, + Ordering: spec.OrderingAscending, }, }, - ExpectedOperation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {FieldReference: spec.FieldReference{cityField}, Comparator: spec.ComparatorNot, ParamIndex: 1}, - }}, - }, }, - { - Name: "FindByArgLessThan method", - Method: code.Method{ - Name: "FindByAgeLessThan", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, + // FindByCityOrderByAgeDesc + spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "City")}, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 1, + }, + }}, + Sorts: []spec.Sort{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age")}, + }, + Ordering: spec.OrderingDescending, }, }, - ExpectedOperation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {FieldReference: spec.FieldReference{ageField}, Comparator: spec.ComparatorLessThan, ParamIndex: 1}, - }}, - }, }, - { - Name: "FindByArgLessThanEqual method", - Method: code.Method{ - Name: "FindByAgeLessThanEqual", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - }, - ExpectedOperation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{ageField}, - Comparator: spec.ComparatorLessThanEqual, - ParamIndex: 1, + // FindByCityOrderByCityAndAgeDesc + spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "City")}, }, - }}, - }, - }, - { - Name: "FindByArgGreaterThan method", - Method: code.Method{ - Name: "FindByAgeGreaterThan", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - }, - ExpectedOperation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{ageField}, - Comparator: spec.ComparatorGreaterThan, - ParamIndex: 1, + Comparator: spec.ComparatorEqual, + ParamIndex: 1, + }, + }}, + Sorts: []spec.Sort{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "City")}, }, - }}, - }, - }, - { - Name: "FindByArgGreaterThanEqual method", - Method: code.Method{ - Name: "FindByAgeGreaterThanEqual", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, + Ordering: spec.OrderingAscending, }, - }, - ExpectedOperation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{ageField}, - Comparator: spec.ComparatorGreaterThanEqual, - ParamIndex: 1, + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age")}, }, - }}, - }, - }, - { - Name: "FindByArgBetween method", - Method: code.Method{ - Name: "FindByAgeBetween", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeInt}, - {Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - }, - ExpectedOperation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {FieldReference: spec.FieldReference{ageField}, Comparator: spec.ComparatorBetween, ParamIndex: 1}, - }}, - }, - }, - { - Name: "FindByArgIn method", - Method: code.Method{ - Name: "FindByCityIn", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.ArrayType{ContainedType: code.TypeString}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, + Ordering: spec.OrderingDescending, }, }, - ExpectedOperation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {FieldReference: spec.FieldReference{cityField}, Comparator: spec.ComparatorIn, ParamIndex: 1}, - }}, - }, }, - { - Name: "FindByArgNotIn method", - Method: code.Method{ - Name: "FindByCityNotIn", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.ArrayType{ContainedType: code.TypeString}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, + // FindByCityOrderByNameFirst + spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "City")}, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 1, + }, + }}, + Sorts: []spec.Sort{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Name")}, + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeNameStruct, "First")}, + }, + Ordering: spec.OrderingAscending, }, }, - ExpectedOperation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {FieldReference: spec.FieldReference{cityField}, Comparator: spec.ComparatorNotIn, ParamIndex: 1}, - }}, - }, }, - { - Name: "FindByArgTrue method", - Method: code.Method{ - Name: "FindByEnabledTrue", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, + // FindByEnabledFalse + spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Enabled")}, + }, + Comparator: spec.ComparatorFalse, + ParamIndex: 1, + }, + }}, + }, + // FindByEnabledTrue + spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Enabled")}, + }, + Comparator: spec.ComparatorTrue, + ParamIndex: 1, + }, + }}, + }, + // FindByID + spec.FindOperation{ + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "ID")}, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 1, + }, + }}, + }, + // FindByNameFirst + spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Name")}, + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeNameStruct, "First")}, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 1, + }, + }}, + }, + // FindByPhoneNumber + spec.FindOperation{ + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{ + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "PhoneNumber"), + }, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 1, + }, + }}, + }, + // FindByReferrerExists + spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{ + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Referrer"), + }, + }, + Comparator: spec.ComparatorExists, + ParamIndex: 1, + }, + }}, + }, + // FindByReferrerID + spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Referrer")}, + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "ID")}, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 1, + }, + }}, + }, + // FindByReferrerNotExists + spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Referrer")}, + }, + Comparator: spec.ComparatorNotExists, + ParamIndex: 1, + }, + }}, + }, + // FindTop5ByGenderOrderByAgeDesc + spec.FindOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Gender")}, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 1, + }, + }}, + Sorts: []spec.Sort{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age")}, + }, + Ordering: spec.OrderingDescending, }, }, - ExpectedOperation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {FieldReference: spec.FieldReference{enabledField}, Comparator: spec.ComparatorTrue, ParamIndex: 1}, - }}, - }, + Limit: 5, }, - { - Name: "FindByArgFalse method", - Method: code.Method{ - Name: "FindByEnabledFalse", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, + } + + for i := 0; i < repoIntf.NumMethods(); i++ { + method := repoIntf.Method(i) + + t.Run(method.Name(), func(t *testing.T) { + actualSpec, err := spec.ParseInterfaceMethod(testutils.Pkg, testutils.TypeUserNamed, method) + + if err != nil { + t.Errorf("Error = %s", err) + } + if method.Name() != actualSpec.Name { + t.Errorf("Expected = %+v\nReceived = %+v", method.Name(), actualSpec.Name) + } + if !types.Identical(method.Type(), actualSpec.Signature) { + t.Errorf("Expected = %+v\nReceived = %+v", method.Type(), actualSpec.Signature) + } + if !reflect.DeepEqual(expectedOperations[i], actualSpec.Operation) { + t.Errorf("Expected = %+v\nReceived = %+v", expectedOperations[i], actualSpec.Operation) + } + }) + } +} + +func TestParseInterfaceMethod_Update(t *testing.T) { + repoIntf := testutils.Pkg.Scope().Lookup("UserRepositoryUpdate").Type().Underlying().(*types.Interface) + + expectedOperations := []spec.Operation{ + // UpdateAgeIncByID + spec.UpdateOperation{ + Update: spec.UpdateFields{ + spec.UpdateField{ + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age")}, + }, + ParamIndex: 1, + Operator: spec.UpdateOperatorInc, }, }, - ExpectedOperation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{enabledField}, - Comparator: spec.ComparatorFalse, - ParamIndex: 1, + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "ID")}, }, - }}, - }, - }, - { - Name: "FindByArgExists method", - Method: code.Method{ - Name: "FindByReferrerExists", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, + Comparator: spec.ComparatorEqual, + ParamIndex: 2, + }, + }}, + }, + // UpdateByID + spec.UpdateOperation{ + Update: spec.UpdateModel{}, + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "ID")}, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 2, + }, + }}, + }, + // UpdateConsentHistoryPushByID + spec.UpdateOperation{ + Update: spec.UpdateFields{ + spec.UpdateField{ + FieldReference: spec.FieldReference{ + code.StructField{ + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "ConsentHistory"), + }, + }, + ParamIndex: 1, + Operator: spec.UpdateOperatorPush, }, }, - ExpectedOperation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{referrerField}, - Comparator: spec.ComparatorExists, - ParamIndex: 1, + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "ID")}, }, - }}, - }, - }, - { - Name: "FindByArgNotExists method", - Method: code.Method{ - Name: "FindByReferrerNotExists", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, + Comparator: spec.ComparatorEqual, + ParamIndex: 2, }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, + }}, + }, + // UpdateEnabledAndConsentHistoryPushByID + spec.UpdateOperation{ + Update: spec.UpdateFields{ + spec.UpdateField{ + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Enabled")}, + }, + ParamIndex: 1, + Operator: spec.UpdateOperatorSet, }, - }, - ExpectedOperation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{referrerField}, - Comparator: spec.ComparatorNotExists, - ParamIndex: 1, + spec.UpdateField{ + FieldReference: spec.FieldReference{ + code.StructField{ + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "ConsentHistory"), + }, }, - }}, - }, - }, - { - Name: "FindByArgOrderByArg method", - Method: code.Method{ - Name: "FindByCityOrderByAge", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, + ParamIndex: 2, + Operator: spec.UpdateOperatorPush, }, }, - ExpectedOperation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {FieldReference: spec.FieldReference{cityField}, Comparator: spec.ComparatorEqual, ParamIndex: 1}, - }}, - Sorts: []spec.Sort{ - {FieldReference: spec.FieldReference{ageField}, Ordering: spec.OrderingAscending}, + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "ID")}, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 3, }, - }, + }}, }, - { - Name: "FindByArgOrderByArgAsc method", - Method: code.Method{ - Name: "FindByCityOrderByAgeAsc", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - }, - ExpectedOperation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {FieldReference: spec.FieldReference{cityField}, Comparator: spec.ComparatorEqual, ParamIndex: 1}, - }}, - Sorts: []spec.Sort{ - {FieldReference: spec.FieldReference{ageField}, Ordering: spec.OrderingAscending}, + // UpdateGenderAndCityByID + spec.UpdateOperation{ + Update: spec.UpdateFields{ + spec.UpdateField{ + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Gender")}, + }, + ParamIndex: 1, + Operator: spec.UpdateOperatorSet, }, - }, - }, - { - Name: "FindByArgOrderByArgDesc method", - Method: code.Method{ - Name: "FindByCityOrderByAgeDesc", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, + spec.UpdateField{ + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "City")}, + }, + ParamIndex: 2, + Operator: spec.UpdateOperatorSet, }, }, - ExpectedOperation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {FieldReference: spec.FieldReference{cityField}, Comparator: spec.ComparatorEqual, ParamIndex: 1}, - }}, - Sorts: []spec.Sort{ - {FieldReference: spec.FieldReference{ageField}, Ordering: spec.OrderingDescending}, + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "ID")}, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 3, }, - }, + }}, }, - { - Name: "FindByArgOrderByDeepArg method", - Method: code.Method{ - Name: "FindByCityOrderByNameFirst", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, + // UpdateGenderByAge + spec.UpdateOperation{ + Update: spec.UpdateFields{ + spec.UpdateField{ + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Gender")}, + }, + ParamIndex: 1, + Operator: spec.UpdateOperatorSet, }, }, - ExpectedOperation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {FieldReference: spec.FieldReference{cityField}, Comparator: spec.ComparatorEqual, ParamIndex: 1}, - }}, - Sorts: []spec.Sort{ - {FieldReference: spec.FieldReference{nameField, firstNameField}, Ordering: spec.OrderingAscending}, + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age")}, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 2, }, - }, + }}, }, - { - Name: "FindByArgOrderByArgAndArg method", - Method: code.Method{ - Name: "FindByCityOrderByCityAndAgeDesc", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, + // UpdateGenderByID + spec.UpdateOperation{ + Update: spec.UpdateFields{ + spec.UpdateField{ + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Gender")}, + }, + ParamIndex: 1, + Operator: spec.UpdateOperatorSet, }, }, - ExpectedOperation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {FieldReference: spec.FieldReference{cityField}, Comparator: spec.ComparatorEqual, ParamIndex: 1}, - }}, - Sorts: []spec.Sort{ - {FieldReference: spec.FieldReference{cityField}, Ordering: spec.OrderingAscending}, - {FieldReference: spec.FieldReference{ageField}, Ordering: spec.OrderingDescending}, + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "ID")}, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 2, + }, + }}, + }, + // UpdateNameFirstByID + spec.UpdateOperation{ + Update: spec.UpdateFields{ + spec.UpdateField{ + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Name")}, + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeNameStruct, "First")}, + }, + ParamIndex: 1, + Operator: spec.UpdateOperatorSet, }, }, - }, - { - Name: "FindTopNByArg method", - Method: code.Method{ - Name: "FindTop5ByGenderOrderByAgeDesc", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("Gender")}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "ID")}, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 2, }, - }, - ExpectedOperation: spec.FindOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {FieldReference: spec.FieldReference{genderField}, Comparator: spec.ComparatorEqual, ParamIndex: 1}, - }}, - Sorts: []spec.Sort{ - {FieldReference: spec.FieldReference{ageField}, Ordering: spec.OrderingDescending}, - }, - Limit: 5, - }, + }}, }, } - for _, testCase := range testTable { - t.Run(testCase.Name, func(t *testing.T) { - actualSpec, err := spec.ParseInterfaceMethod(structs, structModel, testCase.Method) + for i := 0; i < repoIntf.NumMethods(); i++ { + method := repoIntf.Method(i) + + t.Run(method.Name(), func(t *testing.T) { + actualSpec, err := spec.ParseInterfaceMethod(testutils.Pkg, testutils.TypeUserNamed, method) if err != nil { t.Errorf("Error = %s", err) } - expectedOutput := spec.MethodSpec{ - Name: testCase.Method.Name, - Params: testCase.Method.Params, - Returns: testCase.Method.Returns, - Operation: testCase.ExpectedOperation, + if method.Name() != actualSpec.Name { + t.Errorf("Expected = %+v\nReceived = %+v", method.Name(), actualSpec.Name) + } + if !types.Identical(method.Type(), actualSpec.Signature) { + t.Errorf("Expected = %+v\nReceived = %+v", method.Type(), actualSpec.Signature) } - if !reflect.DeepEqual(actualSpec, expectedOutput) { - t.Errorf("Expected = %+v\nReceived = %+v", expectedOutput, actualSpec) + if !reflect.DeepEqual(expectedOperations[i], actualSpec.Operation) { + t.Errorf("Expected = %+v\nReceived = %+v", expectedOperations[i], actualSpec.Operation) } }) } } -func TestParseInterfaceMethod_Update(t *testing.T) { - testTable := []ParseInterfaceMethodTestCase{ - { - Name: "UpdateByArg", - Method: code.Method{ - Name: "UpdateByID", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.TypeBool, - code.TypeError, - }, - }, - ExpectedOperation: spec.UpdateOperation{ - Update: spec.UpdateModel{}, - Mode: spec.QueryModeOne, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {FieldReference: spec.FieldReference{idField}, Comparator: spec.ComparatorEqual, ParamIndex: 2}, - }}, - }, - }, - { - Name: "UpdateArgByArg one method", - Method: code.Method{ - Name: "UpdateGenderByID", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("Gender")}, - {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.TypeBool, - code.TypeError, - }, - }, - ExpectedOperation: spec.UpdateOperation{ - Update: spec.UpdateFields{ - spec.UpdateField{ - FieldReference: spec.FieldReference{genderField}, - ParamIndex: 1, - Operator: spec.UpdateOperatorSet, +func TestParseInterfaceMethod_Delete(t *testing.T) { + repoIntf := testutils.Pkg.Scope().Lookup("UserRepositoryDelete").Type().Underlying().(*types.Interface) + + expectedOperations := []spec.Operation{ + // DeleteAll + spec.DeleteOperation{ + Mode: spec.QueryModeMany, + }, + // DeleteByAgeBetween + spec.DeleteOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age")}, }, - }, - Mode: spec.QueryModeOne, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{idField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 2, + Comparator: spec.ComparatorBetween, + ParamIndex: 1, + }, + }}, + }, + // DeleteByAgeGreaterThan + spec.DeleteOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age")}, }, - }}, - }, - }, - { - Name: "UpdateArgByArg many method", - Method: code.Method{ - Name: "UpdateGenderByID", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("Gender")}, - {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedOperation: spec.UpdateOperation{ - Update: spec.UpdateFields{ - spec.UpdateField{ - FieldReference: spec.FieldReference{genderField}, - ParamIndex: 1, - Operator: spec.UpdateOperatorSet, + Comparator: spec.ComparatorGreaterThan, + ParamIndex: 1, + }, + }}, + }, + // DeleteByAgeGreaterThanEqual + spec.DeleteOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age")}, }, - }, - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{idField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 2, + Comparator: spec.ComparatorGreaterThanEqual, + ParamIndex: 1, + }, + }}, + }, + // DeleteByAgeLessThan + spec.DeleteOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age")}, }, - }}, - }, - }, - { - Name: "UpdateArgByArg one with deeply referenced update field method", - Method: code.Method{ - Name: "UpdateNameFirstByID", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.TypeBool, - code.TypeError, - }, - }, - ExpectedOperation: spec.UpdateOperation{ - Update: spec.UpdateFields{ - spec.UpdateField{ - FieldReference: spec.FieldReference{nameField, firstNameField}, - ParamIndex: 1, - Operator: spec.UpdateOperatorSet, + Comparator: spec.ComparatorLessThan, + ParamIndex: 1, + }, + }}, + }, + // DeleteByAgeLessThanEqual + spec.DeleteOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Age")}, }, - }, - Mode: spec.QueryModeOne, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{idField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 2, + Comparator: spec.ComparatorLessThanEqual, + ParamIndex: 1, + }, + }}, + }, + // DeleteByCity + spec.DeleteOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "City")}, }, - }}, - }, - }, - { - Name: "UpdateArgAndArgByArg method", - Method: code.Method{ - Name: "UpdateGenderAndCityByID", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("Gender")}, - {Type: code.TypeString}, - {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, + Comparator: spec.ComparatorEqual, + ParamIndex: 1, }, - }, - ExpectedOperation: spec.UpdateOperation{ - Update: spec.UpdateFields{ - spec.UpdateField{ - FieldReference: spec.FieldReference{genderField}, - ParamIndex: 1, - Operator: spec.UpdateOperatorSet, - }, - spec.UpdateField{ - FieldReference: spec.FieldReference{cityField}, - ParamIndex: 2, - Operator: spec.UpdateOperatorSet, - }, - }, - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {FieldReference: spec.FieldReference{idField}, Comparator: spec.ComparatorEqual, ParamIndex: 3}, - }}, - }, + }}, }, - { - Name: "UpdateArgPushByArg method", - Method: code.Method{ - Name: "UpdateConsentHistoryPushByID", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("ConsentHistoryItem")}, - {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedOperation: spec.UpdateOperation{ - Update: spec.UpdateFields{ - spec.UpdateField{ - FieldReference: spec.FieldReference{consentHistoryField}, - ParamIndex: 1, - Operator: spec.UpdateOperatorPush, + // DeleteByCityAndGender + spec.DeleteOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Operator: spec.OperatorAnd, + Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "City")}, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 1, }, - }, - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{idField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 2, + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Gender")}, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 2, }, - }}, - }, - }, - { - Name: "UpdateArgPushByArg method", - Method: code.Method{ - Name: "UpdateAgeIncByID", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeInt}, - {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.TypeBool, - code.TypeError, }, }, - ExpectedOperation: spec.UpdateOperation{ - Update: spec.UpdateFields{ - spec.UpdateField{ - FieldReference: spec.FieldReference{ageField}, - ParamIndex: 1, - Operator: spec.UpdateOperatorInc, + }, + // DeleteByCityIn + spec.DeleteOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "City")}, }, + Comparator: spec.ComparatorIn, + ParamIndex: 1, + }, + }}, + }, + // DeleteByCityNot + spec.DeleteOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "City")}, + }, + Comparator: spec.ComparatorNot, + ParamIndex: 1, }, - Mode: spec.QueryModeOne, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ + }}, + }, + // DeleteByCityOrGender + spec.DeleteOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{ + Operator: spec.OperatorOr, + Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{idField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 2, + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "City")}, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 1, + }, + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Gender")}, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 2, }, - }}, + }, }, }, - { - Name: "UpdateArgAndArgPushByArg method", - Method: code.Method{ - Name: "UpdateEnabledAndConsentHistoryPushByID", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeBool}, - {Type: code.SimpleType("ConsentHistoryItem")}, - {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, + // DeleteByID + spec.DeleteOperation{ + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "ID")}, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 1, + }, + }}, + }, + // DeleteByNameFirst + spec.DeleteOperation{ + Mode: spec.QueryModeMany, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Name")}, + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeNameStruct, "First")}, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 1, + }, + }}, + }, + // DeleteByPhoneNumber + spec.DeleteOperation{ + Mode: spec.QueryModeOne, + Query: spec.QuerySpec{Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{ + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "PhoneNumber"), + }, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 1, }, - }, - ExpectedOperation: spec.UpdateOperation{ - Update: spec.UpdateFields{ - spec.UpdateField{ - FieldReference: spec.FieldReference{enabledField}, - ParamIndex: 1, - Operator: spec.UpdateOperatorSet, - }, - spec.UpdateField{ - FieldReference: spec.FieldReference{consentHistoryField}, - ParamIndex: 2, - Operator: spec.UpdateOperatorPush, - }, - }, - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {FieldReference: spec.FieldReference{idField}, Comparator: spec.ComparatorEqual, ParamIndex: 3}, - }}, - }, + }}, }, } - for _, testCase := range testTable { - t.Run(testCase.Name, func(t *testing.T) { - actualSpec, err := spec.ParseInterfaceMethod(structs, structModel, testCase.Method) + for i := 0; i < repoIntf.NumMethods(); i++ { + method := repoIntf.Method(i) + + t.Run(method.Name(), func(t *testing.T) { + actualSpec, err := spec.ParseInterfaceMethod(testutils.Pkg, testutils.TypeUserNamed, method) if err != nil { t.Errorf("Error = %s", err) } - expectedOutput := spec.MethodSpec{ - Name: testCase.Method.Name, - Params: testCase.Method.Params, - Returns: testCase.Method.Returns, - Operation: testCase.ExpectedOperation, + if method.Name() != actualSpec.Name { + t.Errorf("Expected = %+v\nReceived = %+v", method.Name(), actualSpec.Name) } - if !reflect.DeepEqual(actualSpec, expectedOutput) { - t.Errorf("Expected = %+v\nReceived = %+v", expectedOutput, actualSpec) + if !types.Identical(method.Type(), actualSpec.Signature) { + t.Errorf("Expected = %+v\nReceived = %+v", method.Type(), actualSpec.Signature) + } + if !reflect.DeepEqual(expectedOperations[i], actualSpec.Operation) { + t.Errorf("Expected = %+v\nReceived = %+v", expectedOperations[i], actualSpec.Operation) } }) } } -func TestParseInterfaceMethod_Delete(t *testing.T) { - testTable := []ParseInterfaceMethodTestCase{ - { - Name: "DeleteByArg one-mode method", - Method: code.Method{ - Name: "DeleteByID", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.TypeBool, - code.TypeError, - }, - }, - ExpectedOperation: spec.DeleteOperation{ - Mode: spec.QueryModeOne, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {FieldReference: spec.FieldReference{idField}, Comparator: spec.ComparatorEqual, ParamIndex: 1}, - }}, - }, - }, - { - Name: "DeleteByArg many-mode method", - Method: code.Method{ - Name: "DeleteByCity", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedOperation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {FieldReference: spec.FieldReference{cityField}, Comparator: spec.ComparatorEqual, ParamIndex: 1}, - }}, - }, - }, - { - Name: "DeleteByMultiWordArg method", - Method: code.Method{ - Name: "DeleteByPhoneNumber", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - }, - Returns: []code.Type{ - code.TypeBool, - code.TypeError, - }, - }, - ExpectedOperation: spec.DeleteOperation{ - Mode: spec.QueryModeOne, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ +func TestParseInterfaceMethod_Count(t *testing.T) { + repoIntf := testutils.Pkg.Scope().Lookup("UserRepositoryCount").Type().Underlying().(*types.Interface) + + expectedOperations := []spec.Operation{ + // CountAll + spec.CountOperation{ + Query: spec.QuerySpec{}, + }, + // CountByGender + spec.CountOperation{ + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ { - FieldReference: spec.FieldReference{phoneNumberField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 1, + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Gender")}, + }, + Comparator: spec.ComparatorEqual, + ParamIndex: 1, }, - }}, - }, - }, - { - Name: "DeleteAll method", - Method: code.Method{ - Name: "DeleteAll", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, }, }, - ExpectedOperation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - }, }, - { - Name: "DeleteByArgAndArg method", - Method: code.Method{ - Name: "DeleteByCityAndGender", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - {Type: code.SimpleType("Gender")}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedOperation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Operator: spec.OperatorAnd, - Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{cityField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 1, - }, - { - FieldReference: spec.FieldReference{genderField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 2, + // CountByNameFirst + spec.CountOperation{ + Query: spec.QuerySpec{ + Predicates: []spec.Predicate{ + { + FieldReference: spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Name")}, + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeNameStruct, "First")}, }, + Comparator: spec.ComparatorEqual, + ParamIndex: 1, }, }, }, }, - { - Name: "DeleteByArgOrArg method", - Method: code.Method{ - Name: "DeleteByCityOrGender", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - {Type: code.SimpleType("Gender")}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedOperation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{ - Operator: spec.OperatorOr, - Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{cityField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 1, - }, - { - FieldReference: spec.FieldReference{genderField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 2, - }, - }, - }, - }, - }, - { - Name: "DeleteByArgNot method", - Method: code.Method{ - Name: "DeleteByCityNot", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedOperation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {FieldReference: spec.FieldReference{cityField}, Comparator: spec.ComparatorNot, ParamIndex: 1}, - }}, - }, - }, - { - Name: "DeleteByArgLessThan method", - Method: code.Method{ - Name: "DeleteByAgeLessThan", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedOperation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {FieldReference: spec.FieldReference{ageField}, Comparator: spec.ComparatorLessThan, ParamIndex: 1}, - }}, - }, - }, - { - Name: "DeleteByArgLessThanEqual method", - Method: code.Method{ - Name: "DeleteByAgeLessThanEqual", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedOperation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{ageField}, - Comparator: spec.ComparatorLessThanEqual, - ParamIndex: 1, - }, - }}, - }, - }, - { - Name: "DeleteByArgGreaterThan method", - Method: code.Method{ - Name: "DeleteByAgeGreaterThan", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedOperation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{ageField}, - Comparator: spec.ComparatorGreaterThan, - ParamIndex: 1, - }, - }}, - }, - }, - { - Name: "DeleteByArgGreaterThanEqual method", - Method: code.Method{ - Name: "DeleteByAgeGreaterThanEqual", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedOperation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{ageField}, - Comparator: spec.ComparatorGreaterThanEqual, - ParamIndex: 1, - }, - }}, - }, - }, - { - Name: "DeleteByArgBetween method", - Method: code.Method{ - Name: "DeleteByAgeBetween", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeInt}, - {Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedOperation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {FieldReference: spec.FieldReference{ageField}, Comparator: spec.ComparatorBetween, ParamIndex: 1}, - }}, - }, - }, - { - Name: "DeleteByArgIn method", - Method: code.Method{ - Name: "DeleteByCityIn", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.ArrayType{ContainedType: code.TypeString}}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedOperation: spec.DeleteOperation{ - Mode: spec.QueryModeMany, - Query: spec.QuerySpec{Predicates: []spec.Predicate{ - {FieldReference: spec.FieldReference{cityField}, Comparator: spec.ComparatorIn, ParamIndex: 1}, - }}, - }, - }, } - for _, testCase := range testTable { - t.Run(testCase.Name, func(t *testing.T) { - actualSpec, err := spec.ParseInterfaceMethod(structs, structModel, testCase.Method) - - if err != nil { - t.Errorf("Error = %s", err) - } - expectedOutput := spec.MethodSpec{ - Name: testCase.Method.Name, - Params: testCase.Method.Params, - Returns: testCase.Method.Returns, - Operation: testCase.ExpectedOperation, - } - if !reflect.DeepEqual(actualSpec, expectedOutput) { - t.Errorf("Expected = %+v\nReceived = %+v", expectedOutput, actualSpec) - } - }) - } -} - -func TestParseInterfaceMethod_Count(t *testing.T) { - testTable := []ParseInterfaceMethodTestCase{ - { - Name: "CountAll method", - Method: code.Method{ - Name: "CountAll", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedOperation: spec.CountOperation{ - Query: spec.QuerySpec{}, - }, - }, - { - Name: "CountByArg method", - Method: code.Method{ - Name: "CountByGender", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("Gender")}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedOperation: spec.CountOperation{ - Query: spec.QuerySpec{ - Predicates: []spec.Predicate{ - { - FieldReference: spec.FieldReference{genderField}, - Comparator: spec.ComparatorEqual, - ParamIndex: 1, - }, - }, - }, - }, - }, - } + for i := 0; i < repoIntf.NumMethods(); i++ { + method := repoIntf.Method(i) - for _, testCase := range testTable { - t.Run(testCase.Name, func(t *testing.T) { - actualSpec, err := spec.ParseInterfaceMethod(structs, structModel, testCase.Method) + t.Run(method.Name(), func(t *testing.T) { + actualSpec, err := spec.ParseInterfaceMethod(testutils.Pkg, testutils.TypeUserNamed, method) if err != nil { t.Errorf("Error = %s", err) } - expectedOutput := spec.MethodSpec{ - Name: testCase.Method.Name, - Params: testCase.Method.Params, - Returns: testCase.Method.Returns, - Operation: testCase.ExpectedOperation, + if method.Name() != actualSpec.Name { + t.Errorf("Expected = %+v\nReceived = %+v", method.Name(), actualSpec.Name) + } + if !types.Identical(method.Type(), actualSpec.Signature) { + t.Errorf("Expected = %+v\nReceived = %+v", method.Type(), actualSpec.Signature) } - if !reflect.DeepEqual(actualSpec, expectedOutput) { - t.Errorf("Expected = %+v\nReceived = %+v", expectedOutput, actualSpec) + if !reflect.DeepEqual(expectedOperations[i], actualSpec.Operation) { + t.Errorf("Expected = %+v\nReceived = %+v", expectedOperations[i], actualSpec.Operation) } }) } } -type ParseInterfaceMethodInvalidTestCase struct { - Name string - Method code.Method - ExpectedError error -} +func TestParseInterfaceMethod_InvalidOperation(t *testing.T) { + repoIntf := testutils.Pkg.Scope().Lookup("UserRepositoryInvalidOperation").Type().Underlying().(*types.Interface) + method := repoIntf.Method(0) -func TestParseInterfaceMethod_Invalid(t *testing.T) { - _, err := spec.ParseInterfaceMethod(structs, structModel, code.Method{ - Name: "SearchByID", - }) + _, err := spec.ParseInterfaceMethod(testutils.Pkg, testutils.TypeUserNamed, method) expectedError := spec.NewUnknownOperationError("Search") if !errors.Is(err, expectedError) { @@ -1471,1043 +999,240 @@ func TestParseInterfaceMethod_Invalid(t *testing.T) { } func TestParseInterfaceMethod_Insert_Invalid(t *testing.T) { - testTable := []ParseInterfaceMethodInvalidTestCase{ - { - Name: "invalid number of returns", - Method: code.Method{ - Name: "Insert", - Returns: []code.Type{ - code.PointerType{ContainedType: code.SimpleType("UserModel")}, - code.InterfaceType{}, - code.TypeError, - }, - }, - ExpectedError: spec.NewOperationReturnCountUnmatchedError(2), - }, - { - Name: "unsupported return types from insert method", - Method: code.Method{ - Name: "Insert", - Returns: []code.Type{ - code.PointerType{ContainedType: code.SimpleType("UserModel")}, - code.TypeError, - }, - }, - ExpectedError: spec.NewUnsupportedReturnError( - code.PointerType{ContainedType: code.SimpleType("UserModel")}, - 0, - ), - }, - { - Name: "unempty interface return from insert method", - Method: code.Method{ - Name: "Insert", - Returns: []code.Type{ - code.InterfaceType{ - Methods: []code.Method{ - {Name: "DoSomething"}, - }, - }, - code.TypeError, - }, - }, - ExpectedError: spec.NewUnsupportedReturnError(code.InterfaceType{}, 0), - }, - { - Name: "error return not provided", - Method: code.Method{ - Name: "Insert", - Returns: []code.Type{ - code.PointerType{ContainedType: code.SimpleType("UserModel")}, - code.InterfaceType{}, - }, - }, - ExpectedError: spec.NewUnsupportedReturnError(code.InterfaceType{}, 1), - }, - { - Name: "no context parameter", - Method: code.Method{ - Name: "Insert", - Params: []code.Param{ - {Name: "userModel", Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - }, - Returns: []code.Type{ - code.InterfaceType{}, - code.TypeError, - }, - }, - ExpectedError: spec.ErrContextParamRequired, - }, - { - Name: "mismatched model parameter for one mode", - Method: code.Method{ - Name: "Insert", - Params: []code.Param{ - { - Name: "ctx", - Type: code.ExternalType{PackageAlias: "context", Name: "Context"}, - }, - { - Name: "userModel", - Type: code.ArrayType{ - ContainedType: code.PointerType{ - ContainedType: code.SimpleType("UserModel"), - }, - }, - }, - }, - Returns: []code.Type{ - code.InterfaceType{}, - code.TypeError, - }, - }, - ExpectedError: spec.ErrInvalidParam, - }, - { - Name: "mismatched model parameter for many mode", - Method: code.Method{ - Name: "Insert", - Params: []code.Param{ - {Name: "ctx", Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Name: "userModel", Type: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.InterfaceType{}}, - code.TypeError, - }, - }, - ExpectedError: spec.ErrInvalidParam, - }, + repoIntf := testutils.Pkg.Scope().Lookup("UserRepositoryInvalidInsert").Type().Underlying().(*types.Interface) + + expectedErrors := []error{ + // Insert1 + spec.NewOperationReturnCountUnmatchedError(2), + // Insert2 + spec.NewUnsupportedReturnError(types.NewPointer(testutils.TypeUserNamed), 0), + // Insert3 + spec.NewUnsupportedReturnError(repoIntf.Method(2).Type().(*types.Signature).Results().At(0).Type(), 0), + // Insert4 + spec.NewUnsupportedReturnError(repoIntf.Method(3).Type().(*types.Signature).Results().At(1).Type(), 1), + // Insert5 + spec.ErrContextParamRequired, + // Insert6 + spec.ErrInvalidParam, + // Insert7 + spec.ErrInvalidParam, } - for _, testCase := range testTable { - t.Run(testCase.Name, func(t *testing.T) { - _, err := spec.ParseInterfaceMethod(structs, structModel, testCase.Method) + for i := 0; i < repoIntf.NumMethods(); i++ { + method := repoIntf.Method(i) - if err.Error() != testCase.ExpectedError.Error() { - t.Errorf("\nExpected = %+v\nReceived = %+v", testCase.ExpectedError, err) + t.Run(method.Name(), func(t *testing.T) { + _, err := spec.ParseInterfaceMethod(testutils.Pkg, testutils.TypeUserNamed, method) + + if err.Error() != expectedErrors[i].Error() { + t.Errorf("\nExpected = %+v\nReceived = %+v", expectedErrors[i], err) } }) } } func TestParseInterfaceMethod_Find_Invalid(t *testing.T) { - testTable := []ParseInterfaceMethodInvalidTestCase{ - { - Name: "invalid number of returns", - Method: code.Method{ - Name: "FindByID", - Returns: []code.Type{ - code.SimpleType("UserModel"), - code.TypeInt, - code.TypeError, - }, - }, - ExpectedError: spec.NewOperationReturnCountUnmatchedError(2), - }, - { - Name: "unsupported return types from find method", - Method: code.Method{ - Name: "FindByID", - Returns: []code.Type{ - code.SimpleType("UserModel"), - code.TypeError, - }, - }, - ExpectedError: spec.NewUnsupportedReturnError(code.SimpleType("UserModel"), 0), - }, - { - Name: "error return not provided", - Method: code.Method{ - Name: "FindByID", - Returns: []code.Type{ - code.SimpleType("UserModel"), - code.TypeInt, - }, - }, - ExpectedError: spec.NewUnsupportedReturnError(code.TypeInt, 1), - }, - { - Name: "find method with Top keyword but no number and query", - Method: code.Method{ - Name: "FindTop", - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - }, - ExpectedError: spec.ErrLimitAmountRequired, - }, - { - Name: "find method with Top keyword but no number", - Method: code.Method{ - Name: "FindTopAll", - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - }, - ExpectedError: spec.ErrLimitAmountRequired, - }, - { - Name: "find method with TopN keyword where N is not positive", - Method: code.Method{ - Name: "FindTop0All", - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - }, - ExpectedError: spec.ErrLimitNonPositive, - }, - { - Name: "find one method with TopN keyword", - Method: code.Method{ - Name: "FindTop5All", - Returns: []code.Type{ - code.PointerType{ContainedType: code.SimpleType("UserModel")}, - code.TypeError, - }, - }, - ExpectedError: spec.ErrLimitOnFindOne, - }, - { - Name: "find method without query", - Method: code.Method{ - Name: "Find", - Returns: []code.Type{ - code.PointerType{ContainedType: code.SimpleType("UserModel")}, - code.TypeError, - }, - }, - ExpectedError: spec.ErrQueryRequired, - }, - { - Name: "misplaced operator token (leftmost)", - Method: code.Method{ - Name: "FindByAndGender", - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - }, - ExpectedError: spec.NewInvalidQueryError([]string{"And", "Gender"}), - }, - { - Name: "misplaced operator token (rightmost)", - Method: code.Method{ - Name: "FindByGenderAnd", - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - }, - ExpectedError: spec.NewInvalidQueryError([]string{"Gender", "And"}), - }, - { - Name: "misplaced operator token (double operator)", - Method: code.Method{ - Name: "FindByGenderAndAndCity", - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - }, - ExpectedError: spec.NewInvalidQueryError([]string{"Gender", "And", "And", "City"}), - }, - { - Name: "ambiguous query", - Method: code.Method{ - Name: "FindByGenderAndCityOrAge", - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - }, - ExpectedError: spec.NewInvalidQueryError([]string{"Gender", "And", "City", "Or", "Age"}), - }, - { - Name: "no context parameter", - Method: code.Method{ - Name: "FindByGender", - Params: []code.Param{ - {Type: code.SimpleType("Gender")}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - }, - ExpectedError: spec.ErrContextParamRequired, - }, - { - Name: "mismatched number of parameters", - Method: code.Method{ - Name: "FindByCity", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - {Type: code.TypeString}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - }, - ExpectedError: spec.ErrInvalidParam, - }, - { - Name: "struct field not found", - Method: code.Method{ - Name: "FindByCountry", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - }, - ExpectedError: spec.NewStructFieldNotFoundError([]string{"Country"}), - }, - { - Name: "deeply referenced struct field not found", - Method: code.Method{ - Name: "FindByNameMiddle", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - }, - ExpectedError: spec.NewStructFieldNotFoundError([]string{"Name", "Middle"}), - }, - { - Name: "deeply referenced struct not found", - Method: code.Method{ - Name: "FindByContactPhone", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - }, - ExpectedError: spec.NewStructFieldNotFoundError([]string{"Contact", "Phone"}), - }, - { - Name: "deeply referenced external struct field", - Method: code.Method{ - Name: "FindByDefaultPaymentMethod", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - }, - ExpectedError: spec.NewStructFieldNotFoundError([]string{"Default", "Payment", "Method"}), - }, - { - Name: "incompatible struct field for True comparator", - Method: code.Method{ - Name: "FindByGenderTrue", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - }, - ExpectedError: spec.NewIncompatibleComparatorError(spec.ComparatorTrue, code.StructField{ - Name: "Gender", - Type: code.SimpleType("Gender"), - }), - }, - { - Name: "incompatible struct field for False comparator", - Method: code.Method{ - Name: "FindByGenderFalse", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - }, - ExpectedError: spec.NewIncompatibleComparatorError(spec.ComparatorFalse, code.StructField{ - Name: "Gender", - Type: code.SimpleType("Gender"), - }), - }, - { - Name: "mismatched method parameter type", - Method: code.Method{ - Name: "FindByGender", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - }, - ExpectedError: spec.NewArgumentTypeNotMatchedError(genderField.Name, genderField.Type, code.TypeString), - }, - { - Name: "mismatched method parameter type for special case", - Method: code.Method{ - Name: "FindByCityIn", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - }, - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - }, - ExpectedError: spec.NewArgumentTypeNotMatchedError(cityField.Name, - code.ArrayType{ContainedType: code.TypeString}, code.TypeString), - }, - { - Name: "misplaced operator token (leftmost)", - Method: code.Method{ - Name: "FindAllOrderByAndAge", - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - }, - ExpectedError: spec.NewInvalidSortError([]string{"Order", "By", "And", "Age"}), - }, - { - Name: "misplaced operator token (rightmost)", - Method: code.Method{ - Name: "FindAllOrderByAgeAnd", - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - }, - ExpectedError: spec.NewInvalidSortError([]string{"Order", "By", "Age", "And"}), - }, - { - Name: "misplaced operator token (double operator)", - Method: code.Method{ - Name: "FindAllOrderByAgeAndAndGender", - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - }, - ExpectedError: spec.NewInvalidSortError([]string{"Order", "By", "Age", "And", "And", "Gender"}), - }, - { - Name: "sort field not found", - Method: code.Method{ - Name: "FindAllOrderByCountry", - Returns: []code.Type{ - code.ArrayType{ContainedType: code.PointerType{ContainedType: code.SimpleType("UserModel")}}, - code.TypeError, - }, - }, - ExpectedError: spec.NewStructFieldNotFoundError([]string{"Country"}), - }, + repoIntf := testutils.Pkg.Scope().Lookup("UserRepositoryInvalidFind").Type().Underlying().(*types.Interface) + + expectedErrors := []error{ + // Find + spec.ErrQueryRequired, + // FindAll + spec.NewOperationReturnCountUnmatchedError(2), + // FindAllOrderByAgeAnd + spec.NewInvalidSortError([]string{"Order", "By", "Age", "And"}), + // FindAllOrderByAgeAndAndGender + spec.NewInvalidSortError([]string{"Order", "By", "Age", "And", "And", "Gender"}), + // FindAllOrderByAndAge + spec.NewInvalidSortError([]string{"Order", "By", "And", "Age"}), + // FindAllOrderByCountry + spec.NewStructFieldNotFoundError([]string{"Country"}), + // FindByAge + spec.ErrContextParamRequired, + // FindByAndGender + spec.NewInvalidQueryError([]string{"And", "Gender"}), + // FindByCity + spec.ErrInvalidParam, + // FindByCityIn + spec.NewArgumentTypeNotMatchedError("City", types.NewSlice(code.TypeString), code.TypeString), + // FindByCountry + spec.NewStructFieldNotFoundError([]string{"Country"}), + // FindByGender + spec.NewArgumentTypeNotMatchedError("Gender", testutils.TypeGenderNamed, code.TypeString), + // FindByGenderAnd + spec.NewInvalidQueryError([]string{"Gender", "And"}), + // FindByGenderAndAndCity + spec.NewInvalidQueryError([]string{"Gender", "And", "And", "City"}), + // FindByGenderAndCityOrAge + spec.NewInvalidQueryError([]string{"Gender", "And", "City", "Or", "Age"}), + // FindByGenderFalse + spec.NewIncompatibleComparatorError(spec.ComparatorFalse, code.StructField{ + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Gender"), + }), + // FindByGenderTrue + spec.NewIncompatibleComparatorError(spec.ComparatorTrue, code.StructField{ + Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Gender"), + }), + // FindByID + spec.NewUnsupportedReturnError(testutils.TypeUserNamed, 0), + // FindByNameMiddle + spec.NewStructFieldNotFoundError([]string{"Name", "Middle"}), + // FindTop + spec.ErrLimitAmountRequired, + // FindTop0All + spec.ErrLimitNonPositive, + // FindTop5All + spec.ErrLimitOnFindOne, + // FindTopAll + spec.ErrLimitAmountRequired, } - for _, testCase := range testTable { - t.Run(testCase.Name, func(t *testing.T) { - _, err := spec.ParseInterfaceMethod(structs, structModel, testCase.Method) + for i := 0; i < repoIntf.NumMethods(); i++ { + method := repoIntf.Method(i) + + t.Run(method.Name(), func(t *testing.T) { + _, err := spec.ParseInterfaceMethod(testutils.Pkg, testutils.TypeUserNamed, method) - if err.Error() != testCase.ExpectedError.Error() { - t.Errorf("\nExpected = %+v\nReceived = %+v", testCase.ExpectedError.Error(), err.Error()) + if err.Error() != expectedErrors[i].Error() { + t.Errorf("\nExpected = %+v\nReceived = %+v", expectedErrors[i], err) } }) } } func TestParseInterfaceMethod_Update_Invalid(t *testing.T) { - testTable := []ParseInterfaceMethodInvalidTestCase{ - { - Name: "invalid number of returns", - Method: code.Method{ - Name: "UpdateAgeByID", - Returns: []code.Type{ - code.TypeBool, - code.TypeInt, - code.TypeError, - }, - }, - ExpectedError: spec.NewOperationReturnCountUnmatchedError(2), - }, - { - Name: "unsupported return types from update method", - Method: code.Method{ - Name: "UpdateAgeByID", - Returns: []code.Type{ - code.TypeFloat64, - code.TypeError, - }, - }, - ExpectedError: spec.NewUnsupportedReturnError(code.TypeFloat64, 0), - }, - { - Name: "error return not provided", - Method: code.Method{ - Name: "UpdateAgeByID", - Returns: []code.Type{ - code.TypeBool, - code.TypeBool, - }, - }, - ExpectedError: spec.NewUnsupportedReturnError(code.TypeBool, 1), - }, - { - Name: "update with no field provided", - Method: code.Method{ - Name: "UpdateByID", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.TypeBool, - code.TypeError, - }, - }, - ExpectedError: spec.ErrInvalidUpdateFields, - }, - { - Name: "misplaced And token in update fields", - Method: code.Method{ - Name: "UpdateAgeAndAndGenderByID", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.TypeBool, - code.TypeError, - }, - }, - ExpectedError: spec.ErrInvalidUpdateFields, - }, - { - Name: "push operator in non-array field", - Method: code.Method{ - Name: "UpdateGenderPushByID", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("Gender")}, - {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.TypeBool, - code.TypeError, - }, - }, - ExpectedError: spec.NewIncompatibleUpdateOperatorError(spec.UpdateOperatorPush, spec.FieldReference{ - code.StructField{ - Name: "Gender", - Type: code.SimpleType("Gender"), - }, - }), - }, - { - Name: "inc operator in non-number field", - Method: code.Method{ - Name: "UpdateCityIncByID", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("Gender")}, - {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.TypeBool, - code.TypeError, - }, - }, - ExpectedError: spec.NewIncompatibleUpdateOperatorError(spec.UpdateOperatorInc, spec.FieldReference{ - code.StructField{ - Name: "City", - Type: code.TypeString, - }, - }), - }, - { - Name: "update method without query", - Method: code.Method{ - Name: "UpdateCity", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - }, - Returns: []code.Type{ - code.TypeBool, - code.TypeError, - }, - }, - ExpectedError: spec.ErrQueryRequired, - }, - { - Name: "ambiguous query", - Method: code.Method{ - Name: "UpdateAgeByIDAndUsernameOrGender", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedError: spec.NewInvalidQueryError([]string{"ID", "And", "Username", "Or", "Gender"}), - }, - { - Name: "parameters for push operator is not array's contained type", - Method: code.Method{ - Name: "UpdateConsentHistoryPushByID", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.ArrayType{ContainedType: code.SimpleType("ConsentHistoryItem")}}, - {Type: code.ExternalType{PackageAlias: "primitive", Name: "ObjectID"}}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedError: spec.NewArgumentTypeNotMatchedError( - consentHistoryField.Name, - code.SimpleType("ConsentHistoryItem"), - code.ArrayType{ - ContainedType: code.SimpleType("ConsentHistoryItem"), - }, - ), - }, - { - Name: "insufficient function parameters", - Method: code.Method{ - Name: "UpdateEnabledAll", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - // {Type: code.SimpleType("Enabled")}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedError: spec.ErrInvalidUpdateFields, - }, - { - Name: "update model with invalid parameter", - Method: code.Method{ - Name: "UpdateByID", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - }, - Returns: []code.Type{ - code.TypeBool, - code.TypeError, - }, - }, - ExpectedError: spec.ErrInvalidUpdateFields, - }, - { - Name: "no context parameter", - Method: code.Method{ - Name: "UpdateAgeByGender", - Params: []code.Param{ - {Type: code.TypeInt}, - {Type: code.SimpleType("Gender")}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedError: spec.ErrContextParamRequired, - }, - { - Name: "struct field not found in update fields", - Method: code.Method{ - Name: "UpdateCountryByGender", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - {Type: code.SimpleType("Gender")}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedError: spec.NewStructFieldNotFoundError([]string{"Country"}), - }, - { - Name: "struct field does not match parameter in update fields", - Method: code.Method{ - Name: "UpdateAgeByGender", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeFloat64}, - {Type: code.SimpleType("Gender")}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedError: spec.NewArgumentTypeNotMatchedError(ageField.Name, ageField.Type, code.TypeFloat64), - }, - { - Name: "struct field does not match parameter in query", - Method: code.Method{ - Name: "UpdateAgeByGender", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeInt}, - {Type: code.TypeString}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedError: spec.NewArgumentTypeNotMatchedError(genderField.Name, genderField.Type, code.TypeString), - }, + repoIntf := testutils.Pkg.Scope().Lookup("UserRepositoryInvalidUpdate").Type().Underlying().(*types.Interface) + + expectedErrors := []error{ + // UpdateAgeAndAndGenderByID + spec.ErrInvalidUpdateFields, + // UpdateAgeByGender + spec.ErrContextParamRequired, + // UpdateAgeByID + spec.NewOperationReturnCountUnmatchedError(2), + // UpdateAgeByIDAndUsernameOrGender + spec.NewInvalidQueryError([]string{"ID", "And", "Username", "Or", "Gender"}), + // UpdateByGender + spec.ErrInvalidUpdateFields, + // UpdateByID + spec.ErrInvalidUpdateFields, + // UpdateCity + spec.ErrQueryRequired, + // UpdateCityByID + spec.NewUnsupportedReturnError(code.TypeFloat64, 0), + // UpdateCityIncByID + spec.NewIncompatibleUpdateOperatorError(spec.UpdateOperatorInc, spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "City")}, + }), + // UpdateConsentHistoryPushByID + spec.NewArgumentTypeNotMatchedError("ConsentHistory", + testutils.TypeConsentHistoryNamed, types.NewSlice(testutils.TypeConsentHistoryNamed)), + // UpdateCountryByGender + spec.NewStructFieldNotFoundError([]string{"Country"}), + // UpdateEnabledAll + spec.ErrInvalidUpdateFields, + // UpdateEnabledByCity + spec.NewArgumentTypeNotMatchedError("City", code.TypeString, code.TypeInt), + // UpdateEnabledByGender + spec.NewArgumentTypeNotMatchedError("Enabled", code.TypeBool, code.TypeInt), + // UpdateEnabledByID + spec.NewUnsupportedReturnError(code.TypeBool, 1), + // UpdateGenderPushByID + spec.NewIncompatibleUpdateOperatorError(spec.UpdateOperatorPush, spec.FieldReference{ + code.StructField{Var: testutils.FindStructFieldByName(testutils.TypeUserStruct, "Gender")}, + }), } - for _, testCase := range testTable { - t.Run(testCase.Name, func(t *testing.T) { - _, err := spec.ParseInterfaceMethod(structs, structModel, testCase.Method) + for i := 0; i < repoIntf.NumMethods(); i++ { + method := repoIntf.Method(i) + + t.Run(method.Name(), func(t *testing.T) { + _, err := spec.ParseInterfaceMethod(testutils.Pkg, testutils.TypeUserNamed, method) - if err.Error() != testCase.ExpectedError.Error() { - t.Errorf("\nExpected = %+v\nReceived = %+v", testCase.ExpectedError, err) + if err.Error() != expectedErrors[i].Error() { + t.Errorf("\nExpected = %+v\nReceived = %+v", expectedErrors[i], err) } }) } } func TestParseInterfaceMethod_Delete_Invalid(t *testing.T) { - testTable := []ParseInterfaceMethodInvalidTestCase{ - { - Name: "invalid number of returns", - Method: code.Method{ - Name: "DeleteByID", - Returns: []code.Type{ - code.SimpleType("UserModel"), - code.TypeInt, - code.TypeError, - }, - }, - ExpectedError: spec.NewOperationReturnCountUnmatchedError(2), - }, - { - Name: "unsupported return types from delete method", - Method: code.Method{ - Name: "DeleteByID", - Returns: []code.Type{ - code.TypeFloat64, - code.TypeError, - }, - }, - ExpectedError: spec.NewUnsupportedReturnError(code.TypeFloat64, 0), - }, - { - Name: "error return not provided", - Method: code.Method{ - Name: "DeleteByID", - Returns: []code.Type{ - code.TypeInt, - code.TypeBool, - }, - }, - ExpectedError: spec.NewUnsupportedReturnError(code.TypeBool, 1), - }, - { - Name: "delete method without query", - Method: code.Method{ - Name: "Delete", - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedError: spec.ErrQueryRequired, - }, - { - Name: "misplaced operator token (leftmost)", - Method: code.Method{ - Name: "DeleteByAndGender", - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedError: spec.NewInvalidQueryError([]string{"And", "Gender"}), - }, - { - Name: "misplaced operator token (rightmost)", - Method: code.Method{ - Name: "DeleteByGenderAnd", - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedError: spec.NewInvalidQueryError([]string{"Gender", "And"}), - }, - { - Name: "misplaced operator token (double operator)", - Method: code.Method{ - Name: "DeleteByGenderAndAndCity", - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedError: spec.NewInvalidQueryError([]string{"Gender", "And", "And", "City"}), - }, - { - Name: "ambiguous query", - Method: code.Method{ - Name: "DeleteByGenderAndCityOrAge", - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedError: spec.NewInvalidQueryError([]string{"Gender", "And", "City", "Or", "Age"}), - }, - { - Name: "no context parameter", - Method: code.Method{ - Name: "DeleteByGender", - Params: []code.Param{ - {Type: code.SimpleType("Gender")}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedError: spec.ErrContextParamRequired, - }, - { - Name: "mismatched number of parameters", - Method: code.Method{ - Name: "DeleteByCity", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - {Type: code.TypeString}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedError: spec.ErrInvalidParam, - }, - { - Name: "struct field not found", - Method: code.Method{ - Name: "DeleteByCountry", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedError: spec.NewStructFieldNotFoundError([]string{"Country"}), - }, - { - Name: "mismatched method parameter type", - Method: code.Method{ - Name: "DeleteByGender", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedError: spec.NewArgumentTypeNotMatchedError("Gender", code.SimpleType("Gender"), code.TypeString), - }, - { - Name: "mismatched method parameter type for special case", - Method: code.Method{ - Name: "DeleteByCityIn", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedError: spec.NewArgumentTypeNotMatchedError("City", - code.ArrayType{ContainedType: code.TypeString}, code.TypeString), - }, + repoIntf := testutils.Pkg.Scope().Lookup("UserRepositoryInvalidDelete").Type().Underlying().(*types.Interface) + + expectedErrors := []error{ + // Delete + spec.ErrQueryRequired, + // DeleteAll + spec.NewOperationReturnCountUnmatchedError(2), + // DeleteByAge + spec.NewUnsupportedReturnError(code.TypeFloat64, 0), + // DeleteByAndGender + spec.NewInvalidQueryError([]string{"And", "Gender"}), + // DeleteByCity + spec.NewUnsupportedReturnError(code.TypeBool, 1), + // DeleteByCityIn + spec.NewArgumentTypeNotMatchedError("City", types.NewSlice(code.TypeString), code.TypeString), + // DeleteByCountry + spec.NewStructFieldNotFoundError([]string{"Country"}), + // DeleteByEnabled + spec.ErrInvalidParam, + // DeleteByGender + spec.ErrContextParamRequired, + // DeleteByGenderAnd + spec.NewInvalidQueryError([]string{"Gender", "And"}), + // DeleteByGenderAndAndCity + spec.NewInvalidQueryError([]string{"Gender", "And", "And", "City"}), + // DeleteByGenderAndCityOrAge + spec.NewInvalidQueryError([]string{"Gender", "And", "City", "Or", "Age"}), + // DeleteByPhoneNumber + spec.NewArgumentTypeNotMatchedError("PhoneNumber", code.TypeString, code.TypeInt), } - for _, testCase := range testTable { - t.Run(testCase.Name, func(t *testing.T) { - _, err := spec.ParseInterfaceMethod(structs, structModel, testCase.Method) + for i := 0; i < repoIntf.NumMethods(); i++ { + method := repoIntf.Method(i) + + t.Run(method.Name(), func(t *testing.T) { + _, err := spec.ParseInterfaceMethod(testutils.Pkg, testutils.TypeUserNamed, method) - if err.Error() != testCase.ExpectedError.Error() { - t.Errorf("\nExpected = %+v\nReceived = %+v", testCase.ExpectedError, err) + if err.Error() != expectedErrors[i].Error() { + t.Errorf("\nExpected = %+v\nReceived = %+v", expectedErrors[i], err) } }) } } func TestParseInterfaceMethod_Count_Invalid(t *testing.T) { - testTable := []ParseInterfaceMethodInvalidTestCase{ - { - Name: "invalid number of returns", - Method: code.Method{ - Name: "CountAll", - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - code.TypeBool, - }, - }, - ExpectedError: spec.NewOperationReturnCountUnmatchedError(2), - }, - { - Name: "invalid integer return", - Method: code.Method{ - Name: "CountAll", - Returns: []code.Type{ - code.SimpleType("int64"), - code.TypeError, - }, - }, - ExpectedError: spec.NewUnsupportedReturnError(code.SimpleType("int64"), 0), - }, - { - Name: "error return not provided", - Method: code.Method{ - Name: "CountAll", - Returns: []code.Type{ - code.TypeInt, - code.TypeBool, - }, - }, - ExpectedError: spec.NewUnsupportedReturnError(code.TypeBool, 1), - }, - { - Name: "count method without query", - Method: code.Method{ - Name: "Count", - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedError: spec.ErrQueryRequired, - }, - { - Name: "invalid query", - Method: code.Method{ - Name: "CountBy", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedError: spec.NewInvalidQueryError([]string{"By"}), - }, - { - Name: "context parameter not provided", - Method: code.Method{ - Name: "CountAll", - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedError: spec.ErrContextParamRequired, - }, - { - Name: "mismatched number of parameter", - Method: code.Method{ - Name: "CountByGender", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.SimpleType("Gender")}, - {Type: code.TypeInt}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedError: spec.ErrInvalidParam, - }, - { - Name: "mismatched method parameter type", - Method: code.Method{ - Name: "CountByGender", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedError: spec.NewArgumentTypeNotMatchedError("Gender", code.SimpleType("Gender"), code.TypeString), - }, - { - Name: "struct field not found", - Method: code.Method{ - Name: "CountByCountry", - Params: []code.Param{ - {Type: code.ExternalType{PackageAlias: "context", Name: "Context"}}, - {Type: code.TypeString}, - }, - Returns: []code.Type{ - code.TypeInt, - code.TypeError, - }, - }, - ExpectedError: spec.NewStructFieldNotFoundError([]string{"Country"}), - }, + repoIntf := testutils.Pkg.Scope().Lookup("UserRepositoryInvalidCount").Type().Underlying().(*types.Interface) + + expectedErrors := []error{ + // Count + spec.ErrQueryRequired, + // CountAll + spec.NewOperationReturnCountUnmatchedError(2), + // CountBy + spec.NewInvalidQueryError([]string{"By"}), + // CountByAge + spec.NewUnsupportedReturnError(code.TypeInt64, 0), + // CountByCity + spec.NewUnsupportedReturnError(code.TypeBool, 1), + // CountByCountry + spec.NewStructFieldNotFoundError([]string{"Country"}), + // CountByEnabled + spec.ErrInvalidParam, + // CountByGender + spec.ErrContextParamRequired, + // CountByPhoneNumber + spec.NewArgumentTypeNotMatchedError("PhoneNumber", code.TypeString, code.TypeInt), } - for _, testCase := range testTable { - t.Run(testCase.Name, func(t *testing.T) { - _, err := spec.ParseInterfaceMethod(structs, structModel, testCase.Method) + for i := 0; i < repoIntf.NumMethods(); i++ { + method := repoIntf.Method(i) + + t.Run(method.Name(), func(t *testing.T) { + _, err := spec.ParseInterfaceMethod(testutils.Pkg, testutils.TypeUserNamed, method) - if err.Error() != testCase.ExpectedError.Error() { - t.Errorf("\nExpected = %+v\nReceived = %+v", testCase.ExpectedError, err) + if err.Error() != expectedErrors[i].Error() { + t.Errorf("\nExpected = %+v\nReceived = %+v", expectedErrors[i], err) } }) } diff --git a/internal/spec/query.go b/internal/spec/query.go index 1365193..929c4ce 100644 --- a/internal/spec/query.go +++ b/internal/spec/query.go @@ -1,8 +1,6 @@ package spec -import ( - "github.com/sunboyy/repogen/internal/code" -) +import "go/types" // QuerySpec is a set of conditions of querying the database type QuerySpec struct { @@ -50,10 +48,10 @@ const ( // ArgumentTypeFromFieldType returns a type of required argument from the given // struct field type. -func (c Comparator) ArgumentTypeFromFieldType(t code.Type) code.Type { +func (c Comparator) ArgumentTypeFromFieldType(t types.Type) types.Type { switch c { case ComparatorIn, ComparatorNotIn: - return code.ArrayType{ContainedType: t} + return types.NewSlice(t) default: return t } @@ -80,8 +78,7 @@ type Predicate struct { } type queryParser struct { - fieldResolver fieldResolver - StructModel code.Struct + UnderlyingStruct *types.Struct } func (p queryParser) parseQuery(rawTokens []string, paramIndex int) (QuerySpec, @@ -200,7 +197,7 @@ func (p queryParser) parsePredicate(t []string, paramIndex int) (Predicate, func (p queryParser) createPredicate(t []string, comparator Comparator, paramIndex int) (Predicate, error) { - fields, ok := p.fieldResolver.ResolveStructField(p.StructModel, t) + fields, ok := resolveStructField(p.UnderlyingStruct, t) if !ok { return Predicate{}, NewStructFieldNotFoundError(t) } diff --git a/internal/spec/update.go b/internal/spec/update.go index 4c7cfea..8d1bccf 100644 --- a/internal/spec/update.go +++ b/internal/spec/update.go @@ -1,6 +1,6 @@ package spec -import "github.com/sunboyy/repogen/internal/code" +import "go/types" // UpdateOperation is a method specification for update operations type UpdateOperation struct { @@ -72,18 +72,20 @@ func (o UpdateOperator) NumberOfArguments() int { } // ArgumentType returns type that is required for function parameter -func (o UpdateOperator) ArgumentType(fieldType code.Type) code.Type { +func (o UpdateOperator) ArgumentType(fieldType types.Type) types.Type { switch o { case UpdateOperatorPush: - arrayType := fieldType.(code.ArrayType) - return arrayType.ContainedType + sliceType := fieldType.(*types.Slice) + return sliceType.Elem() default: return fieldType } } func (p interfaceMethodParser) parseUpdateOperation(tokens []string) (Operation, error) { - mode, err := p.extractIntOrBoolReturns(p.Method.Returns) + signature := p.Method.Type().(*types.Signature) + + mode, err := p.extractIntOrBoolReturns(signature.Results()) if err != nil { return nil, err } @@ -104,7 +106,7 @@ func (p interfaceMethodParser) parseUpdateOperation(tokens []string) (Operation, return nil, err } - if err := p.validateQueryFromParams(p.Method.Params[update.NumberOfArguments()+1:], querySpec); err != nil { + if err := p.validateQueryFromParams(signature.Params(), 1+update.NumberOfArguments(), querySpec); err != nil { return nil, err } @@ -116,9 +118,11 @@ func (p interfaceMethodParser) parseUpdateOperation(tokens []string) (Operation, } func (p interfaceMethodParser) parseUpdate(tokens []string) (Update, error) { + signature := p.Method.Type().(*types.Signature) + if len(tokens) == 0 { - requiredType := code.PointerType{ContainedType: p.StructModel.ReferencedType()} - if len(p.Method.Params) <= 1 || p.Method.Params[1].Type != requiredType { + expectedType := types.NewPointer(p.NamedStruct) + if signature.Params().Len() <= 1 || !types.Identical(signature.Params().At(1).Type(), expectedType) { return nil, ErrInvalidUpdateFields } return UpdateModel{}, nil @@ -143,16 +147,16 @@ func (p interfaceMethodParser) parseUpdate(tokens []string) (Update, error) { } for _, field := range updateFields { - if len(p.Method.Params) < field.ParamIndex+field.Operator.NumberOfArguments() { + if signature.Params().Len() < field.ParamIndex+field.Operator.NumberOfArguments() { return nil, ErrInvalidUpdateFields } - requiredType := field.Operator.ArgumentType(field.FieldReference.ReferencedField().Type) + expectedType := field.Operator.ArgumentType(field.FieldReference.ReferencedField().Var.Type()) for i := 0; i < field.Operator.NumberOfArguments(); i++ { - if requiredType != p.Method.Params[field.ParamIndex+i].Type { - return nil, NewArgumentTypeNotMatchedError(field.FieldReference.ReferencingCode(), requiredType, - p.Method.Params[field.ParamIndex+i].Type) + if !types.Identical(signature.Params().At(field.ParamIndex+i).Type(), expectedType) { + return nil, NewArgumentTypeNotMatchedError(field.FieldReference.ReferencingCode(), expectedType, + signature.Params().At(field.ParamIndex+i).Type()) } } } @@ -175,12 +179,12 @@ func (p interfaceMethodParser) parseUpdateField(t []string, func (p interfaceMethodParser) createUpdateField(t []string, operator UpdateOperator, paramIndex int) (UpdateField, error) { - fieldReference, ok := p.fieldResolver.ResolveStructField(p.StructModel, t) + fieldReference, ok := resolveStructField(p.UnderlyingStruct, t) if !ok { return UpdateField{}, NewStructFieldNotFoundError(t) } - if !p.validateUpdateOperator(fieldReference.ReferencedField().Type, operator) { + if !p.validateUpdateOperator(fieldReference.ReferencedField().Var.Type(), operator) { return UpdateField{}, NewIncompatibleUpdateOperatorError(operator, fieldReference) } @@ -191,13 +195,28 @@ func (p interfaceMethodParser) createUpdateField(t []string, }, nil } -func (p interfaceMethodParser) validateUpdateOperator(referencedType code.Type, operator UpdateOperator) bool { +func (p interfaceMethodParser) validateUpdateOperator(referencedType types.Type, operator UpdateOperator) bool { switch operator { case UpdateOperatorPush: - _, ok := referencedType.(code.ArrayType) + _, ok := referencedType.(*types.Slice) return ok case UpdateOperatorInc: - return referencedType.IsNumber() + switch t := referencedType.(type) { + case *types.Basic: + return t.Kind() == types.Int || t.Kind() == types.Int8 || t.Kind() == types.Int16 || + t.Kind() == types.Int32 || t.Kind() == types.Int64 || t.Kind() == types.Uint || + t.Kind() == types.Uint8 || t.Kind() == types.Uint16 || t.Kind() == types.Uint32 || + t.Kind() == types.Uint64 || t.Kind() == types.Float32 || t.Kind() == types.Float64 + + case *types.Pointer: + return p.validateUpdateOperator(t.Elem(), operator) + + case *types.Named: + return p.validateUpdateOperator(t.Underlying(), operator) + + default: + return false + } } return true } diff --git a/internal/teststub/user.go b/internal/teststub/user.go new file mode 100644 index 0000000..897d9c6 --- /dev/null +++ b/internal/teststub/user.go @@ -0,0 +1,316 @@ +package teststub + +import ( + "context" + + "go.mongodb.org/mongo-driver/bson/primitive" +) + +type Gender string + +type User struct { + ID primitive.ObjectID + PhoneNumber string + Gender Gender + City string + Age int + Name Name + Contact Contact + Referrer *User + Enabled bool + ConsentHistory []ConsentHistory + AccessToken string +} + +type Name struct { + First string + Last string +} + +type Contact struct { + Phone string +} + +type ConsentHistory struct { + ID primitive.ObjectID + Value bool +} + +type UserRepositoryInsert interface { + InsertMany(ctx context.Context, users []*User) ([]interface{}, error) + InsertOne(ctx context.Context, user *User) (interface{}, error) +} + +type UserRepositoryFind interface { + // Test find all + FindAll(ctx context.Context) ([]*User, error) + // Test find with Between operator + FindByAgeBetween(ctx context.Context, fromAge int, toAge int) ([]*User, error) + // Test find with GreaterThan operator + FindByAgeGreaterThan(ctx context.Context, age int) ([]*User, error) + // Test find with GreaterThanEqual operator + FindByAgeGreaterThanEqual(ctx context.Context, age int) ([]*User, error) + // Test find with LessThan operator + FindByAgeLessThan(ctx context.Context, age int) ([]*User, error) + // Test find with LessThanEqual operator + FindByAgeLessThanEqual(ctx context.Context, age int) ([]*User, error) + // Test find MANY mode + FindByCity(ctx context.Context, city string) ([]*User, error) + // Test find with And operator + FindByCityAndGender(ctx context.Context, city string, gender Gender) ([]*User, error) + // Test find with In operator + FindByCityIn(ctx context.Context, cities []string) ([]*User, error) + // Test find with Not operator + FindByCityNot(ctx context.Context, city string) ([]*User, error) + // Test find with NotIn operator + FindByCityNotIn(ctx context.Context, cities []string) ([]*User, error) + // Test find with Or operator + FindByCityOrGender(ctx context.Context, city string, gender Gender) ([]*User, error) + // Test find ordering without explicit direction + FindByCityOrderByAge(ctx context.Context, city string) ([]*User, error) + // Test find ordering with explicit ascending direction + FindByCityOrderByAgeAsc(ctx context.Context, city string) ([]*User, error) + // Test find ordering with explicit descending direction + FindByCityOrderByAgeDesc(ctx context.Context, city string) ([]*User, error) + // Test find with multiple ordering + FindByCityOrderByCityAndAgeDesc(ctx context.Context, city string) ([]*User, error) + // Test find with deep reference ordering + FindByCityOrderByNameFirst(ctx context.Context, city string) ([]*User, error) + // Test find with False operator + FindByEnabledFalse(ctx context.Context) ([]*User, error) + // Test find with True operator + FindByEnabledTrue(ctx context.Context) ([]*User, error) + // Test find ONE mode + FindByID(ctx context.Context, id primitive.ObjectID) (*User, error) + // Test find with deep referencing + FindByNameFirst(ctx context.Context, firstName string) ([]*User, error) + // Test find with multi-word arg + FindByPhoneNumber(ctx context.Context, phoneNumber string) (*User, error) + // Test find with Exists operator + FindByReferrerExists(ctx context.Context) ([]*User, error) + // Test find with deep pointer referencing + FindByReferrerID(ctx context.Context, id primitive.ObjectID) ([]*User, error) + // Test find with NotExists operator + FindByReferrerNotExists(ctx context.Context) ([]*User, error) + // Test find Top N + FindTop5ByGenderOrderByAgeDesc(ctx context.Context, gender Gender) ([]*User, error) +} + +type UserRepositoryUpdate interface { + // Test update inc operator + UpdateAgeIncByID(ctx context.Context, age int, id primitive.ObjectID) (bool, error) + // Test update model ONE mode + UpdateByID(ctx context.Context, user *User, id primitive.ObjectID) (bool, error) + // Test update push operator + UpdateConsentHistoryPushByID(ctx context.Context, consentHistoryItem ConsentHistory, + id primitive.ObjectID) (int, error) + // Test update multiple fields with push operator + UpdateEnabledAndConsentHistoryPushByID(ctx context.Context, enabled bool, + consentHistoryItem ConsentHistory, id primitive.ObjectID) (int, error) + // Test update multiple fields + UpdateGenderAndCityByID(ctx context.Context, gender Gender, city string, id primitive.ObjectID) (int, error) + // Test update field MANY mode + UpdateGenderByAge(ctx context.Context, gender Gender, age int) (int, error) + // Test update field ONE mode + UpdateGenderByID(ctx context.Context, gender Gender, id primitive.ObjectID) (bool, error) + // Test update deep reference field + UpdateNameFirstByID(ctx context.Context, firstName string, id primitive.ObjectID) (bool, error) +} + +type UserRepositoryDelete interface { + // Test delete all + DeleteAll(ctx context.Context) (int, error) + // Test delete with Between operator + DeleteByAgeBetween(ctx context.Context, fromAge int, toAge int) (int, error) + // Test delete with GreaterThan operator + DeleteByAgeGreaterThan(ctx context.Context, age int) (int, error) + // Test delete with GreaterThanEqual operator + DeleteByAgeGreaterThanEqual(ctx context.Context, age int) (int, error) + // Test delete with LessThan operator + DeleteByAgeLessThan(ctx context.Context, age int) (int, error) + // Test delete with LessThanEqual operator + DeleteByAgeLessThanEqual(ctx context.Context, age int) (int, error) + // Test delete MANY mode + DeleteByCity(ctx context.Context, city string) (int, error) + // Test delete with And operator + DeleteByCityAndGender(ctx context.Context, city string, gender Gender) (int, error) + // Test delete with In operator + DeleteByCityIn(ctx context.Context, cities []string) (int, error) + // Test delete with Not operator + DeleteByCityNot(ctx context.Context, city string) (int, error) + // Test delete with Or operator + DeleteByCityOrGender(ctx context.Context, city string, gender Gender) (int, error) + // Test delete ONE mode + DeleteByID(ctx context.Context, id primitive.ObjectID) (bool, error) + // Test delete with deep reference + DeleteByNameFirst(ctx context.Context, firstName string) (int, error) + // Test delete multi-word arg + DeleteByPhoneNumber(ctx context.Context, phoneNumber string) (bool, error) +} + +type UserRepositoryCount interface { + // Test count all + CountAll(ctx context.Context) (int, error) + // Test count with query + CountByGender(ctx context.Context, gender Gender) (int, error) + // Test count with deep reference + CountByNameFirst(ctx context.Context, firstName string) (int, error) +} + +type UserRepositoryInvalidOperation interface { + SearchByID(ctx context.Context, id primitive.ObjectID) (*User, error) +} + +type UserRepositoryInvalidInsert interface { + // Test insert with invalid number of returns + Insert1(ctx context.Context, user *User) (*User, interface{}, error) + // Test insert with invalid return type + Insert2(ctx context.Context, user *User) (*User, error) + // Test insert with unempty interface return + Insert3(ctx context.Context, user *User) (interface{ Foo() }, error) + // Test insert with no error return + Insert4(ctx context.Context, user *User) (interface{}, bool) + // Test insert with no context parameter + Insert5(user *User) (interface{}, error) + // Test insert with mismatched model parameter for ONE mode + Insert6(ctx context.Context, userModel []*User) (interface{}, error) + // Test insert with mismatched model parameter for MANY mode + Insert7(ctx context.Context, userModel []*User) (interface{}, error) +} + +type UserRepositoryInvalidFind interface { + // Test find without query + Find(ctx context.Context) ([]*User, error) + // Test find with invalid number of returns + FindAll(ctx context.Context) ([]*User, int, error) + // Test find with misplaced sort operator token (rightmost) + FindAllOrderByAgeAnd(ctx context.Context) ([]*User, error) + // Test find with misplaced sort operator token (double operator) + FindAllOrderByAgeAndAndGender(ctx context.Context) ([]*User, error) + // Test find with misplaced sort operator token (leftmost) + FindAllOrderByAndAge(ctx context.Context) ([]*User, error) + // Test find with sort struct field not found + FindAllOrderByCountry(ctx context.Context) ([]*User, error) + // Test find with no context parameter + FindByAge(age int) ([]*User, error) + // Test find with misplaced query operator token (leftmost) + FindByAndGender(ctx context.Context, gender Gender) ([]*User, error) + // Test find with mismatched number of parameters + FindByCity(ctx context.Context, city string, gender Gender) ([]*User, error) + // Test find with mismatched parameter with In query + FindByCityIn(ctx context.Context, city string) ([]*User, error) + // Test find with query struct field not found + FindByCountry(ctx context.Context, country string) ([]*User, error) + // test find with mismatched parameter type + FindByGender(ctx context.Context, gender string) ([]*User, error) + // Test find with misplaced query operator token (rightmost) + FindByGenderAnd(ctx context.Context, gender Gender) ([]*User, error) + // Test find with misplaced query operator token (double operator) + FindByGenderAndAndCity(ctx context.Context, gender Gender, city string) ([]*User, error) + // Test find with ambiguous operator + FindByGenderAndCityOrAge(ctx context.Context, gender Gender, city string, age int) ([]*User, error) + // Test find with incompatible struct field for False comparator + FindByGenderFalse(ctx context.Context) ([]*User, error) + // Test find with incompatible struct field for True comparator + FindByGenderTrue(ctx context.Context) ([]*User, error) + // Test find with invalid return type + FindByID(ctx context.Context, id primitive.ObjectID) (User, error) + // Test find with deep reference field not found + FindByNameMiddle(ctx context.Context, middleName string) ([]*User, error) + // Test find top with no number and query + FindTop(ctx context.Context) ([]*User, error) + // Test find top 0 + FindTop0All(ctx context.Context) ([]*User, error) + // Test find top in ONE mode + FindTop5All(ctx context.Context) (*User, error) + // Test find top with no number + FindTopAll(ctx context.Context) ([]*User, error) +} + +type UserRepositoryInvalidUpdate interface { + // Test update with mismatched And token in update fields + UpdateAgeAndAndGenderByID(ctx context.Context, age int, gender Gender, + id primitive.ObjectID) (bool, error) + // Test update without context parameter + UpdateAgeByGender(age int, gender Gender) (int, error) + // Test update with invalid number of returns + UpdateAgeByID(ctx context.Context, age int, id primitive.ObjectID) (bool, int, error) + // Test update with ambiguous query + UpdateAgeByIDAndUsernameOrGender(ctx context.Context, age int, id primitive.ObjectID, + username string, gender Gender) (bool, error) + // Test update model with invalid parameter type + UpdateByGender(ctx context.Context, gender Gender) (bool, error) + // Test update with no update parameter provided + UpdateByID(ctx context.Context, id primitive.ObjectID) (bool, error) + // Test update without query + UpdateCity(ctx context.Context, city string) (bool, error) + // Test update with invalid return type + UpdateCityByID(ctx context.Context, city string, id primitive.ObjectID) (float64, error) + // Test update with inc operator in non-number field + UpdateCityIncByID(ctx context.Context, city string, id primitive.ObjectID) (bool, error) + // Test update with push operator with incorrect parameter type + UpdateConsentHistoryPushByID(ctx context.Context, consentHistoryItem []ConsentHistory, + id primitive.ObjectID) (int, error) + // Test update field not found in struct + UpdateCountryByGender(ctx context.Context, country string, gender Gender) (int, error) + // Test update with insufficient function parameters + UpdateEnabledAll(ctx context.Context) (int, error) + // Test update with incorrect parameter type for query + UpdateEnabledByCity(ctx context.Context, enabled bool, city int) (bool, error) + // Test update with incorrect parameter type for update field + UpdateEnabledByGender(ctx context.Context, enabled int, gender Gender) (bool, error) + // Test update with no error return + UpdateEnabledByID(ctx context.Context, enabled bool, id primitive.ObjectID) (bool, bool) + // Test update with push operator in non-array field + UpdateGenderPushByID(ctx context.Context, gender Gender, id primitive.ObjectID) (bool, error) +} + +type UserRepositoryInvalidDelete interface { + // Test delete without query + Delete(ctx context.Context) (int, error) + // Test delete with invalid number of returns + DeleteAll(ctx context.Context) (*User, int, error) + // Test delete with unsupported return type + DeleteByAge(ctx context.Context, age int) (float64, error) + // Test delete with misplaced operator token (leftmost) + DeleteByAndGender(ctx context.Context, gender Gender) (bool, error) + // Test delete with no error return + DeleteByCity(ctx context.Context, city string) (int, bool) + // Test delete with mismatched parameter type for In operator + DeleteByCityIn(ctx context.Context, city string) (int, error) + // Test delete with query struct field not found + DeleteByCountry(ctx context.Context, country string) (int, error) + // Test delete with mismatched number of parameters + DeleteByEnabled(ctx context.Context, enabled bool, enabled2 bool) (int, error) + // Test delete without context parameter + DeleteByGender(gender Gender) (int, error) + // Test delete with misplaced operator token (rightmost) + DeleteByGenderAnd(ctx context.Context, gender Gender) (bool, error) + // Test delete with misplaced operator token (double operator) + DeleteByGenderAndAndCity(ctx context.Context, gender Gender, city string) (bool, error) + // Test delete with ambiguous query + DeleteByGenderAndCityOrAge(ctx context.Context, gender Gender, city string, age int) (bool, error) + // Test delete with mismatched parameter type + DeleteByPhoneNumber(ctx context.Context, phoneNumber int) (bool, error) +} + +type UserRepositoryInvalidCount interface { + // Test count with query + Count(ctx context.Context) (int, error) + // Test count with invalid number of returns + CountAll(ctx context.Context) (int, error, bool) + // Test count with invalid query + CountBy(ctx context.Context) (int, error) + // Test count with invalid integer return + CountByAge(ctx context.Context, age int) (int64, error) + // Test count with no error return + CountByCity(ctx context.Context, city string) (int, bool) + // Test count with struct field not found + CountByCountry(ctx context.Context, country string) (int, error) + // Test count with mismatched number of parameters + CountByEnabled(ctx context.Context, enabled bool, enabled2 bool) (int, error) + // Test count without context parameter + CountByGender(gender Gender) (int, error) + // Test count with mismatched parameter type + CountByPhoneNumber(ctx context.Context, phoneNumber int) (int, error) +} diff --git a/internal/testutils/stub_provider.go b/internal/testutils/stub_provider.go new file mode 100644 index 0000000..f1e03b6 --- /dev/null +++ b/internal/testutils/stub_provider.go @@ -0,0 +1,62 @@ +package testutils + +import ( + "go/types" + + "golang.org/x/tools/go/packages" +) + +var ( + TypeContextNamed *types.Named + TypeObjectIDNamed *types.Named + TypeCollectionNamed *types.Named + + Pkg *types.Package + TypeUserNamed *types.Named + TypeUserStruct *types.Struct + TypeGenderNamed *types.Named + TypeNameStruct *types.Struct + TypeConsentHistoryNamed *types.Named +) + +func init() { + cfg := &packages.Config{Mode: packages.NeedTypes} + + contextPkgs, err := packages.Load(cfg, "context") + if err != nil { + panic(err) + } + TypeContextNamed = contextPkgs[0].Types.Scope().Lookup("Context").Type().(*types.Named) + + primitivePkgs, err := packages.Load(cfg, "go.mongodb.org/mongo-driver/bson/primitive") + if err != nil { + panic(err) + } + TypeObjectIDNamed = primitivePkgs[0].Types.Scope().Lookup("ObjectID").Type().(*types.Named) + + mongoPkgs, err := packages.Load(cfg, "go.mongodb.org/mongo-driver/mongo") + if err != nil { + panic(err) + } + TypeCollectionNamed = mongoPkgs[0].Types.Scope().Lookup("Collection").Type().(*types.Named) + + stubPkgs, err := packages.Load(cfg, "../teststub") + if err != nil { + panic(err) + } + Pkg = stubPkgs[0].Types + TypeUserNamed = Pkg.Scope().Lookup("User").Type().(*types.Named) + TypeUserStruct = TypeUserNamed.Underlying().(*types.Struct) + TypeGenderNamed = Pkg.Scope().Lookup("Gender").Type().(*types.Named) + TypeNameStruct = Pkg.Scope().Lookup("Name").Type().Underlying().(*types.Struct) + TypeConsentHistoryNamed = Pkg.Scope().Lookup("ConsentHistory").Type().(*types.Named) +} + +func FindStructFieldByName(s *types.Struct, name string) *types.Var { + for i := 0; i < s.NumFields(); i++ { + if s.Field(i).Name() == name { + return s.Field(i) + } + } + return nil +} diff --git a/main.go b/main.go index 82fe498..21d529a 100644 --- a/main.go +++ b/main.go @@ -4,11 +4,11 @@ import ( "errors" "flag" "fmt" + "go/types" "log" "os" "path/filepath" - "github.com/sunboyy/repogen/internal/code" "github.com/sunboyy/repogen/internal/generator" "github.com/sunboyy/repogen/internal/spec" "golang.org/x/tools/go/packages" @@ -82,7 +82,7 @@ func printVersion() { func generateFromRequest(pkgDir, structModelName, repositoryInterfaceName string) (string, error) { cfg := packages.Config{ - Mode: packages.NeedName | packages.NeedSyntax | packages.NeedTypes, + Mode: packages.NeedName | packages.NeedTypes, } pkgs, err := packages.Load(&cfg, pkgDir) if err != nil { @@ -92,41 +92,49 @@ func generateFromRequest(pkgDir, structModelName, repositoryInterfaceName string return "", errNoPackageFound } - pkgPkg := pkgs[0] + pkg := pkgs[0] - pkg, err := code.ParsePackage(pkgPkg) - if err != nil { - return "", err - } - - return generateRepository(pkg, structModelName, repositoryInterfaceName) + return generateRepository(pkg.Types, structModelName, repositoryInterfaceName) } var ( errNoPackageFound = errors.New("no package found") errStructNotFound = errors.New("struct not found") + errNotNamedStruct = errors.New("not a named struct") errInterfaceNotFound = errors.New("interface not found") + errNotInterface = errors.New("not an interface") ) -func generateRepository(pkg code.Package, structModelName, repositoryInterfaceName string) (string, error) { - structModel, ok := pkg.Structs[structModelName] - if !ok { +func generateRepository(pkg *types.Package, structModelName, repositoryInterfaceName string) (string, error) { + structModelObj := pkg.Scope().Lookup(structModelName) + if structModelObj == nil { return "", errStructNotFound } - - intf, ok := pkg.Interfaces[repositoryInterfaceName] + namedStruct, ok := structModelObj.Type().(*types.Named) if !ok { + return "", errNotNamedStruct + } + + intfObj := pkg.Scope().Lookup(repositoryInterfaceName) + if intfObj == nil { return "", errInterfaceNotFound } + intf, ok := intfObj.Type().Underlying().(*types.Interface) + if !ok { + return "", errNotInterface + } var methodSpecs []spec.MethodSpec - for _, method := range intf.Methods { - methodSpec, err := spec.ParseInterfaceMethod(pkg.Structs, structModel, method) + for i := 0; i < intf.NumMethods(); i++ { + method := intf.Method(i) + log.Println("Generating method:", method.Name()) + + methodSpec, err := spec.ParseInterfaceMethod(pkg, namedStruct, method) if err != nil { return "", err } methodSpecs = append(methodSpecs, methodSpec) } - return generator.GenerateRepository(pkg.Name, structModel, repositoryInterfaceName, methodSpecs) + return generator.GenerateRepository(pkg, structModelName, repositoryInterfaceName, methodSpecs) } diff --git a/test/generator_test_expected.txt b/test/generator_test_expected.txt index 4b23b95..22a270e 100644 --- a/test/generator_test_expected.txt +++ b/test/generator_test_expected.txt @@ -1,5 +1,5 @@ // Code generated by repogen. DO NOT EDIT. -package user +package teststub import ( "context" @@ -20,8 +20,8 @@ type UserRepositoryMongo struct { collection *mongo.Collection } -func (r *UserRepositoryMongo) FindByID(arg0 context.Context, arg1 primitive.ObjectID) (*UserModel, error) { - var entity UserModel +func (r *UserRepositoryMongo) FindByID(arg0 context.Context, arg1 primitive.ObjectID) (*User, error) { + var entity User if err := r.collection.FindOne(arg0, bson.M{ "_id": arg1, }, options.FindOne().SetSort(bson.M{})).Decode(&entity); err != nil { @@ -30,7 +30,7 @@ func (r *UserRepositoryMongo) FindByID(arg0 context.Context, arg1 primitive.Obje return &entity, nil } -func (r *UserRepositoryMongo) FindByGenderNotAndAgeLessThan(arg0 context.Context, arg1 Gender, arg2 int) (*UserModel, error) { +func (r *UserRepositoryMongo) FindByGenderNotAndAgeLessThan(arg0 context.Context, arg1 Gender, arg2 int) (*User, error) { cursor, err := r.collection.Find(arg0, bson.M{ "$and": []bson.M{ { @@ -48,14 +48,14 @@ func (r *UserRepositoryMongo) FindByGenderNotAndAgeLessThan(arg0 context.Context if err != nil { return nil, err } - entities := []*UserModel{} + entities := []*User{} if err := cursor.All(arg0, &entities); err != nil { return nil, err } return entities, nil } -func (r *UserRepositoryMongo) FindByAgeLessThanEqualOrderByAge(arg0 context.Context, arg1 int) ([]*UserModel, error) { +func (r *UserRepositoryMongo) FindByAgeLessThanEqualOrderByAge(arg0 context.Context, arg1 int) ([]*User, error) { cursor, err := r.collection.Find(arg0, bson.M{ "age": bson.M{ "$lte": arg1, @@ -66,14 +66,14 @@ func (r *UserRepositoryMongo) FindByAgeLessThanEqualOrderByAge(arg0 context.Cont if err != nil { return nil, err } - entities := []*UserModel{} + entities := []*User{} if err := cursor.All(arg0, &entities); err != nil { return nil, err } return entities, nil } -func (r *UserRepositoryMongo) FindByAgeGreaterThanOrderByAgeAsc(arg0 context.Context, arg1 int) ([]*UserModel, error) { +func (r *UserRepositoryMongo) FindByAgeGreaterThanOrderByAgeAsc(arg0 context.Context, arg1 int) ([]*User, error) { cursor, err := r.collection.Find(arg0, bson.M{ "age": bson.M{ "$gt": arg1, @@ -84,14 +84,14 @@ func (r *UserRepositoryMongo) FindByAgeGreaterThanOrderByAgeAsc(arg0 context.Con if err != nil { return nil, err } - entities := []*UserModel{} + entities := []*User{} if err := cursor.All(arg0, &entities); err != nil { return nil, err } return entities, nil } -func (r *UserRepositoryMongo) FindByAgeGreaterThanEqualOrderByAgeDesc(arg0 context.Context, arg1 int) ([]*UserModel, error) { +func (r *UserRepositoryMongo) FindByAgeGreaterThanEqualOrderByAgeDesc(arg0 context.Context, arg1 int) ([]*User, error) { cursor, err := r.collection.Find(arg0, bson.M{ "age": bson.M{ "$gte": arg1, @@ -102,14 +102,14 @@ func (r *UserRepositoryMongo) FindByAgeGreaterThanEqualOrderByAgeDesc(arg0 conte if err != nil { return nil, err } - entities := []*UserModel{} + entities := []*User{} if err := cursor.All(arg0, &entities); err != nil { return nil, err } return entities, nil } -func (r *UserRepositoryMongo) FindByAgeBetween(arg0 context.Context, arg1 int, arg2 int) ([]*UserModel, error) { +func (r *UserRepositoryMongo) FindByAgeBetween(arg0 context.Context, arg1 int, arg2 int) ([]*User, error) { cursor, err := r.collection.Find(arg0, bson.M{ "age": bson.M{ "$gte": arg1, @@ -119,14 +119,14 @@ func (r *UserRepositoryMongo) FindByAgeBetween(arg0 context.Context, arg1 int, a if err != nil { return nil, err } - entities := []*UserModel{} + entities := []*User{} if err := cursor.All(arg0, &entities); err != nil { return nil, err } return entities, nil } -func (r *UserRepositoryMongo) FindByGenderOrAge(arg0 context.Context, arg1 Gender, arg2 int) ([]*UserModel, error) { +func (r *UserRepositoryMongo) FindByGenderOrAge(arg0 context.Context, arg1 Gender, arg2 int) ([]*User, error) { cursor, err := r.collection.Find(arg0, bson.M{ "$or": []bson.M{ { @@ -140,7 +140,7 @@ func (r *UserRepositoryMongo) FindByGenderOrAge(arg0 context.Context, arg1 Gende if err != nil { return nil, err } - entities := []*UserModel{} + entities := []*User{} if err := cursor.All(arg0, &entities); err != nil { return nil, err }