Skip to content

Commit

Permalink
[SNOW-1246922] Fix printing of env variables in connection list
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pjob committed Mar 18, 2024
1 parent 6a48f26 commit d3d488f
Show file tree
Hide file tree
Showing 11 changed files with 157 additions and 48 deletions.
5 changes: 5 additions & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@
* Changing imports in function/procedure section in `snowflake.yml` will cause the definition update on replace
* Adding `--pattern` flag to `stage list` command for filtering out results with regex.

# v2.1.1

## Fixes and improvements
* Improved security of printing connection details in `snow connection list`.

# v2.1.0

## Backward incompatibility
Expand Down
70 changes: 63 additions & 7 deletions src/snowflake/cli/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import os
from contextlib import contextmanager
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional, Union

Expand Down Expand Up @@ -47,6 +48,50 @@ class Empty:
)


@dataclass
class ConnectionConfig:
account: Optional[str] = None
user: Optional[str] = None
password: Optional[str] = None
host: Optional[str] = None
region: Optional[str] = None
port: Optional[int] = None
database: Optional[str] = None
schema: Optional[str] = None
warehouse: Optional[str] = None
role: Optional[str] = None
authenticator: Optional[str] = None
private_key_path: Optional[str] = None

_other_settings: dict = field(default_factory=lambda: {})

@classmethod
def from_dict(cls, config_dict: dict) -> ConnectionConfig:
self = cls()
for key, value in config_dict.items():
if hasattr(self, key):
setattr(self, key, value)
else:
self._other_settings[key] = value
return self

def to_dict_of_known_non_empty_values(self) -> dict:
return {
k: v
for k, v in asdict(self).items()
if k != "_other_settings" and v is not None
}

def _non_empty_other_values(self) -> dict:
return {k: v for k, v in self._other_settings.items() if v is not None}

def to_dict_of_all_non_empty_values(self) -> dict:
return {
**self.to_dict_of_known_non_empty_values(),
**self._non_empty_other_values(),
}


