Skip to content

Commit

Permalink
ModelFamilyPlot wip
Browse files Browse the repository at this point in the history
  • Loading branch information
ChanLumerico authored Oct 8, 2024
1 parent a752ce1 commit 69a3cdf
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 42 deletions.
18 changes: 12 additions & 6 deletions luma/__import__.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,8 @@
Xception,
MobileNet_V1,
MobileNet_V2,
MobileNet_V3_Small,
MobileNet_V3_Large,
MobileNet_V3_S,
MobileNet_V3_L,
SE_ResNet_50,
SE_ResNet_152,
SE_Inception_ResNet_V2,
Expand Down Expand Up @@ -298,9 +298,13 @@
ClusterPlot,
ROCCurve,
PrecisionRecallCurve,
ConfusionMatrix,
ResidualPlot,
LearningCurve,
ValidationCurve,
InertiaPlot,
)
from luma.visual.eval import ConfusionMatrix, ResidualPlot, LearningCurve
from luma.visual.eval import ValidationCurve, InertiaPlot
from luma.visual.neural import ModelScatterPlot

from luma.migrate.port import ModelPorter

Expand Down Expand Up @@ -415,8 +419,8 @@
ResNet_18, ResNet_34, ResNet_50, ResNet_101, ResNet_152,
ResNet_200, ResNet_1001,
Xception,
MobileNet_V1, MobileNet_V2, MobileNet_V3_Small,
MobileNet_V3_Large,
MobileNet_V1, MobileNet_V2, MobileNet_V3_S,
MobileNet_V3_L,
SE_ResNet_50, SE_ResNet_152, SE_Inception_ResNet_V2,
DenseNet_121, DenseNet_169, DenseNet_201, DenseNet_264,
EfficientNet_B0, EfficientNet_B1, EfficientNet_B2,
Expand Down Expand Up @@ -502,5 +506,7 @@
ConfusionMatrix, ResidualPlot, LearningCurve,
ValidationCurve, InertiaPlot

ModelScatterPlot

# ------------------ [ luma.migrate ] ----------------------
ModelPorter
2 changes: 1 addition & 1 deletion luma/neural
146 changes: 111 additions & 35 deletions luma/visual/neural.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,42 @@

from luma.core.super import Visualizer
from luma.neural.base import NeuralModel
from luma.neural.model import get_model, load_model_registry
from luma.neural.model import get_model, load_model_registry, load_entire_registry


__all__ = ("ModelScatterPlot", "ModelFamilyPlot")


FormattedKey = str


def _get_key_value(reg: dict, key: FormattedKey) -> Any:
value = reg
split_key = key.split(":")
for k in split_key:
value = value[k]
return value


def _scale(data: list[int | float]) -> list[int | float]:
return [d / min(data) * 20 for d in data]


def _format_number(num: float, decimals: int = 1) -> str:
suffixes = ["K", "M", "B", "T", "P", "E", "Z", "Y"]

magnitude = 0
while abs(num) >= 1000 and magnitude < len(suffixes):
num /= 1000.0
magnitude += 1

formatted_num = f"{num:.{decimals}f}"
if magnitude > 0:
formatted_num += suffixes[magnitude - 1]

return formatted_num


class ModelScatterPlot(Visualizer):
def __init__(
self,
Expand All @@ -35,14 +65,14 @@ def __init__(

self.x_data, self.y_data, self.s_data = [], [], []

model_regs = [load_model_registry(m) for m in self.model_names]
for reg in model_regs:
x_val = self._get_key_value(reg, self.x_axis)
y_val = self._get_key_value(reg, self.y_axis)
self.model_regs = [load_model_registry(m) for m in self.model_names]
for reg in self.model_regs:
x_val = _get_key_value(reg, self.x_axis)
y_val = _get_key_value(reg, self.y_axis)

s_val = None
if self.s_key is not None:
s_val = self._get_key_value(reg, self.s_key)
s_val = _get_key_value(reg, self.s_key)

if isinstance(x_val, dict) or isinstance(y_val, dict):
raise ValueError(
Expand All @@ -56,16 +86,6 @@ def __init__(
if s_val is not None and isinstance(s_val, (int, float)):
self.s_data.append(s_val)

def _get_key_value(self, reg: dict, key: FormattedKey) -> Any:
value = reg
split_key = key.split(":")
for k in split_key:
value = value[k]
return value

def _scale(self, data: list[int | float]) -> list[int | float]:
return [d / min(data) * 20 for d in data]

def plot(
self,
ax: Optional[plt.Axes] = None,
Expand All @@ -74,7 +94,7 @@ def plot(
cmap: str = "viridis",
scale_size: bool = False,
grid: bool = True,
title: Literal["auto"] | str = "auto", # handle this further
title: Optional[str] = None,
show: bool = False,
) -> plt.Axes:
if ax is None:
Expand All @@ -83,41 +103,97 @@ def plot(

size_arr = self.s_data if self.s_data else None
sc = ax.scatter(
self.x_data,
self.y_data,
s=self._scale(size_arr) if scale_size else size_arr,
c=size_arr,
marker="o",
cmap=cmap,
self.x_data,
self.y_data,
s=_scale(size_arr) if scale_size else size_arr,
c=size_arr if size_arr else self.y_data,
marker="o",
cmap=cmap,
alpha=0.7,
)

for x, y, name in zip(self.x_data, self.y_data, self.model_names):
ax.text(
x,
y,
name,
fontsize="x-small",
alpha=0.8,
horizontalalignment="center",
verticalalignment="center",
x,
y,
name,
fontsize="x-small",
alpha=0.8,
ha="center",
va="bottom",
)
for x, y, reg in zip(self.x_data, self.y_data, self.model_regs):
ax.text(
x,
y,
_format_number(reg["params"]),
fontsize="x-small",
alpha=0.5,
ha="center",
va="top",
)

ax.set_xscale(x_scale)
ax.set_yscale(y_scale)

ax.set_xlabel(self.x_axis.split(":")[0])
ax.set_ylabel(self.y_axis.split(":")[0])
ax.set_title(title)
if title is not None:
ax.set_title(title)

if grid:
ax.grid(alpha=0.2)

cbar = ax.figure.colorbar(sc)
cbar.set_label(self.s_key.split(":")[0])
ax.figure.tight_layout()
if self.s_key is not None:
cbar.set_label(self.s_key.split(":")[0])

ax.figure.tight_layout()
if show:
plt.show()
plt.savefig("test") # Remove this later
return ax


class ModelFamilyPlot(Visualizer):
def __init__(
self,
families: list[str],
x_axis: FormattedKey,
y_axis: FormattedKey,
) -> None:
self.x_axis = x_axis
self.y_axis = y_axis

self.model_families = []
reg_json = load_entire_registry()

for fam_name in families:
model_reg_arr = []

for reg_dict in reg_json:
reg_fam_name: str = reg_dict["family"]
if reg_fam_name is None:
continue

alt_fam_name = reg_fam_name.lower().replace("_", "-")
if reg_fam_name == fam_name or alt_fam_name == fam_name:
model_reg_arr.append(reg_dict)

self.model_families.append(model_reg_arr)

self.x_data_arr, self.y_data_arr = [], []
for family_reg in self.model_families:
x_data, y_data = [], []

for model_reg in family_reg:
x_val = _get_key_value(model_reg, self.x_axis)
y_val = _get_key_value(model_reg, self.y_axis)

x_data.append(x_val)
y_data.append(y_val)

self.x_data_arr.append(x_data)
self.y_data_arr.append(y_data)

def plot(self, *args) -> None:
return super().plot(*args)

0 comments on commit 69a3cdf

Please sign in to comment.