Skip to content

Commit

Permalink
fix: Revert "fix: Error handling in StreamingBatchWriter" (#1918)
Browse files Browse the repository at this point in the history
Reverts #1913

This broke some stuff, so reverting it to unblock SDK changes cloudquery/cloudquery#19312 (comment)
  • Loading branch information
erezrokah authored Oct 3, 2024
1 parent 00b9d9a commit 38b4bfd
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 103 deletions.
128 changes: 54 additions & 74 deletions writers/streamingbatchwriter/streamingbatchwriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,28 +182,26 @@ func (w *StreamingBatchWriter) Write(ctx context.Context, msgs <-chan message.Wr
errCh := make(chan error)
defer close(errCh)

for {
select {
case msg, ok := <-msgs:
if !ok {
return w.Close(ctx)
}
go func() {
for err := range errCh {
w.logger.Err(err).Msg("error from StreamingBatchWriter")
}
}()

msgType := writers.MsgID(msg)
if w.lastMsgType != writers.MsgTypeUnset && w.lastMsgType != msgType {
if err := w.Flush(ctx); err != nil {
return err
}
}
w.lastMsgType = msgType
if err := w.startWorker(ctx, errCh, msg); err != nil {
for msg := range msgs {
msgType := writers.MsgID(msg)
if w.lastMsgType != writers.MsgTypeUnset && w.lastMsgType != msgType {
if err := w.Flush(ctx); err != nil {
return err
}

case err := <-errCh:
}
w.lastMsgType = msgType
if err := w.startWorker(ctx, errCh, msg); err != nil {
return err
}
}

return w.Close(ctx)
}

func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- error, msg message.WriteMessage) error {
Expand All @@ -223,14 +221,13 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
case *message.WriteMigrateTable:
w.workersLock.Lock()
defer w.workersLock.Unlock()

if w.migrateWorker != nil {
w.migrateWorker.ch <- m
return nil
}

ch := make(chan *message.WriteMigrateTable)
w.migrateWorker = &streamingWorkerManager[*message.WriteMigrateTable]{
ch: make(chan *message.WriteMigrateTable),
ch: ch,
writeFunc: w.client.MigrateTable,

flush: make(chan chan bool),
Expand All @@ -244,19 +241,17 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
w.workersWaitGroup.Add(1)
go w.migrateWorker.run(ctx, &w.workersWaitGroup, tableName)
w.migrateWorker.ch <- m

return nil
case *message.WriteDeleteStale:
w.workersLock.Lock()
defer w.workersLock.Unlock()

if w.deleteStaleWorker != nil {
w.deleteStaleWorker.ch <- m
return nil
}

ch := make(chan *message.WriteDeleteStale)
w.deleteStaleWorker = &streamingWorkerManager[*message.WriteDeleteStale]{
ch: make(chan *message.WriteDeleteStale),
ch: ch,
writeFunc: w.client.DeleteStale,

flush: make(chan chan bool),
Expand All @@ -270,29 +265,19 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
w.workersWaitGroup.Add(1)
go w.deleteStaleWorker.run(ctx, &w.workersWaitGroup, tableName)
w.deleteStaleWorker.ch <- m

return nil
case *message.WriteInsert:
w.workersLock.RLock()
worker, ok := w.insertWorkers[tableName]
wr, ok := w.insertWorkers[tableName]
w.workersLock.RUnlock()
if ok {
worker.ch <- m
wr.ch <- m
return nil
}

w.workersLock.Lock()
activeWorker, ok := w.insertWorkers[tableName]
if ok {
w.workersLock.Unlock()
// some other goroutine could have already added the worker
// just send the message to it & discard our allocated worker
activeWorker.ch <- m
return nil
}

worker = &streamingWorkerManager[*message.WriteInsert]{
ch: make(chan *message.WriteInsert),
ch := make(chan *message.WriteInsert)
wr = &streamingWorkerManager[*message.WriteInsert]{
ch: ch,
writeFunc: w.client.WriteTable,

flush: make(chan chan bool),
Expand All @@ -302,27 +287,33 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
batchTimeout: w.batchTimeout,
tickerFn: w.tickerFn,
}

w.insertWorkers[tableName] = worker
w.workersLock.Lock()
wrOld, ok := w.insertWorkers[tableName]
if ok {
w.workersLock.Unlock()
// some other goroutine could have already added the worker
// just send the message to it & discard our allocated worker
wrOld.ch <- m
return nil
}
w.insertWorkers[tableName] = wr
w.workersLock.Unlock()

w.workersWaitGroup.Add(1)
go worker.run(ctx, &w.workersWaitGroup, tableName)
worker.ch <- m

go wr.run(ctx, &w.workersWaitGroup, tableName)
ch <- m
return nil
case *message.WriteDeleteRecord:
w.workersLock.Lock()
defer w.workersLock.Unlock()

if w.deleteRecordWorker != nil {
w.deleteRecordWorker.ch <- m
return nil
}

ch := make(chan *message.WriteDeleteRecord)
// TODO: flush all workers for nested tables as well (See https://github.com/cloudquery/plugin-sdk/issues/1296)
w.deleteRecordWorker = &streamingWorkerManager[*message.WriteDeleteRecord]{
ch: make(chan *message.WriteDeleteRecord),
ch: ch,
writeFunc: w.client.DeleteRecords,

flush: make(chan chan bool),
Expand All @@ -336,7 +327,6 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
w.workersWaitGroup.Add(1)
go w.deleteRecordWorker.run(ctx, &w.workersWaitGroup, tableName)
w.deleteRecordWorker.ch <- m

return nil
default:
return fmt.Errorf("unhandled message type: %T", msg)
Expand All @@ -358,40 +348,35 @@ type streamingWorkerManager[T message.WriteMessage] struct {
func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup, tableName string) {
defer wg.Done()
var (
inputCh chan T
outputCh chan error
open bool
clientCh chan T
clientErrCh chan error
open bool
)

ensureOpened := func() {
if open {
return
}

inputCh = make(chan T)
outputCh = make(chan error)
clientCh = make(chan T)
clientErrCh = make(chan error, 1)
go func() {
defer close(outputCh)
defer close(clientErrCh)
defer func() {
if msg := recover(); msg != nil {
switch v := msg.(type) {
case error:
outputCh <- fmt.Errorf("panic: %w [recovered]", v)
default:
outputCh <- fmt.Errorf("panic: %v [recovered]", msg)
}
if err := recover(); err != nil {
clientErrCh <- fmt.Errorf("panic: %v", err)
}
}()
result := s.writeFunc(ctx, inputCh)
outputCh <- result
clientErrCh <- s.writeFunc(ctx, clientCh)
}()

open = true
}

closeFlush := func() {
if open {
close(inputCh)
close(clientCh)
if err := <-clientErrCh; err != nil {
s.errCh <- fmt.Errorf("handler failed on %s: %w", tableName, err)
}
s.limit.Reset()
}
open = false
Expand All @@ -415,7 +400,7 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
if add != nil {
ensureOpened()
s.limit.AddSlice(add)
inputCh <- any(&message.WriteInsert{Record: add.Record}).(T)
clientCh <- any(&message.WriteInsert{Record: add.Record}).(T)
}
if len(toFlush) > 0 || rest != nil || s.limit.ReachedLimit() {
// flush current batch
Expand All @@ -425,7 +410,7 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
for _, sliceToFlush := range toFlush {
ensureOpened()
s.limit.AddRows(sliceToFlush.NumRows())
inputCh <- any(&message.WriteInsert{Record: sliceToFlush}).(T)
clientCh <- any(&message.WriteInsert{Record: sliceToFlush}).(T)
closeFlush()
ticker.Reset(s.batchTimeout)
}
Expand All @@ -434,11 +419,11 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
if rest != nil {
ensureOpened()
s.limit.AddSlice(rest)
inputCh <- any(&message.WriteInsert{Record: rest.Record}).(T)
clientCh <- any(&message.WriteInsert{Record: rest.Record}).(T)
}
} else {
ensureOpened()
inputCh <- r
clientCh <- r
s.limit.AddRows(1)
if s.limit.ReachedLimit() {
closeFlush()
Expand All @@ -456,11 +441,6 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
ticker.Reset(s.batchTimeout)
}
done <- true
case err := <-outputCh:
if err != nil {
s.errCh <- fmt.Errorf("handler failed on %s: %w", tableName, err)
return
}
case <-ctxDone:
// this means the request was cancelled
return // after this NO other call will succeed
Expand Down
51 changes: 22 additions & 29 deletions writers/streamingbatchwriter/streamingbatchwriter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,30 +201,20 @@ func TestStreamingBatchSizeRows(t *testing.T) {
ch <- &message.WriteInsert{
Record: record,
}
time.Sleep(50 * time.Millisecond)

waitForLength(t, testClient.MessageLen, messageTypeInsert, 0)
waitForLength(t, testClient.InflightLen, messageTypeInsert, 1)

ch <- &message.WriteInsert{
Record: record,
if l := testClient.MessageLen(messageTypeInsert); l != 0 {
t.Fatalf("expected 0 insert messages, got %d", l)
}

waitForLength(t, testClient.MessageLen, messageTypeInsert, 2)
waitForLength(t, testClient.InflightLen, messageTypeInsert, 0)

ch <- &message.WriteInsert{
Record: record,
}

waitForLength(t, testClient.MessageLen, messageTypeInsert, 2)
waitForLength(t, testClient.InflightLen, messageTypeInsert, 1)

ch <- &message.WriteInsert{
ch <- &message.WriteInsert{ // third message, because we flush before exceeding the limit and then save the third one
Record: record,
}

waitForLength(t, testClient.MessageLen, messageTypeInsert, 4)
waitForLength(t, testClient.InflightLen, messageTypeInsert, 0)
waitForLength(t, testClient.MessageLen, messageTypeInsert, 2)

close(ch)
if err := <-errCh; err != nil {
Expand All @@ -235,7 +225,7 @@ func TestStreamingBatchSizeRows(t *testing.T) {
t.Fatalf("expected 0 open tables, got %d", l)
}

if l := testClient.MessageLen(messageTypeInsert); l != 4 {
if l := testClient.MessageLen(messageTypeInsert); l != 3 {
t.Fatalf("expected 3 insert messages, got %d", l)
}
}
Expand Down Expand Up @@ -263,12 +253,18 @@ func TestStreamingBatchTimeout(t *testing.T) {
ch <- &message.WriteInsert{
Record: record,
}
time.Sleep(50 * time.Millisecond)

waitForLength(t, testClient.MessageLen, messageTypeInsert, 0)
if l := testClient.MessageLen(messageTypeInsert); l != 0 {
t.Fatalf("expected 0 insert messages, got %d", l)
}

time.Sleep(time.Millisecond * 50) // we need to wait for the batch to be flushed
// we need to wait for the batch to be flushed
time.Sleep(time.Millisecond * 50)

waitForLength(t, testClient.MessageLen, messageTypeInsert, 0)
if l := testClient.MessageLen(messageTypeInsert); l != 0 {
t.Fatalf("expected 0 insert messages, got %d", l)
}

// flush
tickFn()
Expand Down Expand Up @@ -305,35 +301,32 @@ func TestStreamingBatchNoTimeout(t *testing.T) {
ch <- &message.WriteInsert{
Record: record,
}
time.Sleep(50 * time.Millisecond)

waitForLength(t, testClient.MessageLen, messageTypeInsert, 0)
waitForLength(t, testClient.InflightLen, messageTypeInsert, 1)
if l := testClient.MessageLen(messageTypeInsert); l != 0 {
t.Fatalf("expected 0 insert messages, got %d", l)
}

time.Sleep(2 * time.Second)

waitForLength(t, testClient.MessageLen, messageTypeInsert, 0)
waitForLength(t, testClient.InflightLen, messageTypeInsert, 1)
if l := testClient.MessageLen(messageTypeInsert); l != 0 {
t.Fatalf("expected 0 insert messages, got %d", l)
}

ch <- &message.WriteInsert{
Record: record,
}
waitForLength(t, testClient.MessageLen, messageTypeInsert, 2)
waitForLength(t, testClient.InflightLen, messageTypeInsert, 0)

ch <- &message.WriteInsert{
Record: record,
}

waitForLength(t, testClient.MessageLen, messageTypeInsert, 2)
waitForLength(t, testClient.InflightLen, messageTypeInsert, 1)

close(ch)
if err := <-errCh; err != nil {
t.Fatal(err)
}

time.Sleep(50 * time.Millisecond)

if l := testClient.OpenLen(messageTypeInsert); l != 0 {
t.Fatalf("expected 0 open tables, got %d", l)
}
Expand Down

0 comments on commit 38b4bfd

Please sign in to comment.