diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c51bb0df..eb206ac2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -14,6 +14,7 @@ jobs: matrix: go-version: [ 1.18.x, 1.19.x, 1.20.x, 1.21.x, 1.22.x ] platform: [ ubuntu-latest ] + mysql-version: ['mysql:latest', 'mysql:5.7'] runs-on: ${{ matrix.platform }} services: @@ -23,6 +24,21 @@ jobs: --health-cmd "redis-cli ping" --health-interval 10s --health-timeout 5s --health-retries 5 ports: - 6379:6379 + mysql: + image: ${{ matrix.mysql-version }} + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + ports: + - 3306:3306 + options: >- + --health-cmd "mysqladmin ping -ugorm -pgorm" + --health-interval 10s + --health-start-period 10s + --health-timeout 5s + --health-retries 10 steps: - name: Checkout code diff --git a/go.mod b/go.mod index 0e052bcd..e917bb12 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/stretchr/testify v1.8.4 golang.org/x/term v0.17.0 golang.org/x/text v0.14.0 + gorm.io/driver/mysql v1.5.4 gorm.io/gorm v1.25.7 ) @@ -31,9 +32,12 @@ require ( github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.18.0 // indirect + github.com/go-sql-driver/mysql v1.7.0 // indirect github.com/goccy/go-json v0.10.2 // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/gorilla/mux v1.8.1 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.6 // indirect github.com/kr/text v0.2.0 // indirect diff --git a/go.sum b/go.sum index dc7c25fc..133ab269 100644 --- a/go.sum +++ b/go.sum @@ -45,6 +45,8 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.18.0 h1:BvolUXjp4zuvkZ5YN5t7ebzbhlUtPsPm2S9NAZ5nl9U= github.com/go-playground/validator/v10 v10.18.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= +github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= @@ -57,7 +59,10 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= @@ -141,6 +146,9 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/mysql v1.5.4 h1:igQmHfKcbaTVyAIHNhhB888vvxh8EdQ2uSUT0LPcBso= +gorm.io/driver/mysql v1.5.4/go.mod h1:9rYxJph/u9SWkWc9yY4XJ1F/+xO0S/ChOmbk3+Z5Tvs= +gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gorm.io/gorm v1.25.7 h1:VsD6acwRjz2zFxGO50gPO6AkNs7KKnvfzUjHQhZDz/A= gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= diff --git a/gorm/scopes/README.md b/gorm/scopes/README.md new file mode 100644 index 00000000..42298182 --- /dev/null +++ b/gorm/scopes/README.md @@ -0,0 +1,66 @@ +# Gorm/Scopes + +## Example + +```go +package main + +import ( + "time" + + "gorm.io/gorm" + + "github.com/go-kratos-ecosystem/components/v2/gorm/scopes" +) + +func main() { + var db *gorm.DB + + db.Scopes(scopes. + // trait + When(true, func(db *gorm.DB) *gorm.DB { + return db.Where("deleted_at IS NULL") + }). + Unless(true, func(db *gorm.DB) *gorm.DB { + return db.Where("deleted_at IS NOT NULL") + }). + + // Where + Where("name = ?", "Flc"). + WhereBetween("created_at", time.Now(), time.Now()). + WhereNotBetween("created_at", time.Now(), time.Now()). + WhereIn("name", "Flc", "Flc 2"). + WhereNotIn("name", "Flc", "Flc 2"). + WhereLike("name", "Flc%"). + WhereNotLike("name", "Flc%"). + WhereEq("name", "Flc"). + WhereNe("name", "Flc"). + WhereGt("age", 18). + WhereEgt("age", 18). + WhereLt("age", 18). + WhereElt("age", 18). + + // Order + OrderBy("id"). + OrderBy("id", "desc"). + OrderBy("id", "asc"). + OrderByDesc("id"). + OrderByAsc("id"). + OrderByRaw("id desc"). + + // Limit + Limit(10). + Take(10). + + // Offset + Offset(10). + Skip(10). + + // Page + Page(1, 20). + + // To Scope() + Scope()). + Find(&[]struct{}{}) +} +``` \ No newline at end of file diff --git a/gorm/scopes/init_test.go b/gorm/scopes/init_test.go new file mode 100644 index 00000000..60e4dd71 --- /dev/null +++ b/gorm/scopes/init_test.go @@ -0,0 +1,92 @@ +package scopes + +import ( + "log" + "os" + "time" + + "gorm.io/driver/mysql" + "gorm.io/gorm" +) + +var ( + DB *gorm.DB + dsn = "gorm:gorm@tcp(localhost:3306)/gorm?charset=utf8&parseTime=True&loc=Local" +) + +type User struct { + gorm.Model + Name string `gorm:"column:name"` + Age uint `gorm:"column:age"` + Sex string `gorm:"column:sex"` + Birthday *time.Time `gorm:"column:birthday"` + Address *string `gorm:"column:address"` +} + +func init() { + var err error + DB, err = gorm.Open(mysql.Open(dsn), &gorm.Config{}) + if err != nil { + log.Println("failed to connect database, got error", err) + os.Exit(1) + } + + runMigrations() +} + +func runMigrations() { + var err error + models := []interface{}{&User{}} + + if err = DB.Migrator().DropTable(models...); err != nil { + log.Printf("Didn't drop table, got error %v\n", err) + os.Exit(1) + } + + if err = DB.AutoMigrate(models...); err != nil { + log.Printf("Failed to auto migrate, but got error %v\n", err) + os.Exit(1) + } + + for _, m := range models { + if !DB.Migrator().HasTable(m) { + log.Printf("Didn't create table for %#v\n", m) + os.Exit(1) + } + } +} + +type GetUserOptions struct { + Age int + Birthday *time.Time + Address *string +} + +func GetUser(name string, opts GetUserOptions) *User { + var ( + birthday = time.Now().Round(time.Second) + user = User{ + Name: name, + Age: 18, + Birthday: &birthday, + } + ) + + if opts.Age > 0 { + user.Age = uint(opts.Age) + } + + if opts.Birthday != nil { + user.Birthday = opts.Birthday + } + + if opts.Address != nil { + user.Address = opts.Address + } + + return &user +} + +func CleanUsers() { + DB.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&User{}) +} diff --git a/gorm/scopes/order.go b/gorm/scopes/order.go new file mode 100644 index 00000000..23a9a24f --- /dev/null +++ b/gorm/scopes/order.go @@ -0,0 +1,85 @@ +package scopes + +import ( + "fmt" + "strings" + + "gorm.io/gorm" +) + +// OrderBy add order by condition +// +// OrderBy("name") +// OrderBy("name", "desc") +// OrderBy("name", "asc") +func OrderBy(column string, reorder ...string) *Scopes { + return New().OrderBy(column, reorder...) +} + +// OrderByDesc add order by desc condition +// +// OrderByDesc("name") +func OrderByDesc(column string) *Scopes { + return New().OrderByDesc(column) +} + +// OrderByAsc add order by asc condition +// +// OrderByAsc("name") +func OrderByAsc(column string) *Scopes { + return New().OrderByAsc(column) +} + +// OrderByRaw add order by raw condition +// +// OrderByRaw("name desc") +// OrderByRaw("name asc") +// OrderByRaw("name desc, age asc") +// OrderByRaw("FIELD(id, 3, 1, 2)") +func OrderByRaw(sql interface{}) *Scopes { + return New().OrderByRaw(sql) +} + +// OrderBy add order by condition +// +// OrderBy("name") +// OrderBy("name", "desc") +// OrderBy("name", "asc") +func (s *Scopes) OrderBy(column string, reorder ...string) *Scopes { + return s.Add(func(db *gorm.DB) *gorm.DB { + return db.Order(fmt.Sprintf("%s %s", column, s.buildReorder(reorder...))) + }) +} + +// OrderByDesc add order by desc condition +// +// OrderByDesc("name") +func (s *Scopes) OrderByDesc(column string) *Scopes { + return s.OrderBy(column, "desc") +} + +// OrderByAsc add order by asc condition +// +// OrderByAsc("name") +func (s *Scopes) OrderByAsc(column string) *Scopes { + return s.OrderBy(column, "asc") +} + +// OrderByRaw add order by raw condition +// +// OrderByRaw("name desc") +// OrderByRaw("name asc") +// OrderByRaw("name desc, age asc") +// OrderByRaw("FIELD(id, 3, 1, 2)") +func (s *Scopes) OrderByRaw(sql interface{}) *Scopes { + return s.Add(func(db *gorm.DB) *gorm.DB { + return db.Order(sql) + }) +} + +func (s *Scopes) buildReorder(reorder ...string) string { + if len(reorder) > 0 && strings.ToUpper(reorder[0]) == "DESC" { + return "DESC" + } + return "ASC" +} diff --git a/gorm/scopes/order_test.go b/gorm/scopes/order_test.go new file mode 100644 index 00000000..ea57bc7a --- /dev/null +++ b/gorm/scopes/order_test.go @@ -0,0 +1,74 @@ +package scopes + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestOrderBy(t *testing.T) { + birthday1 := time.Now() + birthday2 := time.Now() + birthday3 := time.Now().Add(2 * time.Hour) + birthday4 := time.Now().Add(3 * time.Hour) + users := []*User{ + GetUser("OrderUser1", GetUserOptions{Age: 17, Birthday: &birthday1}), + GetUser("OrderUser2", GetUserOptions{Age: 20, Birthday: &birthday2}), + GetUser("OrderUser3", GetUserOptions{Age: 21, Birthday: &birthday3}), + GetUser("OrderUser4", GetUserOptions{Age: 22, Birthday: &birthday4}), + } + + CleanUsers() + DB.Create(&users) + + // OrderBy + var users1, users2, users3, users4 []User + DB.Scopes(OrderBy("age").Scope()).Limit(2).Find(&users1) + assert.Len(t, users1, 2) + assert.Equal(t, "OrderUser1", users1[0].Name) + assert.Equal(t, "OrderUser2", users1[1].Name) + + DB.Scopes(OrderBy("age", "asc").Scope()).Limit(2).Find(&users2) + assert.Len(t, users2, 2) + assert.Equal(t, "OrderUser1", users2[0].Name) + assert.Equal(t, "OrderUser2", users2[1].Name) + + DB.Scopes(OrderBy("age", "desc").Scope()).Limit(2).Find(&users3) + assert.Len(t, users3, 2) + assert.Equal(t, "OrderUser4", users3[0].Name) + assert.Equal(t, "OrderUser3", users3[1].Name) + + DB.Scopes(OrderBy("age", "unknown").Scope()).Limit(2).Find(&users4) + assert.Len(t, users4, 2) + assert.Equal(t, "OrderUser1", users4[0].Name) + assert.Equal(t, "OrderUser2", users4[1].Name) + + // OrderByAsc + var users5 []User + DB.Scopes(OrderByAsc("age").Scope()).Limit(2).Find(&users5) + assert.Len(t, users5, 2) + assert.Equal(t, "OrderUser1", users5[0].Name) + assert.Equal(t, "OrderUser2", users5[1].Name) + + // OrderByDesc + var users6 []User + DB.Scopes(OrderByDesc("age").Scope()).Limit(2).Find(&users6) + assert.Len(t, users6, 2) + assert.Equal(t, "OrderUser4", users6[0].Name) + assert.Equal(t, "OrderUser3", users6[1].Name) + + // OrderByRaw + var users7 []User + DB.Scopes(OrderByRaw("age % 2 asc").Scope()).Limit(2).Find(&users7) + assert.Len(t, users7, 2) + assert.Equal(t, "OrderUser2", users7[0].Name) + assert.Equal(t, "OrderUser4", users7[1].Name) + + // multiple OrderBy + var users8 []User + DB.Scopes(OrderBy("birthday", "asc").OrderBy("age", "desc").Scope()).Limit(2).Find(&users8) + assert.Len(t, users8, 2) + assert.Equal(t, "OrderUser2", users8[0].Name) + assert.Equal(t, "OrderUser1", users8[1].Name) +} diff --git a/gorm/scopes/pagination.go b/gorm/scopes/pagination.go new file mode 100644 index 00000000..d29650f2 --- /dev/null +++ b/gorm/scopes/pagination.go @@ -0,0 +1,95 @@ +package scopes + +import ( + "gorm.io/gorm" +) + +// Offset add offset condition +// +// Offset(3) +func Offset(offset int) *Scopes { + return New().Offset(offset) +} + +// Skip add offset condition +// +// Skip(3) +func Skip(offset int) *Scopes { + return New().Skip(offset) +} + +// Limit add limit condition +// +// Limit(3) +func Limit(limit int) *Scopes { + return New().Limit(limit) +} + +// Take add limit condition +// +// Take(3) +func Take(limit int) *Scopes { + return New().Take(limit) +} + +// Page add page condition +// +// Page(2, 10) +func Page(page, prePage int) *Scopes { + return New().Page(page, prePage) +} + +// Offset add offset condition +// +// Offset(3) +func (s *Scopes) Offset(offset int) *Scopes { + return s.Add(func(db *gorm.DB) *gorm.DB { + return db.Offset(offset) + }) +} + +// Skip add offset condition +// +// Skip(3) +func (s *Scopes) Skip(offset int) *Scopes { + return s.Offset(offset) +} + +// Limit add limit condition +// +// Limit(3) +func (s *Scopes) Limit(limit int) *Scopes { + return s.Add(func(db *gorm.DB) *gorm.DB { + return db.Limit(limit) + }) +} + +// Take add limit condition +// +// Take(3) +func (s *Scopes) Take(limit int) *Scopes { + return s.Limit(limit) +} + +// Page add page condition +// +// Page(2, 10) +func (s *Scopes) Page(page, prePage int) *Scopes { + return s.Limit(prePage).Offset((page - 1) * prePage) +} + +// Paginate TODO: 未实现 +// func (s *Scopes) Paginate(page, pageSize int, dest interface{}) *Scopes { +// var total int64 +// +// dest = pagination.Paginator{ +// Page: page, +// PrePage: pageSize, +// Total: int(total), +// } +// +// return s.Add(func(db *gorm.DB) *gorm.DB { +// db.Count(&total) +// return db.Offset((page - 1) * pageSize).Limit(pageSize) +// }) +// } diff --git a/gorm/scopes/pagination_test.go b/gorm/scopes/pagination_test.go new file mode 100644 index 00000000..d8547d7b --- /dev/null +++ b/gorm/scopes/pagination_test.go @@ -0,0 +1,51 @@ +package scopes + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPagination(t *testing.T) { + users := []*User{ + GetUser("PaginationUser1", GetUserOptions{}), + GetUser("PaginationUser2", GetUserOptions{}), + GetUser("PaginationUser3", GetUserOptions{}), + GetUser("PaginationUser4", GetUserOptions{}), + GetUser("PaginationUser5", GetUserOptions{}), + } + + CleanUsers() + DB.Create(&users) + + var users1, users2, users3, users4, users5, users6 []User + + // Offset&Limit + DB.Scopes(Offset(2).Limit(2).Scope()).Find(&users1) + assert.Len(t, users1, 2) + assert.Equal(t, "PaginationUser3", users1[0].Name) + assert.Equal(t, "PaginationUser4", users1[1].Name) + + DB.Scopes(Limit(2).Skip(4).Scope()).Find(&users2) + assert.Len(t, users2, 1) + assert.Equal(t, "PaginationUser5", users2[0].Name) + + // Skip&Take + DB.Scopes(Skip(2).Take(2).Scope()).Find(&users3) + assert.Len(t, users3, 2) + assert.Equal(t, "PaginationUser3", users3[0].Name) + assert.Equal(t, "PaginationUser4", users3[1].Name) + + DB.Scopes(Take(2).Skip(4).Scope()).Find(&users4) + assert.Len(t, users4, 1) + assert.Equal(t, "PaginationUser5", users4[0].Name) + + // Page + DB.Scopes(Page(2, 2).Scope()).Find(&users5) + assert.Len(t, users5, 2) + assert.Equal(t, "PaginationUser3", users5[0].Name) + + DB.Scopes(Page(3, 2).Scope()).Find(&users6) + assert.Len(t, users6, 1) + assert.Equal(t, "PaginationUser5", users6[0].Name) +} diff --git a/gorm/scopes/scopes.go b/gorm/scopes/scopes.go new file mode 100644 index 00000000..d2887827 --- /dev/null +++ b/gorm/scopes/scopes.go @@ -0,0 +1,28 @@ +package scopes + +import "gorm.io/gorm" + +type Scopes []func(*gorm.DB) *gorm.DB + +func New() *Scopes { + return &Scopes{} +} + +func (s *Scopes) Apply(db *gorm.DB) *gorm.DB { + return db.Scopes(*s...) +} + +func (s *Scopes) Scope() func(*gorm.DB) *gorm.DB { + return func(db *gorm.DB) *gorm.DB { + return s.Apply(db) + } +} + +func (s *Scopes) Scopes() []func(*gorm.DB) *gorm.DB { + return *s +} + +func (s *Scopes) Add(scopes ...func(*gorm.DB) *gorm.DB) *Scopes { + *s = append(*s, scopes...) + return s +} diff --git a/gorm/scopes/scopes_test.go b/gorm/scopes/scopes_test.go new file mode 100644 index 00000000..d90d094c --- /dev/null +++ b/gorm/scopes/scopes_test.go @@ -0,0 +1,42 @@ +package scopes + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "gorm.io/gorm" +) + +func TestScopes(t *testing.T) { + users := []*User{ + GetUser("ScopeUser1", GetUserOptions{}), + GetUser("ScopeUser2", GetUserOptions{}), + GetUser("ScopeUser3", GetUserOptions{}), + } + + scopes := New().Add(func(db *gorm.DB) *gorm.DB { + return db.Where("name in (?)", []string{"ScopeUser1", "ScopeUser2"}) + }) + + CleanUsers() + DB.Create(&users) + + var users1, users2, users3, users4 []User + + // Scope/Add + DB.Scopes(scopes.Scope()).Find(&users1) + assert.Len(t, users1, 2) + assert.Equal(t, "ScopeUser1", users1[0].Name) + assert.Equal(t, "ScopeUser2", users1[1].Name) + + // Apply + scopes.Apply(DB).Find(&users2) + assert.Len(t, users2, 2) + + DB.Scopes(scopes.Apply).Find(&users3) + assert.Len(t, users3, 2) + + // Scopes + DB.Scopes(scopes.Scopes()...).Find(&users4) + assert.Len(t, users4, 2) +} diff --git a/gorm/scopes/trait.go b/gorm/scopes/trait.go new file mode 100644 index 00000000..b78ffb8d --- /dev/null +++ b/gorm/scopes/trait.go @@ -0,0 +1,38 @@ +package scopes + +import "gorm.io/gorm" + +// When if condition is true, apply the scopes +// +// When(true, func(db *gorm.DB) *gorm.DB { return db.Where("name = ?", "Flc") }) +// When(false, func(db *gorm.DB) *gorm.DB { return db.Where("name = ?", "Flc") }) +func When(condition bool, f func(db *gorm.DB) *gorm.DB) *Scopes { + return New().When(condition, f) +} + +// Unless if condition is false, apply the scopes +// +// Unless(false, func(db *gorm.DB) *gorm.DB { return db.Where("name = ?", "Flc") }) +// Unless(true, func(db *gorm.DB) *gorm.DB { return db.Where("name = ?", "Flc") }) +func Unless(condition bool, f func(db *gorm.DB) *gorm.DB) *Scopes { + return New().Unless(condition, f) +} + +// When if condition is true, apply the scopes +// +// When(true, func(db *gorm.DB) *gorm.DB { return db.Where("name = ?", "Flc") }) +// When(false, func(db *gorm.DB) *gorm.DB { return db.Where("name = ?", "Flc") }) +func (s *Scopes) When(condition bool, fc func(*gorm.DB) *gorm.DB) *Scopes { + if condition { + return s.Add(fc) + } + return s +} + +// Unless if condition is false, apply the scopes +// +// Unless(false, func(db *gorm.DB) *gorm.DB { return db.Where("name = ?", "Flc") }) +// Unless(true, func(db *gorm.DB) *gorm.DB { return db.Where("name = ?", "Flc") }) +func (s *Scopes) Unless(condition bool, fc func(*gorm.DB) *gorm.DB) *Scopes { + return s.When(!condition, fc) +} diff --git a/gorm/scopes/trait_test.go b/gorm/scopes/trait_test.go new file mode 100644 index 00000000..2dcea6ab --- /dev/null +++ b/gorm/scopes/trait_test.go @@ -0,0 +1,57 @@ +package scopes + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "gorm.io/gorm" +) + +func TestTrait(t *testing.T) { + users := []*User{ + GetUser("TraitWhenUser1", GetUserOptions{Age: 18}), + GetUser("TraitWhenUser2", GetUserOptions{Age: 20}), + GetUser("TraitWhenUser3", GetUserOptions{Age: 22}), + } + + CleanUsers() + DB.Create(&users) + + // When + var users1, users2 []User + DB.Scopes(When(true, func(db *gorm.DB) *gorm.DB { + return db.Where("age > ?", 18) + }).Scope()).Find(&users1) + assert.Len(t, users1, 2) + assert.Equal(t, "TraitWhenUser2", users1[0].Name) + assert.Equal(t, "TraitWhenUser3", users1[1].Name) + + DB.Scopes(When(false, func(db *gorm.DB) *gorm.DB { + return db.Where("age > ?", 18) + }).Scope()).Find(&users2) + assert.Len(t, users2, 3) + + // Unless + var users3, users4 []User + DB.Scopes(Unless(true, func(db *gorm.DB) *gorm.DB { + return db.Where("age > ?", 18) + }).Scope()).Find(&users3) + assert.Len(t, users3, 3) + + DB.Scopes(Unless(false, func(db *gorm.DB) *gorm.DB { + return db.Where("age > ?", 18) + }).Scope()).Find(&users4) + assert.Len(t, users4, 2) + assert.Equal(t, "TraitWhenUser2", users4[0].Name) + assert.Equal(t, "TraitWhenUser3", users4[1].Name) + + // multiple When and Unless + var users5 []User + DB.Scopes(When(true, func(db *gorm.DB) *gorm.DB { + return db.Where("age > ?", 18) + }).Unless(false, func(db *gorm.DB) *gorm.DB { + return db.Where("age < ?", 22) + }).Scope()).Find(&users5) + assert.Len(t, users5, 1) + assert.Equal(t, "TraitWhenUser2", users5[0].Name) +} diff --git a/gorm/scopes/where.go b/gorm/scopes/where.go new file mode 100644 index 00000000..c8152eed --- /dev/null +++ b/gorm/scopes/where.go @@ -0,0 +1,297 @@ +package scopes + +import ( + "fmt" + + "gorm.io/gorm" +) + +// Where add where condition +// +// Where("name = ?", "Flc") +// Where("name = ? AND age = ?", "Flc", 20) +func Where(query interface{}, args ...interface{}) *Scopes { + return New().Where(query, args...) +} + +// WhereBetween add where between condition +// +// WhereBetween("age", 18, 20) +func WhereBetween(field string, start, end interface{}) *Scopes { + return New().WhereBetween(field, start, end) +} + +// WhereNotBetween add where not between condition +// +// WhereNotBetween("age", 18, 20) +func WhereNotBetween(field string, start, end interface{}) *Scopes { + return New().WhereNotBetween(field, start, end) +} + +// WhereIn add where in condition +// +// WhereIn("name", []string{"WhereInUser1", "WhereInUser2"}) +// WhereIn("age", []int{18, 20}) +// WhereIn("name", "WhereInUser1", "WhereInUser2") +func WhereIn(field string, values ...interface{}) *Scopes { + return New().WhereIn(field, values...) +} + +// WhereNotIn add where not in condition +// +// WhereNotIn("name", []string{"WhereInUser1", "WhereInUser2"}) +// WhereNotIn("age", []int{18, 20}) +// WhereNotIn("name", "WhereInUser1", "WhereInUser2") +func WhereNotIn(field string, values ...interface{}) *Scopes { + return New().WhereNotIn(field, values...) +} + +// WhereLike add where like condition +// +// WhereLike("name", "Flc") +// WhereLike("name", "Flc%") +// WhereLike("name", "%Flc") +// WhereLike("name", "%Flc%") +func WhereLike(field string, value interface{}) *Scopes { + return New().WhereLike(field, value) +} + +// WhereNotLike add where not like condition +// +// WhereNotLike("name", "Flc") +// WhereNotLike("name", "Flc%") +// WhereNotLike("name", "%Flc") +// WhereNotLike("name", "%Flc%") +func WhereNotLike(field string, value interface{}) *Scopes { + return New().WhereNotLike(field, value) +} + +// WhereEq add where eq condition +// +// WhereEq("name", "Flc") +// WhereEq("age", 18) +func WhereEq(field string, value interface{}) *Scopes { + return New().WhereEq(field, value) +} + +// WhereEgt add where egt condition +// +// WhereEgt("age", 18) +func WhereEgt(field string, value interface{}) *Scopes { + return New().WhereEgt(field, value) +} + +// WhereGt add where gt condition +// +// WhereGt("age", 18) +func WhereGt(field string, value interface{}) *Scopes { + return New().WhereGt(field, value) +} + +// WhereElt add where elt condition +// +// WhereElt("age", 18) +func WhereElt(field string, value interface{}) *Scopes { + return New().WhereElt(field, value) +} + +// WhereLt add where lt condition +// +// WhereLt("age", 18) +func WhereLt(field string, value interface{}) *Scopes { + return New().WhereLt(field, value) +} + +// WhereNe add where ne condition +// +// WhereNe("name", "Flc") +// WhereNe("age", 18) +func WhereNe(field string, value interface{}) *Scopes { + return New().WhereNe(field, value) +} + +// WhereNot add where not condition +// +// WhereNot("name = ?", "Flc") +// WhereNot("name = ? AND age = ?", "Flc", 20) +func WhereNot(query interface{}, args ...interface{}) *Scopes { + return New().WhereNot(query, args...) +} + +// WhereNull add where null condition +// +// WhereNull("name") +func WhereNull(field string) *Scopes { + return New().WhereNull(field) +} + +// WhereNotNull add where not null condition +// +// WhereNotNull("name") +func WhereNotNull(field string) *Scopes { + return New().WhereNotNull(field) +} + +// Where add where condition +// +// Where("name = ?", "Flc") +// Where("name = ? AND age = ?", "Flc", 20) +func (s *Scopes) Where(query interface{}, args ...interface{}) *Scopes { + return s.Add(func(db *gorm.DB) *gorm.DB { + return db.Where(query, args...) + }) +} + +// WhereBetween add where between condition +// +// WhereBetween("age", 18, 20) +func (s *Scopes) WhereBetween(column string, start, end interface{}) *Scopes { + return s.Add(func(db *gorm.DB) *gorm.DB { + return db.Where(fmt.Sprintf("%s BETWEEN ? AND ?", column), start, end) + }) +} + +// WhereNotBetween add where not between condition +// +// WhereNotBetween("age", 18, 20) +func (s *Scopes) WhereNotBetween(column string, start, end interface{}) *Scopes { + return s.Add(func(db *gorm.DB) *gorm.DB { + return db.Where(fmt.Sprintf("%s NOT BETWEEN ? AND ?", column), start, end) + }) +} + +// WhereIn add where in condition +// +// WhereIn("name", []string{"WhereInUser1", "WhereInUser2"}) +// WhereIn("age", []int{18, 20}) +// WhereIn("name", "WhereInUser1", "WhereInUser2") +func (s *Scopes) WhereIn(column string, values ...interface{}) *Scopes { + return s.Add(func(db *gorm.DB) *gorm.DB { + if len(values) > 1 { + return db.Where(fmt.Sprintf("%s IN (?)", column), values) + } + return db.Where(fmt.Sprintf("%s IN ?", column), values...) + }) +} + +// WhereNotIn add where not in condition +// +// WhereNotIn("name", []string{"WhereInUser1", "WhereInUser2"}) +// WhereNotIn("age", []int{18, 20}) +// WhereNotIn("name", "WhereInUser1", "WhereInUser2") +func (s *Scopes) WhereNotIn(column string, values ...interface{}) *Scopes { + return s.Add(func(db *gorm.DB) *gorm.DB { + if len(values) > 1 { + return db.Where(fmt.Sprintf("%s NOT IN (?)", column), values) + } + return db.Where(fmt.Sprintf("%s NOT IN ?", column), values...) + }) +} + +// WhereLike add where like condition +// +// WhereLike("name", "Flc") +// WhereLike("name", "Flc%") +// WhereLike("name", "%Flc") +// WhereLike("name", "%Flc%") +func (s *Scopes) WhereLike(column string, value interface{}) *Scopes { + return s.Add(func(db *gorm.DB) *gorm.DB { + return db.Where(fmt.Sprintf("%s LIKE ?", column), value) + }) +} + +// WhereNotLike add where not like condition +// +// WhereNotLike("name", "Flc") +// WhereNotLike("name", "Flc%") +// WhereNotLike("name", "%Flc") +// WhereNotLike("name", "%Flc%") +func (s *Scopes) WhereNotLike(column string, value interface{}) *Scopes { + return s.Add(func(db *gorm.DB) *gorm.DB { + return db.Where(fmt.Sprintf("%s NOT LIKE ?", column), value) + }) +} + +// WhereEq add where eq condition +// +// WhereEq("name", "Flc") +// WhereEq("age", 18) +func (s *Scopes) WhereEq(column string, value interface{}) *Scopes { + return s.Add(func(db *gorm.DB) *gorm.DB { + return db.Where(fmt.Sprintf("%s = ?", column), value) + }) +} + +// WhereEgt add where egt condition +// +// WhereEgt("age", 18) +func (s *Scopes) WhereEgt(column string, value interface{}) *Scopes { + return s.Add(func(db *gorm.DB) *gorm.DB { + return db.Where(fmt.Sprintf("%s >= ?", column), value) + }) +} + +// WhereGt add where gt condition +// +// WhereGt("age", 18) +func (s *Scopes) WhereGt(column string, value interface{}) *Scopes { + return s.Add(func(db *gorm.DB) *gorm.DB { + return db.Where(fmt.Sprintf("%s > ?", column), value) + }) +} + +// WhereElt add where elt condition +// +// WhereElt("age", 18) +func (s *Scopes) WhereElt(column string, value interface{}) *Scopes { + return s.Add(func(db *gorm.DB) *gorm.DB { + return db.Where(fmt.Sprintf("%s <= ?", column), value) + }) +} + +// WhereLt add where lt condition +// +// WhereLt("age", 18) +func (s *Scopes) WhereLt(column string, value interface{}) *Scopes { + return s.Add(func(db *gorm.DB) *gorm.DB { + return db.Where(fmt.Sprintf("%s < ?", column), value) + }) +} + +// WhereNe add where ne condition +// +// WhereNe("name", "Flc") +// WhereNe("age", 18) +func (s *Scopes) WhereNe(column string, value interface{}) *Scopes { + return s.Add(func(db *gorm.DB) *gorm.DB { + return db.Where(fmt.Sprintf("%s <> ?", column), value) + }) +} + +// WhereNot add where not condition +// +// WhereNot("name = ?", "Flc") +// WhereNot("name = ? AND age = ?", "Flc", 20) +func (s *Scopes) WhereNot(query interface{}, args ...interface{}) *Scopes { + return s.Add(func(db *gorm.DB) *gorm.DB { + return db.Not(query, args...) + }) +} + +// WhereNull add where null condition +// +// WhereNull("name") +func (s *Scopes) WhereNull(column string) *Scopes { + return s.Add(func(db *gorm.DB) *gorm.DB { + return db.Where(fmt.Sprintf("%s IS NULL", column)) + }) +} + +// WhereNotNull add where not null condition +// +// WhereNotNull("name") +func (s *Scopes) WhereNotNull(column string) *Scopes { + return s.Add(func(db *gorm.DB) *gorm.DB { + return db.Where(fmt.Sprintf("%s IS NOT NULL", column)) + }) +} diff --git a/gorm/scopes/where_test.go b/gorm/scopes/where_test.go new file mode 100644 index 00000000..90fcebf7 --- /dev/null +++ b/gorm/scopes/where_test.go @@ -0,0 +1,258 @@ +package scopes + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestWhere_Where(t *testing.T) { + users := []*User{ + GetUser("WhereUser1", GetUserOptions{}), + GetUser("WhereUser2", GetUserOptions{}), + GetUser("WhereUser3", GetUserOptions{}), + } + + CleanUsers() + DB.Create(&users) + + var users1, users2, users3 []User + DB.Scopes(Where("name in (?)", []string{"WhereUser1", "WhereUser2"}).Scope()).Find(&users1) + assert.Len(t, users1, 2) + + DB.Scopes(Where("name in (?)", []string{"WhereUser1", "WhereUser4"}).Scope()).Find(&users2) + assert.Len(t, users2, 1) + + DB.Scopes(Where("name = ?", "WhereUser3").Scope()).Find(&users3) + assert.Len(t, users3, 1) + assert.Equal(t, "WhereUser3", users3[0].Name) +} + +func TestWhere_Between(t *testing.T) { + users := []*User{ + GetUser("WhereBetweenUser1", GetUserOptions{Age: 18}), + GetUser("WhereBetweenUser2", GetUserOptions{Age: 20}), + GetUser("WhereBetweenUser3", GetUserOptions{Age: 22}), + } + + CleanUsers() + DB.Create(&users) + + var users1, users2, users3 []User + DB.Scopes(WhereBetween("age", 18, 20).Scope()).Find(&users1) + assert.Len(t, users1, 2) + + DB.Scopes(WhereBetween("age", 18, 19).Scope()).Find(&users2) + assert.Len(t, users2, 1) + assert.Equal(t, "WhereBetweenUser1", users2[0].Name) + + DB.Scopes(WhereBetween("age", 12, 16).Scope()).Find(&users3) + assert.Len(t, users3, 0) + + var users4, users5, users6 []User + DB.Scopes(WhereNotBetween("age", 18, 20).Scope()).Find(&users4) + assert.Len(t, users4, 1) + assert.Equal(t, "WhereBetweenUser3", users4[0].Name) + + DB.Scopes(WhereNotBetween("age", 18, 19).Scope()).Find(&users5) + assert.Len(t, users5, 2) + + DB.Scopes(WhereNotBetween("age", 12, 16).Scope()).Find(&users6) + assert.Len(t, users6, 3) +} + +func TestWhere_In(t *testing.T) { + users := []*User{ + GetUser("WhereInUser1", GetUserOptions{Age: 18}), + GetUser("WhereInUser2", GetUserOptions{Age: 20}), + GetUser("WhereInUser3", GetUserOptions{Age: 22}), + } + + CleanUsers() + DB.Create(&users) + + var users1, users2, users3 []User + DB.Debug().Scopes(WhereIn("name", "WhereInUser1", "WhereInUser2").Scope()).Find(&users1) + assert.Len(t, users1, 2) + + DB.Scopes(WhereIn("age", []int{18, 20}).Scope()).Find(&users2) + assert.Len(t, users2, 2) + + DB.Scopes(WhereIn("name", []string{"WhereInUser1", "WhereInUser2"}).Scope()).Find(&users3) + assert.Len(t, users3, 2) + + var users4, users5, users6 []User + DB.Scopes(WhereNotIn("name", "WhereInUser1", "WhereInUser2").Scope()).Find(&users4) + assert.Len(t, users4, 1) + assert.Equal(t, "WhereInUser3", users4[0].Name) + + DB.Scopes(WhereNotIn("age", []int{18, 20}).Scope()).Find(&users5) + assert.Len(t, users5, 1) + assert.Equal(t, "WhereInUser3", users5[0].Name) + + DB.Scopes(WhereNotIn("name", []string{"WhereInUser1", "WhereInUser2"}).Scope()).Find(&users6) + assert.Len(t, users6, 1) + assert.Equal(t, "WhereInUser3", users6[0].Name) +} + +func TestWhere_Like(t *testing.T) { + users := []*User{ + GetUser("WhereLikeUser1", GetUserOptions{Age: 18}), + GetUser("WhereLikeUser2", GetUserOptions{Age: 20}), + GetUser("WhereLikeUser3", GetUserOptions{Age: 22}), + } + + CleanUsers() + DB.Create(&users) + + var users1, users2, users3, users4 []User + DB.Scopes(WhereLike("name", "WhereLikeUser1").Scope()).Find(&users1) + assert.Len(t, users1, 1) + assert.Equal(t, "WhereLikeUser1", users1[0].Name) + + DB.Scopes(WhereLike("name", "WhereLike%").Scope()).Find(&users2) + assert.Len(t, users2, 3) + + DB.Scopes(WhereLike("name", "%LikeUser3").Scope()).Find(&users3) + assert.Len(t, users3, 1) + assert.Equal(t, "WhereLikeUser3", users3[0].Name) + + DB.Scopes(WhereLike("name", "%Like%").Scope()).Find(&users4) + assert.Len(t, users4, 3) + + var users5, users6, users7, users8 []User + DB.Scopes(WhereNotLike("name", "WhereLikeUser1").Scope()).Find(&users5) + assert.Len(t, users5, 2) + + DB.Scopes(WhereNotLike("name", "WhereLike%").Scope()).Find(&users6) + assert.Len(t, users6, 0) + + DB.Scopes(WhereNotLike("name", "%LikeUser3").Scope()).Find(&users7) + assert.Len(t, users7, 2) + + DB.Scopes(WhereNotLike("name", "%Like%").Scope()).Find(&users8) + assert.Len(t, users8, 0) +} + +func TestWhere_OP(t *testing.T) { + users := []*User{ + GetUser("WhereLikeUser1", GetUserOptions{Age: 18}), + GetUser("WhereLikeUser2", GetUserOptions{Age: 20}), + GetUser("WhereLikeUser3", GetUserOptions{Age: 22}), + GetUser("WhereLikeUser4", GetUserOptions{Age: 22}), + } + + CleanUsers() + DB.Create(&users) + + // Eq + var users1, users2, users3 []User + DB.Scopes(WhereEq("name", "WhereLikeUser1").Scope()).Find(&users1) + assert.Len(t, users1, 1) + assert.Equal(t, "WhereLikeUser1", users1[0].Name) + + DB.Scopes(WhereEq("age", 18).Scope()).Find(&users2) + assert.Len(t, users2, 1) + assert.Equal(t, "WhereLikeUser1", users2[0].Name) + + DB.Scopes(WhereEq("age", 22).Scope()).Find(&users3) + assert.Len(t, users3, 2) + + // Egt + var users4, users5, users6 []User + DB.Scopes(WhereEgt("age", 20).Scope()).Find(&users4) + assert.Len(t, users4, 3) + + DB.Scopes(WhereEgt("age", 22).Scope()).Find(&users5) + assert.Len(t, users5, 2) + + DB.Scopes(WhereEgt("age", 23).Scope()).Find(&users6) + assert.Len(t, users6, 0) + + // Elt + var users7, users8, users9 []User + DB.Scopes(WhereElt("age", 20).Scope()).Find(&users7) + assert.Len(t, users7, 2) + + DB.Scopes(WhereElt("age", 22).Scope()).Find(&users8) + assert.Len(t, users8, 4) + + DB.Scopes(WhereElt("age", 18).Scope()).Find(&users9) + assert.Len(t, users9, 1) + + // Gt + var users10, users11, users12 []User + DB.Scopes(WhereGt("age", 20).Scope()).Find(&users10) + assert.Len(t, users10, 2) + + DB.Scopes(WhereGt("age", 22).Scope()).Find(&users11) + assert.Len(t, users11, 0) + + DB.Scopes(WhereGt("age", 18).Scope()).Find(&users12) + assert.Len(t, users12, 3) + + // Lt + var users13, users14, users15 []User + DB.Scopes(WhereLt("age", 20).Scope()).Find(&users13) + assert.Len(t, users13, 1) + + DB.Scopes(WhereLt("age", 22).Scope()).Find(&users14) + assert.Len(t, users14, 2) + + DB.Scopes(WhereLt("age", 18).Scope()).Find(&users15) + assert.Len(t, users15, 0) + + // Ne + var users16, users17, users18 []User + DB.Scopes(WhereNe("age", 20).Scope()).Find(&users16) + assert.Len(t, users16, 3) + + DB.Scopes(WhereNe("age", 22).Scope()).Find(&users17) + assert.Len(t, users17, 2) + + DB.Scopes(WhereNe("age", 18).Scope()).Find(&users18) + assert.Len(t, users18, 3) +} + +func TestWhere_WhereNot(t *testing.T) { + users := []*User{ + GetUser("WhereLikeUser1", GetUserOptions{Age: 18}), + GetUser("WhereLikeUser2", GetUserOptions{Age: 20}), + GetUser("WhereLikeUser3", GetUserOptions{Age: 22}), + } + + CleanUsers() + DB.Create(&users) + + var users1 []User + DB.Scopes(WhereNot("name = ?", "WhereLikeUser1").Scope()).Find(&users1) + assert.Len(t, users1, 2) + assert.Equal(t, "WhereLikeUser2", users1[0].Name) + assert.Equal(t, "WhereLikeUser3", users1[1].Name) +} + +func TestWhere_Null(t *testing.T) { + address2 := "WhereNullAddress2" + address3 := "WhereNullAddress3" + users := []*User{ + GetUser("WhereNullUser1", GetUserOptions{Age: 18, Address: nil}), + GetUser("WhereNullUser2", GetUserOptions{Age: 20, Address: &address2}), + GetUser("WhereNullUser3", GetUserOptions{Age: 22, Address: &address3}), + } + + CleanUsers() + DB.Create(&users) + + // Null + var users1 []User + DB.Scopes(WhereNull("address").Scope()).Find(&users1) + assert.Len(t, users1, 1) + assert.Equal(t, "WhereNullUser1", users1[0].Name) + + // NotNull + var users2 []User + DB.Scopes(WhereNotNull("address").Scope()).Find(&users2) + assert.Len(t, users2, 2) + assert.Equal(t, "WhereNullUser2", users2[0].Name) + assert.Equal(t, "WhereNullUser3", users2[1].Name) +} diff --git a/pagination/paginator.go b/pagination/paginator.go new file mode 100644 index 00000000..39b671bc --- /dev/null +++ b/pagination/paginator.go @@ -0,0 +1,75 @@ +package pagination + +import "encoding/json" + +type Paginator struct { + Page int `json:"page"` + PrePage int `json:"pre_page"` + Total int `json:"total"` +} + +func NewPaginator(page, pageSize, total int) *Paginator { + return &Paginator{ + Page: page, + PrePage: pageSize, + Total: total, + } +} + +func (p *Paginator) GetPage() int { + return p.Page +} + +func (p *Paginator) GetPerPage() int { + return p.PrePage +} + +func (p *Paginator) GetTotal() int { + return p.Total +} + +func (p *Paginator) GetLastPage() int { + return (p.Total + p.PrePage - 1) / p.PrePage +} + +func (p *Paginator) GetOffset() int { + return (p.Page - 1) * p.PrePage +} + +func (p *Paginator) GetLimit() int { + return p.PrePage +} + +func (p *Paginator) HasMore() bool { + return p.Page*p.PrePage < p.Total +} + +func (p *Paginator) GetNextPage() int { + if p.HasMore() { + return p.Page + 1 + } + return p.Page +} + +func (p *Paginator) GetPrevPage() int { + if p.Page > 1 { + return p.Page - 1 + } + return p.Page +} + +func (p *Paginator) ToMap() map[string]interface{} { + return map[string]interface{}{ + "page": p.Page, + "pre_page": p.PrePage, + "total": p.Total, + "next_page": p.GetNextPage(), + "prev_page": p.GetPrevPage(), + "last_page": p.GetLastPage(), + "has_more": p.HasMore(), + } +} + +func (p *Paginator) ToJSON() ([]byte, error) { + return json.Marshal(p.ToMap()) +}