From 8de2a65d569b90278e6d9882c4c8bb957b7c23cf Mon Sep 17 00:00:00 2001 From: PARYA JAFARI Date: Tue, 22 Oct 2024 14:15:25 -0400 Subject: [PATCH] new mock executor class --- .../cli/_plugins/nativeapp/sf_sql_facade.py | 2 + tests/nativeapp/test_new_sql_facade.py | 83 +++++++++++++++++++ tests/nativeapp/test_post_deploy_for_app.py | 2 +- 3 files changed, 86 insertions(+), 1 deletion(-) create mode 100644 tests/nativeapp/test_new_sql_facade.py diff --git a/src/snowflake/cli/_plugins/nativeapp/sf_sql_facade.py b/src/snowflake/cli/_plugins/nativeapp/sf_sql_facade.py index f7cf788b55..742085321f 100644 --- a/src/snowflake/cli/_plugins/nativeapp/sf_sql_facade.py +++ b/src/snowflake/cli/_plugins/nativeapp/sf_sql_facade.py @@ -30,6 +30,8 @@ from snowflake.cli.api.sql_execution import SqlExecutor from snowflake.connector import ProgrammingError +# TODO: replace SqlExecutor with the smallest set of methods we want to call (abc) + class SnowflakeSQLFacade: def __init__(self, sql_executor: SqlExecutor | None = None): diff --git a/tests/nativeapp/test_new_sql_facade.py b/tests/nativeapp/test_new_sql_facade.py new file mode 100644 index 0000000000..fd6efad1d1 --- /dev/null +++ b/tests/nativeapp/test_new_sql_facade.py @@ -0,0 +1,83 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import deque +from textwrap import dedent +from typing import Any, List, Tuple + +from snowflake.cli._plugins.nativeapp.sf_sql_facade import SnowflakeSQLFacade +from snowflake.cli.api.sql_execution import SqlExecutor + +# Objective: +# - correct calls are made to SF. in the correct order. +# Requirements: + +# - can ensure execute_queries is called with expected query (closest method to SF before _execute_string, all queries call this method.) +# - can mock execute_queries to return result we want +# - can assert calls are made in the right order + + +class MockQueriesExecutor(SqlExecutor): + def __init__(self, expected_calls: List[Tuple[str, Any]]): + super().__init__() + self._mock_results = deque(expected_calls) + + @property + def _conn(self): + raise Exception("Connection should not be accessed in tests.") + + def _execute_queries(self, queries: str, **kwargs): + (query_expected, result) = self._mock_results.popleft() + if dedent(queries).strip() != dedent(query_expected).strip(): + raise Exception( + f"Test failed, expected query {query_expected} but received query {queries}" + ) + + return result + + def all_calls_made(self): + return len(self._mock_results) == 0 + + +def test_execute_with_role_wh_db(mock_cursor): + mock_script = "-- my comment\nselect 1;\nselect 2;" + mock_script_name = "test-user-sql-script.sql" + role = "mock_role" + wh = "mock_wh" + database = "mock_db" + sql_facade = SnowflakeSQLFacade( + MockQueriesExecutor( + expected_calls=[ + ("select current_role()", [mock_cursor([("old_role",)], [])]), + ("use role mock_role", [None]), + ("select current_warehouse()", [mock_cursor([("old_wh",)], [])]), + ("use warehouse mock_wh", [None]), + ("select current_database()", [mock_cursor([("old_db",)], [])]), + ("use database mock_db", [None]), + (mock_script, [None]), + ("use database old_db", [None]), + ("use warehouse old_wh", [None]), + ("use role old_role", [None]), + ] + ) + ) + + sql_facade.execute_user_script( + queries=mock_script, + script_name=mock_script_name, + role=role, + warehouse=wh, + database=database, + ) + + assert sql_facade._sql_executor.all_calls_made() # noqa: SLF001 diff --git a/tests/nativeapp/test_post_deploy_for_app.py b/tests/nativeapp/test_post_deploy_for_app.py index bcba039e9b..0600f80847 100644 --- a/tests/nativeapp/test_post_deploy_for_app.py +++ b/tests/nativeapp/test_post_deploy_for_app.py @@ -172,8 +172,8 @@ def test_sql_scripts_with_no_warehouse_no_database( @mock.patch(SQL_EXECUTOR_EXECUTE) @mock_connection() def test_missing_sql_script( - mock_execute_query, mock_conn, + mock_execute_query, project_directory, ): mock_conn.return_value = MockConnectionCtx()