diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index ab2b19d5b2..dfaf491701 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -26,6 +26,7 @@ ## New additions ## Fixes and improvements +* Fixed problem with whitespaces in `snow connection add` command # v2.7.0 diff --git a/src/snowflake/cli/plugins/connection/commands.py b/src/snowflake/cli/plugins/connection/commands.py index 23650e3dcb..d86ac2f247 100644 --- a/src/snowflake/cli/plugins/connection/commands.py +++ b/src/snowflake/cli/plugins/connection/commands.py @@ -43,6 +43,7 @@ MessageResult, ObjectResult, ) +from snowflake.cli.plugins.connection.util import strip_if_value_present from snowflake.cli.plugins.object.manager import ObjectManager from snowflake.connector import ProgrammingError from snowflake.connector.config_manager import CONFIG_MANAGER @@ -94,8 +95,8 @@ def require_integer(field_name: str): def callback(value: str): if value is None: return None - if value.isdigit(): - return value + if value.strip().isdigit(): + return value.strip() raise ClickException(f"Value of {field_name} must be integer") return callback @@ -117,6 +118,7 @@ def add( prompt="Name for this connection", help="Name of the new connection.", show_default=False, + callback=strip_if_value_present, ), account: str = typer.Option( None, @@ -126,6 +128,7 @@ def add( prompt="Snowflake account name", help="Account name to use when authenticating with Snowflake.", show_default=False, + callback=strip_if_value_present, ), user: str = typer.Option( None, @@ -135,6 +138,7 @@ def add( prompt="Snowflake username", show_default=False, help="Username to connect to Snowflake.", + callback=strip_if_value_present, ), password: str = typer.Option( EmptyInput(), @@ -153,6 +157,7 @@ def add( click_type=OptionalPrompt(), prompt="Role for the connection", help="Role to use on Snowflake.", + callback=strip_if_value_present, ), warehouse: str = typer.Option( EmptyInput(), @@ -161,6 +166,7 @@ def add( click_type=OptionalPrompt(), prompt="Warehouse for the connection", help="Warehouse to use on Snowflake.", + callback=strip_if_value_present, ), database: str = typer.Option( EmptyInput(), @@ -169,6 +175,7 @@ def add( click_type=OptionalPrompt(), prompt="Database for the connection", help="Database to use on Snowflake.", + callback=strip_if_value_present, ), schema: str = typer.Option( EmptyInput(), @@ -177,6 +184,7 @@ def add( click_type=OptionalPrompt(), prompt="Schema for the connection", help="Schema to use on Snowflake.", + callback=strip_if_value_present, ), host: str = typer.Option( EmptyInput(), @@ -185,6 +193,7 @@ def add( click_type=OptionalPrompt(), prompt="Connection host", help="Host name the connection attempts to connect to Snowflake.", + callback=strip_if_value_present, ), port: int = typer.Option( EmptyInput(), @@ -202,6 +211,7 @@ def add( click_type=OptionalPrompt(), prompt="Snowflake region", help="Region name if not the default Snowflake deployment.", + callback=strip_if_value_present, ), authenticator: str = typer.Option( EmptyInput(), @@ -218,6 +228,7 @@ def add( click_type=OptionalPrompt(), prompt="Path to private key file", help="Path to file containing private key", + callback=strip_if_value_present, ), token_file_path: str = typer.Option( EmptyInput(), @@ -226,6 +237,7 @@ def add( click_type=OptionalPrompt(), prompt="Path to token file", help="Path to file with an OAuth token that should be used when connecting to Snowflake", + callback=strip_if_value_present, ), set_as_default: bool = typer.Option( False, diff --git a/src/snowflake/cli/plugins/connection/util.py b/src/snowflake/cli/plugins/connection/util.py index 6096bacbf6..af92f33bd9 100644 --- a/src/snowflake/cli/plugins/connection/util.py +++ b/src/snowflake/cli/plugins/connection/util.py @@ -16,6 +16,7 @@ import json import logging +from typing import Optional from click.exceptions import ClickException from snowflake.connector import SnowflakeConnection @@ -177,3 +178,7 @@ def make_snowsight_url(conn: SnowflakeConnection, path: str) -> str: account = get_account(conn) path_with_slash = path if path.startswith("/") else f"/{path}" return f"{snowsight_host}/{deployment}/{account}{path_with_slash}" + + +def strip_if_value_present(value: Optional[str]) -> Optional[str]: + return value.strip() if value else value diff --git a/tests/__snapshots__/test_connection.ambr b/tests/__snapshots__/test_connection.ambr index 0378a70011..2fc1f12c21 100644 --- a/tests/__snapshots__/test_connection.ambr +++ b/tests/__snapshots__/test_connection.ambr @@ -1,4 +1,49 @@ # serializer version: 1 +# name: test_if_whitespaces_are_stripped_from_connection_name + ''' + [connections.whitespaceTest] + account = "accName" + user = "userName" + password = "123" + host = "baz" + region = "Kaszuby" + port = "12345" + database = "foo" + schema = "bar" + warehouse = "some warehouse" + role = "some role" + authenticator = " foo " + private_key_path = "bar" + token_file_path = "baz" + + ''' +# --- +# name: test_if_whitespaces_are_stripped_from_connection_name.1 + ''' + [ + { + "connection_name": "whitespaceTest", + "parameters": { + "account": "accName", + "user": "userName", + "password": "****", + "host": "baz", + "region": "Kaszuby", + "port": "12345", + "database": "foo", + "schema": "bar", + "warehouse": "some warehouse", + "role": "some role", + "authenticator": " foo ", + "private_key_path": "bar", + "token_file_path": "baz" + }, + "is_default": false + } + ] + + ''' +# --- # name: test_new_connection_add_prompt_handles_default_values ''' [connections.connName] diff --git a/tests/test_connection.py b/tests/test_connection.py index 26d4ad8d16..ea75897f90 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -102,6 +102,40 @@ def test_new_connection_with_jwt_auth(runner, os_agnostic_snapshot): assert content == os_agnostic_snapshot +def test_if_whitespaces_are_stripped_from_connection_name(runner, os_agnostic_snapshot): + with NamedTemporaryFile("w+", suffix=".toml") as tmp_file: + result = runner.invoke_with_config_file( + tmp_file.name, + [ + "connection", + "add", + "--connection-name", + " whitespaceTest ", + "--username", + "userName ", + "--account", + " accName", + ], + input="123\n some role \n some warehouse\n foo \n bar \n baz \n 12345 \n Kaszuby \n foo \n bar \n baz ", + ) + content = tmp_file.read() + + assert result.exit_code == 0, result.output + assert content == os_agnostic_snapshot + + connections_list = runner.invoke_with_config_file( + tmp_file.name, ["connection", "list", "--format", "json"] + ) + assert connections_list.exit_code == 0 + assert connections_list.output == os_agnostic_snapshot + + set_as_default = runner.invoke_with_config_file( + tmp_file.name, ["connection", "set-default", "whitespaceTest"] + ) + assert set_as_default.exit_code == 0 + assert "Default connection set to: whitespaceTest" in set_as_default.output + + def test_port_has_cannot_be_string(runner): with NamedTemporaryFile("w+", suffix=".toml") as tmp_file: result = runner.invoke_with_config_file(