diff --git a/src/predictions/profiles_mlcorelib/connectors/SnowflakeConnector.py b/src/predictions/profiles_mlcorelib/connectors/SnowflakeConnector.py index c117d9069..e04ebb84a 100644 --- a/src/predictions/profiles_mlcorelib/connectors/SnowflakeConnector.py +++ b/src/predictions/profiles_mlcorelib/connectors/SnowflakeConnector.py @@ -8,7 +8,8 @@ import pandas as pd from datetime import datetime -from typing import Any, Iterable, List, Tuple, Union, Optional, Sequence, Dict + +from typing import Any, Iterable, List, Union, Sequence, Dict import snowflake.snowpark import snowflake.snowpark.types as T @@ -22,6 +23,8 @@ from ..utils.logger import logger from ..connectors.Connector import Connector from ..wht.rudderPB import MATERIAL_PREFIX +from cryptography.hazmat.primitives.serialization import load_pem_private_key +from cryptography.hazmat.backends import default_backend local_folder = constants.SF_LOCAL_STORAGE_DIR @@ -74,7 +77,20 @@ def __init__(self, creds: dict) -> None: def build_session(self, credentials: dict) -> snowflake.snowpark.Session: self.schema = credentials.get("schema", None) self.connection_parameters = self.remap_credentials(credentials) + if "privateKey" in credentials: + private_key = load_pem_private_key( + credentials["privateKey"].encode(), + password=( + credentials["privateKeyPassphrase"].encode() + if credentials.get("privateKeyPassphrase") + else None + ), + backend=default_backend(), + ) + self.connection_parameters["private_key"] = private_key session = Session.builder.configs(self.connection_parameters).create() + # Removing the private key to prevent serialisation error in the snowflake stored procedure + _ = self.connection_parameters.pop("private_key", None) return session def join_file_path(self, file_name: str) -> str: diff --git a/src/predictions/profiles_mlcorelib/utils/constants.py b/src/predictions/profiles_mlcorelib/utils/constants.py index be77be165..c972e48b1 100644 --- a/src/predictions/profiles_mlcorelib/utils/constants.py +++ b/src/predictions/profiles_mlcorelib/utils/constants.py @@ -83,6 +83,7 @@ "seaborn==0.12.0", "scikit-plot==0.3.7", "pycaret<=3.3.0", + "cryptography==42.0.2", ] diff --git a/src/predictions/setup.py b/src/predictions/setup.py index 812764a90..7d97d4cd3 100644 --- a/src/predictions/setup.py +++ b/src/predictions/setup.py @@ -39,6 +39,7 @@ "pycaret==3.3.1", "boto3>=1.34.153", "google-auth-oauthlib>=1.0.0", + "cryptography>=42.0.2", "plotly>=5.24.1", "networkx>=3.3", ], diff --git a/tests/unit/pythonWHT.py b/tests/unit/pythonWHT.py index 85c03eb41..d32623e24 100644 --- a/tests/unit/pythonWHT.py +++ b/tests/unit/pythonWHT.py @@ -77,7 +77,6 @@ def test_all_data_present_and_valid(self, mock_validate_historical_materials_has self.assertEqual(len(materials), 1) def test_missing_sequence_number(self): - self.pythonWHT.connector.check_table_entry_in_material_registry = Mock( return_value=True )