diff --git a/src/snowflake/cli/app/snow_connector.py b/src/snowflake/cli/app/snow_connector.py index 2906be0700..d577bd8d5d 100644 --- a/src/snowflake/cli/app/snow_connector.py +++ b/src/snowflake/cli/app/snow_connector.py @@ -3,7 +3,7 @@ import contextlib import logging import os -from typing import Dict, Optional +from typing import Any, Dict, Optional import snowflake.connector from click.exceptions import ClickException @@ -77,13 +77,8 @@ def connect_to_snowflake( "connection_diag_allowlist_path" ] = diag_allowlist_path - if ( - "session_token" in overrides - and overrides["session_token"] is not None - and "master_token" in overrides - and overrides["master_token"] is not None - ): - connection_parameters["server_session_keep_alive"] = True + # Make sure the connection is not closed if it was shared to the SnowCLI, instead of being created in the SnowCLI + _avoid_closing_the_connection_if_it_was_shared(overrides, connection_parameters) try: # Whatever output is generated when creating connection, @@ -101,6 +96,18 @@ def connect_to_snowflake( raise InvalidConnectionConfiguration(err.msg) +def _avoid_closing_the_connection_if_it_was_shared( + overrides: Dict[str, Any], connection_parameters: Dict +): + if ( + "session_token" in overrides + and overrides["session_token"] is not None + and "master_token" in overrides + and overrides["master_token"] is not None + ): + connection_parameters["server_session_keep_alive"] = True + + def _update_connection_details_with_private_key(connection_parameters: Dict): if "private_key_path" in connection_parameters: if connection_parameters.get("authenticator") == "SNOWFLAKE_JWT":