Skip to content

Commit

Permalink
change doc type hinting for tests to pass
Browse files Browse the repository at this point in the history
Signed-off-by: tdhooghe <thomas_dhooghe@mckinsey.com>
  • Loading branch information
tdhooghe committed Oct 21, 2024
1 parent 2c804c9 commit 2ac1453
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pandas as pd
import snowflake.snowpark as sp
from kedro.io.core import AbstractDataset, DatasetError
from snowflake.snowpark import DataFrame, Session

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -111,7 +112,7 @@ def __init__( # noqa: PLR0913
load_args: dict[str, Any] | None = None,
save_args: dict[str, Any] | None = None,
credentials: dict[str, Any] | None = None,
session: sp.Session | None = None,
session: Session | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
"""
Expand Down Expand Up @@ -182,7 +183,7 @@ def _describe(self) -> dict[str, Any]:
}

@staticmethod
def _get_session(connection_parameters) -> sp.Session:
def _get_session(connection_parameters) -> Session:
"""
Given a connection string, create singleton connection
to be used across all instances of `SnowparkTableDataset` that
Expand Down Expand Up @@ -211,32 +212,32 @@ def _get_session(connection_parameters) -> sp.Session:
return session

@property
def _session(self) -> sp.Session:
def _session(self) -> Session:
"""
Retrieve or create a session.
Returns:
sp.Session: The current session associated with the object.
Session: The current session associated with the object.
"""
if not self.__session:
self.__session = self._get_session(self._connection_parameters)
return self.__session

def load(self) -> sp.DataFrame:
def load(self) -> DataFrame:
"""
Load data from a specified database table.
Returns:
sp.DataFrame: The loaded data as a Snowpark DataFrame.
DataFrame: The loaded data as a Snowpark DataFrame.
"""
return self._session.table(self._validate_and_get_table_name())

def save(self, data: pd.DataFrame | sp.DataFrame) -> None:
def save(self, data: pd.DataFrame | DataFrame) -> None:
"""
Save data to a specified database table.
Args:
data (pd.DataFrame | sp.DataFrame): The data to save.
data (pd.DataFrame | DataFrame): The data to save.
"""
if isinstance(data, pd.DataFrame):
data = self._session.create_dataframe(data)
Expand Down

0 comments on commit 2ac1453

Please sign in to comment.