Skip to content

Commit

Permalink
20681: Adds UMAP plot function, MINOR (#17)
Browse files Browse the repository at this point in the history
Co-authored-by: howso-automation <support@howso.com>
  • Loading branch information
jdbeel and howso-automation authored Jul 12, 2024
1 parent bd84307 commit d9ffb0b
Show file tree
Hide file tree
Showing 13 changed files with 5,983 additions and 2,379 deletions.
1,356 changes: 1,299 additions & 57 deletions LICENSE-3RD-PARTY.txt

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions howso/visuals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
plot_feature_importances,
plot_interpretable_prediction,
plot_kl_divergence,
plot_umap,
)

__all__ = [
Expand All @@ -18,4 +19,5 @@
"plot_feature_importances",
"plot_interpretable_prediction",
"plot_kl_divergence",
"plot_umap",
]
24 changes: 18 additions & 6 deletions howso/visuals/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
plot_dataset,
plot_interpretable_prediction,
plot_fairness_disparity,
plot_kl_divergence
plot_kl_divergence,
plot_umap,
)
from howso.visuals.visuals import plot_drift, plot_feature_importances

Expand Down Expand Up @@ -56,8 +57,8 @@ def test_plot_interpretable_prediction_react(
generative_reacts = None

if do_residual:
iris_trainee.react_into_trainee(residuals=True)
residual = iris_trainee.get_prediction_stats(stats=["mae"])[action_feature].iloc[0]
residual = iris_trainee.react_aggregate(details={"prediction_stats": True, "selected_prediction_stats": ["mae"]})
residual = residual[action_feature].iloc[0]
else:
residual = None

Expand Down Expand Up @@ -164,12 +165,12 @@ def outliers_convictions(iris_trainee, iris_features):
details={
"boundary_cases": True,
"influential_cases": True,
"global_case_feature_residual_convictions": True,
"local_case_feature_residual_convictions": True,
"global_case_feature_residual_convictions_full": True,
"local_case_feature_residual_convictions_full": True,
}
)
convictions = pd.DataFrame(
convictions["details"]["global_case_feature_residual_convictions"]
convictions["details"]["global_case_feature_residual_convictions_full"]
)

yield outliers, convictions
Expand Down Expand Up @@ -224,3 +225,14 @@ def test_plot_fairness_disparity(x_tickangle):
fig = plot_fairness_disparity(fairness_results, reference_class='Male', x_tickangle=x_tickangle)

assert fig is not None


@pytest.mark.parametrize("n_cases", [None, 50])
@pytest.mark.parametrize("data", ["iris_train", "iris_trainee"])
def test_plot_umap(data, n_cases, request):
    """
    Smoke-test ``plot_umap`` for both supported kinds of input.

    ``data`` arrives as a fixture *name* (``iris_train`` or ``iris_trainee``)
    and is resolved through ``request`` so that a single parametrized test
    covers both. When a case count is requested, the first trace of the
    figure must contain exactly that many points.
    """
    figure = plot_umap(request.getfixturevalue(data), n_cases=n_cases)

    assert figure is not None
    if n_cases is None:
        # Without sub-sampling there is no expected point count to verify.
        return
    assert len(figure.data[0].y) == n_cases
111 changes: 111 additions & 0 deletions howso/visuals/visuals.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import typing as t
import warnings

