Skip to content

Commit

Permalink
refactor: use mypy type checking
Browse files Browse the repository at this point in the history
* add proper type hints
* add mypy type checking to CI
  • Loading branch information
mbelak-dtml committed Mar 6, 2024
1 parent 8f63194 commit 2503421
Show file tree
Hide file tree
Showing 18 changed files with 86 additions and 59 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ jobs:
poetry run pylint --rcfile=.pylintrc edvart/
poetry run black --check --line-length 100 edvart/ tests/
poetry run isort --check --line-length 100 --profile black edvart/ tests/
poetry run mypy edvart/
dismiss-stale-reviews:
runs-on: ubuntu-22.04
Expand Down
6 changes: 5 additions & 1 deletion edvart/export_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ def embed_image_base64(image_path: str, mime: str = "image/png") -> str:
# Look up directory where currently executed template is located
# Jinja's @environmentfilter or @contextfilter does not seem to provide
# any information about the path of the template.
template_dir = os.path.dirname(inspect.getfile(inspect.currentframe().f_back))
current_frame = inspect.currentframe()
assert current_frame is not None
frame_back = current_frame.f_back
assert frame_back is not None
template_dir = os.path.dirname(inspect.getfile(frame_back))
with open(os.path.join(template_dir, image_path), "rb") as img:
return f"data:{mime};base64," + str(base64.b64encode(img.read()).decode("utf-8"))
2 changes: 1 addition & 1 deletion edvart/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _scatter_plot_2d_noninteractive(
color_categorical = pd.Categorical(df[color_col])
color_codes = color_categorical.codes
else:
color_codes = df[color_col]
color_codes = df[color_col].values.astype(np.signedinteger)
scatter = ax.scatter(x, y, c=color_codes, alpha=opacity)

if is_color_categorical:
Expand Down
14 changes: 7 additions & 7 deletions edvart/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import warnings
from abc import ABC
from copy import copy
from typing import List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Sized, Tuple, Union

import isort
import nbconvert
Expand Down Expand Up @@ -64,7 +64,7 @@ def __init__(
self.df = dataframe
self.sections: list[Section] = []
self.verbosity = Verbosity(verbosity)
self._table_of_contents = None
self._table_of_contents: Optional[TableOfContents] = None

def _warn_if_empty(self) -> None:
"""Warns if the report contains no sections."""
Expand Down Expand Up @@ -207,7 +207,7 @@ def _export_html(
Maximum number of seconds to wait for a cell to finish execution.
"""
# Execute notebook to produce output of cells
html_exp_kwargs = dict(
html_exp_kwargs: Dict[str, Any] = dict(
preprocessors=[nbconvert.preprocessors.ExecutePreprocessor(timeout=timeout)]
)
if template_name is not None:
Expand Down Expand Up @@ -275,7 +275,7 @@ def export_html(
# and unpickles the whole report object from the decoded binary data
unpickle_report = code_dedent(
f"""
data = {buffer_base64}
data = {buffer_base64!r}
report = pickle.loads(base64.b85decode(data), fix_imports=False)
"""
)
Expand Down Expand Up @@ -676,7 +676,7 @@ def __init__(
columns_bivariate_analysis: Optional[List[str]] = None,
columns_multivariate_analysis: Optional[List[str]] = None,
columns_group_analysis: Optional[List[str]] = None,
groupby: Union[str, List[str]] = None,
groupby: Optional[Union[str, List[str]]] = None,
):
super().__init__(dataframe, verbosity)

Expand All @@ -699,7 +699,7 @@ def __init__(
)
if isinstance(groupby, str):
color_col = groupby
elif hasattr(groupby, "__len__") and len(groupby) == 1:
elif isinstance(groupby, Sized) and len(groupby) == 1:
color_col = groupby[0]
else:
color_col = None
Expand Down Expand Up @@ -740,7 +740,7 @@ def __init__(
verbosity: Verbosity = Verbosity.LOW,
):
super().__init__(dataframe, verbosity)
if not is_date(dataframe.index):
if not is_date(dataframe.index.to_series()):
raise ValueError(
"Input dataframe needs to be indexed by time."
"Please reindex your data to be indexed by either a DatetimeIndex or a PeriodIndex."
Expand Down
8 changes: 7 additions & 1 deletion edvart/report_sections/bivariate_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ def __init__(
raise ValueError("Either both or neither of columns_x, columns_y must be specified.")
# For analyses which do not take columns_pairs, prepare columns_x and columns_y in case
# columns_pairs is the only parameter specified
columns_x_no_pairs: Optional[List[str]]
columns_y_no_pairs: Optional[List[str]]
if columns is None and columns_x is None and columns_pairs is not None:
columns_x_no_pairs = [pair[0] for pair in columns_pairs]
columns_y_no_pairs = [pair[1] for pair in columns_pairs]
Expand Down Expand Up @@ -456,6 +458,7 @@ def _get_columns_x_y(
if columns is None:
columns = list(df.columns)
columns_x = columns_y = columns
assert columns_y is not None
columns_x = [col for col in columns_x if is_numeric(df[col])]
columns_y = [col for col in columns_y if is_numeric(df[col])]

Expand Down Expand Up @@ -722,6 +725,7 @@ def include_column(col: str) -> bool:
columns_x = columns
columns_y = columns
if not allow_categorical:
assert columns_y is not None
columns_x = list(filter(include_column, columns_x))
columns_y = list(filter(include_column, columns_y))
sns.pairplot(df, x_vars=columns_x, y_vars=columns_y, hue=color_col)
Expand Down Expand Up @@ -908,6 +912,8 @@ def include_column(col: str) -> bool:
if columns_x is None:
columns_pairs = list(itertools.combinations(columns, 2))
else:
assert columns_x is not None
assert columns_y is not None
columns_pairs = [
(col_x, col_y)
for (col_x, col_y) in itertools.product(columns_x, columns_y)
Expand Down Expand Up @@ -971,7 +977,7 @@ def contingency_table(
annot = table.replace(0, "") if hide_zeros else table

ax = sns.heatmap(
scaling_func(table),
scaling_func(table.values),
annot=annot,
fmt="",
cbar=False,
Expand Down
8 changes: 4 additions & 4 deletions edvart/report_sections/dataset_overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def data_types(df: pd.DataFrame, columns: Optional[List[str]] = None) -> None:
"""
if columns is not None:
df = df[columns]
dtypes = df.apply(
dtypes = df.apply( # type: ignore
func=lambda x_: str(infer_data_type(x_)),
axis=0,
result_type="expand",
Expand Down Expand Up @@ -652,7 +652,7 @@ def missing_values(
bar_plot_title: str = "Missing Values Percentage of Each Column",
bar_plot_ylim: float = 0,
bar_plot_color: str = "#FFA07A",
**bar_plot_args: Dict[str, Any],
**bar_plot_args: Any,
) -> None:
"""Displays a table of missing values percentages for each column of df and a bar plot
of the percentages.
Expand All @@ -675,7 +675,7 @@ def missing_values(
Bar plot y axis bottom limit.
bar_plot_color : str
Color of bars in the bar plot in hex format.
bar_plot_args : Dict[str, Any]
bar_plot_args : Any
Additional kwargs passed to pandas.Series.bar.
"""
if columns is not None:
Expand Down Expand Up @@ -717,7 +717,7 @@ def missing_values(
title=bar_plot_title,
ylim=bar_plot_ylim,
color=bar_plot_color,
**bar_plot_args,
**bar_plot_args, # type: ignore
)
.set_ylabel("Missing Values [%]")
)
Expand Down
20 changes: 12 additions & 8 deletions edvart/report_sections/group_analysis.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, Hashable, Iterable, List, Optional, Union

import colorlover as cl
import nbformat.v4 as nbfv4
Expand Down Expand Up @@ -102,7 +102,7 @@ def required_imports(self) -> List[str]:
"import plotly.graph_objects as go",
"from edvart.data_types import infer_data_type, DataType",
"from edvart import utils",
"from typing import List, Dict, Optional, Callable",
"from typing import List, Dict, Optional, Callable, Iterable",
"from plotly.subplots import make_subplots",
]

Expand Down Expand Up @@ -218,7 +218,7 @@ def add_cells(self, cells: List[Dict[str, Any]], df: pd.DataFrame) -> None:
)
cells.append(nbfv4.new_code_cell(code))

columns = self.columns if self.columns is not None else df.columns
columns = self.columns if self.columns is not None else df.columns.to_list()

if not self.show_statistics and not self.show_dist:
return
Expand Down Expand Up @@ -362,7 +362,7 @@ def within_group_stats(
df: pd.DataFrame,
groupby: List[str],
column: str,
stats: Dict[str, Callable[[pd.Series], float]] = None,
stats: Optional[Dict[str, Callable[[pd.Series], float]]] = None,
round_decimals: int = 2,
) -> None:
"""Display within-group statistics for a column of df grouped by one or more columns.
Expand Down Expand Up @@ -448,7 +448,9 @@ def group_missing_values(
df_grouped = df.groupby(groupby)[columns]

# Calculate number of samples in each group
sizes = df_grouped.size().rename("Group Size")
sizes = df_grouped.size()
assert isinstance(sizes, pd.Series)
sizes = sizes.rename("Group Size")

# Calculate missing values percentage of each column for each group
missing = df_grouped.apply(lambda g: g.isna().sum(axis=0))
Expand Down Expand Up @@ -490,7 +492,7 @@ def color_cell(value):
background-color: {bg_hex};
"""

render = final_table.style.applymap(
render = final_table.style.map(
func=color_cell, subset=pd.IndexSlice[:, colored_columns]
).format(formatter="{0:.2f} %", subset=pd.IndexSlice[:, colored_columns])
else:
Expand Down Expand Up @@ -553,7 +555,8 @@ def group_barplot(

fig = go.Figure()
for color_idx, (idx, row) in enumerate(pivot.iterrows()):
if hasattr(idx, "__len__") and not isinstance(idx, str):
group_name: Hashable
if isinstance(idx, Iterable) and not isinstance(idx, str):
group_name = "_".join([str(i) for i in idx])
else:
group_name = idx
Expand Down Expand Up @@ -641,7 +644,8 @@ def overlaid_histograms(
)

for color_idx, (name, group) in enumerate(df.groupby(groupby)):
if hasattr(name, "__len__") and not isinstance(name, str):
group_name: Hashable
if isinstance(name, Iterable) and not isinstance(name, str):
group_name = "_".join([str(i) for i in name])
else:
group_name = name
Expand Down
20 changes: 11 additions & 9 deletions edvart/report_sections/multivariate_analysis.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import IntEnum
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import nbformat.v4 as nbfv4
Expand Down Expand Up @@ -487,7 +487,7 @@ def pca_explained_variance(
plt.ylabel("Explained variance ratio")
plt.xticks(
ticks=range(len(pca.explained_variance_ratio_)),
labels=range(1, (len(pca.explained_variance_ratio_) + 1)),
labels=[str(label) for label in range(1, (len(pca.explained_variance_ratio_) + 1))],
)
if show_grid:
plt.grid()
Expand Down Expand Up @@ -630,13 +630,15 @@ def parallel_coordinates(
columns = [col for col in columns if col not in hide_columns]
if drop_na:
df = df.dropna()

line: Optional[Dict[str, Any]] = None
if color_col is not None:
is_categorical_color = infer_data_type(df[color_col]) in (
DataType.CATEGORICAL,
DataType.UNIQUE,
DataType.BOOLEAN,
)

colorscale: Union[List[Tuple[float, str]], str]
if is_categorical_color:
categories = df[color_col].unique()
colorscale = get_default_discrete_colorscale(n_colors=len(categories))
Expand Down Expand Up @@ -669,8 +671,6 @@ def parallel_coordinates(
"cmax": len(categories) - 0.5,
}
)
else:
line = None
# Add numeric columns to dimensions
dimensions = [{"label": col_name, "values": df[col_name]} for col_name in numeric_columns]
# Add categorical columns to dimensions
Expand Down Expand Up @@ -818,12 +818,15 @@ def parallel_categories(
columns = [col for col in columns if col not in hide_columns]
if drop_na:
df = df.dropna()

line: Optional[Dict[str, Any]] = None
if color_col is not None:
categorical_color = infer_data_type(df[color_col]) in (
DataType.CATEGORICAL,
DataType.UNIQUE,
DataType.BOOLEAN,
)
colorscale: Union[List[Tuple[float, str]], str]
if categorical_color:
categories = df[color_col].unique()
colorscale = get_default_discrete_colorscale(n_colors=len(categories))
Expand All @@ -833,14 +836,15 @@ def parallel_categories(
color_series = df[color_col]
colorscale = "Bluered_r"

colorbar: Dict[str, Any] = {"title": color_col}
line = {
"color": color_series,
"colorscale": colorscale,
"colorbar": {"title": color_col},
"colorbar": colorbar,
}

if categorical_color:
line["colorbar"].update(
colorbar.update(
{
"tickvals": color_series.unique(),
"ticktext": categories,
Expand All @@ -855,8 +859,6 @@ def parallel_categories(
"cmax": len(categories) - 0.5,
}
)
else:
line = None

dimensions = [go.parcats.Dimension(values=df[col_name], label=col_name) for col_name in columns]

Expand Down
2 changes: 1 addition & 1 deletion edvart/report_sections/table_of_contents.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def show(self, sections: List[Section]) -> None:
"""
display(Markdown(self._title))

lines = []
lines: List[str] = []
for section in sections:
self._add_section_lines(section, 1, lines, self._include_subsections)
display(Markdown("\n".join(lines)))
10 changes: 5 additions & 5 deletions edvart/report_sections/timeseries_analysis/boxplots_over_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(
self,
verbosity: Verbosity = Verbosity.LOW,
columns: Optional[List[str]] = None,
grouping_function: Callable[[Any], str] = None,
grouping_function: Optional[Callable[[Any], str]] = None,
grouping_function_imports: Optional[List[str]] = None,
grouping_name: Optional[str] = None,
default_nunique_max: int = 80,
Expand Down Expand Up @@ -161,7 +161,7 @@ def show(self, df: pd.DataFrame) -> None:
)


def default_grouping_functions() -> Dict[str, Callable[[datetime], str]]:
def default_grouping_functions() -> Dict[str, Callable[[pd.Timestamp], str]]:
"""Return a dictionary of function names and functions.
The function takes a pandas datetime and represents it as a rougher (in terms of time)
Expand All @@ -170,7 +170,7 @@ def default_grouping_functions() -> Dict[str, Callable[[datetime], str]]:
Returns
-------
Dict[str, Callable[[datetime], str]]
Dict[str, Callable[[pandas.Timestamp], str]]
Dictionary from grouping function names to grouping functions.
"""
return {
Expand Down Expand Up @@ -217,7 +217,7 @@ def get_default_grouping_func(df: pd.DataFrame, nunique_max: int = 80) -> Tuple[
def show_boxplots_over_time(
df: pd.DataFrame,
columns: Optional[List[str]] = None,
grouping_function: Callable[[Any], str] = None,
grouping_function: Optional[Callable[[Any], str]] = None,
grouping_name: Optional[str] = None,
default_nunique_max: int = 80,
figsize: Tuple[float, float] = (20, 7),
Expand Down Expand Up @@ -264,7 +264,7 @@ def show_boxplots_over_time(
grouping_name, grouping_function = get_default_grouping_func(
df, nunique_max=default_nunique_max
)
elif default_grouping_funcs.get(grouping_name) is not None:
elif grouping_name is not None and default_grouping_funcs.get(grouping_name) is not None:
grouping_function = default_grouping_funcs[grouping_name]

if columns is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def show_fourier_transform(
for col in columns:
if not is_numeric(df[col]):
raise ValueError(f"Cannot perform Fourier transform for non-numeric column `{col}`")
index_freq = pd.infer_freq(df.index) or ""
index_freq = pd.infer_freq(df.index.to_series()) or ""
for col in columns:
# FFT requires samples at regular intervals
df_col = df[col].interpolate(method="time")
Expand Down
Loading

0 comments on commit 2503421

Please sign in to comment.