Skip to content

Commit

Permalink
Allowed nxtf to signal end of data by returning nil,nil
Browse files Browse the repository at this point in the history
Added some test
Improved documentation
  • Loading branch information
robfordww authored and jackc committed Nov 11, 2023
1 parent 9b6d380 commit d38dd85
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 12 deletions.
13 changes: 7 additions & 6 deletions copy_from.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ func (cts *copyFromSlice) Err() error {
return cts.err
}

// CopyFromCh returns a CopyFromSource interface over the provided channel.
// FieldNames is an ordered list of field names to copy from the struct, which
// order must match the order of the columns.
func CopyFromFunc(nxtf func() ([]any, error)) CopyFromSource {
// CopyFromFunc returns a CopyFromSource interface that relies on nxtf for values.
// nxtf returns rows until it either signals an 'end of data' by returning row=nil and err=nil,
// or it returns an error. If nxtf returns an error, the copy is aborted.
func CopyFromFunc(nxtf func() (row []any, err error)) CopyFromSource {
return &copyFromFunc{next: nxtf}
}

Expand All @@ -79,11 +79,12 @@ type copyFromFunc struct {

func (g *copyFromFunc) Next() bool {
g.valueRow, g.err = g.next()
return g.err == nil
// only return true if valueRow exists and no error
return g.valueRow != nil && g.err == nil
}

func (g *copyFromFunc) Values() ([]any, error) {
return g.valueRow, nil
return g.valueRow, g.err
}

func (g *copyFromFunc) Err() error {
Expand Down
23 changes: 17 additions & 6 deletions copy_from_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package pgx_test

import (
"context"
"errors"
"fmt"
"os"
"reflect"
Expand Down Expand Up @@ -815,7 +814,6 @@ func TestCopyFromFunc(t *testing.T) {
)`)

dataCh := make(chan int, 1)
closeChanErr := errors.New("closed channel")

const channelItems = 10
go func() {
Expand All @@ -829,14 +827,12 @@ func TestCopyFromFunc(t *testing.T) {
pgx.CopyFromFunc(func() ([]any, error) {
v, ok := <-dataCh
if !ok {
return nil, closeChanErr
return nil, nil
}
return []any{v}, nil
}))

fmt.Print(copyCount, err, "\n")

require.ErrorIs(t, err, closeChanErr)
require.ErrorIs(t, err, nil)
require.EqualValues(t, channelItems, copyCount)

rows, err := conn.Query(context.Background(), "select * from foo order by a")
Expand All @@ -845,5 +841,20 @@ func TestCopyFromFunc(t *testing.T) {
require.NoError(t, err)
require.Equal(t, []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, nums)

// simulate a failure
copyCount, err = conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"},
pgx.CopyFromFunc(func() func() ([]any, error) {
x := 9
return func() ([]any, error) {
x++
if x > 100 {
return nil, fmt.Errorf("simulated error")
}
return []any{x}, nil
}
}()))
require.NotErrorIs(t, err, nil)
require.EqualValues(t, 0, copyCount) // no change, due to error

ensureConnValid(t, conn)
}

0 comments on commit d38dd85

Please sign in to comment.