diff --git a/core/events.go b/core/events.go new file mode 100644 index 0000000..baf8d26 --- /dev/null +++ b/core/events.go @@ -0,0 +1,10 @@ +package core + +import "github.com/fxamacker/cbor/v2" + +var CBORTags = cbor.NewTagSet() + +type ReplicableEvent[T any] interface { + Wrap() (T, error) + Unwrap() (T, error) +} diff --git a/db/change_log.go b/db/change_log.go index 786f5f0..0815ecb 100644 --- a/db/change_log.go +++ b/db/change_log.go @@ -368,7 +368,7 @@ func (conn *SqliteStreamDB) publishChangeLog() { err = conn.consumeChangeLogs(change.TableName, []*changeLogEntry{&logEntry}) if err != nil { - if err == ErrLogNotReadyToPublish || err == context.Canceled { + if errors.Is(err, ErrLogNotReadyToPublish) || errors.Is(err, context.Canceled) { break } diff --git a/db/change_log_event.go b/db/change_log_event.go index d594ca1..35c23a4 100644 --- a/db/change_log_event.go +++ b/db/change_log_event.go @@ -2,15 +2,23 @@ package db import ( "hash/fnv" + "reflect" "sort" "sync" + "time" "github.com/fxamacker/cbor/v2" + "github.com/maxpert/marmot/core" + "github.com/rs/zerolog/log" ) var tablePKColumnsCache = make(map[string][]string) var tablePKColumnsLock = sync.RWMutex{} +type sensitiveTypeWrapper struct { + Time *time.Time `cbor:"1,keyasint,omitempty"` +} + type ChangeLogEvent struct { Id int64 Type string @@ -19,15 +27,50 @@ type ChangeLogEvent struct { tableInfo []*ColumnInfo `cbor:"-"` } -func (e *ChangeLogEvent) Marshal() ([]byte, error) { - return cbor.Marshal(e) +func init() { + err := core.CBORTags.Add( + cbor.TagOptions{ + DecTag: cbor.DecTagRequired, + EncTag: cbor.EncTagRequired, + }, + reflect.TypeOf(sensitiveTypeWrapper{}), + 32, + ) + + log.Panic().Err(err) +} + +func (s sensitiveTypeWrapper) GetValue() any { + // Right now only sensitive value is Time + return s.Time +} + +func (e ChangeLogEvent) Wrap() (ChangeLogEvent, error) { + return e.prepare(), nil } -func (e *ChangeLogEvent) Unmarshal(data []byte) error { - return cbor.Unmarshal(data, e) +func (e ChangeLogEvent) Unwrap() (ChangeLogEvent, error) { + ret := ChangeLogEvent{ + Id: e.Id, + TableName: e.TableName, + Type: e.Type, + Row: map[string]any{}, + tableInfo: e.tableInfo, + } + + for k, v := range e.Row { + if st, ok := v.(sensitiveTypeWrapper); ok { + ret.Row[k] = st.GetValue() + continue + } + + ret.Row[k] = v + } + + return ret, nil } -func (e *ChangeLogEvent) Hash() (uint64, error) { +func (e ChangeLogEvent) Hash() (uint64, error) { hasher := fnv.New64() enc := cbor.NewEncoder(hasher) err := enc.StartIndefiniteArray() @@ -56,7 +99,7 @@ func (e *ChangeLogEvent) Hash() (uint64, error) { return hasher.Sum64(), nil } -func (e *ChangeLogEvent) getSortedPKColumns() []string { +func (e ChangeLogEvent) getSortedPKColumns() []string { tablePKColumnsLock.RLock() if values, found := tablePKColumnsCache[e.TableName]; found { @@ -79,3 +122,28 @@ func (e *ChangeLogEvent) getSortedPKColumns() []string { tablePKColumnsCache[e.TableName] = pkColumns return pkColumns } + +func (e ChangeLogEvent) prepare() ChangeLogEvent { + needsTransform := false + preparedRow := map[string]any{} + for k, v := range e.Row { + if t, ok := v.(time.Time); ok { + preparedRow[k] = sensitiveTypeWrapper{Time: &t} + needsTransform = true + } else { + preparedRow[k] = v + } + } + + if !needsTransform { + return e + } + + return ChangeLogEvent{ + Id: e.Id, + Type: e.Type, + TableName: e.TableName, + Row: preparedRow, + tableInfo: e.tableInfo, + } +} diff --git a/logstream/replication_event.go b/logstream/replication_event.go index ded01f2..e7d1c9f 100644 --- a/logstream/replication_event.go +++ b/logstream/replication_event.go @@ -1,16 +1,49 @@ package logstream -import "github.com/fxamacker/cbor/v2" +import ( + "github.com/fxamacker/cbor/v2" + "github.com/maxpert/marmot/core" +) -type ReplicationEvent[T any] struct { +type ReplicationEvent[T core.ReplicableEvent[T]] struct { FromNodeId uint64 - Payload *T + Payload T } func (e *ReplicationEvent[T]) Marshal() ([]byte, error) { - return cbor.Marshal(e) + wrappedPayload, err := e.Payload.Wrap() + if err != nil { + return nil, err + } + + ev := ReplicationEvent[T]{ + FromNodeId: e.FromNodeId, + Payload: wrappedPayload, + } + + em, err := cbor.EncOptions{}.EncModeWithTags(core.CBORTags) + if err != nil { + return nil, err + } + + return em.Marshal(ev) } func (e *ReplicationEvent[T]) Unmarshal(data []byte) error { - return cbor.Unmarshal(data, e) + dm, err := cbor.DecOptions{}.DecModeWithTags(core.CBORTags) + if err != nil { + return nil + } + + err = dm.Unmarshal(data, e) + if err != nil { + return err + } + + e.Payload, err = e.Payload.Unwrap() + if err != nil { + return err + } + + return nil } diff --git a/logstream/replicator.go b/logstream/replicator.go index 2ca2329..e6faa4d 100644 --- a/logstream/replicator.go +++ b/logstream/replicator.go @@ -2,6 +2,7 @@ package logstream import ( "context" + "errors" "fmt" "time" @@ -178,8 +179,7 @@ func (r *Replicator) Listen(shardID uint64, callback func(payload []byte) error) savedSeq := r.repState.get(streamName(shardID, r.compressionEnabled)) for sub.IsValid() { msg, err := sub.NextMsg(5 * time.Second) - - if err == nats.ErrTimeout { + if errors.Is(err, nats.ErrTimeout) { continue } @@ -199,7 +199,7 @@ func (r *Replicator) Listen(shardID uint64, callback func(payload []byte) error) err = r.invokeListener(callback, msg) if err != nil { msg.Nak() - if err == context.Canceled { + if errors.Is(err, context.Canceled) { return nil } diff --git a/marmot.go b/marmot.go index 918ddef..f39a2e7 100644 --- a/marmot.go +++ b/marmot.go @@ -189,7 +189,7 @@ func onChangeEvent(streamDB *db.SqliteStreamDB, ctxSt *utils.StateContext, event return err } - return streamDB.Replicate(ev.Payload) + return streamDB.Replicate(&ev.Payload) } } @@ -206,7 +206,7 @@ func onTableChanged(r *logstream.Replicator, ctxSt *utils.StateContext, events E ev := &logstream.ReplicationEvent[db.ChangeLogEvent]{ FromNodeId: nodeID, - Payload: event, + Payload: *event, } data, err := ev.Marshal()