Skip to content

Commit

Permalink
emit subscription events on pause + resume (#327)
Browse files Browse the repository at this point in the history
  • Loading branch information
bgentry authored May 3, 2024
1 parent c266c5d commit 377eac0
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 71 deletions.
51 changes: 34 additions & 17 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,9 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client
MaxWorkers: queueConfig.MaxWorkers,
Notifier: client.notifier,
Queue: queue,
QueueEventCallback: func(event *Event) {
client.distributeQueueEvent(event)
},
RetryPolicy: config.RetryPolicy,
SchedulerInterval: config.schedulerInterval,
StatusFunc: client.monitor.SetProducerStatus,
Expand Down Expand Up @@ -909,8 +912,11 @@ func (c *Client[TTx]) SubscribeConfig(config *SubscribeConfig) (<-chan *Event, f
return subChan, cancel
}

// Distribute a single job into any listening subscriber channels.
func (c *Client[TTx]) distributeJob(job *rivertype.JobRow, stats *JobStatistics) {
// Distribute a single job event into any listening subscriber channels.
//
// This handles job events only (a job plus its stats). Queue events, which
// carry only the queue, are distributed separately by distributeQueueEvent.
func (c *Client[TTx]) distributeJobEvent(job *rivertype.JobRow, stats *JobStatistics) {
c.subscriptionsMu.Lock()
defer c.subscriptionsMu.Unlock()

Expand Down Expand Up @@ -948,6 +954,22 @@ func (c *Client[TTx]) distributeJob(job *rivertype.JobRow, stats *JobStatistics)
}
}

// Distribute a single queue event (such as a queue being paused or resumed)
// into any listening subscriber channels.
//
// Only subscriptions that listen for the event's kind receive it. The
// subscriptions mutex is held for the duration of the fan-out.
func (c *Client[TTx]) distributeQueueEvent(event *Event) {
	c.subscriptionsMu.Lock()
	defer c.subscriptionsMu.Unlock()

	for _, subscriber := range c.subscriptions {
		// Skip subscribers that didn't ask for this kind of event.
		if !subscriber.ListensFor(event.Kind) {
			continue
		}

		// Sends never block, so this loop is always fast and there's no risk
		// of falling behind what producers are sending. If a subscriber's
		// channel can't accept the event right now, the event is dropped for
		// that subscriber.
		select {
		case subscriber.Chan <- event:
		default:
		}
	}
}

// Callback invoked by the completer and which prompts the client to update
// statistics and distribute jobs into any listening subscriber channels.
// (Subscriber channels are non-blocking so this should be quite fast.)
Expand All @@ -963,7 +985,7 @@ func (c *Client[TTx]) distributeJobCompleterCallback(update jobcompleter.Complet
c.statsNumJobs++
}()

c.distributeJob(update.Job, jobStatisticsFromInternal(update.JobStats))
c.distributeJobEvent(update.Job, jobStatisticsFromInternal(update.JobStats))
}

// Dump aggregate stats from job completions to logs periodically. These
Expand Down Expand Up @@ -1544,20 +1566,15 @@ func (c *Client[TTx]) maybeNotifyInsertForQueues(ctx context.Context, tx riverdr
}

