diff --git a/examples/replicate.go b/examples/replicate.go index 46f0be0..803fc63 100644 --- a/examples/replicate.go +++ b/examples/replicate.go @@ -17,7 +17,7 @@ func main() { log.Fatal(err) } - set := pgoutput.NewRelationSet() + set := pgoutput.NewRelationSet(nil) dump := func(relation uint32, row []pgoutput.Tuple) error { values, err := set.Values(relation, row) @@ -31,8 +31,7 @@ func main() { return nil } - handler := func(m pgoutput.Message) error { - return fmt.Errorf("hey") + handler := func(m pgoutput.Message, walPos uint64) error { switch v := m.(type) { case pgoutput.Relation: log.Printf("RELATION") @@ -50,8 +49,8 @@ func main() { return nil } - sub := pgoutput.NewSubscription("sub1", "pub1") - if err := sub.Start(ctx, conn, handler); err != nil { + sub := pgoutput.NewSubscription(conn, "sub1", "pub1", 0, false) + if err := sub.Start(ctx, 0, handler); err != nil { log.Fatal(err) } } diff --git a/parse.go b/parse.go index 4b413da..9dab528 100644 --- a/parse.go +++ b/parse.go @@ -129,6 +129,10 @@ type Relation struct { Columns []Column } +func (r Relation) IsEmpty() bool { + return r.ID == 0 && r.Name == "" && r.Replica == 0 && len(r.Columns) == 0 +} + type Type struct { // ID of the data type ID uint32 diff --git a/parse_test.go b/parse_test.go index 59422e2..08be3fa 100644 --- a/parse_test.go +++ b/parse_test.go @@ -51,7 +51,7 @@ func GenerateLogicalReplicationFiles(t *testing.T) { func TestParseWalData(t *testing.T) { files, _ := filepath.Glob("testdata/*") - set := NewRelationSet() + set := NewRelationSet(nil) expected := map[int]struct { ID int32 diff --git a/sub.go b/sub.go index 641c800..393a983 100644 --- a/sub.go +++ b/sub.go @@ -3,7 +3,8 @@ package pgoutput import ( "context" "fmt" - "log" + "sync" + "sync/atomic" "time" "github.com/jackc/pgx" @@ -14,120 +15,190 @@ type Subscription struct { Publication string WaitTimeout time.Duration StatusTimeout time.Duration - CopyData bool + + conn *pgx.ReplicationConn + maxWal uint64 + walRetain uint64 + walFlushed uint64 + + failOnHandler bool + + // Mutex is used to prevent reading and writing to a connection at the same time + sync.Mutex } -type Handler func(Message) error +type Handler func(Message, uint64) error -func NewSubscription(name, publication string) *Subscription { +func NewSubscription(conn *pgx.ReplicationConn, name, publication string, walRetain uint64, failOnHandler bool) *Subscription { return &Subscription{ Name: name, Publication: publication, - WaitTimeout: time.Second * 10, - StatusTimeout: time.Second * 10, - CopyData: true, + WaitTimeout: 1 * time.Second, + StatusTimeout: 10 * time.Second, + + conn: conn, + walRetain: walRetain, + failOnHandler: failOnHandler, } } func pluginArgs(version, publication string) string { - return fmt.Sprintf(`("proto_version" '%s', "publication_names" '%s')`, version, publication) + return fmt.Sprintf(`"proto_version" '%s', "publication_names" '%s'`, version, publication) } -func (s *Subscription) Start(ctx context.Context, conn *pgx.ReplicationConn, h Handler) error { - // TODO: Struct Validation here - _ = conn.DropReplicationSlot(s.Name) - +// CreateSlot creates a replication slot if it doesn't exist +func (s *Subscription) CreateSlot() (err error) { // If creating the replication slot fails with code 42710, this means // the replication slot already exists. - err := conn.CreateReplicationSlot(s.Name, "pgoutput") - if err != nil { + if err = s.conn.CreateReplicationSlot(s.Name, "pgoutput"); err != nil { pgerr, ok := err.(pgx.PgError) - if !ok { - return fmt.Errorf("failed to create replication slot: %s", err) - } - if pgerr.Code != "42710" { - return fmt.Errorf("failed to create replication slot: %s", err) + if !ok || pgerr.Code != "42710" { + return } + + err = nil + } + + return +} + +func (s *Subscription) sendStatus(walWrite, walFlush uint64) error { + if walFlush > walWrite { + return fmt.Errorf("walWrite should be >= walFlush") } - // rows, err := conn.IdentifySystem() - // if err != nil { - // return err - // } + s.Lock() + defer s.Unlock() - // var slotName, consitentPoint, snapshotName, outputPlugin string - // if err := row.Scan(&slotName, &consitentPoint, &snapshotName, &outputPlugin); err != nil { - // return err - // } + k, err := pgx.NewStandbyStatus(walFlush, walFlush, walWrite) + if err != nil { + return fmt.Errorf("error creating status: %s", err) + } - // log.Printf("slotName: %s\n", slotName) - // log.Printf("consitentPoint: %s\n", consitentPoint) - // log.Printf("snapshotName: %s\n", snapshotName) - // log.Printf("outputPlugin: %s\n", outputPlugin) + if err = s.conn.SendStandbyStatus(k); err != nil { + return err + } - // Open a transaction on the server - // SET TRANSACTION SNAPSHOT id - // read all the data from the tables + return nil +} - err = conn.StartReplication(s.Name, 0, -1, pluginArgs("1", s.Publication)) +// Flush sends the status message to server indicating that we've fully applied all of the events until maxWal. +// This allows PostgreSQL to purge it's WAL logs +func (s *Subscription) Flush() error { + wp := atomic.LoadUint64(&s.maxWal) + err := s.sendStatus(wp, wp) + if err == nil { + atomic.StoreUint64(&s.walFlushed, wp) + } + + return err +} + +// Start replication and block until error or ctx is canceled +func (s *Subscription) Start(ctx context.Context, startLSN uint64, h Handler) (err error) { + err = s.conn.StartReplication(s.Name, startLSN, -1, pluginArgs("1", s.Publication)) if err != nil { return fmt.Errorf("failed to start replication: %s", err) } - var maxWal uint64 + s.maxWal = startLSN sendStatus := func() error { - k, err := pgx.NewStandbyStatus(maxWal) - if err != nil { - return fmt.Errorf("error creating standby status: %s", err) + walPos := atomic.LoadUint64(&s.maxWal) + walLastFlushed := atomic.LoadUint64(&s.walFlushed) + + // Confirm only walRetain bytes in past + // If walRetain is zero - will confirm current walPos as flushed + walFlush := walPos - s.walRetain + + if walLastFlushed > walFlush { + // If there was a manual flush - report it's position until we're past it + walFlush = walLastFlushed + } else if walFlush < 0 { + // If we have less than walRetain bytes - just report zero + walFlush = 0 } - if err := conn.SendStandbyStatus(k); err != nil { - return fmt.Errorf("failed to send standy status: %s", err) - } - return nil + + return s.sendStatus(walPos, walFlush) } - tick := time.NewTicker(s.StatusTimeout).C + go func() { + tick := time.NewTicker(s.StatusTimeout) + defer tick.Stop() + + for { + select { + case <-tick.C: + if err = sendStatus(); err != nil { + return + } + + case <-ctx.Done(): + return + } + } + }() + for { select { - case <-tick: - log.Println("pub status") - if maxWal == 0 { - continue - } - if err := sendStatus(); err != nil { - return err + case <-ctx.Done(): + // Send final status and exit + if err = sendStatus(); err != nil { + return fmt.Errorf("Unable to send final status: %s", err) } + + return + default: var message *pgx.ReplicationMessage wctx, cancel := context.WithTimeout(ctx, s.WaitTimeout) - message, err = conn.WaitForReplicationMessage(wctx) + s.Lock() + message, err = s.conn.WaitForReplicationMessage(wctx) + s.Unlock() cancel() + if err == context.DeadlineExceeded { continue - } - if err != nil { + } else if err == context.Canceled { + return + } else if err != nil { return fmt.Errorf("replication failed: %s", err) } + + if message == nil { + return fmt.Errorf("replication failed: nil message received, should not happen") + } + if message.WalMessage != nil { - if message.WalMessage.WalStart > maxWal { - maxWal = message.WalMessage.WalStart + var logmsg Message + walStart := message.WalMessage.WalStart + + // Skip stuff that's in the past + if walStart > 0 && walStart <= startLSN { + continue } - logmsg, err := Parse(message.WalMessage.WalData) + + if walStart > atomic.LoadUint64(&s.maxWal) { + atomic.StoreUint64(&s.maxWal, walStart) + } + + logmsg, err = Parse(message.WalMessage.WalData) if err != nil { return fmt.Errorf("invalid pgoutput message: %s", err) } - if err := h(logmsg); err != nil { - return fmt.Errorf("error handling waldata: %s", err) + + // Ignore the error from handler for now + if err = h(logmsg, walStart); err != nil && s.failOnHandler { + return } - } - if message.ServerHeartbeat != nil { + } else if message.ServerHeartbeat != nil { if message.ServerHeartbeat.ReplyRequested == 1 { - log.Println("server wants a reply") - if err := sendStatus(); err != nil { - return err + if err = sendStatus(); err != nil { + return } } + } else { + return fmt.Errorf("No WalMessage/ServerHeartbeat defined in packet, should not happen") } } } diff --git a/values.go b/values.go index 04b584b..7ea9f06 100644 --- a/values.go +++ b/values.go @@ -7,34 +7,47 @@ import ( ) type RelationSet struct { - // TODO: Add mutex + // Mutex probably will be redundant as receiving + // a replication stream is currently strictly single-threaded relations map[uint32]Relation + connInfo *pgtype.ConnInfo } -func NewRelationSet() *RelationSet { - return &RelationSet{relations: map[uint32]Relation{}} +// NewRelationSet creates a new relation set. +// Optionally ConnInfo can be provided, however currently we need some changes to pgx to get it out +// from ReplicationConn. +func NewRelationSet(ci *pgtype.ConnInfo) *RelationSet { + return &RelationSet{map[uint32]Relation{}, ci} } func (rs *RelationSet) Add(r Relation) { rs.relations[r.ID] = r } +func (rs *RelationSet) Get(ID uint32) (r Relation, ok bool) { + r, ok = rs.relations[ID] + return +} + func (rs *RelationSet) Values(id uint32, row []Tuple) (map[string]pgtype.Value, error) { values := map[string]pgtype.Value{} - rel, ok := rs.relations[id] + rel, ok := rs.Get(id) if !ok { return values, fmt.Errorf("no relation for %d", id) } + // assert same number of row and columns for i, tuple := range row { col := rel.Columns[i] decoder := col.Decoder() - // TODO: Pass in connection? - if err := decoder.DecodeText(nil, tuple.Value); err != nil { - return values, fmt.Errorf("error decoding tuple %d: %s", i, err) + + if err := decoder.DecodeText(rs.connInfo, tuple.Value); err != nil { + return nil, fmt.Errorf("error decoding tuple %d: %s", i, err) } + values[col.Name] = decoder } + return values, nil }