diff --git a/client.go b/client.go index d23959ec..7b1f5e02 100644 --- a/client.go +++ b/client.go @@ -1441,14 +1441,41 @@ func (c *Client[TTx]) InsertManyTx(ctx context.Context, tx TTx, params []InsertM } func (c *Client[TTx]) insertMany(ctx context.Context, tx riverdriver.ExecutorTx, params []InsertManyParams) ([]*rivertype.JobInsertResult, error) { + return c.insertManyShared(ctx, tx, params, func(ctx context.Context, insertParams []*riverdriver.JobInsertFastParams) ([]*rivertype.JobInsertResult, error) { + results, err := tx.JobInsertFastMany(ctx, insertParams) + if err != nil { + return nil, err + } + + return sliceutil.Map(results, + func(result *riverdriver.JobInsertFastResult) *rivertype.JobInsertResult { + return (*rivertype.JobInsertResult)(result) + }, + ), nil + }) +} + +// The shared code path for all InsertMany methods. It takes a function that +// executes the actual insert operation and allows for different implementations +// of the insert query to be passed in, each mapping their results back to a +// common result type. +// +// TODO(bgentry): this isn't yet used for the single insert path. The only thing +// blocking that is the removal of advisory lock unique inserts. +func (c *Client[TTx]) insertManyShared( + ctx context.Context, + tx riverdriver.ExecutorTx, + params []InsertManyParams, + execute func(context.Context, []*riverdriver.JobInsertFastParams) ([]*rivertype.JobInsertResult, error), +) ([]*rivertype.JobInsertResult, error) { insertParams, err := c.insertManyParams(params) if err != nil { return nil, err } - jobRows, err := tx.JobInsertFastMany(ctx, insertParams) + inserted, err := execute(ctx, insertParams) if err != nil { - return nil, err + return inserted, err } queues := make([]string, 0, 10) @@ -1460,12 +1487,7 @@ func (c *Client[TTx]) insertMany(ctx context.Context, tx riverdriver.ExecutorTx, if err := c.maybeNotifyInsertForQueues(ctx, tx, queues); err != nil { return nil, err } - - return sliceutil.Map(jobRows, - func(result *riverdriver.JobInsertFastResult) *rivertype.JobInsertResult { - return (*rivertype.JobInsertResult)(result) - }, - ), nil + return inserted, nil } // Validates input parameters for a batch insert operation and generates a set @@ -1516,11 +1538,6 @@ func (c *Client[TTx]) InsertManyFast(ctx context.Context, params []InsertManyPar return 0, errNoDriverDBPool } - insertParams, err := c.insertManyFastParams(params) - if err != nil { - return 0, err - } - // Wrap in a transaction in case we need to notify about inserts. tx, err := c.driver.GetExecutor().Begin(ctx) if err != nil { @@ -1528,7 +1545,7 @@ func (c *Client[TTx]) InsertManyFast(ctx context.Context, params []InsertManyPar } defer tx.Rollback(ctx) - inserted, err := c.insertManyFast(ctx, tx, insertParams) + inserted, err := c.insertManyFast(ctx, tx, params) if err != nil { return 0, err } @@ -1562,54 +1579,23 @@ func (c *Client[TTx]) InsertManyFast(ctx context.Context, params []InsertManyPar // unique conflicts cannot be handled gracefully. If a unique constraint is // violated, the operation will fail and no jobs will be inserted. func (c *Client[TTx]) InsertManyFastTx(ctx context.Context, tx TTx, params []InsertManyParams) (int, error) { - insertParams, err := c.insertManyFastParams(params) - if err != nil { - return 0, err - } - exec := c.driver.UnwrapExecutor(tx) - return c.insertManyFast(ctx, exec, insertParams) -} - -func (c *Client[TTx]) insertManyFast(ctx context.Context, tx riverdriver.ExecutorTx, insertParams []*riverdriver.JobInsertFastParams) (int, error) { - inserted, err := tx.JobInsertFastManyNoReturning(ctx, insertParams) - if err != nil { - return inserted, err - } - - queues := make([]string, 0, 10) - for _, params := range insertParams { - if params.State == rivertype.JobStateAvailable { - queues = append(queues, params.Queue) - } - } - if err := c.maybeNotifyInsertForQueues(ctx, tx, queues); err != nil { - return 0, err - } - return inserted, nil + return c.insertManyFast(ctx, exec, params) } -// Validates input parameters for an a batch insert operation and generates a -// set of batch insert parameters. -func (c *Client[TTx]) insertManyFastParams(params []InsertManyParams) ([]*riverdriver.JobInsertFastParams, error) { - if len(params) < 1 { - return nil, errors.New("no jobs to insert") - } - - insertParams := make([]*riverdriver.JobInsertFastParams, len(params)) - for i, param := range params { - if err := c.validateJobArgs(param.Args); err != nil { - return nil, err - } - - insertParamsItem, _, err := insertParamsFromConfigArgsAndOptions(&c.baseService.Archetype, c.config, param.Args, param.InsertOpts, true) +func (c *Client[TTx]) insertManyFast(ctx context.Context, tx riverdriver.ExecutorTx, params []InsertManyParams) (int, error) { + results, err := c.insertManyShared(ctx, tx, params, func(ctx context.Context, insertParams []*riverdriver.JobInsertFastParams) ([]*rivertype.JobInsertResult, error) { + count, err := tx.JobInsertFastManyNoReturning(ctx, insertParams) if err != nil { return nil, err } - insertParams[i] = insertParamsItem + return make([]*rivertype.JobInsertResult, count), nil + }) + if err != nil { + return 0, err } - return insertParams, nil + return len(results), nil } func (c *Client[TTx]) maybeNotifyInsert(ctx context.Context, tx riverdriver.ExecutorTx, state rivertype.JobState, queue string) error {