import numpy as np
import numpy.typing as npt
from pandas import (
DataFrame,
Series,
Expand All @@ -13,6 +15,14 @@
from plotly.subplots import make_subplots
from scipy.stats import gaussian_kde

with warnings.catch_warnings():
warnings.simplefilter("ignore")
import umap

from howso.engine import Trainee
from howso.utilities import infer_feature_attributes


if t.TYPE_CHECKING:
from howso.engine.trainee import Reaction

Expand Down Expand Up @@ -728,3 +738,104 @@ def compose_figures(
return_figure.update_traces(bingroup=f"subplot_{i},{j}", row=i, col=j)

return return_figure


def plot_umap(
    data: DataFrame | Trainee,
    *,
    color: t.Optional[str] = None,
    min_dist: t.Optional[float] = None,
    n_cases: t.Optional[int] = None,
    n_neighbors: t.Optional[int] = None,
    title: str = "UMAP Representation",
    xaxis_title: str = "Component 1",
    yaxis_title: str = "Component 2",
) -> go.Figure:
    """
    Transform data into a lower-dimensionality representation using Howso Engine and UMAP and then plot it.

    Howso Engine computes pairwise distances which are then used with UMAP's ``precomputed``
    metric.

    Parameters
    ----------
    data : DataFrame | Trainee
        The data to transform or a :class:`Trainee` containing the data to transform. A
        DataFrame is trained into a new :class:`Trainee` and analyzed before distances are
        computed.
    color : str, optional
        The name of the feature to use for determining marker color. It is read from the
        sampled cases when ``n_cases`` is given, otherwise from ``data`` (a DataFrame column,
        or the cases stored in the :class:`Trainee`).
    min_dist : float, optional
        The ``min_dist`` parameter for ``umap.UMAP``. If None, this will be the :math:`p` norm
        of the feature residuals, where :math:`p` is selected by :meth:`Trainee.analyze`,
        rounded to 3 decimal places and capped at 1.
    n_cases : int, optional
        The number of cases to compute pairwise distances for. If None, then all of the cases
        are used. Cases are sampled from the first session reported by the :class:`Trainee`.
    n_neighbors : int, optional
        The ``n_neighbors`` parameter for ``umap.UMAP``. If None, this will be the :math:`k`
        selected by :meth:`Trainee.analyze`.
    title : str, default "UMAP Representation"
        The title for the figure.
    xaxis_title : str, default "Component 1"
        The title for the x-axis.
    yaxis_title : str, default "Component 2"
        The title for the y-axis.

    Returns
    -------
    Figure
        The resultant `Plotly` figure.

    Raises
    ------
    TypeError
        If ``data`` is neither a DataFrame nor a :class:`Trainee`.
    """
    # NOTE: the local is deliberately not called ``t`` so that it does not shadow the
    # module-level ``import typing as t`` alias.
    if isinstance(data, DataFrame):
        features = infer_feature_attributes(data)
        trainee = Trainee(features=features)
        trainee.train(data, skip_auto_analyze=True)
        trainee.analyze()
    elif isinstance(data, Trainee):
        trainee = data
    else:
        raise TypeError("`data` must be a Trainee or a DataFrame.")

    # Internal features that together identify a single trained case.
    index_features = [".session", ".session_training_index"]

    case_indices = None
    sampled_cases = None
    if n_cases is not None:
        sampled_cases = trainee.get_cases(
            features=index_features + list(trainee.features),
            session=trainee.get_sessions()[0]["id"],
        ).sample(n_cases)
        case_indices = sampled_cases[index_features].values.tolist()

    distance_result = trainee.get_distances(case_indices=case_indices)
    distances = distance_result["distances"]
    hyperparameter_map = trainee.get_params(action_feature=".targetless")["hyperparameter_map"]

    if n_neighbors is None:
        n_neighbors = hyperparameter_map["k"]
    p = hyperparameter_map["p"]

    if min_dist is None:
        residuals = trainee.react_aggregate(details={"feature_residuals_full": True})
        min_dist = float((residuals.values ** p).sum() ** (1 / p))
        min_dist = min(round(min_dist, 3), 1)

    # Importing/running UMAP is noisy; silence its warnings for the duration of the fit.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        points = umap.UMAP(
            metric="precomputed",
            min_dist=min_dist,
            n_neighbors=n_neighbors,
        ).fit_transform(distances)

    scatter_kwargs = {}
    labels = {
        "x": xaxis_title,
        "y": yaxis_title,
    }
    if color is not None:
        if sampled_cases is not None:
            color_values = sampled_cases[color]
        elif isinstance(data, DataFrame):
            color_values = data[color]
        else:
            # ``data`` is a Trainee, which cannot be indexed like a DataFrame, so the
            # color values have to be fetched from the stored cases instead.
            # NOTE(review): relies on ``get_distances`` reporting the ``case_indices``
            # that correspond to the rows of the distance matrix - confirm against the
            # Howso Engine documentation.
            stored_cases = trainee.get_cases(features=index_features + [color])
            color_by_case = stored_cases.set_index(index_features)[color]
            ordered_keys = [tuple(index) for index in distance_result["case_indices"]]
            color_values = color_by_case.loc[ordered_keys].reset_index(drop=True)
        # Cast to object so that the values are used as discrete marker categories.
        scatter_kwargs["color"] = color_values.astype(object)
        labels["color"] = color

    fig = px.scatter(
        x=points[:, 0],
        y=points[:, 1],
        title=title,
        labels=labels,
        **scatter_kwargs
    )
    return fig
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@ classifiers = [
"Programming Language :: Python :: 3"
]
dependencies = [
"seaborn~=0.12.0",
"howso-engine~=25.0",
"plotly",
"scipy"
"scipy",
"seaborn~=0.12.0",
"umap-learn~=0.5",
]

[project.optional-dependencies]
dev = [
"flake8",
"howso-engine",
"isort",
"pytest",
"pytest-cov",
Expand Down
Loading

0 comments on commit d9ffb0b

Please sign in to comment.