// Emit a notification about a queue being paused or resumed.
func (c *Client[TTx]) notifyQueuePauseOrResume(ctx context.Context, tx riverdriver.ExecutorTx, action, queue string, opts *QueuePauseOpts) error {
type queueStateChange struct {
Action string `json:"action"`
Queue string `json:"queue"`
}

func (c *Client[TTx]) notifyQueuePauseOrResume(ctx context.Context, tx riverdriver.ExecutorTx, action controlAction, queue string, opts *QueuePauseOpts) error {
c.baseService.Logger.DebugContext(ctx,
c.baseService.Name+": Notifying about queue state change",
slog.String("action", action),
slog.String("action", string(action)),
slog.String("queue", queue),
slog.String("opts", fmt.Sprintf("%+v", opts)),
)

payload, err := json.Marshal(queueStateChange{Action: action, Queue: queue})
payload, err := json.Marshal(jobControlPayload{Action: action, Queue: queue})
if err != nil {
return err
}
Expand Down Expand Up @@ -1737,18 +1754,18 @@ func (c *Client[TTx]) QueueList(ctx context.Context, params *QueueListParams) (*
// The provided context is used for the underlying Postgres update and can be
// used to cancel the operation or apply a timeout. The opts are reserved for
// future functionality.
func (c *Client[TTx]) QueuePause(ctx context.Context, queue string, opts *QueuePauseOpts) error {
func (c *Client[TTx]) QueuePause(ctx context.Context, name string, opts *QueuePauseOpts) error {
tx, err := c.driver.GetExecutor().Begin(ctx)
if err != nil {
return err
}
defer tx.Rollback(ctx)

if err = tx.QueuePause(ctx, queue); err != nil {
if err := tx.QueuePause(ctx, name); err != nil {
return err
}

if err = c.notifyQueuePauseOrResume(ctx, tx, "pause", queue, opts); err != nil {
if err := c.notifyQueuePauseOrResume(ctx, tx, controlActionPause, name, opts); err != nil {
return err
}

Expand All @@ -1767,18 +1784,18 @@ func (c *Client[TTx]) QueuePause(ctx context.Context, queue string, opts *QueueP
// The provided context is used for the underlying Postgres update and can be
// used to cancel the operation or apply a timeout. The opts are reserved for
// future functionality.
func (c *Client[TTx]) QueueResume(ctx context.Context, queue string, opts *QueuePauseOpts) error {
func (c *Client[TTx]) QueueResume(ctx context.Context, name string, opts *QueuePauseOpts) error {
tx, err := c.driver.GetExecutor().Begin(ctx)
if err != nil {
return err
}
defer tx.Rollback(ctx)

if err = tx.QueueResume(ctx, queue); err != nil {
if err := tx.QueueResume(ctx, name); err != nil {
return err
}

if err = c.notifyQueuePauseOrResume(ctx, tx, "resume", queue, opts); err != nil {
if err := c.notifyQueuePauseOrResume(ctx, tx, controlActionResume, name, opts); err != nil {
return err
}

Expand Down
97 changes: 43 additions & 54 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,8 @@ func Test_Client(t *testing.T) {
EventKindJobCompleted,
EventKindJobFailed,
EventKindJobSnoozed,
EventKindQueuePaused,
EventKindQueueResumed,
)
t.Cleanup(cancel)
return subscribeChan
Expand Down Expand Up @@ -464,44 +466,36 @@ func Test_Client(t *testing.T) {
config, bundle := setupConfig(t)
client := newTestClient(t, bundle.dbPool, config)

jobStartedChan := make(chan int64)

type JobArgs struct {
JobArgsReflectKind[JobArgs]
}

AddWorker(client.config.Workers, WorkFunc(func(ctx context.Context, job *Job[JobArgs]) error {
jobStartedChan <- job.ID
return nil
}))

subscribeChan := subscribe(t, client)
startClient(ctx, t, client)

client.producersByQueueName[QueueDefault].testSignals.Init()

insertRes1, err := client.Insert(ctx, &JobArgs{}, nil)
insertRes1, err := client.Insert(ctx, &noOpArgs{}, nil)
require.NoError(t, err)

startedJobID := riverinternaltest.WaitOrTimeout(t, jobStartedChan)
require.Equal(t, insertRes1.Job.ID, startedJobID)
event := riverinternaltest.WaitOrTimeout(t, subscribeChan)
require.Equal(t, EventKindJobCompleted, event.Kind)
require.Equal(t, insertRes1.Job.ID, event.Job.ID)

require.NoError(t, client.QueuePause(ctx, QueueDefault, nil))
client.producersByQueueName[QueueDefault].testSignals.Paused.WaitOrTimeout()
event = riverinternaltest.WaitOrTimeout(t, subscribeChan)
require.Equal(t, &Event{Kind: EventKindQueuePaused, Queue: &rivertype.Queue{Name: QueueDefault}}, event)

insertRes2, err := client.Insert(ctx, &JobArgs{}, nil)
insertRes2, err := client.Insert(ctx, &noOpArgs{}, nil)
require.NoError(t, err)

select {
case <-jobStartedChan:
case <-subscribeChan:
t.Fatal("expected job 2 to not start on paused queue")
case <-time.After(500 * time.Millisecond):
}

require.NoError(t, client.QueueResume(ctx, QueueDefault, nil))
client.producersByQueueName[QueueDefault].testSignals.Resumed.WaitOrTimeout()
event = riverinternaltest.WaitOrTimeout(t, subscribeChan)
require.Equal(t, &Event{Kind: EventKindQueueResumed, Queue: &rivertype.Queue{Name: QueueDefault}}, event)

startedJobID = riverinternaltest.WaitOrTimeout(t, jobStartedChan)
require.Equal(t, insertRes2.Job.ID, startedJobID)
event = riverinternaltest.WaitOrTimeout(t, subscribeChan)
require.Equal(t, EventKindJobCompleted, event.Kind)
require.Equal(t, insertRes2.Job.ID, event.Job.ID)
})

t.Run("PauseAndResumeMultipleQueues", func(t *testing.T) {
Expand All @@ -511,74 +505,69 @@ func Test_Client(t *testing.T) {
config.Queues["alternate"] = QueueConfig{MaxWorkers: 10}
client := newTestClient(t, bundle.dbPool, config)

jobStartedChan := make(chan int64)

type JobArgs struct {
JobArgsReflectKind[JobArgs]
}

AddWorker(client.config.Workers, WorkFunc(func(ctx context.Context, job *Job[JobArgs]) error {
jobStartedChan <- job.ID
return nil
}))

subscribeChan := subscribe(t, client)
startClient(ctx, t, client)

client.producersByQueueName[QueueDefault].testSignals.Init()
client.producersByQueueName["alternate"].testSignals.Init()

insertRes1, err := client.Insert(ctx, &JobArgs{}, nil)
insertRes1, err := client.Insert(ctx, &noOpArgs{}, nil)
require.NoError(t, err)

startedJobID := riverinternaltest.WaitOrTimeout(t, jobStartedChan)
require.Equal(t, insertRes1.Job.ID, startedJobID)
event := riverinternaltest.WaitOrTimeout(t, subscribeChan)
require.Equal(t, EventKindJobCompleted, event.Kind)
require.Equal(t, insertRes1.Job.ID, event.Job.ID)

// Pause only the default queue:
require.NoError(t, client.QueuePause(ctx, QueueDefault, nil))
client.producersByQueueName[QueueDefault].testSignals.Paused.WaitOrTimeout()
event = riverinternaltest.WaitOrTimeout(t, subscribeChan)
require.Equal(t, &Event{Kind: EventKindQueuePaused, Queue: &rivertype.Queue{Name: QueueDefault}}, event)

insertRes2, err := client.Insert(ctx, &JobArgs{}, nil)
insertRes2, err := client.Insert(ctx, &noOpArgs{}, nil)
require.NoError(t, err)

select {
case <-jobStartedChan:
case <-subscribeChan:
t.Fatal("expected job 2 to not start on paused queue")
case <-time.After(500 * time.Millisecond):
}

// alternate queue should still be running:
insertResAlternate1, err := client.Insert(ctx, &JobArgs{}, &InsertOpts{Queue: "alternate"})
insertResAlternate1, err := client.Insert(ctx, &noOpArgs{}, &InsertOpts{Queue: "alternate"})
require.NoError(t, err)

startedJobID = riverinternaltest.WaitOrTimeout(t, jobStartedChan)
require.Equal(t, insertResAlternate1.Job.ID, startedJobID)
event = riverinternaltest.WaitOrTimeout(t, subscribeChan)
require.Equal(t, EventKindJobCompleted, event.Kind)
require.Equal(t, insertResAlternate1.Job.ID, event.Job.ID)

// Pause all queues:
require.NoError(t, client.QueuePause(ctx, rivercommon.AllQueuesString, nil))
client.producersByQueueName["alternate"].testSignals.Paused.WaitOrTimeout()
event = riverinternaltest.WaitOrTimeout(t, subscribeChan)
require.Equal(t, &Event{Kind: EventKindQueuePaused, Queue: &rivertype.Queue{Name: "alternate"}}, event)

insertResAlternate2, err := client.Insert(ctx, &JobArgs{}, &InsertOpts{Queue: "alternate"})
insertResAlternate2, err := client.Insert(ctx, &noOpArgs{}, &InsertOpts{Queue: "alternate"})
require.NoError(t, err)

select {
case <-jobStartedChan:
case <-subscribeChan:
t.Fatal("expected alternate job 2 to not start on paused queue")
case <-time.After(500 * time.Millisecond):
}

// Resume only the alternate queue:
require.NoError(t, client.QueueResume(ctx, "alternate", nil))
client.producersByQueueName["alternate"].testSignals.Resumed.WaitOrTimeout()
event = riverinternaltest.WaitOrTimeout(t, subscribeChan)
require.Equal(t, &Event{Kind: EventKindQueueResumed, Queue: &rivertype.Queue{Name: "alternate"}}, event)

startedJobID = riverinternaltest.WaitOrTimeout(t, jobStartedChan)
require.Equal(t, insertResAlternate2.Job.ID, startedJobID)
event = riverinternaltest.WaitOrTimeout(t, subscribeChan)
require.Equal(t, EventKindJobCompleted, event.Kind)
require.Equal(t, insertResAlternate2.Job.ID, event.Job.ID)

// Resume all queues:
require.NoError(t, client.QueueResume(ctx, rivercommon.AllQueuesString, nil))
client.producersByQueueName[QueueDefault].testSignals.Resumed.WaitOrTimeout()
event = riverinternaltest.WaitOrTimeout(t, subscribeChan)
require.Equal(t, &Event{Kind: EventKindQueueResumed, Queue: &rivertype.Queue{Name: QueueDefault}}, event)

startedJobID = riverinternaltest.WaitOrTimeout(t, jobStartedChan)
require.Equal(t, insertRes2.Job.ID, startedJobID)
event = riverinternaltest.WaitOrTimeout(t, subscribeChan)
require.Equal(t, EventKindJobCompleted, event.Kind)
require.Equal(t, insertRes2.Job.ID, event.Job.ID)
})

t.Run("PausedBeforeStart", func(t *testing.T) {
Expand Down
11 changes: 11 additions & 0 deletions event.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ const (

// EventKindJobSnoozed occurs when a job is snoozed.
EventKindJobSnoozed EventKind = "job_snoozed"

// EventKindQueuePaused occurs when a queue is paused.
EventKindQueuePaused EventKind = "queue_paused"

// EventKindQueueResumed occurs when a queue is resumed.
EventKindQueueResumed EventKind = "queue_resumed"
)

// All known event kinds, used to validate incoming kinds. This is purposely not
Expand All @@ -35,6 +41,8 @@ var allKinds = map[EventKind]struct{}{ //nolint:gochecknoglobals
EventKindJobCompleted: {},
EventKindJobFailed: {},
EventKindJobSnoozed: {},
EventKindQueuePaused: {},
EventKindQueueResumed: {},
}

// Event wraps an event that occurred within a River client, like a job being
Expand All @@ -50,6 +58,9 @@ type Event struct {

// JobStats are statistics about the run of a job.
JobStats *JobStatistics

// Queue contains queue-related information.
Queue *rivertype.Queue
}

// JobStatistics contains information about a single execution of a job.
Expand Down
9 changes: 9 additions & 0 deletions producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ type producerConfig struct {
Notifier *notifier.Notifier

Queue string
// QueueEventCallback gets called when a queue's config changes (such as
// pausing or resuming) so that events can be emitted to subscriptions.
QueueEventCallback func(event *Event)

// QueuePollInterval is the amount of time between periodic checks for
// queue setting changes. This is only used in poll-only mode (when no
Expand Down Expand Up @@ -427,13 +430,19 @@ func (p *producer) fetchAndRunLoop(fetchCtx, workCtx context.Context, fetchLimit
p.paused = true
p.Logger.DebugContext(workCtx, p.Name+": Paused", slog.String("queue", p.config.Queue), slog.String("queue_in_message", msg.Queue))
p.testSignals.Paused.Signal(struct{}{})
if p.config.QueueEventCallback != nil {
p.config.QueueEventCallback(&Event{Kind: EventKindQueuePaused, Queue: &rivertype.Queue{Name: p.config.Queue}})
}
case controlActionResume:
if !p.paused {
continue
}
p.paused = false
p.Logger.DebugContext(workCtx, p.Name+": Resumed", slog.String("queue", p.config.Queue), slog.String("queue_in_message", msg.Queue))
p.testSignals.Resumed.Signal(struct{}{})
if p.config.QueueEventCallback != nil {
p.config.QueueEventCallback(&Event{Kind: EventKindQueueResumed, Queue: &rivertype.Queue{Name: p.config.Queue}})
}
case controlActionCancel:
// Separate this case to make linter happy:
p.Logger.DebugContext(workCtx, p.Name+": Unhandled queue control action", "action", msg.Action)
Expand Down

0 comments on commit 377eac0

Please sign in to comment.