From 6a9327cc25d970160ed25ed7ecbb1654e3a3b0ac Mon Sep 17 00:00:00 2001 From: Tomasz Urbaszek Date: Mon, 25 Mar 2024 11:46:24 -0700 Subject: [PATCH] Utilise USE instead of DESC in connection test --- src/snowflake/cli/api/sql_execution.py | 4 ++++ src/snowflake/cli/plugins/connection/commands.py | 12 ++++-------- tests/test_connection.py | 11 ++++++----- tests/test_sql.py | 15 +++++++++++++++ 4 files changed, 29 insertions(+), 13 deletions(-) diff --git a/src/snowflake/cli/api/sql_execution.py b/src/snowflake/cli/api/sql_execution.py index d372f4d40b..c9c2b9db34 100644 --- a/src/snowflake/cli/api/sql_execution.py +++ b/src/snowflake/cli/api/sql_execution.py @@ -8,6 +8,7 @@ from typing import Iterable, Optional, Tuple from snowflake.cli.api.cli_global_context import cli_context +from snowflake.cli.api.constants import ObjectType from snowflake.cli.api.exceptions import ( DatabaseNotProvidedError, SchemaNotProvidedError, @@ -62,6 +63,9 @@ def _execute_query(self, query: str, **kwargs): def _execute_queries(self, queries: str, **kwargs): return list(self._execute_string(dedent(queries), **kwargs)) + def use(self, object_type: ObjectType, name: str): + return self._execute_query(f"use {object_type.value.sf_name} {name}") + @contextmanager def use_role(self, new_role: str): """ diff --git a/src/snowflake/cli/plugins/connection/commands.py b/src/snowflake/cli/plugins/connection/commands.py index d0d8995a4d..9ef02ca0cb 100644 --- a/src/snowflake/cli/plugins/connection/commands.py +++ b/src/snowflake/cli/plugins/connection/commands.py @@ -254,17 +254,13 @@ def test( om = ObjectManager() try: if conn.role: - om.use_role(new_role=conn.role) + om.use(object_type=ObjectType.ROLE, name=conn.role) if conn.database: - om.describe( - object_type=ObjectType.DATABASE.value.cli_name, name=conn.database - ) + om.use(object_type=ObjectType.DATABASE, name=conn.database) if conn.schema: - om.describe(object_type=ObjectType.SCHEMA.value.cli_name, name=conn.schema) + om.use(object_type=ObjectType.SCHEMA, name=conn.schema) if conn.warehouse: - om.describe( - object_type=ObjectType.WAREHOUSE.value.cli_name, name=conn.warehouse - ) + om.use(object_type=ObjectType.WAREHOUSE, name=conn.warehouse) except ProgrammingError as err: raise ClickException(str(err)) diff --git a/tests/test_connection.py b/tests/test_connection.py index fb94f40a62..2fdd3e35d9 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -7,6 +7,7 @@ import pytest import tomlkit +from snowflake.cli.api.constants import ObjectType def test_new_connection_can_be_added(runner, snapshot): @@ -303,11 +304,11 @@ def test_connection_test(mock_connect, mock_om, runner): ) conn = mock_connect.return_value - mock_om.return_value.use_role.assert_called_once_with(new_role=conn.role) - assert mock_om.return_value.describe.mock_calls == [ - mock.call(object_type="database", name=conn.database), - mock.call(object_type="schema", name=conn.schema), - mock.call(object_type="warehouse", name=conn.warehouse), + assert mock_om.return_value.use.mock_calls == [ + mock.call(object_type=ObjectType.ROLE, name=conn.role), + mock.call(object_type=ObjectType.DATABASE, name=conn.database), + mock.call(object_type=ObjectType.SCHEMA, name=conn.schema), + mock.call(object_type=ObjectType.WAREHOUSE, name=conn.warehouse), ] diff --git a/tests/test_sql.py b/tests/test_sql.py index b32074f04a..b3cb324a74 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -3,6 +3,7 @@ from unittest import mock import pytest +from snowflake.cli.api.constants import ObjectType from snowflake.cli.api.exceptions import SnowflakeSQLExecutionError from snowflake.cli.api.project.util import identifier_to_show_like_pattern from snowflake.cli.api.sql_execution import SqlExecutionMixin @@ -253,3 +254,17 @@ def test_show_specific_object_multiple_rows(mock_execute_query): mock_execute_query.assert_called_once_with( r"show objects like 'NAME'", cursor_class=DictCursor ) + + +@pytest.mark.parametrize( + "_object", + [ + ObjectType.WAREHOUSE, + ObjectType.ROLE, + ObjectType.DATABASE, + ], +) +@mock.patch("snowflake.cli.api.sql_execution.SqlExecutionMixin._execute_query") +def test_use_command(mock_execute_query, _object): + SqlExecutionMixin().use(object_type=_object, name="foo_name") + mock_execute_query.assert_called_once_with(f"use {_object.value.sf_name} foo_name")