diff --git a/mocks/pgmq.go b/mocks/pgmq.go index 1ee410e..1213cfb 100644 --- a/mocks/pgmq.go +++ b/mocks/pgmq.go @@ -72,6 +72,20 @@ func (mr *MockDBMockRecorder) Exec(ctx, sql any, args ...any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockDB)(nil).Exec), varargs...) } +// Ping mocks base method. +func (m *MockDB) Ping(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Ping", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Ping indicates an expected call of Ping. +func (mr *MockDBMockRecorder) Ping(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ping", reflect.TypeOf((*MockDB)(nil).Ping), ctx) +} + // Query mocks base method. func (m *MockDB) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { m.ctrl.T.Helper() diff --git a/pgmq.go b/pgmq.go index 2a8ac0c..67df61d 100644 --- a/pgmq.go +++ b/pgmq.go @@ -27,6 +27,7 @@ type Message struct { } type DB interface { + Ping(ctx context.Context) error Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row @@ -51,31 +52,27 @@ func New(ctx context.Context, connString string) (*PGMQ, error) { return nil, fmt.Errorf("error creating pool: %w", err) } - err = pool.Ping(ctx) - if err != nil { + return NewFromDB(ctx, pool) +} + +// NewFromDB is a bring your own DB version of New. Given an implementation +// of DB, it will call Ping to ensure the connection has been established, +// then create the PGMQ extension if it does not already exist. +func NewFromDB(ctx context.Context, db DB) (*PGMQ, error) { + if err := db.Ping(ctx); err != nil { return nil, err } - _, err = pool.Exec(ctx, "CREATE EXTENSION IF NOT EXISTS pgmq CASCADE") + _, err := db.Exec(ctx, "CREATE EXTENSION IF NOT EXISTS pgmq CASCADE") if err != nil { return nil, fmt.Errorf("error creating pgmq extension: %w", err) } return &PGMQ{ - db: pool, + db: db, }, nil } -// MustNew is similar to New, but panics if it encounters an error. -func MustNew(ctx context.Context, connString string) *PGMQ { - q, err := New(ctx, connString) - if err != nil { - panic(err) - } - - return q -} - // Close closes the underlying connection pool. func (p *PGMQ) Close() { p.db.Close()