diff --git a/redis/manager.go b/redis/manager.go new file mode 100644 index 00000000..b1466b4a --- /dev/null +++ b/redis/manager.go @@ -0,0 +1,37 @@ +package redis + +import "github.com/redis/go-redis/v9" + +type Manager struct { + redis.Cmdable + + connections map[string]redis.Cmdable +} + +func New(db redis.Cmdable) *Manager { + return &Manager{ + Cmdable: db, + connections: make(map[string]redis.Cmdable), + } +} + +func (m *Manager) Register(name string, db redis.Cmdable) { + m.connections[name] = db +} + +func (m *Manager) Conn(names ...string) redis.Cmdable { + var name string + if len(names) > 0 { + name = names[0] + } + + if name == "" { + return m.Cmdable + } + + if c, ok := m.connections[name]; ok { + return c + } + + panic("redis: the connection [" + name + "] is not registered.") +} diff --git a/redis/manager_test.go b/redis/manager_test.go new file mode 100644 index 00000000..96a0c8b1 --- /dev/null +++ b/redis/manager_test.go @@ -0,0 +1,54 @@ +package redis + +import ( + "context" + "testing" + "time" + + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" +) + +func TestManager(t *testing.T) { + ctx := context.Background() + rdb := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + }) + + m := New(rdb) + + m.Register("rdb2", rdb) + + assert.Equal(t, rdb, m.Cmdable) + assert.Equal(t, rdb, m.Conn()) + assert.Equal(t, rdb, m.Conn("rdb2")) + + assert.Panics(t, func() { + m.Conn("rdb3") + }) + + var ( + key1 = "redis:manager:key1" + key2 = "redis:manager:key2" + key3 = "redis:manager:key3" + ) + + m.Set(ctx, key1, "value1", time.Second*10) + m.Conn().Set(ctx, key2, "value2", time.Second*10) + m.Conn("rdb2").Set(ctx, key3, "value3", time.Second*10) + + // default connection + val, err := m.Get(ctx, key1).Result() + assert.NoError(t, err) + assert.Equal(t, "value1", val) + + // empty name connection + val, err = m.Conn().Get(ctx, key2).Result() + assert.NoError(t, err) + assert.Equal(t, "value2", val) + + // use named connection + val, err = m.Conn("rdb2").Get(ctx, key3).Result() + assert.NoError(t, err) + assert.Equal(t, "value3", val) +}