Skip to content

Commit

Permalink
Fix for whitespaces problem in connection add (#1410)
Browse files Browse the repository at this point in the history
* Fix

* Fix

* Fix
  • Loading branch information
sfc-gh-jsikorski authored Aug 2, 2024
1 parent 9537a43 commit eb18f58
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 2 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
## New additions

## Fixes and improvements
* Fixed problem with whitespaces in `snow connection add` command


# v2.7.0
Expand Down
16 changes: 14 additions & 2 deletions src/snowflake/cli/plugins/connection/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
MessageResult,
ObjectResult,
)
from snowflake.cli.plugins.connection.util import strip_if_value_present
from snowflake.cli.plugins.object.manager import ObjectManager
from snowflake.connector import ProgrammingError
from snowflake.connector.config_manager import CONFIG_MANAGER
Expand Down Expand Up @@ -94,8 +95,8 @@ def require_integer(field_name: str):
def callback(value: str):
if value is None:
return None
if value.isdigit():
return value
if value.strip().isdigit():
return value.strip()
raise ClickException(f"Value of {field_name} must be integer")

return callback
Expand All @@ -117,6 +118,7 @@ def add(
prompt="Name for this connection",
help="Name of the new connection.",
show_default=False,
callback=strip_if_value_present,
),
account: str = typer.Option(
None,
Expand All @@ -126,6 +128,7 @@ def add(
prompt="Snowflake account name",
help="Account name to use when authenticating with Snowflake.",
show_default=False,
callback=strip_if_value_present,
),
user: str = typer.Option(
None,
Expand All @@ -135,6 +138,7 @@ def add(
prompt="Snowflake username",
show_default=False,
help="Username to connect to Snowflake.",
callback=strip_if_value_present,
),
password: str = typer.Option(
EmptyInput(),
Expand All @@ -153,6 +157,7 @@ def add(
click_type=OptionalPrompt(),
prompt="Role for the connection",
help="Role to use on Snowflake.",
callback=strip_if_value_present,
),
warehouse: str = typer.Option(
EmptyInput(),
Expand All @@ -161,6 +166,7 @@ def add(
click_type=OptionalPrompt(),
prompt="Warehouse for the connection",
help="Warehouse to use on Snowflake.",
callback=strip_if_value_present,
),
database: str = typer.Option(
EmptyInput(),
Expand All @@ -169,6 +175,7 @@ def add(
click_type=OptionalPrompt(),
prompt="Database for the connection",
help="Database to use on Snowflake.",
callback=strip_if_value_present,
),
schema: str = typer.Option(
EmptyInput(),
Expand All @@ -177,6 +184,7 @@ def add(
click_type=OptionalPrompt(),
prompt="Schema for the connection",
help="Schema to use on Snowflake.",
callback=strip_if_value_present,
),
host: str = typer.Option(
EmptyInput(),
Expand All @@ -185,6 +193,7 @@ def add(
click_type=OptionalPrompt(),
prompt="Connection host",
help="Host name the connection attempts to connect to Snowflake.",
callback=strip_if_value_present,
),
port: int = typer.Option(
EmptyInput(),
Expand All @@ -202,6 +211,7 @@ def add(
click_type=OptionalPrompt(),
prompt="Snowflake region",
help="Region name if not the default Snowflake deployment.",
callback=strip_if_value_present,
),
authenticator: str = typer.Option(
EmptyInput(),
Expand All @@ -218,6 +228,7 @@ def add(
click_type=OptionalPrompt(),
prompt="Path to private key file",
help="Path to file containing private key",
callback=strip_if_value_present,
),
token_file_path: str = typer.Option(
EmptyInput(),
Expand All @@ -226,6 +237,7 @@ def add(
click_type=OptionalPrompt(),
prompt="Path to token file",
help="Path to file with an OAuth token that should be used when connecting to Snowflake",
callback=strip_if_value_present,
),
set_as_default: bool = typer.Option(
False,
Expand Down
5 changes: 5 additions & 0 deletions src/snowflake/cli/plugins/connection/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import json
import logging
from typing import Optional

from click.exceptions import ClickException
from snowflake.connector import SnowflakeConnection
Expand Down Expand Up @@ -177,3 +178,7 @@ def make_snowsight_url(conn: SnowflakeConnection, path: str) -> str:
account = get_account(conn)
path_with_slash = path if path.startswith("/") else f"/{path}"
return f"{snowsight_host}/{deployment}/{account}{path_with_slash}"


def strip_if_value_present(value: Optional[str]) -> Optional[str]:
return value.strip() if value else value
45 changes: 45 additions & 0 deletions tests/__snapshots__/test_connection.ambr
Original file line number Diff line number Diff line change
@@ -1,4 +1,49 @@
# serializer version: 1
# name: test_if_whitespaces_are_stripped_from_connection_name
'''
[connections.whitespaceTest]
account = "accName"
user = "userName"
password = "123"
host = "baz"
region = "Kaszuby"
port = "12345"
database = "foo"
schema = "bar"
warehouse = "some warehouse"
role = "some role"
authenticator = " foo "
private_key_path = "bar"
token_file_path = "baz"

'''
# ---
# name: test_if_whitespaces_are_stripped_from_connection_name.1
'''
[
{
"connection_name": "whitespaceTest",
"parameters": {
"account": "accName",
"user": "userName",
"password": "****",
"host": "baz",
"region": "Kaszuby",
"port": "12345",
"database": "foo",
"schema": "bar",
"warehouse": "some warehouse",
"role": "some role",
"authenticator": " foo ",
"private_key_path": "bar",
"token_file_path": "baz"
},
"is_default": false
}
]

'''
# ---
# name: test_new_connection_add_prompt_handles_default_values
'''
[connections.connName]
Expand Down
34 changes: 34 additions & 0 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,40 @@ def test_new_connection_with_jwt_auth(runner, os_agnostic_snapshot):
assert content == os_agnostic_snapshot


def test_if_whitespaces_are_stripped_from_connection_name(runner, os_agnostic_snapshot):
with NamedTemporaryFile("w+", suffix=".toml") as tmp_file:
result = runner.invoke_with_config_file(
tmp_file.name,
[
"connection",
"add",
"--connection-name",
" whitespaceTest ",
"--username",
"userName ",
"--account",
" accName",
],
input="123\n some role \n some warehouse\n foo \n bar \n baz \n 12345 \n Kaszuby \n foo \n bar \n baz ",
)
content = tmp_file.read()

assert result.exit_code == 0, result.output
assert content == os_agnostic_snapshot

connections_list = runner.invoke_with_config_file(
tmp_file.name, ["connection", "list", "--format", "json"]
)
assert connections_list.exit_code == 0
assert connections_list.output == os_agnostic_snapshot

set_as_default = runner.invoke_with_config_file(
tmp_file.name, ["connection", "set-default", "whitespaceTest"]
)
assert set_as_default.exit_code == 0
assert "Default connection set to: whitespaceTest" in set_as_default.output


def test_port_has_cannot_be_string(runner):
with NamedTemporaryFile("w+", suffix=".toml") as tmp_file:
result = runner.invoke_with_config_file(
Expand Down

0 comments on commit eb18f58

Please sign in to comment.