Skip to content

Commit

Permalink
test: use fixtures
Browse files Browse the repository at this point in the history
  • Loading branch information
mbelak-dtml committed Feb 29, 2024
1 parent f7b00be commit 81d708f
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 131 deletions.
27 changes: 14 additions & 13 deletions tests/test_bivariate_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from .pyarrow_utils import pyarrow_parameterize


def get_test_df(pyarrow_dtypes: bool = False) -> pd.DataFrame:
@pytest.fixture
def test_df(pyarrow_dtypes: bool = False) -> pd.DataFrame:
test_df = pd.DataFrame(data=[[1.1, "a"], [2.2, "b"], [3.3, "c"]], columns=["A", "B"])
if pyarrow_dtypes:
test_df = test_df.convert_dtypes(dtype_backend="pyarrow")
Expand Down Expand Up @@ -125,7 +126,7 @@ def test_section_adding():
), "Subsection should be ContingencyTable"


def test_code_export_verbosity_low():
def test_code_export_verbosity_low(test_df: pd.DataFrame):
bivariate_section = bivariate_analysis.BivariateAnalysis(verbosity=Verbosity.LOW)
# Export code
exported_cells = []
Expand All @@ -138,10 +139,10 @@ def test_code_export_verbosity_low():
assert len(exported_code) == 1
assert exported_code[0] == expected_code[0], "Exported code mismatch"

check_section_executes(bivariate_section, df=get_test_df())
check_section_executes(bivariate_section, df=test_df)