def config_init(config_file: Optional[Path]):
"""
Initializes the app configuration. Config provided via cli flag takes precedence.
Expand All @@ -61,8 +106,12 @@ def config_init(config_file: Optional[Path]):
CONFIG_MANAGER.read_config()


def add_connection(name: str, parameters: dict):
set_config_value(section=CONNECTIONS_SECTION, key=name, value=parameters)
def add_connection(name: str, connection_config: ConnectionConfig):
set_config_value(
section=CONNECTIONS_SECTION,
key=name,
value=connection_config.to_dict_of_all_non_empty_values(),
)


_DEFAULT_LOGS_CONFIG = {
Expand Down Expand Up @@ -126,16 +175,23 @@ def config_section_exists(*path) -> bool:
return False


def get_connection(connection_name: str) -> dict:
def get_all_connections() -> dict[str, ConnectionConfig]:
return {
k: ConnectionConfig.from_dict(connection_dict)
for k, connection_dict in get_config_section("connections").items()
}


def get_connection_dict(connection_name: str) -> dict:
try:
return get_config_section(CONNECTIONS_SECTION, connection_name)
except KeyError:
raise MissingConfiguration(f"Connection {connection_name} is not configured")


def get_default_connection() -> dict:
def get_default_connection_dict() -> dict:
def_connection_name = CONFIG_MANAGER["default_connection_name"]
return get_connection(def_connection_name)
return get_connection_dict(def_connection_name)


def get_config_section(*path) -> dict:
Expand Down Expand Up @@ -215,9 +271,9 @@ def _merge_section_with_env(section: Union[Table, Any], *path) -> Dict[str, str]


def _get_envs_for_path(*path) -> dict:
env_variables_prefix = "SNOWFLAKE_" + "_".join(p.upper() for p in path)
env_variables_prefix = "_".join(["SNOWFLAKE"] + [p.upper() for p in path]) + "_"
return {
k.replace(f"{env_variables_prefix}_", "").lower(): os.environ[k]
k.replace(env_variables_prefix, "").lower(): os.environ[k]
for k in os.environ.keys()
if k.startswith(env_variables_prefix)
}
Expand Down
6 changes: 3 additions & 3 deletions src/snowflake/cli/app/snow_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import snowflake.connector
from click.exceptions import ClickException
from snowflake.cli.api.config import get_connection, get_default_connection
from snowflake.cli.api.config import get_connection_dict, get_default_connection_dict
from snowflake.cli.api.constants import DEFAULT_SIZE_LIMIT_MB
from snowflake.cli.api.exceptions import (
InvalidConnectionConfiguration,
Expand Down Expand Up @@ -37,11 +37,11 @@ def connect_to_snowflake(
raise ClickException("Can't use connection name and temporary connection.")

if connection_name:
connection_parameters = get_connection(connection_name)
connection_parameters = get_connection_dict(connection_name)
elif temporary_connection:
connection_parameters = {} # we will apply overrides in next step
else:
connection_parameters = get_default_connection()
connection_parameters = get_default_connection_dict()

# Apply overrides to connection details
for key, value in overrides.items():
Expand Down
52 changes: 29 additions & 23 deletions src/snowflake/cli/plugins/connection/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
)
from snowflake.cli.api.commands.snow_typer import SnowTyper
from snowflake.cli.api.config import (
ConnectionConfig,
add_connection,
connection_exists,
get_config_section,
get_connection,
get_all_connections,
get_connection_dict,
set_config_value,
)
from snowflake.cli.api.console import cli_console
Expand Down Expand Up @@ -63,10 +64,15 @@ def list_connections(**options) -> CommandResult:
"""
Lists configured connections.
"""
connections = get_config_section("connections")
connections = get_all_connections()
result = (
{"connection_name": k, "parameters": _mask_password(v)}
for k, v in connections.items()
{
"connection_name": connection_name,
"parameters": _mask_password(
connection_config.to_dict_of_known_non_empty_values()
),
}
for connection_name, connection_config in connections.items()
)
return CollectionResult(result)

Expand Down Expand Up @@ -203,26 +209,26 @@ def add(
**options,
) -> CommandResult:
"""Adds a connection to configuration file."""
connection_entry = {
"account": account,
"user": user,
"password": password,
"host": host,
"region": region,
"port": port,
"database": database,
"schema": schema,
"warehouse": warehouse,
"role": role,
"authenticator": authenticator,
"private_key_path": private_key_path,
}
connection_entry = {k: v for k, v in connection_entry.items() if v is not None}

if connection_exists(connection_name):
raise ClickException(f"Connection {connection_name} already exists")

add_connection(connection_name, connection_entry)
add_connection(
connection_name,
ConnectionConfig(
account=account,
user=user,
password=password,
host=host,
region=region,
port=port,
database=database,
schema=schema,
warehouse=warehouse,
role=role,
authenticator=authenticator,
private_key_path=private_key_path,
),
)
return MessageResult(
f"Wrote new connection {connection_name} to {CONFIG_MANAGER.file_path}"
)
Expand Down Expand Up @@ -289,6 +295,6 @@ def set_default(
**options,
):
"""Changes default connection to provided value."""
get_connection(connection_name=name)
get_connection_dict(connection_name=name)
set_config_value(section=None, key="default_connection_name", value=name)
return MessageResult(f"Default connection set to: {name}")
18 changes: 9 additions & 9 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
ConfigFileTooWidePermissionsError,
config_init,
get_config_section,
get_connection,
get_default_connection,
get_connection_dict,
get_default_connection_dict,
)
from snowflake.cli.api.exceptions import MissingConfiguration

Expand All @@ -31,7 +31,7 @@ def test_empty_config_file_is_created_if_not_present():
def test_get_connection_from_file(test_snowcli_config):
config_init(test_snowcli_config)

assert get_connection("full") == {
assert get_connection_dict("full") == {
"account": "dev_account",
"user": "dev_user",
"host": "dev_host",
Expand All @@ -57,7 +57,7 @@ def test_get_connection_from_file(test_snowcli_config):
def test_environment_variables_override_configuration_value(test_snowcli_config):
config_init(test_snowcli_config)

assert get_connection("default") == {
assert get_connection_dict("default") == {
"database": "database_foo",
"schema": "test_public",
"role": "test_role",
Expand All @@ -79,7 +79,7 @@ def test_environment_variables_override_configuration_value(test_snowcli_config)
def test_environment_variables_works_if_config_value_not_present(test_snowcli_config):
config_init(test_snowcli_config)

assert get_connection("empty") == {
assert get_connection_dict("empty") == {
"account": "some_account",
"database": "test_database",
"warehouse": "large",
Expand Down Expand Up @@ -144,7 +144,7 @@ def test_create_default_config_if_not_exists(mock_config_manager):
def test_default_connection_with_overwritten_values(test_snowcli_config):
config_init(test_snowcli_config)

assert get_default_connection() == {
assert get_default_connection_dict() == {
"database": "db_for_test",
"role": "test_role",
"schema": "test_public",
Expand All @@ -157,7 +157,7 @@ def test_default_connection_with_overwritten_values(test_snowcli_config):
def test_not_found_default_connection(test_root_path):
config_init(Path(test_root_path / "empty_config.toml"))
with pytest.raises(MissingConfiguration) as ex:
get_default_connection()
get_default_connection_dict()

assert ex.value.message == "Connection default is not configured"

Expand All @@ -172,7 +172,7 @@ def test_not_found_default_connection(test_root_path):
def test_not_found_default_connection_from_evn_variable(test_root_path):
config_init(Path(test_root_path / "empty_config.toml"))
with pytest.raises(MissingConfiguration) as ex:
get_default_connection()
get_default_connection_dict()

assert ex.value.message == "Connection not_existed_connection is not configured"

Expand All @@ -188,7 +188,7 @@ def test_connections_toml_override_config_toml(test_snowcli_config, snowflake_ho
)
config_init(test_snowcli_config)

assert get_default_connection() == {"database": "overridden_database"}
assert get_default_connection_dict() == {"database": "overridden_database"}
assert CONFIG_MANAGER["connections"] == {
"default": {"database": "overridden_database"}
}
Expand Down
44 changes: 43 additions & 1 deletion tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,49 @@ def test_lists_connection_information(runner):
"database": "dev_database",
"host": "dev_host",
"port": 8000,
"protocol": "dev_protocol",
"role": "dev_role",
"schema": "dev_schema",
"user": "dev_user",
"warehouse": "dev_warehouse",
},
},
{
"connection_name": "default",
"parameters": {
"database": "db_for_test",
"password": "****", # masked
"role": "test_role",
"schema": "test_public",
"warehouse": "xs",
},
},
{"connection_name": "empty", "parameters": {}},
{"connection_name": "test_connections", "parameters": {"user": "python"}},
]


@mock.patch.dict(
os.environ,
{
# connection not existing in config.toml but with a name starting with connection from config.toml ("empty")
"SNOWFLAKE_CONNECTIONS_EMPTYABC_PASSWORD": "abc123",
# connection existing in config.toml but key not used by CLI
"SNOWFLAKE_CONNECTIONS_EMPTY_PW": "abc123",
},
clear=True,
)
def test_connection_list_does_not_print_too_many_env_variables(runner):
result = runner.invoke(["connection", "list", "--format", "json"])
assert result.exit_code == 0, result.output
payload = json.loads(result.output)
assert payload == [
{
"connection_name": "full",
"parameters": {
"account": "dev_account",
"database": "dev_database",
"host": "dev_host",
"port": 8000,
"role": "dev_role",
"schema": "dev_schema",
"user": "dev_user",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_data/projects/snowpark_procedures/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test(session: Session) -> str:
if __name__ == "__main__":
from snowflake.cli.api.config import cli_config

session = Session.builder.configs(cli_config.get_connection("dev")).create()
session = Session.builder.configs(cli_config.get_connection_dict("dev")).create()
if len(sys.argv) > 1:
print(hello(session, *sys.argv[1:])) # type: ignore
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test(session: Session) -> str:
if __name__ == "__main__":
from snowflake.cli.api.config import cli_config

session = Session.builder.configs(cli_config.get_connection("dev")).create()
session = Session.builder.configs(cli_config.get_connection_dict("dev")).create()
if len(sys.argv) > 1:
print(hello(session, *sys.argv[1:])) # type: ignore
else:
Expand Down
2 changes: 1 addition & 1 deletion tests_integration/test_data/projects/snowpark/app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def hello_function(name: str) -> str:
if __name__ == "__main__":
from snowflake.cli.api.config import cli_config

session = Session.builder.configs(cli_config.get_connection("dev")).create()
session = Session.builder.configs(cli_config.get_connection_dict("dev")).create()
if len(sys.argv) > 1:
print(hello_procedure(session, *sys.argv[1:])) # type: ignore
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test(session: Session) -> str:
if __name__ == "__main__":
from snowflake.cli.api.config import cli_config

session = Session.builder.configs(cli_config.get_connection("dev")).create()
session = Session.builder.configs(cli_config.get_connection_dict("dev")).create()
if len(sys.argv) > 1:
print(hello(session, *sys.argv[1:])) # type: ignore
else:
Expand Down
Loading

0 comments on commit d3d488f

Please sign in to comment.