Skip to content

Commit

Permalink
Change to json.RawMessage for messages (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
craigpastro authored Nov 8, 2023
1 parent fa8eec8 commit 539f177
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 13 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func main() {
panic(err)
}

id, err := q.Send(ctx, "my_queue", map[string]any{"foo": "bar"})
id, err := q.Send(ctx, "my_queue", json.RawMessage(`{"foo": "bar"}`))
if err != nil {
panic(err)
}
Expand Down
11 changes: 6 additions & 5 deletions pgmq.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pgmq

import (
"context"
"encoding/json"
"errors"
"fmt"
"time"
Expand All @@ -22,7 +23,7 @@ type Message struct {
// VT is "visibility time". The UTC timestamp at which the message will
// be available for reading again.
VT time.Time
Message map[string]any
Message json.RawMessage
}

type DB interface {
Expand Down Expand Up @@ -116,13 +117,13 @@ func (p *PGMQ) DropQueue(ctx context.Context, queue string) error {

// Send sends a single message to a queue. The message id, unique to the
// queue, is returned.
func (p *PGMQ) Send(ctx context.Context, queue string, msg map[string]any) (int64, error) {
func (p *PGMQ) Send(ctx context.Context, queue string, msg json.RawMessage) (int64, error) {
return p.SendWithDelay(ctx, queue, msg, 0)
}

// SendWithDelay sends a single message to a queue with a delay. The delay
// is specified in seconds. The message id, unique to the queue, is returned.
func (p *PGMQ) SendWithDelay(ctx context.Context, queue string, msg map[string]any, delay int) (int64, error) {
func (p *PGMQ) SendWithDelay(ctx context.Context, queue string, msg json.RawMessage, delay int) (int64, error) {
var msgID int64
err := p.db.
QueryRow(ctx, "SELECT * FROM pgmq.send($1, $2, $3)", queue, msg, delay).
Expand All @@ -136,14 +137,14 @@ func (p *PGMQ) SendWithDelay(ctx context.Context, queue string, msg map[string]a

// SendBatch sends a batch of messages to a queue. The message ids, unique to
// the queue, are returned.
func (p *PGMQ) SendBatch(ctx context.Context, queue string, msgs []map[string]any) ([]int64, error) {
func (p *PGMQ) SendBatch(ctx context.Context, queue string, msgs []json.RawMessage) ([]int64, error) {
return p.SendBatchWithDelay(ctx, queue, msgs, 0)
}

// SendBatchWithDelay sends a batch of messages to a queue with a delay. The
// delay is specified in seconds. The message ids, unique to the queue, are
// returned.
func (p *PGMQ) SendBatchWithDelay(ctx context.Context, queue string, msgs []map[string]any, delay int) ([]int64, error) {
func (p *PGMQ) SendBatchWithDelay(ctx context.Context, queue string, msgs []json.RawMessage, delay int) ([]int64, error) {
rows, err := p.db.Query(ctx, "SELECT * FROM pgmq.send_batch($1, $2::jsonb[], $3)", queue, msgs, delay)
if err != nil {
return nil, wrapPostgresError(err)
Expand Down
54 changes: 47 additions & 7 deletions pgmq_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pgmq

import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
Expand All @@ -20,8 +21,8 @@ import (
var q *PGMQ

var (
testMsg1 = map[string]any{"foo": "bar1"}
testMsg2 = map[string]any{"foo": "bar2"}
testMsg1 = json.RawMessage(`{"foo": "bar1"}`)
testMsg2 = json.RawMessage(`{"foo": "bar2"}`)
)

func TestMain(m *testing.M) {
Expand Down Expand Up @@ -124,14 +125,53 @@ func TestSend(t *testing.T) {
require.EqualValues(t, 2, id)
}

func TestSendAMarshalledStruct(t *testing.T) {
type A struct {
Val int `json:"val"`
}

a := A{3}
b, err := json.Marshal(a)
require.NoError(t, err)

ctx := context.Background()
queue := t.Name()

err = q.CreateQueue(ctx, queue)
require.NoError(t, err)

_, err = q.Send(ctx, queue, b)
require.NoError(t, err)

msg, err := q.Read(ctx, queue, 0)
require.NoError(t, err)

var aa A
err = json.Unmarshal(msg.Message, &aa)
require.NoError(t, err)

require.EqualValues(t, a, aa)
}

func TestSendInvalidJSONFails(t *testing.T) {
ctx := context.Background()
queue := t.Name()

err := q.CreateQueue(ctx, queue)
require.NoError(t, err)

_, err = q.Send(ctx, queue, json.RawMessage(`{"foo":}`))
require.Error(t, err)
}

func TestSendBatch(t *testing.T) {
ctx := context.Background()
queue := t.Name()

err := q.CreateQueue(ctx, queue)
require.NoError(t, err)

ids, err := q.SendBatch(ctx, queue, []map[string]any{testMsg1, testMsg2})
ids, err := q.SendBatch(ctx, queue, []json.RawMessage{testMsg1, testMsg2})
require.NoError(t, err)
require.Equal(t, []int64{1, 2}, ids)
}
Expand Down Expand Up @@ -174,7 +214,7 @@ func TestReadBatch(t *testing.T) {
err := q.CreateQueue(ctx, queue)
require.NoError(t, err)

_, err = q.SendBatch(ctx, queue, []map[string]any{testMsg1, testMsg2})
_, err = q.SendBatch(ctx, queue, []json.RawMessage{testMsg1, testMsg2})
require.NoError(t, err)

time.Sleep(time.Second)
Expand Down Expand Up @@ -264,7 +304,7 @@ func TestArchiveBatch(t *testing.T) {
err := q.CreateQueue(ctx, queue)
require.NoError(t, err)

ids, err := q.SendBatch(ctx, queue, []map[string]any{testMsg1, testMsg2})
ids, err := q.SendBatch(ctx, queue, []json.RawMessage{testMsg1, testMsg2})
require.NoError(t, err)

archived, err := q.ArchiveBatch(ctx, queue, ids)
Expand Down Expand Up @@ -318,7 +358,7 @@ func TestDeleteBatch(t *testing.T) {
err := q.CreateQueue(ctx, queue)
require.NoError(t, err)

ids, err := q.SendBatch(ctx, queue, []map[string]any{testMsg1, testMsg2})
ids, err := q.SendBatch(ctx, queue, []json.RawMessage{testMsg1, testMsg2})
require.NoError(t, err)

deleted, err := q.DeleteBatch(ctx, queue, ids)
Expand Down Expand Up @@ -372,7 +412,7 @@ func TestErrorCases(t *testing.T) {

t.Run("sendBatchError", func(t *testing.T) {
mockDB.EXPECT().Query(ctx, "SELECT * FROM pgmq.send_batch($1, $2::jsonb[], $3)", queue, gomock.Any(), 0).Return(nil, testErr)
ids, err := q.SendBatch(ctx, queue, []map[string]any{testMsg1})
ids, err := q.SendBatch(ctx, queue, []json.RawMessage{testMsg1})
require.Nil(t, ids)
require.ErrorContains(t, err, "postgres error")
})
Expand Down

0 comments on commit 539f177

Please sign in to comment.