diff --git a/copy_from_test.go b/copy_from_test.go index ac2ccaabd..faed1d461 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -2,6 +2,7 @@ package pgx_test import ( "context" + "errors" "fmt" "os" "reflect" @@ -802,3 +803,47 @@ func TestConnCopyFromAutomaticStringConversion(t *testing.T) { ensureConnValid(t, conn) } + +func TestCopyFromFunc(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int + )`) + + dataCh := make(chan int, 1) + closeChanErr := errors.New("closed channel") + + const channelItems = 10 + go func() { + for i := 0; i < channelItems; i++ { + dataCh <- i + } + close(dataCh) + }() + + copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, + pgx.CopyFromFunc(func() ([]any, error) { + v, ok := <-dataCh + if !ok { + return nil, closeChanErr + } + return []any{v}, nil + })) + + fmt.Print(copyCount, err, "\n") + + require.ErrorIs(t, err, closeChanErr) + require.EqualValues(t, channelItems, copyCount) + + rows, err := conn.Query(context.Background(), "select * from foo order by a") + require.NoError(t, err) + nums, err := pgx.CollectRows(rows, pgx.RowTo[int64]) + require.NoError(t, err) + require.Equal(t, []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, nums) + + ensureConnValid(t, conn) +}