From 57be7f7f575fcd913008b4fddee022bf5261367b Mon Sep 17 00:00:00 2001 From: Kemal Hadimli Date: Wed, 25 Sep 2024 13:25:37 +0100 Subject: [PATCH] feat: Handle panics in handlers --- stream.go | 33 ++++++++++-- stream_panic_test.go | 121 +++++++++++++++++++++++++++++++++++++++++++ write.go | 13 ++++- 3 files changed, 163 insertions(+), 4 deletions(-) create mode 100644 stream_panic_test.go diff --git a/stream.go b/stream.go index 4603d15..8c05d1d 100644 --- a/stream.go +++ b/stream.go @@ -1,6 +1,7 @@ package filetypes import ( + "errors" "fmt" "io" @@ -54,11 +55,22 @@ func (cl *Client) StartStream(table *schema.Table, uploadFunc func(io.Reader) er } // Write to the stream opened with StartStream. -func (s *Stream) Write(records []arrow.Record) error { +func (s *Stream) Write(records []arrow.Record) (retErr error) { if len(records) == 0 { return nil } + defer func() { + if msg := recover(); msg != nil { + switch v := msg.(type) { + case error: + retErr = fmt.Errorf("panic: %w [recovered]", v) + default: + retErr = fmt.Errorf("panic: %v [recovered]", msg) + } + } + }() + return s.h.WriteContent(records) } @@ -74,11 +86,11 @@ func (s *Stream) FinishWithError(finishError error) error { return <-s.done } - if err := s.h.WriteFooter(); err != nil { + if err := s.writeFooter(); err != nil { if !s.wc.closed { _ = s.wc.CloseWithError(err) } - return fmt.Errorf("failed to write footer: %w", <-s.done) + return fmt.Errorf("failed to write footer: %w", errors.Join(err, <-s.done)) } // ParquetWriter likes to close the underlying writer, so we need to check if it's already closed @@ -90,3 +102,18 @@ func (s *Stream) FinishWithError(finishError error) error { return <-s.done } + +func (s *Stream) writeFooter() (retErr error) { + defer func() { + if msg := recover(); msg != nil { + switch v := msg.(type) { + case error: + retErr = fmt.Errorf("panic: %w [recovered]", v) + default: + retErr = fmt.Errorf("panic: %v [recovered]", msg) + } + } + }() + + return s.h.WriteFooter() +} diff --git a/stream_panic_test.go b/stream_panic_test.go new file mode 100644 index 0000000..d23a573 --- /dev/null +++ b/stream_panic_test.go @@ -0,0 +1,121 @@ +package filetypes + +import ( + "errors" + "io" + "testing" + + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/memory" + "github.com/cloudquery/filetypes/v4/types" + "github.com/cloudquery/plugin-sdk/v4/schema" + "github.com/stretchr/testify/require" +) + +func TestPanicOnHeader(t *testing.T) { + r := require.New(t) + cl := &Client{ + spec: &FileSpec{ + Compression: CompressionTypeNone, + }, + filetype: &customWriter{ + PanicOnHeader: true, + }, + } + + stream, err := cl.StartStream(&schema.Table{}, func(io.Reader) error { + return nil + }) + r.Nil(stream) + r.Error(err) + r.ErrorContains(err, "panic:") +} + +func TestPanicOnWrite(t *testing.T) { + r := require.New(t) + cl := &Client{ + spec: &FileSpec{ + Compression: CompressionTypeNone, + }, + filetype: &customWriter{ + PanicOnWrite: true, + }, + } + + table := &schema.Table{ + Name: "test", + Columns: []schema.Column{ + {Name: "name", Type: arrow.BinaryTypes.String}, + }, + } + bldr := array.NewRecordBuilder(memory.DefaultAllocator, table.ToArrowSchema()) + bldr.Field(0).(*array.StringBuilder).Append("foo") + bldr.Field(0).(*array.StringBuilder).Append("bar") + record := bldr.NewRecord() + + stream, err := cl.StartStream(table, func(io.Reader) error { + return nil + }) + r.NoError(err) + err = stream.Write([]arrow.Record{record}) + r.Error(err) + r.ErrorContains(err, "panic:") + + r.NoError(stream.Finish()) +} + +func TestPanicOnClose(t *testing.T) { + r := require.New(t) + cl := &Client{ + spec: &FileSpec{ + Compression: CompressionTypeNone, + }, + filetype: &customWriter{ + PanicOnClose: true, + }, + } + + stream, err := cl.StartStream(&schema.Table{}, func(io.Reader) error { + return nil + }) + r.NoError(err) + r.NoError(stream.Write(nil)) + + err = stream.Finish() + r.Error(err) + r.ErrorContains(err, "panic:") +} + +type customWriter struct { + PanicOnHeader bool + PanicOnWrite bool + PanicOnClose bool +} +type customHandle struct { + w *customWriter +} + +func (w *customWriter) WriteHeader(io.Writer, *schema.Table) (types.Handle, error) { + if w.PanicOnHeader { + panic("test panic") + } + return &customHandle{w: w}, nil +} + +func (*customWriter) Read(types.ReaderAtSeeker, *schema.Table, chan<- arrow.Record) error { + return errors.New("not implemented") +} + +func (h *customHandle) WriteContent([]arrow.Record) error { + if h.w.PanicOnWrite { + panic("test panic") + } + return nil +} +func (h *customHandle) WriteFooter() error { + if h.w.PanicOnClose { + panic("test panic") + } + return nil +} diff --git a/write.go b/write.go index 30d30cf..e1dc64a 100644 --- a/write.go +++ b/write.go @@ -14,7 +14,18 @@ func (cl *Client) WriteTableBatchFile(w io.Writer, table *schema.Table, records return types.WriteAll(cl, w, table, records) } -func (cl *Client) WriteHeader(w io.Writer, t *schema.Table) (types.Handle, error) { +func (cl *Client) WriteHeader(w io.Writer, t *schema.Table) (h types.Handle, retErr error) { + defer func() { + if msg := recover(); msg != nil { + switch v := msg.(type) { + case error: + retErr = fmt.Errorf("panic: %w [recovered]", v) + default: + retErr = fmt.Errorf("panic: %v [recovered]", msg) + } + } + }() + switch cl.spec.Compression { case CompressionTypeNone: return cl.filetype.WriteHeader(w, t)