Skip to content

Commit

Permalink
SNOW-923398: improve robustness in handling authentication response (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-aling authored Sep 29, 2023
1 parent 682c14b commit 246eb8f
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 21 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne

- v3.2.1(September 26,2023)

- Fixed a bug where url port and path were ignore in private link oscp retry.
- Fixed a bug where url port and path were ignored in private link oscp retry.
- Added thread safety in telemetry when instantiating multiple connections concurrently.
- Bumped platformdirs dependency from >=2.6.0,<3.9.0 to >=2.6.0,<4.0.0.0 and made necessary changes to allow this.
- Removed the deprecation warning from the vendored urllib3 about urllib3.contrib.pyopenssl deprecation.
- Improved robustness in handling authentication response.

- v3.2.0(September 06,2023)

Expand Down
60 changes: 41 additions & 19 deletions src/snowflake/connector/auth/_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,11 +282,11 @@ def authenticate(
)

# waiting for MFA authentication
if ret["data"].get("nextAction") in (
if ret["data"] and ret["data"].get("nextAction") in (
"EXT_AUTHN_DUO_ALL",
"EXT_AUTHN_DUO_PUSH_N_PASSCODE",
):
body["inFlightCtx"] = ret["data"]["inFlightCtx"]
body["inFlightCtx"] = ret["data"].get("inFlightCtx")
body["data"]["EXT_AUTHN_DUO_METHOD"] = "push"
self.ret = {"message": "Timeout", "data": {}}

Expand All @@ -310,9 +310,13 @@ def post_request_wrapper(self, url, headers, body) -> None:
t.join(timeout=timeout)

ret = self.ret
if ret and ret["data"].get("nextAction") == "EXT_AUTHN_SUCCESS":
if (
ret
and ret["data"]
and ret["data"].get("nextAction") == "EXT_AUTHN_SUCCESS"
):
body = copy.deepcopy(body_template)
body["inFlightCtx"] = ret["data"]["inFlightCtx"]
body["inFlightCtx"] = ret["data"].get("inFlightCtx")
# final request to get tokens
ret = self._rest._post_request(
url,
Expand All @@ -321,7 +325,7 @@ def post_request_wrapper(self, url, headers, body) -> None:
timeout=self._rest._connection.login_timeout,
socket_timeout=self._rest._connection.login_timeout,
)
elif not ret or not ret["data"].get("token"):
elif not ret or not ret["data"] or not ret["data"].get("token"):
# not token is returned.
Error.errorhandler_wrapper(
self._rest._connection,
Expand All @@ -343,10 +347,10 @@ def post_request_wrapper(self, url, headers, body) -> None:
)
return session_parameters # required for unit test

elif ret["data"].get("nextAction") == "PWD_CHANGE":
elif ret["data"] and ret["data"].get("nextAction") == "PWD_CHANGE":
if callable(password_callback):
body = copy.deepcopy(body_template)
body["inFlightCtx"] = ret["data"]["inFlightCtx"]
body["inFlightCtx"] = ret["data"].get("inFlightCtx")
body["data"]["LOGIN_NAME"] = user
body["data"]["PASSWORD"] = (
auth_instance.password
Expand Down Expand Up @@ -411,41 +415,59 @@ def post_request_wrapper(self, url, headers, body) -> None:
)
else:
logger.debug(
"token = %s", "******" if ret["data"]["token"] is not None else "NULL"
"token = %s",
"******"
if ret["data"] and ret["data"].get("token") is not None
else "NULL",
)
logger.debug(
"master_token = %s",
"******" if ret["data"]["masterToken"] is not None else "NULL",
"******"
if ret["data"] and ret["data"].get("masterToken") is not None
else "NULL",
)
logger.debug(
"id_token = %s",
"******" if ret["data"].get("idToken") is not None else "NULL",
"******"
if ret["data"] and ret["data"].get("idToken") is not None
else "NULL",
)
logger.debug(
"mfa_token = %s",
"******" if ret["data"].get("mfaToken") is not None else "NULL",
"******"
if ret["data"] and ret["data"].get("mfaToken") is not None
else "NULL",
)
if not ret["data"]:
Error.errorhandler_wrapper(
None,
None,
Error,
{
"msg": "There is no data in the returning response, please retry the operation."
},
)
self._rest.update_tokens(
ret["data"]["token"],
ret["data"]["masterToken"],
ret["data"].get("token"),
ret["data"].get("masterToken"),
master_validity_in_seconds=ret["data"].get("masterValidityInSeconds"),
id_token=ret["data"].get("idToken"),
mfa_token=ret["data"].get("mfaToken"),
)
self.write_temporary_credentials(
self._rest._host, user, session_parameters, ret
)
if "sessionId" in ret["data"]:
self._rest._connection._session_id = ret["data"]["sessionId"]
if "sessionInfo" in ret["data"]:
session_info = ret["data"]["sessionInfo"]
if ret["data"] and "sessionId" in ret["data"]:
self._rest._connection._session_id = ret["data"].get("sessionId")
if ret["data"] and "sessionInfo" in ret["data"]:
session_info = ret["data"].get("sessionInfo")
self._rest._connection._database = session_info.get("databaseName")
self._rest._connection._schema = session_info.get("schemaName")
self._rest._connection._warehouse = session_info.get("warehouseName")
self._rest._connection._role = session_info.get("roleName")
if "parameters" in ret["data"]:
if ret["data"] and "parameters" in ret["data"]:
session_parameters.update(
{p["name"]: p["value"] for p in ret["data"]["parameters"]}
{p["name"]: p["value"] for p in ret["data"].get("parameters")}
)
self._rest._connection._update_parameters(session_parameters)
return session_parameters
Expand Down
22 changes: 21 additions & 1 deletion test/unit/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import pytest

import snowflake.connector.errors
from snowflake.connector.constants import OCSPMode
from snowflake.connector.description import CLIENT_NAME, CLIENT_VERSION
from snowflake.connector.network import SnowflakeRestful
Expand Down Expand Up @@ -102,7 +103,12 @@ def _mock_auth_mfa_rest_response_failure(url, headers, body, **kwargs):
"inFlightCtx": "inFlightCtx",
},
}

elif mock_cnt == 2:
ret = {
"success": True,
"message": None,
"data": None,
}
mock_cnt += 1
return ret

Expand All @@ -126,6 +132,12 @@ def _mock_auth_mfa_rest_response_timeout(url, headers, body, **kwargs):
elif mock_cnt == 1:
time.sleep(10) # should timeout while here
ret = {}
elif mock_cnt == 2:
ret = {
"success": True,
"message": None,
"data": None,
}

mock_cnt += 1
return ret
Expand Down Expand Up @@ -168,6 +180,14 @@ def test_auth_mfa(next_action: str):
auth.authenticate(auth_instance, account, user, timeout=1)
assert rest._connection.errorhandler.called # error

# ret["data"] is none
with pytest.raises(snowflake.connector.errors.Error):
mock_cnt = 2
rest = _init_rest(application, _mock_auth_mfa_rest_response_timeout)
auth = Auth(rest)
auth_instance = AuthByDefault(password)
auth.authenticate(auth_instance, account, user)


def _mock_auth_password_change_rest_response(url, headers, body, **kwargs):
"""Test successful case."""
Expand Down

0 comments on commit 246eb8f

Please sign in to comment.