def test_code_export_verbosity_low_with_subsections():
def test_code_export_verbosity_low_with_subsections(test_df: pd.DataFrame):
bivariate_section = bivariate_analysis.BivariateAnalysis(
subsections=[
BivariateAnalysisSubsection.ContingencyTable,
Expand All @@ -164,7 +165,7 @@ def test_code_export_verbosity_low_with_subsections():
assert len(exported_code) == 1
assert exported_code[0] == expected_code[0], "Exported code mismatch"

check_section_executes(bivariate_section, df=get_test_df())
check_section_executes(bivariate_section, df=test_df)


def test_generated_code_verbosity_low_columns():
Expand Down Expand Up @@ -209,7 +210,7 @@ def test_generated_code_verbosity_low_columns():
check_section_executes(bivariate_section, df=test_df)


def test_generated_code_verbosity_medium():
def test_generated_code_verbosity_medium(test_df: pd.DataFrame):
bivariate_section = bivariate_analysis.BivariateAnalysis(
verbosity=Verbosity.MEDIUM,
subsections=[
Expand All @@ -233,7 +234,7 @@ def test_generated_code_verbosity_medium():
for expected_line, exported_line in zip(expected_code, exported_code):
assert expected_line == exported_line, "Exported code mismatch"

check_section_executes(bivariate_section, df=get_test_df())
check_section_executes(bivariate_section, df=test_df)


def test_generated_code_verbosity_medium_columns_x_y():
Expand Down Expand Up @@ -307,7 +308,7 @@ def test_generated_code_verbosity_medium_columns_pairs():
check_section_executes(bivariate_section, df=test_df)


def test_generated_code_verbosity_high():
def test_generated_code_verbosity_high(test_df: pd.DataFrame):
bivariate_section = bivariate_analysis.BivariateAnalysis(
verbosity=Verbosity.HIGH,
subsections=[
Expand Down Expand Up @@ -345,10 +346,10 @@ def test_generated_code_verbosity_high():
for expected_line, exported_line in zip(expected_code, exported_code):
assert expected_line == exported_line, "Exported code mismatch"

check_section_executes(bivariate_section, df=get_test_df())
check_section_executes(bivariate_section, df=test_df)


def test_verbosity_low_different_subsection_verbosities():
def test_verbosity_low_different_subsection_verbosities(test_df: pd.DataFrame):
bivariate_section = BivariateAnalysis(
verbosity=Verbosity.LOW,
subsections=[
Expand Down Expand Up @@ -377,7 +378,7 @@ def test_verbosity_low_different_subsection_verbosities():
for expected_line, exported_line in zip(expected_code, exported_code):
assert expected_line == exported_line, "Exported code mismatch"

check_section_executes(bivariate_section, df=get_test_df())
check_section_executes(bivariate_section, df=test_df)


def test_imports_verbosity_low():
Expand Down Expand Up @@ -450,9 +451,9 @@ def test_imports_verbosity_low_different_subsection_verbosities():


@pyarrow_parameterize
def test_show(pyarrow_dtypes: bool):
def test_show(pyarrow_dtypes: bool, test_df: pd.DataFrame):
bivariate_section = BivariateAnalysis()
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
with redirect_stdout(None):
bivariate_section.show(get_test_df(pyarrow_dtypes=pyarrow_dtypes))
bivariate_section.show(test_df)
84 changes: 39 additions & 45 deletions tests/test_group_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
plotly.io.renderers.default = "json"


def get_test_df(pyarrow_dtypes: bool = False) -> pd.DataFrame:
@pytest.fixture
def test_df(pyarrow_dtypes: bool = False) -> pd.DataFrame:
test_df = pd.DataFrame(
data=[
["P" if np.random.uniform() < 0.4 else "N", 1.5 * i, "X" if i % 2 == 0 else "Y"]
Expand All @@ -54,50 +55,46 @@ def test_invalid_verbosities():


@pyarrow_parameterize
def test_groupby_nonexistent_col(pyarrow_dtypes: bool):
def test_groupby_nonexistent_col(pyarrow_dtypes: bool, test_df: pd.DataFrame):
with pytest.raises(ValueError):
show_group_analysis(df=get_test_df(pyarrow_dtypes=pyarrow_dtypes), groupby=["non-existent"])
show_group_analysis(df=test_df, groupby=["non-existent"])
with pytest.raises(ValueError):
group_missing_values(
df=get_test_df(pyarrow_dtypes=pyarrow_dtypes), groupby=["non-existent"]
)
group_missing_values(df=test_df, groupby=["non-existent"])


@pyarrow_parameterize
def test_static_methods(pyarrow_dtypes: bool):
df = get_test_df(pyarrow_dtypes=pyarrow_dtypes)
def test_static_methods(pyarrow_dtypes: bool, test_df: pd.DataFrame):
with redirect_stdout(None):
show_group_analysis(df=df, groupby="C")
show_group_analysis(df=df, groupby=["C"], columns=["A"])
show_group_analysis(df=df, groupby=["C"], columns=["A", "B"])
show_group_analysis(df=df, groupby="C", columns=["A", "B", "C"])
show_group_analysis(df=df, groupby="C", columns=["C"])
show_group_analysis(df=test_df, groupby="C")
show_group_analysis(df=test_df, groupby=["C"], columns=["A"])
show_group_analysis(df=test_df, groupby=["C"], columns=["A", "B"])
show_group_analysis(df=test_df, groupby="C", columns=["A", "B", "C"])
show_group_analysis(df=test_df, groupby="C", columns=["C"])

group_barplot(df, groupby=["A"], column="B")
group_barplot(df, groupby=["A"], column="A")
group_barplot(df, groupby=["A", "C"], column="B")
group_barplot(df, groupby=["A"], column="C")
group_barplot(df, groupby=["A"], column="C")
group_barplot(test_df, groupby=["A"], column="B")
group_barplot(test_df, groupby=["A"], column="A")
group_barplot(test_df, groupby=["A", "C"], column="B")
group_barplot(test_df, groupby=["A"], column="C")
group_barplot(test_df, groupby=["A"], column="C")

group_missing_values(df, groupby=["C"])
group_missing_values(df, groupby=["C"], columns=["A", "B"])
group_missing_values(df, groupby=["C"], columns=["A", "B", "C"])
group_missing_values(df, groupby=["C"], columns=["C"])
group_missing_values(test_df, groupby=["C"])
group_missing_values(test_df, groupby=["C"], columns=["A", "B"])
group_missing_values(test_df, groupby=["C"], columns=["A", "B", "C"])
group_missing_values(test_df, groupby=["C"], columns=["C"])

overlaid_histograms(df, groupby=["A"], column="B")
overlaid_histograms(df, groupby=["A", "C"], column="B")
overlaid_histograms(df, groupby=["A", "C"], column="B")
overlaid_histograms(df, groupby=["B"], column="B")
overlaid_histograms(test_df, groupby=["A"], column="B")
overlaid_histograms(test_df, groupby=["A", "C"], column="B")
overlaid_histograms(test_df, groupby=["A", "C"], column="B")
overlaid_histograms(test_df, groupby=["B"], column="B")


@pyarrow_parameterize
def test_code_export_verbosity_low(pyarrow_dtypes: bool):
df = get_test_df(pyarrow_dtypes=pyarrow_dtypes)
def test_code_export_verbosity_low(pyarrow_dtypes: bool, test_df: pd.DataFrame):
group_section = GroupAnalysis(groupby="B", verbosity=Verbosity.LOW)

# Export code
exported_cells = []
group_section.add_cells(exported_cells, df=df)
group_section.add_cells(exported_cells, df=test_df)
# Remove markdown and other cells and get code strings
exported_code = [cell["source"] for cell in exported_cells if cell["cell_type"] == "code"]
# Define expected code
Expand All @@ -106,17 +103,16 @@ def test_code_export_verbosity_low(pyarrow_dtypes: bool):
assert len(exported_code) == 1
assert exported_code[0] == expected_code[0], "Exported code mismatch"

check_section_executes(group_section, df)
check_section_executes(group_section, test_df)


@pyarrow_parameterize
def test_code_export_verbosity_medium(pyarrow_dtypes: bool):
df = get_test_df(pyarrow_dtypes=pyarrow_dtypes)
def test_code_export_verbosity_medium(pyarrow_dtypes: bool, test_df: pd.DataFrame):
group_section = GroupAnalysis(groupby="A", verbosity=Verbosity.MEDIUM)

# Export code
exported_cells = []
group_section.add_cells(exported_cells, df=df)
group_section.add_cells(exported_cells, df=test_df)
# Remove markdown and other cells and get code strings
exported_code = [cell["source"] for cell in exported_cells if cell["cell_type"] == "code"]
# Define expected code
Expand All @@ -135,17 +131,16 @@ def test_code_export_verbosity_medium(pyarrow_dtypes: bool):
for expected_line, exported_line in zip(expected_code, exported_code):
assert expected_line == exported_line, "Exported code mismatch"

check_section_executes(group_section, df)
check_section_executes(group_section, test_df)


@pyarrow_parameterize
def test_code_export_verbosity_high(pyarrow_dtypes: bool):
df = get_test_df(pyarrow_dtypes=pyarrow_dtypes)
def test_code_export_verbosity_high(pyarrow_dtypes: bool, test_df: pd.DataFrame):
group_section = GroupAnalysis(groupby="A", verbosity=Verbosity.HIGH)

# Export code
exported_cells = []
group_section.add_cells(exported_cells, df=df)
group_section.add_cells(exported_cells, df=test_df)
# Remove markdown and other cells and get code strings
exported_code = [cell["source"] for cell in exported_cells if cell["cell_type"] == "code"]
# Define expected code
Expand Down Expand Up @@ -192,21 +187,20 @@ def test_code_export_verbosity_high(pyarrow_dtypes: bool):
for expected_line, exported_line in zip(expected_code, exported_code):
assert expected_line == exported_line, "Exported code mismatch"

check_section_executes(group_section, df)
check_section_executes(group_section, test_df)


@pyarrow_parameterize
def test_columns_parameter(pyarrow_dtypes: bool):
df = get_test_df(pyarrow_dtypes=pyarrow_dtypes)
def test_columns_parameter(pyarrow_dtypes: bool, test_df: pd.DataFrame):
ga = GroupAnalysis(groupby="A", columns=["B"])
assert ga.groupby == ["A"]
assert ga.columns == ["B"]

ga = GroupAnalysis(groupby="A")
assert ga.groupby == ["A"]
assert ga.columns is None
ga.show(df)
ga.add_cells([], df=df)
ga.show(test_df)
ga.add_cells([], df=test_df)
assert ga.groupby == ["A"]
assert ga.columns is None

Expand All @@ -218,10 +212,10 @@ def test_column_list_not_modified():


@pyarrow_parameterize
def test_show(pyarrow_dtypes: bool):
df = get_test_df(pyarrow_dtypes=pyarrow_dtypes)
def test_show(pyarrow_dtypes: bool, test_df: pd.DataFrame):
df = test_df
group_section = GroupAnalysis(groupby="A")
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
with redirect_stdout(None):
group_section.show(df)
group_section.show(test_df)
Loading

0 comments on commit 81d708f

Please sign in to comment.