Skip to content

Commit

Permalink
ModelScatterPlot wip
Browse files Browse the repository at this point in the history
  • Loading branch information
ChanLumerico authored Oct 7, 2024
1 parent 5656110 commit a752ce1
Showing 1 changed file with 123 additions and 0 deletions.
123 changes: 123 additions & 0 deletions luma/visual/neural.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from typing import Any, List, Literal, Optional, Type
import matplotlib.pyplot as plt
import numpy as np

from luma.core.super import Visualizer
from luma.neural.base import NeuralModel
from luma.neural.model import get_model, load_model_registry


FormattedKey = str


class ModelScatterPlot(Visualizer):
def __init__(
self,
models: List[Type[NeuralModel] | str],
x_axis: FormattedKey,
y_axis: FormattedKey,
s_key: FormattedKey | None = None,
) -> None:
self.x_axis = x_axis
self.y_axis = y_axis
self.s_key = s_key

self.models = []
self.model_names = []

for model in models:
model_type = model if isinstance(model, type) else get_model(model)
if model_type is None:
raise ValueError(f"'{model}' is an invalid model!")

self.models.append(model_type)
self.model_names.append(model_type.__name__)

self.x_data, self.y_data, self.s_data = [], [], []

model_regs = [load_model_registry(m) for m in self.model_names]
for reg in model_regs:
x_val = self._get_key_value(reg, self.x_axis)
y_val = self._get_key_value(reg, self.y_axis)

s_val = None
if self.s_key is not None:
s_val = self._get_key_value(reg, self.s_key)

if isinstance(x_val, dict) or isinstance(y_val, dict):
raise ValueError(
f"Key pair '{self.x_axis}, {self.y_axis}' is"
+ f" invalid for the model '{reg["name"]}'!"
)

self.x_data.append(x_val)
self.y_data.append(y_val)

if s_val is not None and isinstance(s_val, (int, float)):
self.s_data.append(s_val)

def _get_key_value(self, reg: dict, key: FormattedKey) -> Any:
value = reg
split_key = key.split(":")
for k in split_key:
value = value[k]
return value

def _scale(self, data: list[int | float]) -> list[int | float]:
return [d / min(data) * 20 for d in data]

def plot(
self,
ax: Optional[plt.Axes] = None,
x_scale: str = "linear",
y_scale: str = "linear",
cmap: str = "viridis",
scale_size: bool = False,
grid: bool = True,
title: Literal["auto"] | str = "auto", # handle this further
show: bool = False,
) -> plt.Axes:
if ax is None:
_, ax = plt.subplots()
show = True

size_arr = self.s_data if self.s_data else None
sc = ax.scatter(
self.x_data,
self.y_data,
s=self._scale(size_arr) if scale_size else size_arr,
c=size_arr,
marker="o",
cmap=cmap,
alpha=0.7,
)

for x, y, name in zip(self.x_data, self.y_data, self.model_names):
ax.text(
x,
y,
name,
fontsize="x-small",
alpha=0.8,
horizontalalignment="center",
verticalalignment="center",
)

ax.set_xscale(x_scale)
ax.set_yscale(y_scale)

ax.set_xlabel(self.x_axis.split(":")[0])
ax.set_ylabel(self.y_axis.split(":")[0])
ax.set_title(title)

if grid:
ax.grid(alpha=0.2)

cbar = ax.figure.colorbar(sc)
cbar.set_label(self.s_key.split(":")[0])
ax.figure.tight_layout()

if show:
plt.show()
plt.savefig("test") # Remove this later
return ax

0 comments on commit a752ce1

Please sign in to comment.