diff --git a/sqlframe/base/dataframe.py b/sqlframe/base/dataframe.py index e5f8dab..9181ad3 100644 --- a/sqlframe/base/dataframe.py +++ b/sqlframe/base/dataframe.py @@ -80,7 +80,7 @@ } -DF = t.TypeVar("DF", bound="_BaseDataFrame") +DF = t.TypeVar("DF", bound="BaseDataFrame") class OpenAIMode(enum.Enum): @@ -198,7 +198,7 @@ def cov(self, col1: str, col2: str) -> float: STAT = t.TypeVar("STAT", bound=_BaseDataFrameStatFunctions) -class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]): +class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]): _na: t.Type[NA] _stat: t.Type[STAT] _group_data: t.Type[GROUP_DATA] diff --git a/sqlframe/base/decorators.py b/sqlframe/base/decorators.py index 45af52a..39229ed 100644 --- a/sqlframe/base/decorators.py +++ b/sqlframe/base/decorators.py @@ -43,7 +43,7 @@ def wrapper(*args, **kwargs): col_name = col_name.this alias_name = f"{func.__name__}__{col_name or ''}__" # BigQuery has restrictions on alias names so we constrain it to alphanumeric characters and underscores - return result.alias(re.sub("\W", "_", alias_name)) + return result.alias(re.sub("\W", "_", alias_name)) # type: ignore return result wrapper.unsupported_engines = ( # type: ignore diff --git a/sqlframe/base/mixins/dataframe_mixins.py b/sqlframe/base/mixins/dataframe_mixins.py index 859eeca..8c8e556 100644 --- a/sqlframe/base/mixins/dataframe_mixins.py +++ b/sqlframe/base/mixins/dataframe_mixins.py @@ -11,7 +11,7 @@ SESSION, STAT, WRITER, - _BaseDataFrame, + BaseDataFrame, ) if sys.version_info >= (3, 11): @@ -23,7 +23,7 @@ logger = logging.getLogger(__name__) -class NoCachePersistSupportMixin(_BaseDataFrame, t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]): +class NoCachePersistSupportMixin(BaseDataFrame, t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]): def cache(self) -> Self: logger.warning("This engine does not support caching. 
Ignoring cache() call.") return self @@ -34,7 +34,7 @@ def persist(self) -> Self: class TypedColumnsFromTempViewMixin( - _BaseDataFrame, t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA] + BaseDataFrame, t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA] ): @property def _typed_columns(self) -> t.List[Column]: diff --git a/sqlframe/base/operations.py b/sqlframe/base/operations.py index 72bf883..fbe4cab 100644 --- a/sqlframe/base/operations.py +++ b/sqlframe/base/operations.py @@ -7,7 +7,7 @@ from enum import IntEnum if t.TYPE_CHECKING: - from sqlframe.base.dataframe import _BaseDataFrame + from sqlframe.base.dataframe import BaseDataFrame from sqlframe.base.group import _BaseGroupedData @@ -37,7 +37,7 @@ def operation(op: Operation) -> t.Callable[[t.Callable], t.Callable]: def decorator(func: t.Callable) -> t.Callable: @functools.wraps(func) - def wrapper(self: _BaseDataFrame, *args, **kwargs) -> _BaseDataFrame: + def wrapper(self: BaseDataFrame, *args, **kwargs) -> BaseDataFrame: if self.last_op == Operation.INIT: self = self._convert_leaf_to_cte() self.last_op = Operation.NO_OP @@ -45,7 +45,7 @@ def wrapper(self: _BaseDataFrame, *args, **kwargs) -> _BaseDataFrame: new_op = op if op != Operation.NO_OP else last_op if new_op < last_op or (last_op == new_op == Operation.SELECT): self = self._convert_leaf_to_cte() - df: t.Union[_BaseDataFrame, _BaseGroupedData] = func(self, *args, **kwargs) + df: t.Union[BaseDataFrame, _BaseGroupedData] = func(self, *args, **kwargs) df.last_op = new_op # type: ignore return df # type: ignore @@ -69,7 +69,7 @@ def group_operation(op: Operation) -> t.Callable[[t.Callable], t.Callable]: def decorator(func: t.Callable) -> t.Callable: @functools.wraps(func) - def wrapper(self: _BaseGroupedData, *args, **kwargs) -> _BaseDataFrame: + def wrapper(self: _BaseGroupedData, *args, **kwargs) -> BaseDataFrame: if self._df.last_op == Operation.INIT: self._df = self._df._convert_leaf_to_cte() self._df.last_op = Operation.NO_OP @@ -77,7 +77,7 @@ def wrapper(self: _BaseGroupedData, *args, **kwargs) -> _BaseDataFrame: new_op = op if op != Operation.NO_OP else last_op if new_op < last_op or (last_op == new_op == Operation.SELECT): self._df = self._df._convert_leaf_to_cte() - df: _BaseDataFrame = func(self, *args, **kwargs) + df: BaseDataFrame = func(self, *args, **kwargs) df.last_op = new_op # type: ignore return df diff --git a/sqlframe/base/session.py b/sqlframe/base/session.py index f093b23..f2ddee8 100644 --- a/sqlframe/base/session.py +++ b/sqlframe/base/session.py @@ -24,7 +24,7 @@ from sqlglot.schema import MappingSchema from sqlframe.base.catalog import _BaseCatalog -from sqlframe.base.dataframe import _BaseDataFrame +from sqlframe.base.dataframe import BaseDataFrame from sqlframe.base.normalize import normalize_dict from sqlframe.base.readerwriter import _BaseDataFrameReader, _BaseDataFrameWriter from sqlframe.base.udf import _BaseUDFRegistration @@ -64,7 +64,7 @@ def fetchdf(self) -> pd.DataFrame: ... 
CATALOG = t.TypeVar("CATALOG", bound=_BaseCatalog) READER = t.TypeVar("READER", bound=_BaseDataFrameReader) WRITER = t.TypeVar("WRITER", bound=_BaseDataFrameWriter) -DF = t.TypeVar("DF", bound=_BaseDataFrame) +DF = t.TypeVar("DF", bound=BaseDataFrame) UDF_REGISTRATION = t.TypeVar("UDF_REGISTRATION", bound=_BaseUDFRegistration) _MISSING = "MISSING" diff --git a/sqlframe/bigquery/dataframe.py b/sqlframe/bigquery/dataframe.py index a23931a..47c9d2c 100644 --- a/sqlframe/bigquery/dataframe.py +++ b/sqlframe/bigquery/dataframe.py @@ -5,7 +5,7 @@ from sqlframe.base.catalog import Column as CatalogColumn from sqlframe.base.dataframe import ( - _BaseDataFrame, + BaseDataFrame, _BaseDataFrameNaFunctions, _BaseDataFrameStatFunctions, ) @@ -30,7 +30,7 @@ class BigQueryDataFrameStatFunctions(_BaseDataFrameStatFunctions["BigQueryDataFr class BigQueryDataFrame( NoCachePersistSupportMixin, - _BaseDataFrame[ + BaseDataFrame[ "BigQuerySession", "BigQueryDataFrameWriter", "BigQueryDataFrameNaFunctions", diff --git a/sqlframe/databricks/dataframe.py b/sqlframe/databricks/dataframe.py index eba9704..5e40a52 100644 --- a/sqlframe/databricks/dataframe.py +++ b/sqlframe/databricks/dataframe.py @@ -5,7 +5,7 @@ from sqlframe.base.catalog import Column as CatalogColumn from sqlframe.base.dataframe import ( - _BaseDataFrame, + BaseDataFrame, _BaseDataFrameNaFunctions, _BaseDataFrameStatFunctions, ) @@ -31,7 +31,7 @@ class DatabricksDataFrameStatFunctions(_BaseDataFrameStatFunctions["DatabricksDa class DatabricksDataFrame( NoCachePersistSupportMixin, - _BaseDataFrame[ + BaseDataFrame[ "DatabricksSession", "DatabricksDataFrameWriter", "DatabricksDataFrameNaFunctions", diff --git a/sqlframe/duckdb/dataframe.py b/sqlframe/duckdb/dataframe.py index 79005fc..c799296 100644 --- a/sqlframe/duckdb/dataframe.py +++ b/sqlframe/duckdb/dataframe.py @@ -4,7 +4,7 @@ import typing as t from sqlframe.base.dataframe import ( - _BaseDataFrame, + BaseDataFrame, _BaseDataFrameNaFunctions, _BaseDataFrameStatFunctions, ) @@ -34,7 +34,7 @@ class DuckDBDataFrameStatFunctions(_BaseDataFrameStatFunctions["DuckDBDataFrame" class DuckDBDataFrame( NoCachePersistSupportMixin, TypedColumnsFromTempViewMixin, - _BaseDataFrame[ + BaseDataFrame[ "DuckDBSession", "DuckDBDataFrameWriter", "DuckDBDataFrameNaFunctions", diff --git a/sqlframe/postgres/dataframe.py b/sqlframe/postgres/dataframe.py index bddda39..3fc1fa2 100644 --- a/sqlframe/postgres/dataframe.py +++ b/sqlframe/postgres/dataframe.py @@ -5,7 +5,7 @@ import typing as t from sqlframe.base.dataframe import ( - _BaseDataFrame, + BaseDataFrame, _BaseDataFrameNaFunctions, _BaseDataFrameStatFunctions, ) @@ -39,7 +39,7 @@ class PostgresDataFrameStatFunctions(_BaseDataFrameStatFunctions["PostgresDataFr class PostgresDataFrame( NoCachePersistSupportMixin, TypedColumnsFromTempViewMixin, - _BaseDataFrame[ + BaseDataFrame[ "PostgresSession", "PostgresDataFrameWriter", "PostgresDataFrameNaFunctions", diff --git a/sqlframe/redshift/dataframe.py b/sqlframe/redshift/dataframe.py index 3cec874..eabdc6f 100644 --- a/sqlframe/redshift/dataframe.py +++ b/sqlframe/redshift/dataframe.py @@ -5,7 +5,7 @@ import typing as t from sqlframe.base.dataframe import ( - _BaseDataFrame, + BaseDataFrame, _BaseDataFrameNaFunctions, _BaseDataFrameStatFunctions, ) @@ -30,7 +30,7 @@ class RedshiftDataFrameStatFunctions(_BaseDataFrameStatFunctions["RedshiftDataFr class RedshiftDataFrame( NoCachePersistSupportMixin, - _BaseDataFrame[ + BaseDataFrame[ "RedshiftSession", "RedshiftDataFrameWriter", "RedshiftDataFrameNaFunctions", diff 
--git a/sqlframe/snowflake/dataframe.py b/sqlframe/snowflake/dataframe.py index c8d569a..4a57962 100644 --- a/sqlframe/snowflake/dataframe.py +++ b/sqlframe/snowflake/dataframe.py @@ -6,7 +6,7 @@ from sqlframe.base.catalog import Column as CatalogColumn from sqlframe.base.dataframe import ( - _BaseDataFrame, + BaseDataFrame, _BaseDataFrameNaFunctions, _BaseDataFrameStatFunctions, ) @@ -32,7 +32,7 @@ class SnowflakeDataFrameStatFunctions(_BaseDataFrameStatFunctions["SnowflakeData class SnowflakeDataFrame( NoCachePersistSupportMixin, - _BaseDataFrame[ + BaseDataFrame[ "SnowflakeSession", "SnowflakeDataFrameWriter", "SnowflakeDataFrameNaFunctions", diff --git a/sqlframe/spark/dataframe.py b/sqlframe/spark/dataframe.py index 5988122..d96c5ea 100644 --- a/sqlframe/spark/dataframe.py +++ b/sqlframe/spark/dataframe.py @@ -7,7 +7,7 @@ from sqlframe.base.catalog import Column from sqlframe.base.dataframe import ( - _BaseDataFrame, + BaseDataFrame, _BaseDataFrameNaFunctions, _BaseDataFrameStatFunctions, ) @@ -31,7 +31,7 @@ class SparkDataFrameStatFunctions(_BaseDataFrameStatFunctions["SparkDataFrame"]) class SparkDataFrame( NoCachePersistSupportMixin, - _BaseDataFrame[ + BaseDataFrame[ "SparkSession", "SparkDataFrameWriter", "SparkDataFrameNaFunctions", diff --git a/sqlframe/standalone/dataframe.py b/sqlframe/standalone/dataframe.py index d63a198..3625689 100644 --- a/sqlframe/standalone/dataframe.py +++ b/sqlframe/standalone/dataframe.py @@ -3,7 +3,7 @@ import typing as t from sqlframe.base.dataframe import ( - _BaseDataFrame, + BaseDataFrame, _BaseDataFrameNaFunctions, _BaseDataFrameStatFunctions, ) @@ -23,7 +23,7 @@ class StandaloneDataFrameStatFunctions(_BaseDataFrameStatFunctions["StandaloneDa class StandaloneDataFrame( - _BaseDataFrame[ + BaseDataFrame[ "StandaloneSession", "StandaloneDataFrameWriter", "StandaloneDataFrameNaFunctions", diff --git a/sqlframe/testing/utils.py b/sqlframe/testing/utils.py index f9ed221..058b663 100644 --- a/sqlframe/testing/utils.py +++ b/sqlframe/testing/utils.py @@ -7,7 +7,7 @@ from itertools import zip_longest from sqlframe.base import types -from sqlframe.base.dataframe import _BaseDataFrame +from sqlframe.base.dataframe import BaseDataFrame from sqlframe.base.exceptions import ( DataFrameDiffError, SchemaDiffError, @@ -64,8 +64,8 @@ def red(s: str) -> str: # Source: https://github.com/apache/spark/blob/master/python/pyspark/testing/utils.py#L519 def assertDataFrameEqual( - actual: t.Union[_BaseDataFrame, pd.DataFrame, t.List[types.Row]], - expected: t.Union[_BaseDataFrame, pd.DataFrame, t.List[types.Row]], + actual: t.Union[BaseDataFrame, pd.DataFrame, t.List[types.Row]], + expected: t.Union[BaseDataFrame, pd.DataFrame, t.List[types.Row]], checkRowOrder: bool = False, rtol: float = 1e-5, atol: float = 1e-8, diff --git a/tests/integration/engines/test_engine_dataframe.py b/tests/integration/engines/test_engine_dataframe.py index 668b439..ba1099b 100644 --- a/tests/integration/engines/test_engine_dataframe.py +++ b/tests/integration/engines/test_engine_dataframe.py @@ -5,12 +5,12 @@ from sqlframe.base.types import Row if t.TYPE_CHECKING: - from sqlframe.base.dataframe import _BaseDataFrame + from sqlframe.base.dataframe import BaseDataFrame pytest_plugins = ["tests.integration.fixtures"] -def test_collect(get_engine_df: t.Callable[[str], _BaseDataFrame], get_func): +def test_collect(get_engine_df: t.Callable[[str], BaseDataFrame], get_func): employee = get_engine_df("employee") col = get_func("col", employee.session) results = employee.select(col("fname"), 
col("lname")).collect() @@ -24,7 +24,7 @@ def test_collect(get_engine_df: t.Callable[[str], _BaseDataFrame], get_func): def test_show( - get_engine_df: t.Callable[[str], _BaseDataFrame], + get_engine_df: t.Callable[[str], BaseDataFrame], get_func, capsys, caplog, @@ -53,7 +53,7 @@ def test_show( def test_show_limit( - get_engine_df: t.Callable[[str], _BaseDataFrame], capsys, is_snowflake: t.Callable + get_engine_df: t.Callable[[str], BaseDataFrame], capsys, is_snowflake: t.Callable ): employee = get_engine_df("employee") employee.show(1) diff --git a/tests/integration/engines/test_engine_writer.py b/tests/integration/engines/test_engine_writer.py index 6a3f3c9..236f3fd 100644 --- a/tests/integration/engines/test_engine_writer.py +++ b/tests/integration/engines/test_engine_writer.py @@ -10,15 +10,15 @@ from sqlframe.duckdb.session import DuckDBSession if t.TYPE_CHECKING: - from sqlframe.base.dataframe import _BaseDataFrame + from sqlframe.base.dataframe import BaseDataFrame pytest_plugins = ["tests.integration.fixtures"] @pytest.fixture def cleanup_employee_df( - get_engine_df: t.Callable[[str], _BaseDataFrame], -) -> t.Iterator[_BaseDataFrame]: + get_engine_df: t.Callable[[str], BaseDataFrame], +) -> t.Iterator[BaseDataFrame]: df = get_engine_df("employee") df.session._execute("DROP TABLE IF EXISTS insert_into_employee") df.session._execute("DROP TABLE IF EXISTS save_as_table_employee") @@ -27,7 +27,7 @@ def cleanup_employee_df( df.session._execute("DROP TABLE IF EXISTS save_as_table_employee") -def test_write_json(get_engine_df: t.Callable[[str], _BaseDataFrame], tmp_path: pathlib.Path): +def test_write_json(get_engine_df: t.Callable[[str], BaseDataFrame], tmp_path: pathlib.Path): df_employee = get_engine_df("employee") temp_json = str(tmp_path / "employee.json") df_employee.write.json(temp_json) @@ -50,9 +50,7 @@ def test_write_json_append(get_session: t.Callable[[], _BaseSession], tmp_path: assert df_result.collect() == [Row(_1=1), Row(_1=2)] -def test_write_json_ignore( - get_engine_df: t.Callable[[str], _BaseDataFrame], tmp_path: pathlib.Path -): +def test_write_json_ignore(get_engine_df: t.Callable[[str], BaseDataFrame], tmp_path: pathlib.Path): df_employee = get_engine_df("employee") temp_json = tmp_path / "employee.json" temp_json.touch() @@ -62,7 +60,7 @@ def test_write_json_ignore( def test_write_json_error( - get_engine_df: t.Callable[[str], _BaseDataFrame], tmp_path: pathlib.Path, caplog + get_engine_df: t.Callable[[str], BaseDataFrame], tmp_path: pathlib.Path, caplog ): df_employee = get_engine_df("employee") temp_json = tmp_path / "employee.json" @@ -71,7 +69,7 @@ def test_write_json_error( df_employee.write.json(temp_json, mode="error") -def test_write_parquet(get_engine_df: t.Callable[[str], _BaseDataFrame], tmp_path: pathlib.Path): +def test_write_parquet(get_engine_df: t.Callable[[str], BaseDataFrame], tmp_path: pathlib.Path): df_employee = get_engine_df("employee") temp_parquet = str(tmp_path / "employee.parquet") df_employee.write.parquet(temp_parquet) @@ -80,7 +78,7 @@ def test_write_parquet(get_engine_df: t.Callable[[str], _BaseDataFrame], tmp_pat def test_write_parquet_ignore( - get_engine_df: t.Callable[[str], _BaseDataFrame], tmp_path: pathlib.Path + get_engine_df: t.Callable[[str], BaseDataFrame], tmp_path: pathlib.Path ): df_employee = get_engine_df("employee") temp_parquet = str(tmp_path / "employee.parquet") @@ -95,7 +93,7 @@ def test_write_parquet_ignore( def test_write_parquet_error( - get_engine_df: t.Callable[[str], _BaseDataFrame], tmp_path: pathlib.Path, 
caplog + get_engine_df: t.Callable[[str], BaseDataFrame], tmp_path: pathlib.Path, caplog ): df_employee = get_engine_df("employee") temp_parquet = tmp_path / "employee.parquet" @@ -105,7 +103,7 @@ def test_write_parquet_error( def test_write_parquet_unsupported_modes( - get_engine_df: t.Callable[[str], _BaseDataFrame], tmp_path: pathlib.Path + get_engine_df: t.Callable[[str], BaseDataFrame], tmp_path: pathlib.Path ): df_employee = get_engine_df("employee") temp_json = tmp_path / "employee.parquet" @@ -113,7 +111,7 @@ def test_write_parquet_unsupported_modes( df_employee.write.parquet(str(temp_json), mode="append") -def test_write_csv(get_engine_df: t.Callable[[str], _BaseDataFrame], tmp_path: pathlib.Path): +def test_write_csv(get_engine_df: t.Callable[[str], BaseDataFrame], tmp_path: pathlib.Path): df_employee = get_engine_df("employee") temp_csv = str(tmp_path / "employee.csv") df_employee.write.csv(temp_csv) @@ -136,7 +134,7 @@ def test_write_csv_append(get_session: t.Callable[[], _BaseSession], tmp_path: p assert df_result.collect() == [Row(_1=1), Row(_1=2)] -def test_write_csv_ignore(get_engine_df: t.Callable[[str], _BaseDataFrame], tmp_path: pathlib.Path): +def test_write_csv_ignore(get_engine_df: t.Callable[[str], BaseDataFrame], tmp_path: pathlib.Path): df_employee = get_engine_df("employee") temp_csv = str(tmp_path / "employee.csv") df1 = df_employee.session.createDataFrame([(1,)]) @@ -150,7 +148,7 @@ def test_write_csv_ignore(get_engine_df: t.Callable[[str], _BaseDataFrame], tmp_ assert df_result.collect() == df1.collect() -def test_write_csv_error(get_engine_df: t.Callable[[str], _BaseDataFrame], tmp_path: pathlib.Path): +def test_write_csv_error(get_engine_df: t.Callable[[str], BaseDataFrame], tmp_path: pathlib.Path): df_employee = get_engine_df("employee") temp_csv = tmp_path / "employee.csv" temp_csv.touch() @@ -158,14 +156,14 @@ def test_write_csv_error(get_engine_df: t.Callable[[str], _BaseDataFrame], tmp_p df_employee.write.json(temp_csv, mode="error") -def test_save_as_table(cleanup_employee_df: _BaseDataFrame, caplog): +def test_save_as_table(cleanup_employee_df: BaseDataFrame, caplog): df_employee = cleanup_employee_df df_employee.write.saveAsTable("save_as_table_employee") df2 = df_employee.session.read.table("save_as_table_employee") assert sorted(df2.collect()) == sorted(df_employee.collect()) -def test_insertInto(cleanup_employee_df: _BaseDataFrame, caplog): +def test_insertInto(cleanup_employee_df: BaseDataFrame, caplog): df_employee = cleanup_employee_df df = df_employee.session.createDataFrame( [(9, "Sayid", "Jarrah", 40, 1)], ["id", "first_name", "last_name", "age", "store_id"] diff --git a/tests/integration/engines/test_int_functions.py b/tests/integration/engines/test_int_functions.py index 6e52364..cfd6142 100644 --- a/tests/integration/engines/test_int_functions.py +++ b/tests/integration/engines/test_int_functions.py @@ -25,7 +25,7 @@ from sqlframe.spark.session import SparkSession if t.TYPE_CHECKING: - from sqlframe.base.dataframe import _BaseDataFrame + from sqlframe.base.dataframe import BaseDataFrame pytest_plugins = ["tests.integration.fixtures"] @@ -33,7 +33,7 @@ class GetDfAndFuncCallable(t.Protocol): def __call__( self, name: str, limit: t.Optional[int] = None - ) -> t.Tuple[_BaseDataFrame, t.Callable]: ... + ) -> t.Tuple[BaseDataFrame, t.Callable]: ... 
def get_func_from_session(name: str, session: t.Union[PySparkSession, _BaseSession]) -> t.Callable: diff --git a/tests/integration/fixtures.py b/tests/integration/fixtures.py index 151949b..2b2ba3d 100644 --- a/tests/integration/fixtures.py +++ b/tests/integration/fixtures.py @@ -46,7 +46,7 @@ from sqlframe.standalone.session import StandaloneSession if t.TYPE_CHECKING: - from sqlframe.base.dataframe import _BaseDataFrame + from sqlframe.base.dataframe import BaseDataFrame from sqlframe.base.session import _BaseSession from tests.types import DistrictData, EmployeeData, StoreData @@ -650,14 +650,14 @@ def _make_function(df: PySparkDataFrame, mode: str = "extended") -> str: @pytest.fixture(params=ENGINE_PARAMETERS_NO_PYSPARK_STANDALONE) -def get_engine_df(request: FixtureRequest) -> t.Callable[[str], _BaseDataFrame]: +def get_engine_df(request: FixtureRequest) -> t.Callable[[str], BaseDataFrame]: mapping = { "employee": f"{request.param}_employee", "store": f"{request.param}_store", "district": f"{request.param}_district", } - def _get_engine_df(name: str) -> _BaseDataFrame: + def _get_engine_df(name: str) -> BaseDataFrame: return request.getfixturevalue(mapping[name]) return _get_engine_df @@ -672,14 +672,14 @@ def _get_session() -> _BaseSession: @pytest.fixture(params=ENGINE_PARAMETERS_NO_PYSPARK) -def get_df(request: FixtureRequest) -> t.Callable[[str], _BaseDataFrame]: +def get_df(request: FixtureRequest) -> t.Callable[[str], BaseDataFrame]: mapping = { "employee": f"{request.param}_employee", "store": f"{request.param}_store", "district": f"{request.param}_district", } - def _get_df(name: str) -> _BaseDataFrame: + def _get_df(name: str) -> BaseDataFrame: return request.getfixturevalue(mapping[name]) return _get_df @@ -698,14 +698,14 @@ def _get_engine_session_and_spark() -> t.Union[_BaseSession, PySparkSession]: @pytest.fixture(params=ENGINE_PARAMETERS_NO_STANDALONE) def get_engine_df_and_pyspark( request: FixtureRequest, -) -> t.Callable[[str], t.Union[_BaseDataFrame, PySparkDataFrame]]: +) -> t.Callable[[str], t.Union[BaseDataFrame, PySparkDataFrame]]: mapping = { "employee": f"{request.param}_employee", "store": f"{request.param}_store", "district": f"{request.param}_district", } - def _get_engine_df_and_pyspark(name: str) -> t.Union[_BaseDataFrame, PySparkDataFrame]: + def _get_engine_df_and_pyspark(name: str) -> t.Union[BaseDataFrame, PySparkDataFrame]: return request.getfixturevalue(mapping[name]) return _get_engine_df_and_pyspark diff --git a/tests/integration/test_int_dataframe.py b/tests/integration/test_int_dataframe.py index 18cfa7f..bb89303 100644 --- a/tests/integration/test_int_dataframe.py +++ b/tests/integration/test_int_dataframe.py @@ -12,16 +12,16 @@ from tests.integration.fixtures import StandaloneSession if t.TYPE_CHECKING: - from sqlframe.base.dataframe import _BaseDataFrame + from sqlframe.base.dataframe import BaseDataFrame - DataFrameMapping = t.Dict[str, _BaseDataFrame] + DataFrameMapping = t.Dict[str, BaseDataFrame] pytest_plugins = ["tests.integration.fixtures"] def test_empty_df( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): df_empty = pyspark_employee.sparkSession.createDataFrame([], "cola int, colb int") @@ -31,7 +31,7 @@ def test_empty_df( def test_simple_select( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") 
@@ -42,7 +42,7 @@ def test_simple_select( def test_simple_select_from_table( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -53,7 +53,7 @@ def test_simple_select_from_table( def test_select_star_from_table( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): df = pyspark_employee @@ -63,7 +63,7 @@ def test_select_star_from_table( def test_simple_select_df_attribute( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -74,7 +74,7 @@ def test_simple_select_df_attribute( def test_simple_select_df_dict( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -85,7 +85,7 @@ def test_simple_select_df_dict( def test_multiple_selects( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -98,7 +98,7 @@ def test_multiple_selects( def test_alias_no_op( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -109,7 +109,7 @@ def test_alias_no_op( def test_alias_with_select( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -128,7 +128,7 @@ def test_alias_with_select( def test_alias_with_space( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -149,7 +149,7 @@ def test_alias_with_space( def test_case_when_otherwise( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -176,7 +176,7 @@ def test_case_when_otherwise( def test_case_when_no_otherwise( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -199,7 +199,7 @@ def test_case_when_no_otherwise( def test_case_when_implicit_lit( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -222,7 +222,7 @@ def test_case_when_implicit_lit( def test_where_clause_single( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -233,7 +233,7 @@ def test_where_clause_single( def test_where_clause_eq_nullsafe( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -244,7 +244,7 @@ def test_where_clause_eq_nullsafe( def test_where_clause_multiple_and( 
pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -259,7 +259,7 @@ def test_where_clause_multiple_and( def test_where_many_and( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -280,7 +280,7 @@ def test_where_many_and( def test_where_clause_multiple_or( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -295,7 +295,7 @@ def test_where_clause_multiple_or( def test_where_many_or( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -316,7 +316,7 @@ def test_where_many_or( def test_where_mixed_and_or( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -333,7 +333,7 @@ def test_where_mixed_and_or( def test_where_multiple_chained( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -348,7 +348,7 @@ def test_where_multiple_chained( def test_where_sql_expr( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -359,7 +359,7 @@ def test_where_sql_expr( def test_operators( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -409,7 +409,7 @@ def test_operators( def test_join_inner( pyspark_employee: PySparkDataFrame, pyspark_store: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -438,7 +438,7 @@ def test_join_inner( def test_join_inner_no_select( pyspark_employee: PySparkDataFrame, pyspark_store: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -459,7 +459,7 @@ def test_join_inner_no_select( def test_join_inner_equality_single( pyspark_employee: PySparkDataFrame, pyspark_store: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -498,7 +498,7 @@ def test_join_inner_equality_single( def test_join_inner_equality_multiple( pyspark_employee: PySparkDataFrame, pyspark_store: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -541,7 +541,7 @@ def test_join_inner_equality_multiple( def test_join_inner_equality_multiple_bitwise_and( pyspark_employee: PySparkDataFrame, pyspark_store: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ 
-579,7 +579,7 @@ def test_join_inner_equality_multiple_bitwise_and( def test_join_left_outer( pyspark_employee: PySparkDataFrame, pyspark_store: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -616,7 +616,7 @@ def test_join_left_outer( def test_join_full_outer( pyspark_employee: PySparkDataFrame, pyspark_store: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, is_bigquery: t.Callable, ): @@ -651,7 +651,7 @@ def test_triple_join( pyspark_employee: PySparkDataFrame, pyspark_store: PySparkDataFrame, pyspark_district: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -690,7 +690,7 @@ def test_triple_join_no_select( pyspark_employee: PySparkDataFrame, pyspark_store: PySparkDataFrame, pyspark_district: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, is_duckdb: t.Callable, is_postgres: t.Callable, @@ -752,7 +752,7 @@ def test_triple_joins_filter( pyspark_employee: PySparkDataFrame, pyspark_store: PySparkDataFrame, pyspark_district: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, is_duckdb: t.Callable, is_postgres: t.Callable, @@ -810,7 +810,7 @@ def test_triple_join_column_name_only( pyspark_employee: PySparkDataFrame, pyspark_store: PySparkDataFrame, pyspark_district: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, is_duckdb: t.Callable, is_postgres: t.Callable, @@ -863,7 +863,7 @@ def test_triple_join_column_name_only( def test_join_select_and_select_start( pyspark_employee: PySparkDataFrame, pyspark_store: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -882,7 +882,7 @@ def test_join_select_and_select_start( def test_join_no_on( pyspark_employee: PySparkDataFrame, pyspark_store: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): # No on results in a cross. Testing that "how" is ignored @@ -899,7 +899,7 @@ def test_join_no_on( def test_cross_join( pyspark_employee: PySparkDataFrame, pyspark_store: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): # No on results in a cross. 
Testing that "how" is ignored @@ -915,7 +915,7 @@ def test_cross_join( def test_branching_root_dataframes( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, is_duckdb: t.Callable, is_postgres: t.Callable, @@ -990,7 +990,7 @@ def test_branching_root_dataframes( def test_basic_union( pyspark_employee: PySparkDataFrame, pyspark_store: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -1009,7 +1009,7 @@ def test_union_with_join( pyspark_employee: PySparkDataFrame, pyspark_store: PySparkDataFrame, pyspark_district: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -1040,7 +1040,7 @@ def test_double_union_all( pyspark_employee: PySparkDataFrame, pyspark_store: PySparkDataFrame, pyspark_district: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -1064,7 +1064,7 @@ def test_double_union_all( def test_union_by_name( pyspark_employee: PySparkDataFrame, pyspark_store: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -1091,7 +1091,7 @@ def test_union_by_name( def test_union_by_name_allow_missing( pyspark_employee: PySparkDataFrame, pyspark_store: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, is_postgres: t.Callable, ): @@ -1130,7 +1130,7 @@ def test_union_by_name_allow_missing( def test_order_by_default( pyspark_store: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): store = get_df("store") @@ -1149,7 +1149,7 @@ def test_order_by_default( def test_order_by_array_bool( pyspark_store: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): store = get_df("store") @@ -1170,7 +1170,7 @@ def test_order_by_array_bool( def test_order_by_single_bool( pyspark_store: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): store = get_df("store") @@ -1191,7 +1191,7 @@ def test_order_by_single_bool( def test_order_by_column_sort_method( pyspark_store: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): store = get_df("store") @@ -1212,7 +1212,7 @@ def test_order_by_column_sort_method( def test_order_by_column_sort_method_nulls_last( pyspark_store: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): store = get_df("store") @@ -1235,7 +1235,7 @@ def test_order_by_column_sort_method_nulls_last( def test_order_by_column_sort_method_nulls_first( pyspark_store: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): store = get_df("store") @@ -1259,7 +1259,7 @@ def test_order_by_column_sort_method_nulls_first( def test_intersect( pyspark_employee: 
PySparkDataFrame, pyspark_store: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -1290,7 +1290,7 @@ def test_intersect( def test_intersect_all( pyspark_employee: PySparkDataFrame, pyspark_store: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, is_bigquery: t.Callable, is_redshift: t.Callable, @@ -1330,7 +1330,7 @@ def test_intersect_all( def test_except_all( pyspark_employee: PySparkDataFrame, pyspark_store: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, is_bigquery: t.Callable, is_redshift: t.Callable, @@ -1369,7 +1369,7 @@ def test_except_all( def test_distinct( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -1382,7 +1382,7 @@ def test_distinct( def test_union_distinct( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -1402,7 +1402,7 @@ def test_union_distinct( def test_drop_duplicates_no_subset( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -1413,7 +1413,7 @@ def test_drop_duplicates_no_subset( def test_drop_duplicates_subset( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, is_redshift: t.Callable, ): @@ -1430,7 +1430,7 @@ def test_drop_duplicates_subset( def test_drop_na_default( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -1447,7 +1447,7 @@ def test_drop_na_default( def test_dropna_how( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -1464,7 +1464,7 @@ def test_dropna_how( def test_dropna_thresh( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -1485,7 +1485,7 @@ def test_dropna_thresh( def test_dropna_subset( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -1506,7 +1506,7 @@ def test_dropna_subset( def test_dropna_na_function( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -1523,7 +1523,7 @@ def test_dropna_na_function( def test_fillna_default( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -1540,7 +1540,7 @@ def test_fillna_default( def test_fillna_dict_replacement( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], 
_BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -1563,7 +1563,7 @@ def test_fillna_dict_replacement( def test_fillna_na_func( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -1580,7 +1580,7 @@ def test_fillna_na_func( def test_replace_basic( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -1596,7 +1596,7 @@ def test_replace_basic( def test_replace_basic_subset( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -1613,7 +1613,7 @@ def test_replace_basic_subset( def test_replace_mapping( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -1626,7 +1626,7 @@ def test_replace_mapping( def test_replace_mapping_subset( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -1643,7 +1643,7 @@ def test_replace_mapping_subset( def test_replace_na_func_basic( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -1660,7 +1660,7 @@ def test_replace_na_func_basic( def test_with_column( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -1673,7 +1673,7 @@ def test_with_column( def test_with_column_existing_name( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -1686,7 +1686,7 @@ def test_with_column_existing_name( def test_with_column_renamed( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -1699,7 +1699,7 @@ def test_with_column_renamed( def test_with_column_renamed_double( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -1716,7 +1716,7 @@ def test_with_column_renamed_double( def test_with_columns( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -1739,7 +1739,7 @@ def test_with_columns( def test_with_columns_reference_another( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, is_bigquery: t.Callable, is_postgres: t.Callable, @@ -1774,7 +1774,7 @@ def test_with_columns_reference_another( def test_drop_column_single( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], 
BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -1788,7 +1788,7 @@ def test_drop_column_single( def test_drop_column_reference_join( pyspark_employee: PySparkDataFrame, pyspark_store: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -1814,7 +1814,7 @@ def test_drop_column_reference_join( def test_limit( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -1828,7 +1828,7 @@ def test_limit( def test_hint_broadcast_alias( pyspark_employee: PySparkDataFrame, pyspark_store: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, get_explain_plan: t.Callable, is_duckdb: t.Callable, @@ -1870,7 +1870,7 @@ def test_hint_broadcast_alias( def test_hint_broadcast_no_alias( pyspark_employee: PySparkDataFrame, pyspark_store: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, get_explain_plan: t.Callable, ): @@ -1912,7 +1912,7 @@ def test_hint_broadcast_no_alias( def test_broadcast_func( pyspark_employee: PySparkDataFrame, pyspark_store: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, get_explain_plan: t.Callable, ): @@ -1953,7 +1953,7 @@ def test_broadcast_func( def test_repartition_by_num( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): """ @@ -1974,7 +1974,7 @@ def test_repartition_by_num( def test_repartition_name_only( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, get_explain_plan: t.Callable, ): @@ -1993,7 +1993,7 @@ def test_repartition_name_only( def test_repartition_num_and_multiple_names( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, get_explain_plan: t.Callable, ): @@ -2014,7 +2014,7 @@ def test_repartition_num_and_multiple_names( def test_coalesce( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -2031,7 +2031,7 @@ def test_coalesce( def test_cache_select( pyspark_employee: PySparkDataFrame, pyspark_store: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -2058,7 +2058,7 @@ def test_cache_select( def test_persist_select( pyspark_employee: PySparkDataFrame, pyspark_store: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): employee = get_df("employee") @@ -2084,7 +2084,7 @@ def test_persist_select( def test_transform( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): def cast_all_to_int_pyspark(input_df): @@ -2106,7 +2106,7 @@ def sort_columns_asc(input_df): # 
https://github.com/eakmanrq/sqlframe/issues/51 def test_join_full_outer_no_match( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): spark = pyspark_employee._session @@ -2151,7 +2151,7 @@ def test_join_full_outer_no_match( # https://github.com/eakmanrq/sqlframe/issues/102 def test_join_with_duplicate_column_name( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): spark = pyspark_employee._session @@ -2172,7 +2172,7 @@ def test_join_with_duplicate_column_name( # https://github.com/eakmanrq/sqlframe/issues/103 def test_chained_join_common_key( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, ): spark = pyspark_employee._session @@ -2198,7 +2198,7 @@ def test_chaining_joins_with_selects( pyspark_employee: PySparkDataFrame, pyspark_store: PySparkDataFrame, pyspark_district: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, is_spark: t.Callable, ): @@ -2253,7 +2253,7 @@ def test_chaining_joins_with_selects( # https://github.com/eakmanrq/sqlframe/issues/212 def test_self_join( pyspark_employee: PySparkDataFrame, - get_df: t.Callable[[str], _BaseDataFrame], + get_df: t.Callable[[str], BaseDataFrame], compare_frames: t.Callable, is_spark: t.Callable, ): diff --git a/tests/integration/test_int_dataframe_stats.py b/tests/integration/test_int_dataframe_stats.py index bf419df..93ea4ca 100644 --- a/tests/integration/test_int_dataframe_stats.py +++ b/tests/integration/test_int_dataframe_stats.py @@ -3,14 +3,14 @@ import pytest -from sqlframe.base.dataframe import _BaseDataFrame +from sqlframe.base.dataframe import BaseDataFrame from sqlframe.postgres import PostgresDataFrame from sqlframe.snowflake import SnowflakeDataFrame pytest_plugins = ["tests.integration.fixtures"] -def test_approx_quantile(get_engine_df_and_pyspark: t.Callable[[str], _BaseDataFrame]): +def test_approx_quantile(get_engine_df_and_pyspark: t.Callable[[str], BaseDataFrame]): employee = get_engine_df_and_pyspark("employee") if isinstance(employee, PostgresDataFrame): pytest.skip("Approx quantile is not supported by the engine: postgres") @@ -22,13 +22,13 @@ def test_approx_quantile(get_engine_df_and_pyspark: t.Callable[[str], _BaseDataF assert results == expected -def test_corr(get_engine_df_and_pyspark: t.Callable[[str], _BaseDataFrame]): +def test_corr(get_engine_df_and_pyspark: t.Callable[[str], BaseDataFrame]): employee = get_engine_df_and_pyspark("employee") results = employee.stat.corr("employee_id", "age") assert math.isclose(results, -0.5605569890127448) -def test_cov(get_engine_df_and_pyspark: t.Callable[[str], _BaseDataFrame]): +def test_cov(get_engine_df_and_pyspark: t.Callable[[str], BaseDataFrame]): employee = get_engine_df_and_pyspark("employee") results = employee.stat.cov("employee_id", "age") assert results == -13.5