diff --git a/kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py b/kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py index 1ff6f9d55..f01651e84 100644 --- a/kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py +++ b/kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py @@ -8,6 +8,7 @@ import pandas as pd import snowflake.snowpark as sp from kedro.io.core import AbstractDataset, DatasetError +from snowflake.snowpark import DataFrame, Session logger = logging.getLogger(__name__) @@ -111,7 +112,7 @@ def __init__( # noqa: PLR0913 load_args: dict[str, Any] | None = None, save_args: dict[str, Any] | None = None, credentials: dict[str, Any] | None = None, - session: sp.Session | None = None, + session: Session | None = None, metadata: dict[str, Any] | None = None, ) -> None: """ @@ -182,7 +183,7 @@ def _describe(self) -> dict[str, Any]: } @staticmethod - def _get_session(connection_parameters) -> sp.Session: + def _get_session(connection_parameters) -> Session: """ Given a connection string, create singleton connection to be used across all instances of `SnowparkTableDataset` that @@ -211,32 +212,32 @@ def _get_session(connection_parameters) -> sp.Session: return session @property - def _session(self) -> sp.Session: + def _session(self) -> Session: """ Retrieve or create a session. Returns: - sp.Session: The current session associated with the object. + Session: The current session associated with the object. """ if not self.__session: self.__session = self._get_session(self._connection_parameters) return self.__session - def load(self) -> sp.DataFrame: + def load(self) -> DataFrame: """ Load data from a specified database table. Returns: - sp.DataFrame: The loaded data as a Snowpark DataFrame. + DataFrame: The loaded data as a Snowpark DataFrame. """ return self._session.table(self._validate_and_get_table_name()) - def save(self, data: pd.DataFrame | sp.DataFrame) -> None: + def save(self, data: pd.DataFrame | DataFrame) -> None: """ Save data to a specified database table. Args: - data (pd.DataFrame | sp.DataFrame): The data to save. + data (pd.DataFrame | DataFrame): The data to save. """ if isinstance(data, pd.DataFrame): data = self._session.create_dataframe(data)