diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 247b8c47c..d99d02fb5 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -2369,10 +2369,16 @@ func testConnContextCanceledCancelsRunningQueryOnServer(t *testing.T, connString // server process to clients. However, we can check if the query is running by checking the generated query ID. queryID := fmt.Sprintf("%s testConnContextCanceled %d", dbType, time.Now().UnixNano()) - multiResult := pgConn.Exec(ctx, fmt.Sprintf(` - -- %v - select 'Hello, world', pg_sleep(30) - `, queryID)) + var multiResult *pgconn.MultiResultReader + if pgConn.ParameterStatus("crdb_version") != "" { + // in crdb comments are not shown in query in crdb_internal.node_queries + multiResult = pgConn.Exec(ctx, fmt.Sprintf(` select '%s', pg_sleep(30) `, queryID)) + } else { + multiResult = pgConn.Exec(ctx, fmt.Sprintf(` + -- %v + select 'Hello, world', pg_sleep(30) + `, queryID)) + } for multiResult.NextResult() { } @@ -2392,8 +2398,16 @@ func testConnContextCanceledCancelsRunningQueryOnServer(t *testing.T, connString ctx, cancel = context.WithTimeout(ctx, time.Second*5) defer cancel() + if pgConn.ParameterStatus("crdb_version") != "" { + crdbTest(t, ctx, otherConn, queryID) + } else { + postgresTest(t, ctx, otherConn, queryID) + } +} + +func postgresTest(t *testing.T, ctx context.Context, conn *pgconn.PgConn, queryID string) { for { - result := otherConn.ExecParams(ctx, + result := conn.ExecParams(ctx, `select state from pg_stat_activity where query like $1`, [][]byte{[]byte("%" + queryID + "%")}, nil, @@ -2411,6 +2425,24 @@ func testConnContextCanceledCancelsRunningQueryOnServer(t *testing.T, connString } } +func crdbTest(t *testing.T, ctx context.Context, conn *pgconn.PgConn, queryID string) { + for { + result := conn.ExecParams(ctx, + `select 1 from crdb_internal.node_queries where query like $1`, + [][]byte{[]byte("%" + queryID + "%")}, + nil, + nil, + nil, + ).Read() + require.NoError(t, result.Err) + + // in crdb query is deleted from table node_queries when it is finished + if len(result.Rows) == 0 { + break + } + } +} + func TestHijackAndConstruct(t *testing.T) { t.Parallel()