diff --git a/extras/kms/repository.go b/extras/kms/repository.go index 1d8c7e27..6ebd4d16 100644 --- a/extras/kms/repository.go +++ b/extras/kms/repository.go @@ -68,14 +68,14 @@ func newRepository(r dbw.Reader, w dbw.Writer, opt ...Option) (*repository, erro if w == nil { return nil, fmt.Errorf("%s: nil writer: %w", op, ErrInvalidParameter) } - if _, err := validateSchema(context.Background(), r); err != nil { - return nil, fmt.Errorf("%s: %w", op, err) - } opts := getOpts(opt...) if opts.withLimit == 0 { // zero signals the defaults should be used. opts.withLimit = defaultLimit } + if _, err := validateSchema(context.Background(), r, opts.withTableNamePrefix); err != nil { + return nil, fmt.Errorf("%s: %w", op, err) + } return &repository{ reader: r, writer: w, @@ -88,13 +88,15 @@ func newRepository(r dbw.Reader, w dbw.Writer, opt ...Option) (*repository, erro // required migrations.Version func (r *repository) ValidateSchema(ctx context.Context) (string, error) { const op = "kms.(repository).validateVersion" - return validateSchema(ctx, r.reader) + return validateSchema(ctx, r.reader, r.tableNamePrefix) } -func validateSchema(ctx context.Context, r dbw.Reader) (string, error) { +func validateSchema(ctx context.Context, r dbw.Reader, tableNamePrefix string) (string, error) { const op = "kms.validateSchema" - var s schema - if err := r.LookupWhere(ctx, &s, "1=1", nil); err != nil { + s := schema{ + tableNamePrefix: tableNamePrefix, + } + if err := r.LookupWhere(ctx, &s, "1=1", nil, dbw.WithTable(s.TableName())); err != nil { return "", fmt.Errorf("%s: unable to get version: %w", op, err) } if s.Version != migrations.Version { diff --git a/extras/kms/schema.go b/extras/kms/schema.go index 5d451cdf..c5e85f65 100644 --- a/extras/kms/schema.go +++ b/extras/kms/schema.go @@ -3,7 +3,10 @@ package kms -import "time" +import ( + "fmt" + "time" +) // schema represents the current schema in the database type schema struct { @@ -13,7 +16,15 @@ type schema struct { UpdateTime time.Time // CreateTime is the create time of the initial version CreateTime time.Time + + // tableNamePrefix defines the prefix to use before the table name and + // allows us to support custom prefixes as well as multi KMSs within a + // single schema. + tableNamePrefix string `gorm:"-"` } -// TableName defines the table name for the Version type -func (v *schema) TableName() string { return "kms_schema_version" } +// TableName returns the table name +func (k *schema) TableName() string { + const tableName = "schema_version" + return fmt.Sprintf("%s_%s", k.tableNamePrefix, tableName) +}