From 9caa65929974190c2f53e2958b6b8793ee5cfe60 Mon Sep 17 00:00:00 2001
From: Kemal Hadimli
Date: Thu, 3 Oct 2024 15:34:44 +0100
Subject: [PATCH] Revert "fix: Revert "fix: Error handling in StreamingBatchWriter" (#1918)"

This reverts commit 38b4bfd20e17a00d5a2c83e1d48b8b16270592ba.
---
 .../streamingbatchwriter.go      | 128 ++++++++++--------
 .../streamingbatchwriter_test.go |  51 ++++---
 2 files changed, 103 insertions(+), 76 deletions(-)

diff --git a/writers/streamingbatchwriter/streamingbatchwriter.go b/writers/streamingbatchwriter/streamingbatchwriter.go
index 23b6e8c139..d5825df31d 100644
--- a/writers/streamingbatchwriter/streamingbatchwriter.go
+++ b/writers/streamingbatchwriter/streamingbatchwriter.go
@@ -182,26 +182,28 @@ func (w *StreamingBatchWriter) Write(ctx context.Context, msgs <-chan message.Wr
 	errCh := make(chan error)
 	defer close(errCh)
 
-	go func() {
-		for err := range errCh {
-			w.logger.Err(err).Msg("error from StreamingBatchWriter")
-		}
-	}()
+	for {
+		select {
+		case msg, ok := <-msgs:
+			if !ok {
+				return w.Close(ctx)
+			}
 
-	for msg := range msgs {
-		msgType := writers.MsgID(msg)
-		if w.lastMsgType != writers.MsgTypeUnset && w.lastMsgType != msgType {
-			if err := w.Flush(ctx); err != nil {
+			msgType := writers.MsgID(msg)
+			if w.lastMsgType != writers.MsgTypeUnset && w.lastMsgType != msgType {
+				if err := w.Flush(ctx); err != nil {
+					return err
+				}
+			}
+			w.lastMsgType = msgType
+			if err := w.startWorker(ctx, errCh, msg); err != nil {
 				return err
 			}
-		}
-		w.lastMsgType = msgType
-		if err := w.startWorker(ctx, errCh, msg); err != nil {
+
+		case err := <-errCh:
 			return err
 		}
 	}
-
-	return w.Close(ctx)
 }
 
 func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- error, msg message.WriteMessage) error {
@@ -221,13 +223,14 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
 	case *message.WriteMigrateTable:
 		w.workersLock.Lock()
 		defer w.workersLock.Unlock()
+
 		if w.migrateWorker != nil {
 			w.migrateWorker.ch <- m
 			return nil
 		}
-		ch := make(chan *message.WriteMigrateTable)
+
 		w.migrateWorker = &streamingWorkerManager[*message.WriteMigrateTable]{
-			ch:        ch,
+			ch:        make(chan *message.WriteMigrateTable),
 			writeFunc: w.client.MigrateTable,
 
 			flush: make(chan chan bool),
@@ -241,17 +244,19 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
 		w.workersWaitGroup.Add(1)
 		go w.migrateWorker.run(ctx, &w.workersWaitGroup, tableName)
 		w.migrateWorker.ch <- m
+
 		return nil
 	case *message.WriteDeleteStale:
 		w.workersLock.Lock()
 		defer w.workersLock.Unlock()
+
 		if w.deleteStaleWorker != nil {
 			w.deleteStaleWorker.ch <- m
 			return nil
 		}
-		ch := make(chan *message.WriteDeleteStale)
+
 		w.deleteStaleWorker = &streamingWorkerManager[*message.WriteDeleteStale]{
-			ch:        ch,
+			ch:        make(chan *message.WriteDeleteStale),
 			writeFunc: w.client.DeleteStale,
 
 			flush: make(chan chan bool),
@@ -265,19 +270,29 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
 		w.workersWaitGroup.Add(1)
 		go w.deleteStaleWorker.run(ctx, &w.workersWaitGroup, tableName)
 		w.deleteStaleWorker.ch <- m
+
 		return nil
 	case *message.WriteInsert:
 		w.workersLock.RLock()
-		wr, ok := w.insertWorkers[tableName]
+		worker, ok := w.insertWorkers[tableName]
 		w.workersLock.RUnlock()
 		if ok {
-			wr.ch <- m
+			worker.ch <- m
 			return nil
 		}
 
-		ch := make(chan *message.WriteInsert)
-		wr = &streamingWorkerManager[*message.WriteInsert]{
-			ch:        ch,
+		w.workersLock.Lock()
+		activeWorker, ok := w.insertWorkers[tableName]
+		if ok {
+			w.workersLock.Unlock()
+			// some other goroutine could have already added the worker
+			// just send the message to it & discard our allocated worker
+			activeWorker.ch <- m
+			return nil
+		}
+
+		worker = &streamingWorkerManager[*message.WriteInsert]{
+			ch:        make(chan *message.WriteInsert),
 			writeFunc: w.client.WriteTable,
 
 			flush: make(chan chan bool),
@@ -287,33 +302,27 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
 			batchTimeout: w.batchTimeout,
 			tickerFn:     w.tickerFn,
 		}
-		w.workersLock.Lock()
-		wrOld, ok := w.insertWorkers[tableName]
-		if ok {
-			w.workersLock.Unlock()
-			// some other goroutine could have already added the worker
-			// just send the message to it & discard our allocated worker
-			wrOld.ch <- m
-			return nil
-		}
-		w.insertWorkers[tableName] = wr
+
+		w.insertWorkers[tableName] = worker
 		w.workersLock.Unlock()
 
 		w.workersWaitGroup.Add(1)
-		go wr.run(ctx, &w.workersWaitGroup, tableName)
-		ch <- m
+		go worker.run(ctx, &w.workersWaitGroup, tableName)
+		worker.ch <- m
+
 		return nil
 	case *message.WriteDeleteRecord:
 		w.workersLock.Lock()
 		defer w.workersLock.Unlock()
+
 		if w.deleteRecordWorker != nil {
 			w.deleteRecordWorker.ch <- m
 			return nil
 		}
-		ch := make(chan *message.WriteDeleteRecord)
+
 		// TODO: flush all workers for nested tables as well (See https://github.com/cloudquery/plugin-sdk/issues/1296)
 		w.deleteRecordWorker = &streamingWorkerManager[*message.WriteDeleteRecord]{
-			ch:        ch,
+			ch:        make(chan *message.WriteDeleteRecord),
 			writeFunc: w.client.DeleteRecords,
 
 			flush: make(chan chan bool),
@@ -327,6 +336,7 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
 		w.workersWaitGroup.Add(1)
 		go w.deleteRecordWorker.run(ctx, &w.workersWaitGroup, tableName)
 		w.deleteRecordWorker.ch <- m
+
 		return nil
 	default:
 		return fmt.Errorf("unhandled message type: %T", msg)
@@ -348,9 +358,9 @@ type streamingWorkerManager[T message.WriteMessage] struct {
 func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup, tableName string) {
 	defer wg.Done()
 	var (
-		clientCh    chan T
-		clientErrCh chan error
-		open        bool
+		inputCh  chan T
+		outputCh chan error
+		open     bool
 	)
 
 	ensureOpened := func() {
@@ -358,25 +368,30 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
 			return
 		}
 
-		clientCh = make(chan T)
-		clientErrCh = make(chan error, 1)
+		inputCh = make(chan T)
+		outputCh = make(chan error)
 		go func() {
-			defer close(clientErrCh)
+			defer close(outputCh)
 			defer func() {
-				if err := recover(); err != nil {
-					clientErrCh <- fmt.Errorf("panic: %v", err)
+				if msg := recover(); msg != nil {
+					switch v := msg.(type) {
+					case error:
+						outputCh <- fmt.Errorf("panic: %w [recovered]", v)
+					default:
+						outputCh <- fmt.Errorf("panic: %v [recovered]", msg)
+					}
 				}
 			}()
-			clientErrCh <- s.writeFunc(ctx, clientCh)
+			result := s.writeFunc(ctx, inputCh)
+			outputCh <- result
 		}()
+
 		open = true
 	}
+
 	closeFlush := func() {
 		if open {
-			close(clientCh)
-			if err := <-clientErrCh; err != nil {
-				s.errCh <- fmt.Errorf("handler failed on %s: %w", tableName, err)
-			}
+			close(inputCh)
 			s.limit.Reset()
 		}
 		open = false
@@ -400,7 +415,7 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
 				if add != nil {
 					ensureOpened()
 					s.limit.AddSlice(add)
-					clientCh <- any(&message.WriteInsert{Record: add.Record}).(T)
+					inputCh <- any(&message.WriteInsert{Record: add.Record}).(T)
 				}
 				if len(toFlush) > 0 || rest != nil || s.limit.ReachedLimit() {
 					// flush current batch
@@ -410,7 +425,7 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
 				for _, sliceToFlush := range toFlush {
 					ensureOpened()
 					s.limit.AddRows(sliceToFlush.NumRows())
-					clientCh <- any(&message.WriteInsert{Record: sliceToFlush}).(T)
+					inputCh <- any(&message.WriteInsert{Record: sliceToFlush}).(T)
 					closeFlush()
 					ticker.Reset(s.batchTimeout)
 				}
@@ -419,11 +434,11 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
 				if rest != nil {
 					ensureOpened()
 					s.limit.AddSlice(rest)
-					clientCh <- any(&message.WriteInsert{Record: rest.Record}).(T)
+					inputCh <- any(&message.WriteInsert{Record: rest.Record}).(T)
 				}
 			} else {
 				ensureOpened()
-				clientCh <- r
+				inputCh <- r
 				s.limit.AddRows(1)
 				if s.limit.ReachedLimit() {
 					closeFlush()
@@ -441,6 +456,11 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
 				ticker.Reset(s.batchTimeout)
 			}
 			done <- true
+		case err := <-outputCh:
+			if err != nil {
+				s.errCh <- fmt.Errorf("handler failed on %s: %w", tableName, err)
+				return
+			}
 		case <-ctxDone:
 			// this means the request was cancelled
 			return // after this NO other call will succeed
diff --git a/writers/streamingbatchwriter/streamingbatchwriter_test.go b/writers/streamingbatchwriter/streamingbatchwriter_test.go
index 08cabbfd1a..7e6703c92d 100644
--- a/writers/streamingbatchwriter/streamingbatchwriter_test.go
+++ b/writers/streamingbatchwriter/streamingbatchwriter_test.go
@@ -201,20 +201,30 @@ func TestStreamingBatchSizeRows(t *testing.T) {
 	ch <- &message.WriteInsert{
 		Record: record,
 	}
-	time.Sleep(50 * time.Millisecond)
-	if l := testClient.MessageLen(messageTypeInsert); l != 0 {
-		t.Fatalf("expected 0 insert messages, got %d", l)
-	}
+
+	waitForLength(t, testClient.MessageLen, messageTypeInsert, 0)
+	waitForLength(t, testClient.InflightLen, messageTypeInsert, 1)
 
 	ch <- &message.WriteInsert{
 		Record: record,
 	}
-	ch <- &message.WriteInsert{ // third message, because we flush before exceeding the limit and then save the third one
+
+	waitForLength(t, testClient.MessageLen, messageTypeInsert, 2)
+	waitForLength(t, testClient.InflightLen, messageTypeInsert, 0)
+
+	ch <- &message.WriteInsert{
 		Record: record,
 	}
 
 	waitForLength(t, testClient.MessageLen, messageTypeInsert, 2)
+	waitForLength(t, testClient.InflightLen, messageTypeInsert, 1)
+
+	ch <- &message.WriteInsert{
+		Record: record,
+	}
+
+	waitForLength(t, testClient.MessageLen, messageTypeInsert, 4)
+	waitForLength(t, testClient.InflightLen, messageTypeInsert, 0)
 
 	close(ch)
 	if err := <-errCh; err != nil {
@@ -225,7 +235,7 @@ func TestStreamingBatchSizeRows(t *testing.T) {
 		t.Fatalf("expected 0 open tables, got %d", l)
 	}
 
-	if l := testClient.MessageLen(messageTypeInsert); l != 3 {
-		t.Fatalf("expected 3 insert messages, got %d", l)
+	if l := testClient.MessageLen(messageTypeInsert); l != 4 {
+		t.Fatalf("expected 4 insert messages, got %d", l)
 	}
 }
@@ -253,18 +263,12 @@ func TestStreamingBatchTimeout(t *testing.T) {
 	ch <- &message.WriteInsert{
 		Record: record,
 	}
-	time.Sleep(50 * time.Millisecond)
-	if l := testClient.MessageLen(messageTypeInsert); l != 0 {
-		t.Fatalf("expected 0 insert messages, got %d", l)
-	}
+
+	waitForLength(t, testClient.MessageLen, messageTypeInsert, 0)
 
-	// we need to wait for the batch to be flushed
-	time.Sleep(time.Millisecond * 50)
+	time.Sleep(time.Millisecond * 50) // we need to wait for the batch to be flushed
 
-	if l := testClient.MessageLen(messageTypeInsert); l != 0 {
-		t.Fatalf("expected 0 insert messages, got %d", l)
-	}
+	waitForLength(t, testClient.MessageLen, messageTypeInsert, 0)
 
 	// flush
 	tickFn()
@@ -301,32 +305,35 @@ func TestStreamingBatchNoTimeout(t *testing.T) {
 	ch <- &message.WriteInsert{
 		Record: record,
 	}
-	time.Sleep(50 * time.Millisecond)
-	if l := testClient.MessageLen(messageTypeInsert); l != 0 {
-		t.Fatalf("expected 0 insert messages, got %d", l)
-	}
+
+	waitForLength(t, testClient.MessageLen, messageTypeInsert, 0)
+	waitForLength(t, testClient.InflightLen, messageTypeInsert, 1)
 
 	time.Sleep(2 * time.Second)
 
-	if l := testClient.MessageLen(messageTypeInsert); l != 0 {
-		t.Fatalf("expected 0 insert messages, got %d", l)
-	}
+	waitForLength(t, testClient.MessageLen, messageTypeInsert, 0)
+	waitForLength(t, testClient.InflightLen, messageTypeInsert, 1)
 
 	ch <- &message.WriteInsert{
 		Record: record,
 	}
+
+	waitForLength(t, testClient.MessageLen, messageTypeInsert, 2)
+	waitForLength(t, testClient.InflightLen, messageTypeInsert, 0)
+
 	ch <- &message.WriteInsert{
 		Record: record,
 	}
 
 	waitForLength(t, testClient.MessageLen, messageTypeInsert, 2)
+	waitForLength(t, testClient.InflightLen, messageTypeInsert, 1)
 
 	close(ch)
 	if err := <-errCh; err != nil {
 		t.Fatal(err)
 	}
 
+	time.Sleep(50 * time.Millisecond)
+
 	if l := testClient.OpenLen(messageTypeInsert); l != 0 {
 		t.Fatalf("expected 0 open tables, got %d", l)
 	}
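
Note on the Write change: the old code drained errCh in a side goroutine and only logged handler errors, so Write could report success even when a handler had failed. The restored select loop makes Write return the first worker error. Below is a minimal, self-contained sketch of that pattern under invented names (consume, msgs, errCh); it is not the SDK's API, just the shape of the loop:

package main

import (
	"context"
	"errors"
	"fmt"
)

// consume drains msgs, handing work off to workers, and returns early if any
// worker reports an error — mirroring the select loop restored in Write.
func consume(ctx context.Context, msgs <-chan string, errCh <-chan error) error {
	for {
		select {
		case msg, ok := <-msgs:
			if !ok {
				return nil // input closed: normal shutdown (Close in the patch)
			}
			_ = msg // hand off to a worker here
		case err := <-errCh:
			return err // a worker failed: abort instead of just logging
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}

func main() {
	msgs := make(chan string)
	errCh := make(chan error, 1)
	go func() {
		msgs <- "row-1"
		errCh <- errors.New("handler failed")
	}()
	fmt.Println(consume(context.Background(), msgs, errCh)) // handler failed
}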
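The worker goroutine's recover block now distinguishes error-typed panic values, wrapping them with %w so callers can still inspect them with errors.Is / errors.As. A standalone sketch of the same idiom (callHandler and errBoom are illustrative names, not from the patch):

package main

import (
	"errors"
	"fmt"
)

// callHandler runs h and converts a panic into a returned error, preserving
// error-typed panic values as wrapped (unwrappable) errors.
func callHandler(h func() error) (err error) {
	defer func() {
		if msg := recover(); msg != nil {
			switch v := msg.(type) {
			case error:
				err = fmt.Errorf("panic: %w [recovered]", v)
			default:
				err = fmt.Errorf("panic: %v [recovered]", msg)
			}
		}
	}()
	return h()
}

var errBoom = errors.New("boom")

func main() {
	err := callHandler(func() error { panic(errBoom) })
	fmt.Println(err, errors.Is(err, errBoom)) // panic: boom [recovered] true
}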
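The WriteInsert case re-checks the insertWorkers map under the write lock before registering a new worker, closing the race where two goroutines both miss on the read-locked lookup and each allocate a worker. A rough standalone sketch of that check-lock-check shape (the registry type and channel-of-string workers are invented for illustration):

package main

import (
	"fmt"
	"sync"
	"time"
)

type registry struct {
	mu      sync.RWMutex
	workers map[string]chan string
}

// send delivers msg to the worker for table, creating the worker on first use.
func (r *registry) send(table, msg string) {
	r.mu.RLock() // fast path: read lock only
	ch, ok := r.workers[table]
	r.mu.RUnlock()
	if ok {
		ch <- msg
		return
	}

	r.mu.Lock()
	// Re-check under the write lock: another goroutine may have registered a
	// worker between RUnlock and Lock; if so, use it and discard ours.
	if existing, ok := r.workers[table]; ok {
		r.mu.Unlock()
		existing <- msg
		return
	}
	ch = make(chan string)
	r.workers[table] = ch
	r.mu.Unlock()

	go func() { // per-table worker, started exactly once
		for m := range ch {
			fmt.Println(table, m)
		}
	}()
	ch <- msg
}

func main() {
	r := &registry{workers: map[string]chan string{}}
	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			r.send("users", fmt.Sprintf("row-%d", i))
		}(i)
	}
	wg.Wait()
	time.Sleep(50 * time.Millisecond) // demo only: let the worker drain
}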
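The tests lean on waitForLength and a new InflightLen counter that are not shown in this diff; the helper evidently polls a counter until it settles, replacing the old sleep-then-assert pattern. A guess at its shape, assuming it takes a counter func and a target count — the real signature and messageType definition live in the actual test package and may differ:

package streamingbatchwriter

import (
	"testing"
	"time"
)

// waitForLength polls get(msgType) until it equals want or a deadline passes.
// Hypothetical reconstruction of the helper the tests above call.
func waitForLength(t *testing.T, get func(messageType) int, msgType messageType, want int) {
	t.Helper()
	deadline := time.Now().Add(2 * time.Second)
	for {
		if got := get(msgType); got == want {
			return
		} else if time.Now().After(deadline) {
			t.Fatalf("timed out: wanted %d messages of type %v, got %d", want, msgType, got)
		}
		time.Sleep(10 * time.Millisecond)
	}
}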