diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 2396bd63dd..b1c60804a1 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -7,6 +7,11 @@ ## Fixes and improvements +# v2.1.1 + +## Fixes and improvements +* Improved security of printing connection details in `snow connection list`. + # v2.1.0 ## Backward incompatibility diff --git a/src/snowflake/cli/api/config.py b/src/snowflake/cli/api/config.py index 67859815ab..b301f704be 100644 --- a/src/snowflake/cli/api/config.py +++ b/src/snowflake/cli/api/config.py @@ -3,6 +3,7 @@ import logging import os from contextlib import contextmanager +from dataclasses import asdict, dataclass, field from pathlib import Path from typing import Any, Dict, Optional, Union @@ -45,6 +46,51 @@ class Empty: ) +@dataclass +class ConnectionConfig: + account: Optional[str] = None + user: Optional[str] = None + password: Optional[str] = field(default=None, repr=False) + host: Optional[str] = None + region: Optional[str] = None + port: Optional[int] = None + database: Optional[str] = None + schema: Optional[str] = None + warehouse: Optional[str] = None + role: Optional[str] = None + authenticator: Optional[str] = None + private_key_path: Optional[str] = None + + _other_settings: dict = field(default_factory=lambda: {}) + + @classmethod + def from_dict(cls, config_dict: dict) -> ConnectionConfig: + known_settings = {} + other_settings = {} + for key, value in config_dict.items(): + if key in cls.__dict__: + known_settings[key] = value + else: + other_settings[key] = value + return cls(**known_settings, _other_settings=other_settings) + + def to_dict_of_known_non_empty_values(self) -> dict: + return { + k: v + for k, v in asdict(self).items() + if k != "_other_settings" and v is not None + } + + def _non_empty_other_values(self) -> dict: + return {k: v for k, v in self._other_settings.items() if v is not None} + + def to_dict_of_all_non_empty_values(self) -> dict: + return { + **self.to_dict_of_known_non_empty_values(), + **self._non_empty_other_values(), + } + + def config_init(config_file: Optional[Path]): """ Initializes the app configuration. Config provided via cli flag takes precedence. @@ -59,8 +105,12 @@ def config_init(config_file: Optional[Path]): CONFIG_MANAGER.read_config() -def add_connection(name: str, parameters: dict): - set_config_value(section=CONNECTIONS_SECTION, key=name, value=parameters) +def add_connection(name: str, connection_config: ConnectionConfig): + set_config_value( + section=CONNECTIONS_SECTION, + key=name, + value=connection_config.to_dict_of_all_non_empty_values(), + ) _DEFAULT_LOGS_CONFIG = { @@ -124,16 +174,23 @@ def config_section_exists(*path) -> bool: return False -def get_connection(connection_name: str) -> dict: +def get_all_connections() -> dict[str, ConnectionConfig]: + return { + k: ConnectionConfig.from_dict(connection_dict) + for k, connection_dict in get_config_section("connections").items() + } + + +def get_connection_dict(connection_name: str) -> dict: try: return get_config_section(CONNECTIONS_SECTION, connection_name) except KeyError: raise MissingConfiguration(f"Connection {connection_name} is not configured") -def get_default_connection() -> dict: +def get_default_connection_dict() -> dict: def_connection_name = CONFIG_MANAGER["default_connection_name"] - return get_connection(def_connection_name) + return get_connection_dict(def_connection_name) def get_config_section(*path) -> dict: @@ -193,9 +250,9 @@ def _merge_section_with_env(section: Union[Table, Any], *path) -> Dict[str, str] def _get_envs_for_path(*path) -> dict: - env_variables_prefix = "SNOWFLAKE_" + "_".join(p.upper() for p in path) + env_variables_prefix = "_".join(["SNOWFLAKE"] + [p.upper() for p in path]) + "_" return { - k.replace(f"{env_variables_prefix}_", "").lower(): os.environ[k] + k.replace(env_variables_prefix, "").lower(): os.environ[k] for k in os.environ.keys() if k.startswith(env_variables_prefix) } diff --git a/src/snowflake/cli/app/snow_connector.py b/src/snowflake/cli/app/snow_connector.py index b34178c6ba..5d606a68d5 100644 --- a/src/snowflake/cli/app/snow_connector.py +++ b/src/snowflake/cli/app/snow_connector.py @@ -7,7 +7,7 @@ import snowflake.connector from click.exceptions import ClickException -from snowflake.cli.api.config import get_connection, get_default_connection +from snowflake.cli.api.config import get_connection_dict, get_default_connection_dict from snowflake.cli.api.constants import DEFAULT_SIZE_LIMIT_MB from snowflake.cli.api.exceptions import ( InvalidConnectionConfiguration, @@ -34,11 +34,11 @@ def connect_to_snowflake( raise ClickException("Can't use connection name and temporary connection.") if connection_name: - connection_parameters = get_connection(connection_name) + connection_parameters = get_connection_dict(connection_name) elif temporary_connection: connection_parameters = {} # we will apply overrides in next step else: - connection_parameters = get_default_connection() + connection_parameters = get_default_connection_dict() # Apply overrides to connection details for key, value in overrides.items(): diff --git a/src/snowflake/cli/plugins/connection/commands.py b/src/snowflake/cli/plugins/connection/commands.py index 14bfd54982..03acfbeee9 100644 --- a/src/snowflake/cli/plugins/connection/commands.py +++ b/src/snowflake/cli/plugins/connection/commands.py @@ -14,10 +14,11 @@ ) from snowflake.cli.api.commands.snow_typer import SnowTyper from snowflake.cli.api.config import ( + ConnectionConfig, add_connection, connection_exists, - get_config_section, - get_connection, + get_all_connections, + get_connection_dict, set_config_value, ) from snowflake.cli.api.console import cli_console @@ -60,10 +61,15 @@ def list_connections(**options) -> CommandResult: """ Lists configured connections. """ - connections = get_config_section("connections") + connections = get_all_connections() result = ( - {"connection_name": k, "parameters": _mask_password(v)} - for k, v in connections.items() + { + "connection_name": connection_name, + "parameters": _mask_password( + connection_config.to_dict_of_known_non_empty_values() + ), + } + for connection_name, connection_config in connections.items() ) return CollectionResult(result) @@ -200,26 +206,26 @@ def add( **options, ) -> CommandResult: """Adds a connection to configuration file.""" - connection_entry = { - "account": account, - "user": user, - "password": password, - "host": host, - "region": region, - "port": port, - "database": database, - "schema": schema, - "warehouse": warehouse, - "role": role, - "authenticator": authenticator, - "private_key_path": private_key_path, - } - connection_entry = {k: v for k, v in connection_entry.items() if v is not None} - if connection_exists(connection_name): raise ClickException(f"Connection {connection_name} already exists") - add_connection(connection_name, connection_entry) + add_connection( + connection_name, + ConnectionConfig( + account=account, + user=user, + password=password, + host=host, + region=region, + port=port, + database=database, + schema=schema, + warehouse=warehouse, + role=role, + authenticator=authenticator, + private_key_path=private_key_path, + ), + ) return MessageResult( f"Wrote new connection {connection_name} to {CONFIG_MANAGER.file_path}" ) @@ -276,6 +282,6 @@ def set_default( **options, ): """Changes default connection to provided value.""" - get_connection(connection_name=name) + get_connection_dict(connection_name=name) set_config_value(section=None, key="default_connection_name", value=name) return MessageResult(f"Default connection set to: {name}") diff --git a/tests/test_config.py b/tests/test_config.py index ead231d01e..f71c2bff7c 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,12 +1,15 @@ +import os +from pathlib import Path from tempfile import TemporaryDirectory +from unittest import mock - +import pytest from snowflake.cli.api.config import ( + ConfigFileTooWidePermissionsError, config_init, get_config_section, - get_connection, - get_default_connection, - ConfigFileTooWidePermissionsError, + get_connection_dict, + get_default_connection_dict, ) from snowflake.cli.api.exceptions import MissingConfiguration @@ -29,7 +32,7 @@ def test_empty_config_file_is_created_if_not_present(): def test_get_connection_from_file(test_snowcli_config): config_init(test_snowcli_config) - assert get_connection("full") == { + assert get_connection_dict("full") == { "account": "dev_account", "user": "dev_user", "host": "dev_host", @@ -55,7 +58,7 @@ def test_get_connection_from_file(test_snowcli_config): def test_environment_variables_override_configuration_value(test_snowcli_config): config_init(test_snowcli_config) - assert get_connection("default") == { + assert get_connection_dict("default") == { "database": "database_foo", "schema": "test_public", "role": "test_role", @@ -77,7 +80,7 @@ def test_environment_variables_override_configuration_value(test_snowcli_config) def test_environment_variables_works_if_config_value_not_present(test_snowcli_config): config_init(test_snowcli_config) - assert get_connection("empty") == { + assert get_connection_dict("empty") == { "account": "some_account", "database": "test_database", "warehouse": "large", @@ -142,7 +145,7 @@ def test_create_default_config_if_not_exists(mock_config_manager): def test_default_connection_with_overwritten_values(test_snowcli_config): config_init(test_snowcli_config) - assert get_default_connection() == { + assert get_default_connection_dict() == { "database": "db_for_test", "role": "test_role", "schema": "test_public", @@ -155,7 +158,7 @@ def test_default_connection_with_overwritten_values(test_snowcli_config): def test_not_found_default_connection(test_root_path): config_init(Path(test_root_path / "empty_config.toml")) with pytest.raises(MissingConfiguration) as ex: - get_default_connection() + get_default_connection_dict() assert ex.value.message == "Connection default is not configured" @@ -170,7 +173,7 @@ def test_not_found_default_connection(test_root_path): def test_not_found_default_connection_from_evn_variable(test_root_path): config_init(Path(test_root_path / "empty_config.toml")) with pytest.raises(MissingConfiguration) as ex: - get_default_connection() + get_default_connection_dict() assert ex.value.message == "Connection not_existed_connection is not configured" @@ -186,7 +189,7 @@ def test_connections_toml_override_config_toml(test_snowcli_config, snowflake_ho ) config_init(test_snowcli_config) - assert get_default_connection() == {"database": "overridden_database"} + assert get_default_connection_dict() == {"database": "overridden_database"} assert CONFIG_MANAGER["connections"] == { "default": {"database": "overridden_database"} } diff --git a/tests/test_connection.py b/tests/test_connection.py index 246bc62af5..c9e262d503 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -172,7 +172,49 @@ def test_lists_connection_information(runner): "database": "dev_database", "host": "dev_host", "port": 8000, - "protocol": "dev_protocol", + "role": "dev_role", + "schema": "dev_schema", + "user": "dev_user", + "warehouse": "dev_warehouse", + }, + }, + { + "connection_name": "default", + "parameters": { + "database": "db_for_test", + "password": "****", # masked + "role": "test_role", + "schema": "test_public", + "warehouse": "xs", + }, + }, + {"connection_name": "empty", "parameters": {}}, + {"connection_name": "test_connections", "parameters": {"user": "python"}}, + ] + + +@mock.patch.dict( + os.environ, + { + # connection not existing in config.toml but with a name starting with connection from config.toml ("empty") + "SNOWFLAKE_CONNECTIONS_EMPTYABC_PASSWORD": "abc123", + # connection existing in config.toml but key not used by CLI + "SNOWFLAKE_CONNECTIONS_EMPTY_PW": "abc123", + }, + clear=True, +) +def test_connection_list_does_not_print_too_many_env_variables(runner): + result = runner.invoke(["connection", "list", "--format", "json"]) + assert result.exit_code == 0, result.output + payload = json.loads(result.output) + assert payload == [ + { + "connection_name": "full", + "parameters": { + "account": "dev_account", + "database": "dev_database", + "host": "dev_host", + "port": 8000, "role": "dev_role", "schema": "dev_schema", "user": "dev_user", diff --git a/tests/test_data/projects/snowpark_procedures/app.py b/tests/test_data/projects/snowpark_procedures/app.py index ddd692bb14..602af440e0 100644 --- a/tests/test_data/projects/snowpark_procedures/app.py +++ b/tests/test_data/projects/snowpark_procedures/app.py @@ -18,7 +18,7 @@ def test(session: Session) -> str: if __name__ == "__main__": from snowflake.cli.api.config import cli_config - session = Session.builder.configs(cli_config.get_connection("dev")).create() + session = Session.builder.configs(cli_config.get_connection_dict("dev")).create() if len(sys.argv) > 1: print(hello(session, *sys.argv[1:])) # type: ignore else: diff --git a/tests/test_data/projects/snowpark_procedures_coverage/app.py b/tests/test_data/projects/snowpark_procedures_coverage/app.py index ddd692bb14..602af440e0 100644 --- a/tests/test_data/projects/snowpark_procedures_coverage/app.py +++ b/tests/test_data/projects/snowpark_procedures_coverage/app.py @@ -18,7 +18,7 @@ def test(session: Session) -> str: if __name__ == "__main__": from snowflake.cli.api.config import cli_config - session = Session.builder.configs(cli_config.get_connection("dev")).create() + session = Session.builder.configs(cli_config.get_connection_dict("dev")).create() if len(sys.argv) > 1: print(hello(session, *sys.argv[1:])) # type: ignore else: diff --git a/tests_integration/test_data/projects/snowpark/app/app.py b/tests_integration/test_data/projects/snowpark/app/app.py index fe5526abbf..3430e84c52 100644 --- a/tests_integration/test_data/projects/snowpark/app/app.py +++ b/tests_integration/test_data/projects/snowpark/app/app.py @@ -22,7 +22,7 @@ def hello_function(name: str) -> str: if __name__ == "__main__": from snowflake.cli.api.config import cli_config - session = Session.builder.configs(cli_config.get_connection("dev")).create() + session = Session.builder.configs(cli_config.get_connection_dict("dev")).create() if len(sys.argv) > 1: print(hello_procedure(session, *sys.argv[1:])) # type: ignore else: diff --git a/tests_integration/test_data/projects/snowpark_coverage/app/module/procedures.py b/tests_integration/test_data/projects/snowpark_coverage/app/module/procedures.py index 2a03fbc076..5873f53ce5 100644 --- a/tests_integration/test_data/projects/snowpark_coverage/app/module/procedures.py +++ b/tests_integration/test_data/projects/snowpark_coverage/app/module/procedures.py @@ -18,7 +18,7 @@ def test(session: Session) -> str: if __name__ == "__main__": from snowflake.cli.api.config import cli_config - session = Session.builder.configs(cli_config.get_connection("dev")).create() + session = Session.builder.configs(cli_config.get_connection_dict("dev")).create() if len(sys.argv) > 1: print(hello(session, *sys.argv[1:])) # type: ignore else: diff --git a/tests_integration/test_data/projects/snowpark_fully_qualified_name/app/app.py b/tests_integration/test_data/projects/snowpark_fully_qualified_name/app/app.py index fe5526abbf..3430e84c52 100644 --- a/tests_integration/test_data/projects/snowpark_fully_qualified_name/app/app.py +++ b/tests_integration/test_data/projects/snowpark_fully_qualified_name/app/app.py @@ -22,7 +22,7 @@ def hello_function(name: str) -> str: if __name__ == "__main__": from snowflake.cli.api.config import cli_config - session = Session.builder.configs(cli_config.get_connection("dev")).create() + session = Session.builder.configs(cli_config.get_connection_dict("dev")).create() if len(sys.argv) > 1: print(hello_procedure(session, *sys.argv[1:])) # type: ignore else: