From 98360f036f311f085077e0d423595710d0a4c315 Mon Sep 17 00:00:00 2001 From: keepchen Date: Fri, 10 May 2024 16:47:09 +0800 Subject: [PATCH] =?UTF-8?q?1.=E9=83=A8=E5=88=86=E7=BB=84=E4=BB=B6New?= =?UTF-8?q?=E6=96=B9=E6=B3=95=E5=87=BA=E7=8E=B0=E9=94=99=E8=AF=AF=E4=B8=8D?= =?UTF-8?q?=E5=86=8Dpanic=E8=80=8C=E6=98=AF=E8=BF=94=E5=9B=9E=E9=94=99?= =?UTF-8?q?=E8=AF=AF,2.orm=E7=BB=93=E6=9E=84=E8=B0=83=E6=95=B4=E5=B9=B6?= =?UTF-8?q?=E6=96=B0=E5=A2=9E=E6=B5=8B=E8=AF=95=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lib/db/db.go | 40 +++++++--- lib/kafka/kafka.go | 26 ++++--- lib/logger/zap.go | 28 ++++++- lib/nats/nats.go | 17 +++- lib/redis/redis.go | 27 ++++++- lib/redis/redis_cluster.go | 46 ++++++++++- orm/{model => }/base.go | 2 +- orm/{model => }/helper.go | 26 ++++++- orm/{model => }/hook.go | 2 +- orm/service/orm_example.go | 52 ------------- orm/{service/orm.go => svc.go} | 138 +++++++++++++++++++++------------ orm/svc_test.go | 131 +++++++++++++++++++++++++++++++ 12 files changed, 397 insertions(+), 138 deletions(-) rename orm/{model => }/base.go (97%) rename orm/{model => }/helper.go (50%) rename orm/{model => }/hook.go (97%) delete mode 100644 orm/service/orm_example.go rename orm/{service/orm.go => svc.go} (57%) create mode 100644 orm/svc_test.go diff --git a/lib/db/db.go b/lib/db/db.go index 50df860..2338f17 100644 --- a/lib/db/db.go +++ b/lib/db/db.go @@ -20,9 +20,9 @@ var dbInstance *Instance func InitDB(conf Conf) { dialectR, dialectW := conf.GenDialector() //read instance - dbPtrR := initDB(conf, dialectR) + dbPtrR := mustInitDB(conf, dialectR) //write instance - dbPtrW := initDB(conf, dialectW) + dbPtrW := mustInitDB(conf, dialectW) dbInstance = &Instance{ R: dbPtrR, @@ -33,14 +33,15 @@ func InitDB(conf Conf) { // NewFreshDB 实例化全新的数据库链接 // // rInstance为读实例,wInstance为写实例 -func NewFreshDB(conf Conf) (rInstance, wInstance *gorm.DB) { +func NewFreshDB(conf Conf) (rInstance *gorm.DB, rErr error, wInstance *gorm.DB, wErr error) { dialectR, dialectW := conf.GenDialector() - rInstance, wInstance = initDB(conf, dialectR), initDB(conf, dialectW) + rInstance, rErr = initDB(conf, dialectR) + wInstance, wErr = initDB(conf, dialectW) return } -func initDB(conf Conf, dialect gorm.Dialector) *gorm.DB { +func mustInitDB(conf Conf, dialect gorm.Dialector) *gorm.DB { loggerSvc := NewZapLoggerForGorm(logger.GetLogger(), conf) loggerSvc.SetAsDefault() dbPtr, err := gorm.Open(dialect, &gorm.Config{ @@ -63,6 +64,29 @@ func initDB(conf Conf, dialect gorm.Dialector) *gorm.DB { return dbPtr } +func initDB(conf Conf, dialect gorm.Dialector) (*gorm.DB, error) { + loggerSvc := NewZapLoggerForGorm(logger.GetLogger(), conf) + loggerSvc.SetAsDefault() + dbPtr, err := gorm.Open(dialect, &gorm.Config{ + Logger: loggerSvc, + }) + if err != nil { + return nil, err + } + + sqlDB, err := dbPtr.DB() + if err != nil { + return nil, err + } + + sqlDB.SetMaxOpenConns(conf.ConnectionPool.MaxOpenConnCount) + sqlDB.SetMaxIdleConns(conf.ConnectionPool.MaxIdleConnCount) + sqlDB.SetConnMaxLifetime(time.Minute * time.Duration(conf.ConnectionPool.ConnMaxLifeTimeMinutes)) + sqlDB.SetConnMaxIdleTime(time.Minute * time.Duration(conf.ConnectionPool.ConnMaxIdleTimeMinutes)) + + return dbPtr, nil +} + // GetInstance 获取数据库实例 // // 获取由InitDB实例化后的连接 @@ -73,10 +97,8 @@ func GetInstance() *Instance { // New 初始化化全新的数据库链接 // // rInstance为读实例,wInstance为写实例 -func New(conf Conf) (rInstance, wInstance *gorm.DB) { - rInstance, wInstance = NewFreshDB(conf) - - return +func New(conf Conf) (rInstance *gorm.DB, rErr error, wInstance *gorm.DB, wErr error) { + return NewFreshDB(conf) } // Init 初始化数据库连接 diff --git a/lib/kafka/kafka.go b/lib/kafka/kafka.go index 8de40ec..de31f37 100644 --- a/lib/kafka/kafka.go +++ b/lib/kafka/kafka.go @@ -63,12 +63,14 @@ func Init(conf Conf, topic, groupID string) { // New 初始化连接 // // 该方法会初始化连接、读实例、写实例 -func New(conf Conf, topic, groupID string) ([]*kafkaLib.Conn, *kafkaLib.Writer, *kafkaLib.Reader) { - connections := NewConnections(conf) - writer := NewWriter(conf, topic) - reader := NewReader(conf, topic, groupID) - - return connections, writer, reader +func New(conf Conf, topic, groupID string) (connections []*kafkaLib.Conn, + writer *kafkaLib.Writer, wErr error, + reader *kafkaLib.Reader, rErr error) { + connections = NewConnections(conf) + writer, wErr = NewWriter(conf, topic) + reader, rErr = NewReader(conf, topic, groupID) + + return } // InitConnections 初始化连接 @@ -175,7 +177,7 @@ func InitWriter(conf Conf, topic string) { } // NewWriter 实例化新的写实例 -func NewWriter(conf Conf, topic string) *kafkaLib.Writer { +func NewWriter(conf Conf, topic string) (*kafkaLib.Writer, error) { writer := &kafkaLib.Writer{ Addr: kafkaLib.TCP(conf.Endpoints...), Topic: topic, @@ -189,14 +191,14 @@ func NewWriter(conf Conf, topic string) *kafkaLib.Writer { if len(conf.Username) != 0 && len(conf.Password) != 0 { mechanism, mErr := getMechanism(conf) if mErr != nil { - panic(mErr) + return nil, mErr } transport.SASL = mechanism } writer.Transport = transport - return writer + return writer, nil } // InitReader 初始化读实例 @@ -236,7 +238,7 @@ func InitReader(conf Conf, topic, groupID string) { } // NewReader 实例化新的读实例 -func NewReader(conf Conf, topic, groupID string) *kafkaLib.Reader { +func NewReader(conf Conf, topic, groupID string) (*kafkaLib.Reader, error) { if conf.Timeout < 1 { conf.Timeout = 10000 } @@ -252,7 +254,7 @@ func NewReader(conf Conf, topic, groupID string) *kafkaLib.Reader { if len(conf.Username) != 0 && len(conf.Password) != 0 { mechanism, mErr := getMechanism(conf) if mErr != nil { - panic(mErr) + return nil, mErr } dialer.SASLMechanism = mechanism } @@ -264,7 +266,7 @@ func NewReader(conf Conf, topic, groupID string) *kafkaLib.Reader { Dialer: dialer, }) - return reader + return reader, nil } // 根据SASL授权类型获取认证装置 diff --git a/lib/logger/zap.go b/lib/logger/zap.go index e439764..ee3a919 100644 --- a/lib/logger/zap.go +++ b/lib/logger/zap.go @@ -219,8 +219,13 @@ func exporterProvider(cfg Conf) zapcore.WriteSyncer { switch strings.ToLower(cfg.Exporter.Provider) { case "redis": + redisInstance, err := redis.New(cfg.Exporter.Redis.ConnConf) + if err != nil { + log.Println("[logger] using (redis) exporter, but initialize connection error: ", err) + return writer + } redisWriter := &redisWriterStd{ - cli: redis.New(cfg.Exporter.Redis.ConnConf), + cli: redisInstance, listKey: cfg.Exporter.Redis.ListKey, } @@ -228,8 +233,13 @@ func exporterProvider(cfg Conf) zapcore.WriteSyncer { log.Println("[logger] using (redis) exporter") return writer case "redis-cluster": + redisInstance, err := redis.NewCluster(cfg.Exporter.Redis.ClusterConnConf) + if err != nil { + log.Println("[logger] using (redis-cluster) exporter, but initialize connection error: ", err) + return writer + } redisWriter := &redisClusterWriterStd{ - cli: redis.NewCluster(cfg.Exporter.Redis.ClusterConnConf), + cli: redisInstance, listKey: cfg.Exporter.Redis.ListKey, } @@ -237,8 +247,13 @@ func exporterProvider(cfg Conf) zapcore.WriteSyncer { log.Println("[logger] using (redis-cluster) exporter") return writer case "nats": + natsInstance, err := nats.New(cfg.Exporter.Nats.ConnConf) + if err != nil { + log.Println("[logger] using (nats) exporter, but initialize connection error: ", err) + return writer + } natsWriter := &natsWriterStd{ - cli: nats.New(cfg.Exporter.Nats.ConnConf), + cli: natsInstance, subjectKey: cfg.Exporter.Nats.Subject, } @@ -246,8 +261,13 @@ func exporterProvider(cfg Conf) zapcore.WriteSyncer { log.Println("[logger] using (nats) exporter") return writer case "kafka": + kafkaInstance, err := kafka.NewWriter(cfg.Exporter.Kafka.ConnConf, cfg.Exporter.Kafka.Topic) + if err != nil { + log.Println("[logger] using (kafka) exporter, but initialize writer error: ", err) + return writer + } kafkaWriter := &kafkaWriterStd{ - writer: kafka.NewWriter(cfg.Exporter.Kafka.ConnConf, cfg.Exporter.Kafka.Topic), + writer: kafkaInstance, topic: cfg.Exporter.Kafka.Topic, } diff --git a/lib/nats/nats.go b/lib/nats/nats.go index 0d879c9..97e741c 100644 --- a/lib/nats/nats.go +++ b/lib/nats/nats.go @@ -10,12 +10,12 @@ var natsInstance *natsLib.Conn // Init 初始化 func Init(conf Conf) { - conn := initNats(conf) + conn := mustInitNats(conf) natsInstance = conn } -func initNats(conf Conf) *natsLib.Conn { +func mustInitNats(conf Conf) *natsLib.Conn { var opts natsLib.Option if len(conf.Username) != 0 && len(conf.Password) != 0 { opts = natsLib.UserInfo(conf.Username, conf.Password) @@ -30,12 +30,23 @@ func initNats(conf Conf) *natsLib.Conn { return conn } +func initNats(conf Conf) (*natsLib.Conn, error) { + var opts natsLib.Option + if len(conf.Username) != 0 && len(conf.Password) != 0 { + opts = natsLib.UserInfo(conf.Username, conf.Password) + } + + conn, err := natsLib.Connect(strings.Join(conf.Endpoints, ","), opts) + + return conn, err +} + // GetInstance 获取链接实例 func GetInstance() *natsLib.Conn { return natsInstance } // New 初始化新的nats实例 -func New(conf Conf) *natsLib.Conn { +func New(conf Conf) (*natsLib.Conn, error) { return initNats(conf) } diff --git a/lib/redis/redis.go b/lib/redis/redis.go index 19373b0..aee28a4 100644 --- a/lib/redis/redis.go +++ b/lib/redis/redis.go @@ -12,7 +12,7 @@ var redisInstance *redisLib.Client // InitRedis 初始化redis连接 func InitRedis(conf Conf) { - rdb := initRedis(conf) + rdb := mustInitRedis(conf) redisInstance = rdb } @@ -24,7 +24,7 @@ func GetInstance() *redisLib.Client { return redisInstance } -func initRedis(conf Conf) *redisLib.Client { +func mustInitRedis(conf Conf) *redisLib.Client { opts := &redisLib.Options{ Addr: fmt.Sprintf("%s:%d", conf.Host, conf.Port), Username: conf.Username, @@ -48,7 +48,28 @@ func initRedis(conf Conf) *redisLib.Client { return rdb } +func initRedis(conf Conf) (*redisLib.Client, error) { + opts := &redisLib.Options{ + Addr: fmt.Sprintf("%s:%d", conf.Host, conf.Port), + Username: conf.Username, + Password: conf.Password, + DB: conf.Database, + } + if conf.SSLEnable { + //https://redis.uptrace.dev/guide/go-redis.html#using-tls + // + //To enable TLS/SSL, you need to provide an empty tls.Config. + //If you're using private certs, you need to specify them in the tls.Config + opts.TLSConfig = &tls.Config{} + } + rdb := redisLib.NewClient(opts) + + err := rdb.Ping(context.Background()).Err() + + return rdb, err +} + // New 实例化新的实例 -func New(conf Conf) *redisLib.Client { +func New(conf Conf) (*redisLib.Client, error) { return initRedis(conf) } diff --git a/lib/redis/redis_cluster.go b/lib/redis/redis_cluster.go index 49cd29c..52ff9ea 100644 --- a/lib/redis/redis_cluster.go +++ b/lib/redis/redis_cluster.go @@ -18,7 +18,7 @@ var redisClusterInstance *redisLib.ClusterClient // InitRedisCluster 初始化redis集群连接 func InitRedisCluster(conf ClusterConf) { - rdb := initRedisCluster(conf) + rdb := mustInitRedisCluster(conf) redisClusterInstance = rdb } @@ -28,7 +28,7 @@ func GetClusterInstance() *redisLib.ClusterClient { return redisClusterInstance } -func initRedisCluster(conf ClusterConf) *redisLib.ClusterClient { +func mustInitRedisCluster(conf ClusterConf) *redisLib.ClusterClient { var ( endpoints = make([]string, len(conf.Endpoints)) username string @@ -71,7 +71,47 @@ func initRedisCluster(conf ClusterConf) *redisLib.ClusterClient { return rdb } +func initRedisCluster(conf ClusterConf) (*redisLib.ClusterClient, error) { + var ( + endpoints = make([]string, len(conf.Endpoints)) + username string + password string + ) + for i := 0; i < len(conf.Endpoints); i++ { + endpoints[i] = fmt.Sprintf("%s:%d", conf.Endpoints[i].Host, conf.Endpoints[i].Port) + if len(conf.Endpoints[i].Password) != 0 { + password = conf.Endpoints[i].Password + } + if len(conf.Endpoints[i].Username) != 0 { + username = conf.Endpoints[i].Username + } + } + opts := &redisLib.ClusterOptions{ + Addrs: endpoints, + Username: username, + Password: password, + MaxRedirects: len(conf.Endpoints) - 1, + } + if opts.MaxRedirects < 3 { + opts.MaxRedirects = 3 + } + if conf.SSLEnable { + //https://redis.uptrace.dev/guide/go-redis.html#using-tls + // + //To enable TLS/SSL, you need to provide an empty tls.Config. + //If you're using private certs, you need to specify them in the tls.Config + opts.TLSConfig = &tls.Config{} + } + rdb := redisLib.NewClusterClient(opts) + + err := rdb.ForEachShard(context.Background(), func(ctx context.Context, shard *redisLib.Client) error { + return shard.Ping(ctx).Err() + }) + + return rdb, err +} + // NewCluster 实例化新的实例 -func NewCluster(conf ClusterConf) *redisLib.ClusterClient { +func NewCluster(conf ClusterConf) (*redisLib.ClusterClient, error) { return initRedisCluster(conf) } diff --git a/orm/model/base.go b/orm/base.go similarity index 97% rename from orm/model/base.go rename to orm/base.go index 9d9fe97..dd473a1 100644 --- a/orm/model/base.go +++ b/orm/base.go @@ -1,4 +1,4 @@ -package model +package orm import ( "time" diff --git a/orm/model/helper.go b/orm/helper.go similarity index 50% rename from orm/model/helper.go rename to orm/helper.go index 8fcd26d..42f612d 100644 --- a/orm/model/helper.go +++ b/orm/helper.go @@ -1,6 +1,10 @@ -package model +package orm -import "gorm.io/gorm" +import ( + "errors" + + "gorm.io/gorm" +) // AutoMigrate 自动同步表结构 func AutoMigrate(db *gorm.DB, tables ...interface{}) error { @@ -26,3 +30,21 @@ func Paginate(page, pageSize int) func(db *gorm.DB) *gorm.DB { return db.Offset(page * pageSize).Limit(pageSize) } } + +// IgnoreErrRecordNotFound 忽略记录未找到的错误 +// +// @docs https://gorm.io/docs/v2_release_note.html#ErrRecordNotFound +// +// Example: +// +// err := IgnoreErrRecordNotFound(db.First()) +func IgnoreErrRecordNotFound(db *gorm.DB) error { + if errors.Is(db.Error, gorm.ErrRecordNotFound) { + if db.Statement != nil { + db.Statement.RaiseErrorOnNotFound = false + } + db.Error = nil + } + + return db.Error +} diff --git a/orm/model/hook.go b/orm/hook.go similarity index 97% rename from orm/model/hook.go rename to orm/hook.go index a1a5ef8..f25a2ea 100644 --- a/orm/model/hook.go +++ b/orm/hook.go @@ -1,4 +1,4 @@ -package model +package orm import ( "time" diff --git a/orm/service/orm_example.go b/orm/service/orm_example.go deleted file mode 100644 index d4eb8dc..0000000 --- a/orm/service/orm_example.go +++ /dev/null @@ -1,52 +0,0 @@ -package service - -import ( - "fmt" - - "github.com/keepchen/go-sail/v3/lib/db" - "github.com/keepchen/go-sail/v3/lib/logger" - "github.com/keepchen/go-sail/v3/orm/model" -) - -type User struct { - model.Base - UserID int64 `gorm:"column:user_id;type:bigint;not null;index:,unique;comment:用户ID"` - Nickname string `gorm:"column:nickname;type:varchar(30);comment:用户昵称"` - Status int `gorm:"column:status;type:tinyint;default:0;comment:用户状态"` -} - -func ExampleORMUsage() { - svc := NewORMSvcImpl(db.GetInstance().R, db.GetInstance().W, logger.GetLogger()) - - // ---- read one record - var user User - err := svc.R().Where(&User{UserID: 1000}).First(&user) - fmt.Println(err) - - // ---- create record - var user0 = User{ - UserID: 1000, - Nickname: "go-sail", - Status: 1, - } - err = svc.W().Create(&user0) - fmt.Println(err) - - // ---- force update all fields except some one - var user1 = User{ - UserID: 1000, - Nickname: "go-sail", - Status: 1, - } - err = svc.W().Select("*").Omit("deleted_at").Updates(&user1) - fmt.Println(err) - - // ---- paginate - var ( - users []User - page = 1 - pageSize = 50 - ) - total, err := svc.R().Paginate(users, page, pageSize) - fmt.Println(total, err) -} diff --git a/orm/service/orm.go b/orm/svc.go similarity index 57% rename from orm/service/orm.go rename to orm/svc.go index e0fe506..94edc5f 100644 --- a/orm/service/orm.go +++ b/orm/svc.go @@ -2,50 +2,50 @@ //注意,这个包的存在并而不是为了替代gorm。 // //目前已经包装了常规的创建、查询、更新、删除和分页方法。并且接受传入外部logger,此间的操 -//作方法日志会由传入的外部logger收集和输出。 +//作产生的日志会由传入的外部logger收集和输出。 // -//要指定读、写实例可以调用 R() 或者 W() 方法,请查阅orm_example.go文件。 +//要指定读、写实例可以调用 R() 或者 W() 方法,请查阅svc_test.go文件。 // //更高阶的方法调用,请使用gorm库提供的语法糖。 -package service +package orm import ( "database/sql" "errors" "github.com/keepchen/go-sail/v3/lib/logger" - "github.com/keepchen/go-sail/v3/orm/model" "go.uber.org/zap" "gorm.io/gorm" ) -type ORMSvc interface { - Model(value interface{}) ORMSvc - Where(query interface{}, args ...interface{}) ORMSvc - Or(query interface{}, args ...interface{}) ORMSvc - Not(query interface{}, args ...interface{}) ORMSvc - Joins(query string, args ...interface{}) ORMSvc - Select(query interface{}, args ...interface{}) ORMSvc - Omit(columns ...string) ORMSvc - Order(value interface{}) ORMSvc - Group(name string) ORMSvc - Offset(offset int) ORMSvc - Limit(limit int) ORMSvc - Having(query interface{}, args ...interface{}) ORMSvc - Scopes(fns ...func(*gorm.DB) *gorm.DB) ORMSvc +type Svc interface { + Model(value interface{}) Svc + Where(query interface{}, args ...interface{}) Svc + Or(query interface{}, args ...interface{}) Svc + Not(query interface{}, args ...interface{}) Svc + Joins(query string, args ...interface{}) Svc + Select(query interface{}, args ...interface{}) Svc + Omit(columns ...string) Svc + Order(value interface{}) Svc + Group(name string) Svc + Offset(offset int) Svc + Limit(limit int) Svc + Having(query interface{}, args ...interface{}) Svc + Scopes(fns ...func(*gorm.DB) *gorm.DB) Svc Create(value interface{}) error Find(dest interface{}, conditions ...interface{}) error + First(dest interface{}, conditions ...interface{}) error Updates(values interface{}) error Delete(value interface{}, conditions ...interface{}) error Transaction(fc func(tx *gorm.DB) error, opts ...*sql.TxOptions) (err error) //R 使用读实例 - R() ORMSvc + R() Svc //W 使用写实例 - W() ORMSvc + W() Svc //Paginate 分页查询(多行) // //参数: @@ -58,26 +58,34 @@ type ORMSvc interface { // //总条数和错误 Paginate(dest interface{}, page, pageSize int) (int64, error) + //FindOrNil 查询多条记录 + // + //如果记录不存在忽略 gorm.ErrRecordNotFound 错误 + FindOrNil(dest interface{}, conditions ...interface{}) error + //FirstOrNil 查询单条记录 + // + //如果记录不存在忽略 gorm.ErrRecordNotFound 错误 + FirstOrNil(dest interface{}, conditions ...interface{}) error } -type ORMSvcImpl struct { +type SvcImpl struct { dbr *gorm.DB dbw *gorm.DB tx *gorm.DB logger *zap.Logger } -var _ ORMSvc = (*ORMSvcImpl)(nil) +var _ Svc = (*SvcImpl)(nil) -var NewORMSvcImpl = func(dbr *gorm.DB, dbw *gorm.DB, logger *zap.Logger) ORMSvc { - return &ORMSvcImpl{ +var NewORMSvcImpl = func(dbr *gorm.DB, dbw *gorm.DB, logger *zap.Logger) Svc { + return &SvcImpl{ dbr: dbr, dbw: dbw, logger: logger, } } -func (a *ORMSvcImpl) Model(value interface{}) ORMSvc { +func (a *SvcImpl) Model(value interface{}) Svc { if a.tx == nil { a.tx = a.dbw } @@ -87,7 +95,7 @@ func (a *ORMSvcImpl) Model(value interface{}) ORMSvc { return a } -func (a *ORMSvcImpl) Where(query interface{}, args ...interface{}) ORMSvc { +func (a *SvcImpl) Where(query interface{}, args ...interface{}) Svc { if a.tx == nil { a.tx = a.dbw } @@ -97,7 +105,7 @@ func (a *ORMSvcImpl) Where(query interface{}, args ...interface{}) ORMSvc { return a } -func (a *ORMSvcImpl) Or(query interface{}, args ...interface{}) ORMSvc { +func (a *SvcImpl) Or(query interface{}, args ...interface{}) Svc { if a.tx == nil { a.tx = a.dbw } @@ -107,7 +115,7 @@ func (a *ORMSvcImpl) Or(query interface{}, args ...interface{}) ORMSvc { return a } -func (a *ORMSvcImpl) Not(query interface{}, args ...interface{}) ORMSvc { +func (a *SvcImpl) Not(query interface{}, args ...interface{}) Svc { if a.tx == nil { a.tx = a.dbw } @@ -117,7 +125,7 @@ func (a *ORMSvcImpl) Not(query interface{}, args ...interface{}) ORMSvc { return a } -func (a *ORMSvcImpl) Joins(query string, args ...interface{}) ORMSvc { +func (a *SvcImpl) Joins(query string, args ...interface{}) Svc { if a.tx == nil { a.tx = a.dbw } @@ -127,7 +135,7 @@ func (a *ORMSvcImpl) Joins(query string, args ...interface{}) ORMSvc { return a } -func (a *ORMSvcImpl) Select(query interface{}, args ...interface{}) ORMSvc { +func (a *SvcImpl) Select(query interface{}, args ...interface{}) Svc { if a.tx == nil { a.tx = a.dbw } @@ -137,7 +145,7 @@ func (a *ORMSvcImpl) Select(query interface{}, args ...interface{}) ORMSvc { return a } -func (a *ORMSvcImpl) Omit(columns ...string) ORMSvc { +func (a *SvcImpl) Omit(columns ...string) Svc { if a.tx == nil { a.tx = a.dbw } @@ -147,7 +155,7 @@ func (a *ORMSvcImpl) Omit(columns ...string) ORMSvc { return a } -func (a *ORMSvcImpl) Order(value interface{}) ORMSvc { +func (a *SvcImpl) Order(value interface{}) Svc { if a.tx == nil { a.tx = a.dbw } @@ -157,7 +165,7 @@ func (a *ORMSvcImpl) Order(value interface{}) ORMSvc { return a } -func (a *ORMSvcImpl) Group(name string) ORMSvc { +func (a *SvcImpl) Group(name string) Svc { if a.tx == nil { a.tx = a.dbw } @@ -167,7 +175,7 @@ func (a *ORMSvcImpl) Group(name string) ORMSvc { return a } -func (a *ORMSvcImpl) Having(query interface{}, args ...interface{}) ORMSvc { +func (a *SvcImpl) Having(query interface{}, args ...interface{}) Svc { if a.tx == nil { a.tx = a.dbw } @@ -177,7 +185,7 @@ func (a *ORMSvcImpl) Having(query interface{}, args ...interface{}) ORMSvc { return a } -func (a *ORMSvcImpl) Offset(offset int) ORMSvc { +func (a *SvcImpl) Offset(offset int) Svc { if a.tx == nil { a.tx = a.dbw } @@ -187,7 +195,7 @@ func (a *ORMSvcImpl) Offset(offset int) ORMSvc { return a } -func (a *ORMSvcImpl) Limit(limit int) ORMSvc { +func (a *SvcImpl) Limit(limit int) Svc { if a.tx == nil { a.tx = a.dbw } @@ -197,7 +205,7 @@ func (a *ORMSvcImpl) Limit(limit int) ORMSvc { return a } -func (a *ORMSvcImpl) Scopes(fns ...func(*gorm.DB) *gorm.DB) ORMSvc { +func (a *SvcImpl) Scopes(fns ...func(*gorm.DB) *gorm.DB) Svc { if a.tx == nil { a.tx = a.dbw } @@ -207,7 +215,7 @@ func (a *ORMSvcImpl) Scopes(fns ...func(*gorm.DB) *gorm.DB) ORMSvc { return a } -func (a *ORMSvcImpl) Create(value interface{}) error { +func (a *SvcImpl) Create(value interface{}) error { err := a.dbw.Create(value).Error if err != nil { @@ -219,12 +227,12 @@ func (a *ORMSvcImpl) Create(value interface{}) error { return err } -func (a *ORMSvcImpl) Find(dest interface{}, conditions ...interface{}) error { +func (a *SvcImpl) Find(dest interface{}, conditions ...interface{}) error { if a.tx == nil { a.tx = a.dbw } - err := a.tx.Find(&dest, conditions...).Error + err := a.tx.Find(dest, conditions...).Error if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { a.logger.Error("[Database service]:Find:Error", @@ -236,7 +244,24 @@ func (a *ORMSvcImpl) Find(dest interface{}, conditions ...interface{}) error { return err } -func (a *ORMSvcImpl) First(dest interface{}, conditions ...interface{}) error { +func (a *SvcImpl) FindOrNil(dest interface{}, conditions ...interface{}) error { + if a.tx == nil { + a.tx = a.dbw + } + + err := IgnoreErrRecordNotFound(a.tx.Find(dest, conditions...)) + + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + a.logger.Error("[Database service]:Find:Error", + zap.String("value", logger.MarshalInterfaceValue(dest)), + zap.String("conditions", logger.MarshalInterfaceValue(conditions)), + zap.Errors("errors", []error{err})) + } + + return err +} + +func (a *SvcImpl) First(dest interface{}, conditions ...interface{}) error { if a.tx == nil { a.tx = a.dbw } @@ -253,7 +278,24 @@ func (a *ORMSvcImpl) First(dest interface{}, conditions ...interface{}) error { return err } -func (a *ORMSvcImpl) Updates(values interface{}) error { +func (a *SvcImpl) FirstOrNil(dest interface{}, conditions ...interface{}) error { + if a.tx == nil { + a.tx = a.dbw + } + + err := IgnoreErrRecordNotFound(a.tx.First(dest, conditions...)) + + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + a.logger.Error("[Database service]:First:Error", + zap.String("value", logger.MarshalInterfaceValue(dest)), + zap.String("conditions", logger.MarshalInterfaceValue(conditions)), + zap.Errors("errors", []error{err})) + } + + return err +} + +func (a *SvcImpl) Updates(values interface{}) error { if a.tx == nil { a.tx = a.dbw } @@ -269,7 +311,7 @@ func (a *ORMSvcImpl) Updates(values interface{}) error { return err } -func (a *ORMSvcImpl) Delete(value interface{}, conditions ...interface{}) error { +func (a *SvcImpl) Delete(value interface{}, conditions ...interface{}) error { if a.tx == nil { a.tx = a.dbw } @@ -285,7 +327,7 @@ func (a *ORMSvcImpl) Delete(value interface{}, conditions ...interface{}) error return err } -func (a *ORMSvcImpl) Transaction(fc func(tx *gorm.DB) error, opts ...*sql.TxOptions) (err error) { +func (a *SvcImpl) Transaction(fc func(tx *gorm.DB) error, opts ...*sql.TxOptions) (err error) { if a.tx == nil { a.tx = a.dbw } @@ -299,7 +341,7 @@ func (a *ORMSvcImpl) Transaction(fc func(tx *gorm.DB) error, opts ...*sql.TxOpti return } -func (a *ORMSvcImpl) R() ORMSvc { +func (a *SvcImpl) R() Svc { if a.tx == nil { a.tx = a.dbr } @@ -307,7 +349,7 @@ func (a *ORMSvcImpl) R() ORMSvc { return a } -func (a *ORMSvcImpl) W() ORMSvc { +func (a *SvcImpl) W() Svc { if a.tx == nil { a.tx = a.dbw } @@ -315,7 +357,7 @@ func (a *ORMSvcImpl) W() ORMSvc { return a } -func (a *ORMSvcImpl) Paginate(dest interface{}, page, pageSize int) (int64, error) { +func (a *SvcImpl) Paginate(dest interface{}, page, pageSize int) (int64, error) { if a.tx == nil { a.tx = a.dbw } @@ -323,7 +365,7 @@ func (a *ORMSvcImpl) Paginate(dest interface{}, page, pageSize int) (int64, erro var total int64 a.tx.Count(&total) - err := a.tx.Scopes(model.Paginate(page, pageSize)).Find(&dest).Error + err := a.tx.Scopes(Paginate(page, pageSize)).Find(dest).Error if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { a.logger.Error("[Database service]:Paginate:Error", diff --git a/orm/svc_test.go b/orm/svc_test.go new file mode 100644 index 0000000..9c7de2f --- /dev/null +++ b/orm/svc_test.go @@ -0,0 +1,131 @@ +package orm + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/keepchen/go-sail/v3/lib/db" + "github.com/keepchen/go-sail/v3/lib/logger" +) + +type User struct { + Base + UserID int64 `gorm:"column:user_id;type:bigint;not null;index:,unique;comment:用户ID"` + Nickname string `gorm:"column:nickname;type:varchar(30);comment:用户昵称"` + Status int `gorm:"column:status;type:tinyint;default:0;comment:用户状态"` +} + +func (*User) TableName() string { + return "user" +} + +var ( + loggerConf = logger.Conf{} + dbConf = db.Conf{ + Enable: false, + DriverName: "mysql", + AutoMigrate: true, + ConnectionPool: db.ConnectionPoolConf{ + MaxOpenConnCount: 100, + MaxIdleConnCount: 10, + ConnMaxLifeTimeMinutes: 30, + ConnMaxIdleTimeMinutes: 10, + }, + Mysql: db.MysqlConf{ + Read: db.MysqlConfItem{ + Host: "127.0.0.1", + Port: 33060, + Username: "root", + Password: "changeMe", + Database: "go-sail", + Charset: "utf8mb4", + ParseTime: true, + Loc: "Local", + }, + Write: db.MysqlConfItem{ + Host: "127.0.0.1", + Port: 33060, + Username: "root", + Password: "changeMe", + Database: "go-sail", + Charset: "utf8mb4", + ParseTime: true, + Loc: "Local", + }, + }, + } +) + +func TestSvcUsage(t *testing.T) { + logger.Init(loggerConf, "go-sail") + dbr, _, dbw, _ := db.New(dbConf) + //logger.Init(loggerConf) + if dbr == nil || dbw == nil { + t.Log("database instance is nil, testing not emit.") + return + } + _ = AutoMigrate(dbw, &User{}) + + svc := NewORMSvcImpl(dbr, dbw, logger.GetLogger()) + + dbw.Exec(fmt.Sprintf("truncate table %s", (&User{}).TableName())) + + // ---- ignore gorm.ErrRecordNotFound + var user0 User + err := svc.R().FirstOrNil(&user0) + t.Log("FirstOrNil:", err) + assert.NoError(t, err) + + // ---- create record + var user1 = User{ + UserID: 1000, + Nickname: "go-sail", + Status: 1, + } + err = svc.W().Create(&user1) + assert.NoError(t, err) + t.Log("Create:", user1) + + // ---- read one record + var user2 User + err = svc.R().Where(&User{UserID: 1000}).First(&user2) + assert.NoError(t, err) + t.Log("First:", user2) + + // ---- force update all fields except some one + var user3 = User{ + UserID: 1000, + Nickname: "go-sail", + Status: 2, + } + err = svc.W().Select("*").Omit("id", "deleted_at").Updates(&user3) + assert.NoError(t, err) + t.Log("Updates:", user3) + + // ---- read several records + var ( + users0 []User + ) + err = svc.R().Find(&users0) + assert.NoError(t, err) + t.Log("Find:", users0) + + // ---- ignore gorm.ErrRecordNotFound + var users1 User + err = svc.R().Where(&User{UserID: 99999}).FindOrNil(&users1) + t.Log("FindOrNil:", err) + assert.NoError(t, err) + + // ---- paginate + var ( + users2 []User + page = 1 + pageSize = 50 + ) + total, err := svc.R().Paginate(&users2, page, pageSize) + assert.NoError(t, err) + assert.Equal(t, int64(len(users2)), total) + t.Log("Paginate:", users2, total) +}