Skip to content

Commit

Permalink
feat: Handle panics in handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
disq committed Sep 25, 2024
1 parent abdbe46 commit 57be7f7
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 4 deletions.
33 changes: 30 additions & 3 deletions stream.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package filetypes

import (
"errors"
"fmt"
"io"

Expand Down Expand Up @@ -54,11 +55,22 @@ func (cl *Client) StartStream(table *schema.Table, uploadFunc func(io.Reader) er
}

// Write to the stream opened with StartStream.
func (s *Stream) Write(records []arrow.Record) error {
func (s *Stream) Write(records []arrow.Record) (retErr error) {
if len(records) == 0 {
return nil
}

defer func() {
if msg := recover(); msg != nil {
switch v := msg.(type) {
case error:
retErr = fmt.Errorf("panic: %w [recovered]", v)
default:
retErr = fmt.Errorf("panic: %v [recovered]", msg)
}
}
}()

return s.h.WriteContent(records)
}

Expand All @@ -74,11 +86,11 @@ func (s *Stream) FinishWithError(finishError error) error {
return <-s.done
}

if err := s.h.WriteFooter(); err != nil {
if err := s.writeFooter(); err != nil {
if !s.wc.closed {
_ = s.wc.CloseWithError(err)
}
return fmt.Errorf("failed to write footer: %w", <-s.done)
return fmt.Errorf("failed to write footer: %w", errors.Join(err, <-s.done))
}

// ParquetWriter likes to close the underlying writer, so we need to check if it's already closed
Expand All @@ -90,3 +102,18 @@ func (s *Stream) FinishWithError(finishError error) error {

return <-s.done
}

func (s *Stream) writeFooter() (retErr error) {
defer func() {
if msg := recover(); msg != nil {
switch v := msg.(type) {
case error:
retErr = fmt.Errorf("panic: %w [recovered]", v)
default:
retErr = fmt.Errorf("panic: %v [recovered]", msg)
}
}
}()

return s.h.WriteFooter()
}
121 changes: 121 additions & 0 deletions stream_panic_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package filetypes

import (
"errors"
"io"
"testing"

"github.com/apache/arrow/go/v17/arrow"
"github.com/apache/arrow/go/v17/arrow/array"
"github.com/apache/arrow/go/v17/arrow/memory"
"github.com/cloudquery/filetypes/v4/types"
"github.com/cloudquery/plugin-sdk/v4/schema"
"github.com/stretchr/testify/require"
)

func TestPanicOnHeader(t *testing.T) {
r := require.New(t)
cl := &Client{
spec: &FileSpec{
Compression: CompressionTypeNone,
},
filetype: &customWriter{
PanicOnHeader: true,
},
}

stream, err := cl.StartStream(&schema.Table{}, func(io.Reader) error {
return nil
})
r.Nil(stream)
r.Error(err)
r.ErrorContains(err, "panic:")
}

func TestPanicOnWrite(t *testing.T) {
r := require.New(t)
cl := &Client{
spec: &FileSpec{
Compression: CompressionTypeNone,
},
filetype: &customWriter{
PanicOnWrite: true,
},
}

table := &schema.Table{
Name: "test",
Columns: []schema.Column{
{Name: "name", Type: arrow.BinaryTypes.String},
},
}
bldr := array.NewRecordBuilder(memory.DefaultAllocator, table.ToArrowSchema())
bldr.Field(0).(*array.StringBuilder).Append("foo")
bldr.Field(0).(*array.StringBuilder).Append("bar")
record := bldr.NewRecord()

stream, err := cl.StartStream(table, func(io.Reader) error {
return nil
})
r.NoError(err)
err = stream.Write([]arrow.Record{record})
r.Error(err)
r.ErrorContains(err, "panic:")

r.NoError(stream.Finish())
}

func TestPanicOnClose(t *testing.T) {
r := require.New(t)
cl := &Client{
spec: &FileSpec{
Compression: CompressionTypeNone,
},
filetype: &customWriter{
PanicOnClose: true,
},
}

stream, err := cl.StartStream(&schema.Table{}, func(io.Reader) error {
return nil
})
r.NoError(err)
r.NoError(stream.Write(nil))

err = stream.Finish()
r.Error(err)
r.ErrorContains(err, "panic:")
}

type customWriter struct {
PanicOnHeader bool
PanicOnWrite bool
PanicOnClose bool
}
type customHandle struct {
w *customWriter
}

func (w *customWriter) WriteHeader(io.Writer, *schema.Table) (types.Handle, error) {
if w.PanicOnHeader {
panic("test panic")
}
return &customHandle{w: w}, nil
}

func (*customWriter) Read(types.ReaderAtSeeker, *schema.Table, chan<- arrow.Record) error {
return errors.New("not implemented")
}

func (h *customHandle) WriteContent([]arrow.Record) error {
if h.w.PanicOnWrite {
panic("test panic")
}
return nil
}
func (h *customHandle) WriteFooter() error {
if h.w.PanicOnClose {
panic("test panic")
}
return nil
}
13 changes: 12 additions & 1 deletion write.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,18 @@ func (cl *Client) WriteTableBatchFile(w io.Writer, table *schema.Table, records
return types.WriteAll(cl, w, table, records)
}

func (cl *Client) WriteHeader(w io.Writer, t *schema.Table) (types.Handle, error) {
func (cl *Client) WriteHeader(w io.Writer, t *schema.Table) (h types.Handle, retErr error) {
defer func() {
if msg := recover(); msg != nil {
switch v := msg.(type) {
case error:
retErr = fmt.Errorf("panic: %w [recovered]", v)
default:
retErr = fmt.Errorf("panic: %v [recovered]", msg)
}
}
}()

switch cl.spec.Compression {
case CompressionTypeNone:
return cl.filetype.WriteHeader(w, t)
Expand Down

0 comments on commit 57be7f7

Please sign in to comment.