diff --git a/constants/code.go b/constants/code.go index 5044170..27d5649 100644 --- a/constants/code.go +++ b/constants/code.go @@ -62,6 +62,9 @@ func RegisterCodeTable(language LanguageCode, i18nMsg map[ICodeType]string) { // # 此方法适用于【动态】注入单个错误码的场景 func RegisterCodeSingle(language LanguageCode, code ICodeType, msg string) { ctm.mux.Lock() + if _, ok := ctm.maps[language][code]; !ok { + ctm.maps[language] = make(map[ICodeType]string) + } ctm.maps[language][code] = msg ctm.mux.Unlock() } diff --git a/constants/i18n.go b/constants/i18n.go index 8a83b80..7455eaf 100644 --- a/constants/i18n.go +++ b/constants/i18n.go @@ -272,12 +272,16 @@ var i18n = []LanguageCode{ LanguageZulu, } +func (lc LanguageCode) String() string { + return string(lc) +} + func (lc LanguageCode) ToLowerCase() string { - return strings.ToLower(string(lc)) + return strings.ToLower(lc.String()) } func (lc LanguageCode) ToUpperCase() string { - return strings.ToUpper(string(lc)) + return strings.ToUpper(lc.String()) } func (lc LanguageCode) Exist() bool { diff --git a/http/api/option.go b/http/api/option.go index f16f465..e9a9d79 100644 --- a/http/api/option.go +++ b/http/api/option.go @@ -13,6 +13,7 @@ var ( forceHttpCode200 = false //强制使用200作为http的状态码 timezone = constants.DefaultTimeZone //时区 detectAcceptLanguage = false //是否检测客户端语言 + languageCode = constants.LanguageEnglish //语言代码 ) var ( @@ -35,6 +36,10 @@ type Option struct { Timezone string //是否检测客户端语言,用于错误码消息返回 DetectAcceptLanguage bool + //语言代码 + // + //当没有启用 DetectAcceptLanguage 时,使用该语言代码 + LanguageCode constants.LanguageCode } const ( @@ -91,5 +96,6 @@ func DefaultSetupOption() *Option { ErrNoneCode: constants.ErrNone, ErrNoneCodeMsg: "SUCCESS", ForceHttpCode200: true, + LanguageCode: constants.LanguageEnglish, } } diff --git a/http/api/response.go b/http/api/response.go index ab8eb7c..9d543ad 100644 --- a/http/api/response.go +++ b/http/api/response.go @@ -243,7 +243,7 @@ func (a *responseEngine) mergeBody(code constants.ICodeType, resp interface{}, m body dto.Base requestId string httpCode int - language = []string{"en"} + language = []string{languageCode.String()} ) //从上下文中获取语言代码 if detectAcceptLanguage { diff --git a/orm/base.go b/orm/base.go index d913b99..a736ad8 100644 --- a/orm/base.go +++ b/orm/base.go @@ -16,3 +16,10 @@ type BaseModel struct { // NoneID 空ID const NoneID = uint64(0) + +var nowTime = time.Now() + +// SetHookTime 设置勾子函数的时间对象 +func SetHookTime(now time.Time) { + nowTime = now +} diff --git a/orm/hook.go b/orm/hook.go index a9e1a2f..ce5ee62 100644 --- a/orm/hook.go +++ b/orm/hook.go @@ -1,33 +1,31 @@ package orm import ( - "time" - "gorm.io/gorm" ) func (u *BaseModel) BeforeSave(_ *gorm.DB) (err error) { - u.CreatedAt = time.Now() + u.CreatedAt = nowTime u.UpdatedAt = u.CreatedAt return nil } func (u *BaseModel) BeforeCreate(_ *gorm.DB) (err error) { - u.CreatedAt = time.Now() + u.CreatedAt = nowTime u.UpdatedAt = u.CreatedAt return nil } func (u *BaseModel) BeforeUpdate(_ *gorm.DB) (err error) { - u.UpdatedAt = time.Now() + u.UpdatedAt = nowTime return nil } func (u *BaseModel) BeforeDelete(_ *gorm.DB) (err error) { - now := time.Now() + now := nowTime u.DeletedAt = &now return nil diff --git a/orm/svc.go b/orm/svc.go index 9f8bf8f..8fca869 100644 --- a/orm/svc.go +++ b/orm/svc.go @@ -36,23 +36,23 @@ type Svc interface { Session(session *gorm.Session) Svc WithContext(ctx context.Context) Svc + Count(count *int64) Create(value interface{}) error Find(dest interface{}, conditions ...interface{}) error - First(dest interface{}, conditions ...interface{}) error Updates(values interface{}) error Save(values interface{}) error Delete(value interface{}, conditions ...interface{}) error Transaction(fc func(tx *gorm.DB) error, opts ...*sql.TxOptions) (err error) - //Unwrap 返回gorm原生实例 + // Unwrap 返回gorm原生实例 Unwrap() *gorm.DB - //R 使用读实例 + // R 使用读实例 R() Svc - //W 使用写实例 + // W 使用写实例 W() Svc - //Paginate 分页查询(多行) + // Paginate 分页查询(多行) // //参数: // @@ -64,11 +64,11 @@ type Svc interface { // //总条数和错误 Paginate(dest interface{}, page, pageSize int) (int64, error) - //FindOrNil 查询多条记录 + // FindOrNil 查询多条记录 // //如果记录不存在忽略 gorm.ErrRecordNotFound 错误 FindOrNil(dest interface{}, conditions ...interface{}) error - //FirstOrNil 查询单条记录 + // FirstOrNil 查询单条记录 // //如果记录不存在忽略 gorm.ErrRecordNotFound 错误 FirstOrNil(dest interface{}, conditions ...interface{}) error @@ -255,8 +255,22 @@ func (a *SvcImpl) Unwrap() *gorm.DB { return a.tx } +func (a *SvcImpl) Count(count *int64) { + if a.tx == nil { + a.tx = a.dbw + } + + a.tx.Count(count) + a.clearTx() +} + func (a *SvcImpl) Create(value interface{}) error { - err := a.dbw.Create(value).Error + if a.tx == nil { + a.tx = a.dbw + } + + err := a.tx.Create(value).Error + a.clearTx() if err != nil { a.logger.Error("[Database service]:Create:Error", @@ -273,6 +287,7 @@ func (a *SvcImpl) Find(dest interface{}, conditions ...interface{}) error { } err := IgnoreErrRecordNotFound(a.tx.Find(dest, conditions...)) + a.clearTx() if err != nil { a.logger.Error("[Database service]:Find:Error", @@ -290,6 +305,7 @@ func (a *SvcImpl) FindOrNil(dest interface{}, conditions ...interface{}) error { } err := IgnoreErrRecordNotFound(a.tx.Find(dest, conditions...)) + a.clearTx() if err != nil { a.logger.Error("[Database service]:FindOrNil:Error", @@ -307,6 +323,7 @@ func (a *SvcImpl) First(dest interface{}, conditions ...interface{}) error { } err := IgnoreErrRecordNotFound(a.tx.First(dest, conditions...)) + a.clearTx() if err != nil { a.logger.Error("[Database service]:First:Error", @@ -324,6 +341,7 @@ func (a *SvcImpl) FirstOrNil(dest interface{}, conditions ...interface{}) error } err := IgnoreErrRecordNotFound(a.tx.First(dest, conditions...)) + a.clearTx() if err != nil { a.logger.Error("[Database service]:FirstOrNil:Error", @@ -341,6 +359,7 @@ func (a *SvcImpl) Updates(values interface{}) error { } err := a.tx.Updates(values).Error + a.clearTx() if err != nil { a.logger.Error("[Database service]:Updates:Error", @@ -357,6 +376,7 @@ func (a *SvcImpl) Save(values interface{}) error { } err := a.tx.Save(values).Error + a.clearTx() if err != nil { a.logger.Error("[Database service]:Save:Error", @@ -371,7 +391,9 @@ func (a *SvcImpl) Delete(value interface{}, conditions ...interface{}) error { if a.tx == nil { a.tx = a.dbw } + err := a.tx.Delete(value, conditions...).Error + a.clearTx() if err != nil { a.logger.Error("[Database service]:Delete:Error", @@ -387,7 +409,9 @@ func (a *SvcImpl) Transaction(fc func(tx *gorm.DB) error, opts ...*sql.TxOptions if a.tx == nil { a.tx = a.dbw } + err = a.tx.Transaction(fc, opts...) + a.clearTx() if err != nil { a.logger.Error("[Database service]:Transaction:Error", @@ -418,10 +442,15 @@ func (a *SvcImpl) Paginate(dest interface{}, page, pageSize int) (int64, error) a.tx = a.dbw } - var total int64 - a.tx.Count(&total) + var count int64 + if a.tx.Statement.Model == nil { + a.tx.Model(dest).Count(&count) + } else { + a.tx.Count(&count) + } err := IgnoreErrRecordNotFound(a.tx.Scopes(Paginate(page, pageSize)).Find(dest)) + a.clearTx() if err != nil { a.logger.Error("[Database service]:Paginate:Error", @@ -429,5 +458,9 @@ func (a *SvcImpl) Paginate(dest interface{}, page, pageSize int) (int64, error) zap.Errors("errors", []error{err})) } - return total, err + return count, err +} + +func (a *SvcImpl) clearTx() { + a.tx = nil } diff --git a/orm/svc_test.go b/orm/svc_test.go index 8114795..2a349e9 100644 --- a/orm/svc_test.go +++ b/orm/svc_test.go @@ -2,8 +2,14 @@ package orm import ( "fmt" + "net/url" "testing" + "github.com/keepchen/go-sail/v3/constants" + "github.com/keepchen/go-sail/v3/utils" + + "github.com/shopspring/decimal" + "github.com/stretchr/testify/assert" "github.com/keepchen/go-sail/v3/lib/db" @@ -21,12 +27,26 @@ func (*User) TableName() string { return "user" } +type Wallet struct { + BaseModel + UserID int64 `gorm:"column:user_id;type:bigint;not null;index:,unique;comment:用户ID"` + Amount decimal.Decimal `gorm:"column:amount;type:decimal(10,2);default:0;comment:余额"` +} + +func (*Wallet) TableName() string { + return "wallet" +} + var ( - loggerConf = logger.Conf{} - dbConf = db.Conf{ - Enable: false, + loggerConf = logger.Conf{ + Level: "debug", + Filename: "../examples/logs/testcase_db.log", + } + dbConf = db.Conf{ + Enable: true, DriverName: "mysql", AutoMigrate: true, + LogLevel: "debug", ConnectionPool: db.ConnectionPoolConf{ MaxOpenConnCount: 100, MaxIdleConnCount: 10, @@ -42,7 +62,7 @@ var ( Database: "go-sail", Charset: "utf8mb4", ParseTime: true, - Loc: "Local", + Loc: url.QueryEscape(constants.TimeZoneUTCPlus7), }, Write: db.MysqlConfItem{ Host: "127.0.0.1", @@ -52,7 +72,7 @@ var ( Database: "go-sail", Charset: "utf8mb4", ParseTime: true, - Loc: "Local", + Loc: url.QueryEscape(constants.TimeZoneUTCPlus7), }, }, } @@ -66,11 +86,14 @@ func TestSvcUsage(t *testing.T) { t.Log("database instance is nil, testing not emit.") return } - _ = AutoMigrate(dbw, &User{}) + _ = AutoMigrate(dbw, &User{}, &Wallet{}) svc := New(dbr, dbw, logger.GetLogger()) dbw.Exec(fmt.Sprintf("truncate table %s", (&User{}).TableName())) + dbw.Exec(fmt.Sprintf("truncate table %s", (&Wallet{}).TableName())) + + SetHookTime(utils.NewTimeWithTimeZone(constants.TimeZoneUTCPlus7).Now()) // ---- ignore gorm.ErrRecordNotFound var user0 User @@ -86,7 +109,16 @@ func TestSvcUsage(t *testing.T) { } err = svc.W().Create(&user1) assert.NoError(t, err) - t.Log("Create:", user1) + t.Log("Create user:", user1) + + var wallet1 = Wallet{ + UserID: user1.UserID, + Amount: decimal.NewFromFloat(10.24), + } + + err = svc.Create(&wallet1) + assert.NoError(t, err) + t.Log("Create wallet:", wallet1) // ---- read one record var user2 User @@ -100,15 +132,27 @@ func TestSvcUsage(t *testing.T) { Nickname: "go-sail", Status: 2, } - err = svc.W().Select("*").Omit("id", "created_at", "deleted_at").Updates(&user3) + err = svc.W().Select("*").Omit("id", "created_at", "deleted_at"). + Where("user_id = ?", user3.UserID).Updates(&user3) assert.NoError(t, err) t.Log("Updates:", user3) + var ( + queryUser User + queryWallet Wallet + ) + + err = svc.Model(&User{}).Where("user_id = ?", user3.UserID).First(&queryUser) + assert.NoError(t, err) + + err = svc.Model(&Wallet{}).Where("amount = ?", 10.24).First(&queryWallet) + assert.NoError(t, err) + // ---- read several records var ( users0 []User ) - err = svc.R().Find(&users0) + err = svc.R().Where("id > ?", 0).Find(&users0) assert.NoError(t, err) t.Log("Find:", users0) @@ -117,6 +161,7 @@ func TestSvcUsage(t *testing.T) { err = svc.R().Where(&User{UserID: 99999}).FindOrNil(&users1) t.Log("FindOrNil:", err) assert.NoError(t, err) + assert.Equal(t, users1.ID, NoneID) // ---- paginate var ( diff --git a/sail/components.go b/sail/components.go index f9f7116..49b1dae 100644 --- a/sail/components.go +++ b/sail/components.go @@ -26,15 +26,28 @@ import ( // // 注意,使用前请确保db组件已初始化成功。 func GetDB() (read *gorm.DB, write *gorm.DB) { - read, write = db.GetInstance().R, db.GetInstance().W + if db.GetInstance() == nil { + read, write = nil, nil + } else { + read, write = db.GetInstance().R, db.GetInstance().W + } return } +// NewDB 创建新的数据实例 +func NewDB(conf db.Conf) (read *gorm.DB, rErr error, write *gorm.DB, wErr error) { + return db.New(conf) +} + // GetDBR 获取数据库读实例 // // 注意,使用前请确保db组件已初始化成功。 func GetDBR() *gorm.DB { + if db.GetInstance() == nil { + return nil + } + return db.GetInstance().R } @@ -42,6 +55,10 @@ func GetDBR() *gorm.DB { // // 注意,使用前请确保db组件已初始化成功。 func GetDBW() *gorm.DB { + if db.GetInstance() == nil { + return nil + } + return db.GetInstance().W } @@ -77,6 +94,13 @@ func GetLogger(module ...string) *zap.Logger { return logger.GetLogger(module...) } +// MarshalInterfaceValue 将interface序列化成字符串 +// +// 主要用于日志记录 +func MarshalInterfaceValue(obj interface{}) string { + return logger.MarshalInterfaceValue(obj) +} + // Response http响应组件 func Response(c *gin.Context) api.Responder { return api.Response(c) diff --git a/utils/number.go b/utils/number.go new file mode 100644 index 0000000..3c8dff2 --- /dev/null +++ b/utils/number.go @@ -0,0 +1,139 @@ +package utils + +import ( + "fmt" + "math/rand" + + "github.com/shopspring/decimal" +) + +// RandomInt64 在指定范围内取随机整数 +// +// start和end同时支持正负数 +// +// 结果值区间 ∈ [start, end) +// +// # Note +// +// 若start大于end将panic +// +// # Example: +// +// result := RandomInt64(10, 20) +// //-> 13 +// +// result := RandomInt64(-10, 20) +// //-> 3 +// +// result := RandomInt64(-20, -10) +// //-> -7 +func RandomInt64(start, end int64) int64 { + if start > end { + panic(fmt.Errorf("range invalid: start great than end")) + } + if start == end { + return start + } + //如果范围都是负值区间 + if start < 0 && end < 0 { + fixedStart, fixedEnd := 0-start, 0-end + return 0 - (fixedEnd + rand.Int63n(fixedStart)) + } + //如果是一正一负 + if start < 0 && end >= 0 { + fixed := 0 - start + return rand.Int63n(fixed+end) - fixed + } + //起始为0 + if start == 0 { + return rand.Int63n(end) + } + + return start +} + +// RandomFloat64 在指定范围内取随机浮点数 +// +// start和end同时支持正负数 +// +// precision为精度,此参数将限定返回值的最大小数位数 +// +// 结果值区间 ∈ [start, end) +// +// # Note +// +// 若start大于end将panic +// +// # Example: +// +// result := RandomFloat64(10.10, 20.20, 2) +// //-> 16.22 +// +// result := RandomFloat64(-10.10, 20.20, 3) +// //-> -7.222 +// +// result := RandomFloat64(-20.20, -10.10101010101, 4) +// //-> -8.1234 +func RandomFloat64(start, end float64, precision int) float64 { + if start > end { + panic(fmt.Errorf("range invalid: start great than end")) + } + if start == end { + return start + } + var scale = 1 + + //对start进行放大 + for { + startScaled := start * float64(scale) + if startScaled == float64(int64(startScaled)) { + break + } + scale *= 10 + } + + //对end进行放大 + for { + endScaled := end * float64(scale) + if endScaled == float64(int64(endScaled)) { + break + } + scale *= 10 + } + + start *= float64(scale) + end *= float64(scale) + + randInt64 := RandomInt64(int64(start), int64(end)) + + return decimal.NewFromInt(randInt64). + Div(decimal.NewFromInt(int64(scale))). + Truncate(int32(precision)).InexactFloat64() +} + +// Pow 计算x的y次幂 +// +// # Note +// +// 若y小于0,将panic +func Pow(x, y int64) int64 { + if y < 0 { + panic(fmt.Errorf("y less than zero")) + } + if y == 0 { + return 1 + } + if y == 1 { + return x + } + var times int64 + for { + times++ + if times == y { + break + } + x *= x + } + + return x +} diff --git a/utils/number_test.go b/utils/number_test.go new file mode 100644 index 0000000..3663f6e --- /dev/null +++ b/utils/number_test.go @@ -0,0 +1,78 @@ +package utils + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRandomInt64(t *testing.T) { + var ranges = [][]int64{ + {0, 0}, + {1, 1}, + {0, 1000}, + {-10, 0}, + {-10, 100}, + {-100, 100}, + } + for _, rn := range ranges { + t.Log("range: from", rn[0], "to", rn[1]) + for i := 0; i < 10; i++ { + t.Log(RandomInt64(rn[0], rn[1])) + } + + for j := 0; j < 1000; j++ { + r := RandomInt64(rn[0], rn[1]) + assert.GreaterOrEqual(t, r, rn[0]) + if rn[0] != rn[1] { + assert.Less(t, r, rn[1]) + } else { + assert.Equal(t, r, rn[1]) + } + } + } +} + +func TestRandomFloat64(t *testing.T) { + type target struct { + start float64 + end float64 + precision int + } + var ranges = []target{ + {0, 0, 2}, + {1, 1, 2}, + {0, 1000, 3}, + {-10.11, 0.11, 4}, + {-10.101, 100.10111111, 5}, + {-100.1011, 100.1011111111, 6}, + } + for _, rn := range ranges { + t.Log("range: from", rn.start, "to", rn.end, "precision", rn.precision) + for i := 0; i < 10; i++ { + t.Log(RandomFloat64(rn.start, rn.end, rn.precision)) + } + + for j := 0; j < 1000; j++ { + r := RandomFloat64(rn.start, rn.end, rn.precision) + assert.GreaterOrEqual(t, r, rn.start) + if rn.start != rn.end { + assert.Less(t, r, rn.end) + } else { + assert.Equal(t, r, rn.end) + } + } + } +} + +func TestPow(t *testing.T) { + var cases = [][]int64{ + {1, 0, 1}, + {-1, 0, 1}, + {2, 2, 4}, + {3, 2, 9}, + } + for _, ca := range cases { + assert.Equal(t, Pow(ca[0], ca[1]), ca[2]) + } +}