Skip to content

Commit

Permalink
claryfy query results comparision (#13090)
Browse files Browse the repository at this point in the history
  • Loading branch information
zverevgeny authored Dec 28, 2024
1 parent e1ce008 commit 22e791d
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 41 deletions.
88 changes: 49 additions & 39 deletions ydb/tests/functional/suite_tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,28 @@ class Type(enum.Enum):
StreamQuery = 'statement stream query'
ImportTableData = 'statement import table data'

def __init__(self, suite: str, at_line: int, type: Type, text: [str]):
class SqlStatementType(enum.Enum):
Create = "create"
DropTable = "drop"
Insert = "insert"
Uosert = "upsert"
Replace = "replace"
Delete = "delete"
Select = "select"

def __init__(self, suite: str, at_line: int, type: Type, text: [str], sql_statement_type: SqlStatementType):
self.suite_name = suite
self.at_line = at_line
self.s_type = type
self.text = text
self.sql_statement_type = sql_statement_type

def __str__(self):
return f'''StatementDefinition:
suite: {self.suite_name}
line: {self.at_line}
type: {self.s_type}
sql_stmt_tyoe: {self.sql_statement_type}
text:
''' + '\n'.join([f' {row}' for row in self.text.split('\n')])

Expand All @@ -71,6 +82,17 @@ def _parse_statement_type(statement_line: str) -> Type:
return t
return None

@staticmethod
def _parse_sql_statement_type(lines: [str]) -> SqlStatementType:
for line in lines:
line = line.lower()
if line.startswith("pragma"):
continue
for t in list(StatementDefinition.SqlStatementType):
if line.startswith(t.value):
return t
return None

@staticmethod
def parse(suite: str, at_line: int, lines: list[str]):
if not lines or not lines[0]:
Expand All @@ -86,7 +108,10 @@ def parse(suite: str, at_line: int, lines: list[str]):
pass
else:
statement_lines.append(line)
return StatementDefinition(suite, at_line, type, "\n".join(statement_lines))
sql_statement_type = StatementDefinition._parse_sql_statement_type(statement_lines)
if sql_statement_type is None:
raise RuntimeError(f'Unknown sql statement type in {suite}, at line: {at_line}')
return StatementDefinition(suite, at_line, type, "\n".join(statement_lines), sql_statement_type)


def get_token(length=10):
Expand Down Expand Up @@ -274,6 +299,7 @@ def assert_statement_import_table_data(self, statement):
future_results.append(
tp.submit(
self.execute_query,
statement,
cmd,
)
)
Expand All @@ -298,7 +324,7 @@ def assert_statement(self, parsed_statement):
parsed_statement.at_line, parsed_statement.suite_name, end_time - start_time))

def assert_statement_ok(self, statement):
actual = safe_execute(lambda: self.execute_query(statement.text))
actual = safe_execute(lambda: self.execute_query(statement))
assert_that(
len(actual),
1,
Expand All @@ -307,14 +333,14 @@ def assert_statement_ok(self, statement):

def assert_statement_error(self, statement):
assert_that(
lambda: self.execute_query(statement.text),
lambda: self.execute_query(statement),
raises(
ydb.Error
)
)

def get_query_and_output(self, statement_text):
return statement_text, None
def get_expected_output(self, _):
return None

@staticmethod
def pretty_json(j):
Expand All @@ -331,11 +357,6 @@ def remove_optimizer_estimates(self, query_plan):
del op[key]

def assert_statement_query(self, statement):
def get_actual_and_expected():
query, expected = self.get_query_and_output(statement.text)
actual = self.execute_query(query)
return actual, expected

query_id = next(self.query_id)
query_name = "query_%d" % query_id
if self.plan:
Expand All @@ -350,8 +371,8 @@ def get_actual_and_expected():
)

return

actual_output, expected_output = safe_execute(get_actual_and_expected, statement, query_name)
expected_output = self.get_expected_output(statement.text)
actual_output = safe_execute(lambda: self.execute_query(statement), statement, query_name)

if len(actual_output) > 0:
self.files[query_name] = write_canonical_response(
Expand All @@ -367,25 +388,14 @@ def get_actual_and_expected():
)

def execute_scan_query(self, yql_text):
success = False
retries = 10
while retries > 0 and not success:
retries -= 1

def callee():
it = self.driver.table_client.scan_query(yql_text)
result = []
while True:
try:
response = next(it)
for row in response.result_set.rows:
result.append(row)

except StopIteration:
return result

except Exception:
if retries == 0:
raise
for response in it:
for row in response.result_set.rows:
result.append(row)
return result
return ydb.retry_operation_sync(callee)

def assert_statement_stream_query(self, statement):
if self.plan:
Expand Down Expand Up @@ -420,17 +430,17 @@ def explain(self, query):

return self.legacy_pool.retry_operation_sync(lambda s: s.explain(yql_text)).query_plan

def execute_query(self, statement_text):
yql_text = patch_yql_statement(statement_text, self.table_path_prefix)
def execute_query(self, statement: StatementDefinition, amended_text: str = None):
yql_text = amended_text if amended_text is not None else statement.text
yql_text = patch_yql_statement(yql_text, self.table_path_prefix)
result = self.pool.execute_with_retries(yql_text)

if len(result) == 1:
if statement.sql_statement_type == StatementDefinition.SqlStatementType.Select:
scan_query_result = self.execute_scan_query(yql_text)
for i in range(len(result)):
self.execute_assert(
result[i].rows,
scan_query_result,
"Results are not same",
)
self.execute_assert(
result[0].rows,
scan_query_result,
"Results are not same",
)

return result
4 changes: 2 additions & 2 deletions ydb/tests/functional/suite_tests/test_sql_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def assert_statement_error(self, statement):
assert_that(lambda: self.__execute_sqlitedb(statement.text), raises(sqlite3.Error), str(statement))
super(TestSQLLogic, self).assert_statement_error(statement)

def get_query_and_output(self, statement_text):
return statement_text, self.__execute_sqlitedb(statement_text, query=True)
def get_expected_output(self, statement_text):
return self.__execute_sqlitedb(statement_text, query=True)

def __execute_sqlitedb(self, statement_text, query=False):
cursor = self.sqlite_connection.cursor()
Expand Down

0 comments on commit 22e791d

Please sign in to comment.