Skip to content

Commit

Permalink
Added support for encrypted key-pair authentication on snowflake (#474)
Browse files Browse the repository at this point in the history
* Added support for encrypted key-pair authentication on snowflake
Current repo should support non encrypted p8 file too, but pywht needs a fix in sending that

* PR comments addressed
Added cryptography in snowflake import packages too
  • Loading branch information
dpatchigolla authored Oct 9, 2024
1 parent bd5bf43 commit a2e3541
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import pandas as pd
from datetime import datetime

from typing import Any, Iterable, List, Tuple, Union, Optional, Sequence, Dict

from typing import Any, Iterable, List, Union, Sequence, Dict

import snowflake.snowpark
import snowflake.snowpark.types as T
Expand All @@ -22,6 +23,8 @@
from ..utils.logger import logger
from ..connectors.Connector import Connector
from ..wht.rudderPB import MATERIAL_PREFIX
from cryptography.hazmat.primitives.serialization import load_pem_private_key
from cryptography.hazmat.backends import default_backend

local_folder = constants.SF_LOCAL_STORAGE_DIR

Expand Down Expand Up @@ -74,7 +77,20 @@ def __init__(self, creds: dict) -> None:
def build_session(self, credentials: dict) -> snowflake.snowpark.Session:
self.schema = credentials.get("schema", None)
self.connection_parameters = self.remap_credentials(credentials)
if "privateKey" in credentials:
private_key = load_pem_private_key(
credentials["privateKey"].encode(),
password=(
credentials["privateKeyPassphrase"].encode()
if credentials.get("privateKeyPassphrase")
else None
),
backend=default_backend(),
)
self.connection_parameters["private_key"] = private_key
session = Session.builder.configs(self.connection_parameters).create()
# Removing the private key to prevent serialisation error in the snowflake stored procedure
_ = self.connection_parameters.pop("private_key", None)
return session

def join_file_path(self, file_name: str) -> str:
Expand Down
1 change: 1 addition & 0 deletions src/predictions/profiles_mlcorelib/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
"seaborn==0.12.0",
"scikit-plot==0.3.7",
"pycaret<=3.3.0",
"cryptography==42.0.2",
]


Expand Down
1 change: 1 addition & 0 deletions src/predictions/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
"pycaret==3.3.1",
"boto3>=1.34.153",
"google-auth-oauthlib>=1.0.0",
"cryptography>=42.0.2",
"plotly>=5.24.1",
"networkx>=3.3",
],
Expand Down
1 change: 0 additions & 1 deletion tests/unit/pythonWHT.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ def test_all_data_present_and_valid(self, mock_validate_historical_materials_has
self.assertEqual(len(materials), 1)

def test_missing_sequence_number(self):

self.pythonWHT.connector.check_table_entry_in_material_registry = Mock(
return_value=True
)
Expand Down

0 comments on commit a2e3541

Please sign in to comment.