diff --git a/assert_test.go b/assert_test.go index 8714df3d9..e3568511b 100644 --- a/assert_test.go +++ b/assert_test.go @@ -68,6 +68,10 @@ func assertFalseF(t *testing.T, actual bool, descriptions ...string) { fatalOnNonEmpty(t, validateEqual(actual, false, descriptions...)) } +func assertFalseE(t *testing.T, actual bool, descriptions ...string) { + errorOnNonEmpty(t, validateEqual(actual, false, descriptions...)) +} + func assertStringContainsE(t *testing.T, actual string, expectedToContain string, descriptions ...string) { errorOnNonEmpty(t, validateStringContains(actual, expectedToContain, descriptions...)) } diff --git a/driver_test.go b/driver_test.go index 77da293d8..c783b81c1 100644 --- a/driver_test.go +++ b/driver_test.go @@ -595,6 +595,103 @@ func TestEmptyQueryWithRequestID(t *testing.T) { }) } +func TestRequestIDFromTwoDifferentSessions(t *testing.T) { + db, err := sql.Open("snowflake", dsn) + assertNilF(t, err) + db.SetMaxOpenConns(10) + + conn, err := db.Conn(context.Background()) + assertNilF(t, err) + defer conn.Close() + _, err = conn.ExecContext(context.Background(), forceJSON) + assertNilF(t, err) + + conn2, err := db.Conn(context.Background()) + assertNilF(t, err) + defer conn2.Close() + _, err = conn2.ExecContext(context.Background(), forceJSON) + assertNilF(t, err) + + // creating table + reqIDForCreate := NewUUID() + _, err = conn.ExecContext(WithRequestID(context.Background(), reqIDForCreate), "CREATE TABLE req_id_testing (id INTEGER)") + assertNilF(t, err) + defer func() { + _, err = db.Exec("DROP TABLE IF EXISTS req_id_testing") + assertNilE(t, err) + }() + _, err = conn.ExecContext(WithRequestID(context.Background(), reqIDForCreate), "CREATE TABLE req_id_testing (id INTEGER)") + assertNilF(t, err) + defer func() { + _, err = db.Exec("DROP TABLE IF EXISTS req_id_testing") + assertNilE(t, err) + }() + + // should fail as API v1 does not allow reusing requestID across sessions for DML statements + _, err = conn2.ExecContext(WithRequestID(context.Background(), reqIDForCreate), "CREATE TABLE req_id_testing (id INTEGER)") + assertNotNilE(t, err) + assertStringContainsE(t, err.Error(), "already exists") + + // inserting a record + reqIDForInsert := NewUUID() + execResult, err := conn.ExecContext(WithRequestID(context.Background(), reqIDForInsert), "INSERT INTO req_id_testing VALUES (1)") + assertNilF(t, err) + rowsInserted, err := execResult.RowsAffected() + assertNilF(t, err) + assertEqualE(t, rowsInserted, int64(1)) + + _, err = conn2.ExecContext(WithRequestID(context.Background(), reqIDForInsert), "INSERT INTO req_id_testing VALUES (1)") + assertNilF(t, err) + rowsInserted2, err := execResult.RowsAffected() + assertNilF(t, err) + assertEqualE(t, rowsInserted2, int64(1)) + + // selecting data + reqIDForSelect := NewUUID() + rows, err := conn.QueryContext(WithRequestID(context.Background(), reqIDForSelect), "SELECT * FROM req_id_testing") + assertNilF(t, err) + defer rows.Close() + var i int + assertTrueE(t, rows.Next()) + assertNilF(t, rows.Scan(&i)) + assertEqualE(t, i, 1) + i = 0 + assertTrueE(t, rows.Next()) + assertNilF(t, rows.Scan(&i)) + assertEqualE(t, i, 1) + assertFalseE(t, rows.Next()) + + rows2, err := conn.QueryContext(WithRequestID(context.Background(), reqIDForSelect), "SELECT * FROM req_id_testing") + assertNilF(t, err) + defer rows2.Close() + assertTrueE(t, rows2.Next()) + assertNilF(t, rows2.Scan(&i)) + assertEqualE(t, i, 1) + i = 0 + assertTrueE(t, rows2.Next()) + assertNilF(t, rows2.Scan(&i)) + assertEqualE(t, i, 1) + assertFalseE(t, rows2.Next()) + + // insert another data + _, err = conn.ExecContext(context.Background(), "INSERT INTO req_id_testing VALUES (1)") + assertNilF(t, err) + + // selecting using old request id + rows3, err := conn.QueryContext(WithRequestID(context.Background(), reqIDForSelect), "SELECT * FROM req_id_testing") + assertNilF(t, err) + defer rows3.Close() + assertTrueE(t, rows3.Next()) + assertNilF(t, rows3.Scan(&i)) + assertEqualE(t, i, 1) + i = 0 + assertTrueE(t, rows3.Next()) + assertNilF(t, rows3.Scan(&i)) + assertEqualE(t, i, 1) + i = 0 + assertFalseF(t, rows3.Next()) +} + func TestCRUD(t *testing.T) { runDBTest(t, func(dbt *DBTest) { // Create Table