diff --git a/luma/__import__.py b/luma/__import__.py index 2ccce3a..9d47150 100644 --- a/luma/__import__.py +++ b/luma/__import__.py @@ -183,8 +183,8 @@ Xception, MobileNet_V1, MobileNet_V2, - MobileNet_V3_Small, - MobileNet_V3_Large, + MobileNet_V3_S, + MobileNet_V3_L, SE_ResNet_50, SE_ResNet_152, SE_Inception_ResNet_V2, @@ -298,9 +298,13 @@ ClusterPlot, ROCCurve, PrecisionRecallCurve, + ConfusionMatrix, + ResidualPlot, + LearningCurve, + ValidationCurve, + InertiaPlot, ) -from luma.visual.eval import ConfusionMatrix, ResidualPlot, LearningCurve -from luma.visual.eval import ValidationCurve, InertiaPlot +from luma.visual.neural import ModelScatterPlot from luma.migrate.port import ModelPorter @@ -415,8 +419,8 @@ ResNet_18, ResNet_34, ResNet_50, ResNet_101, ResNet_152, ResNet_200, ResNet_1001, Xception, - MobileNet_V1, MobileNet_V2, MobileNet_V3_Small, - MobileNet_V3_Large, + MobileNet_V1, MobileNet_V2, MobileNet_V3_S, + MobileNet_V3_L, SE_ResNet_50, SE_ResNet_152, SE_Inception_ResNet_V2, DenseNet_121, DenseNet_169, DenseNet_201, DenseNet_264, EfficientNet_B0, EfficientNet_B1, EfficientNet_B2, @@ -502,5 +506,7 @@ ConfusionMatrix, ResidualPlot, LearningCurve, ValidationCurve, InertiaPlot + ModelScatterPlot + # ------------------ [ luma.migrate ] ---------------------- ModelPorter diff --git a/luma/neural b/luma/neural index 710a3da..91600eb 160000 --- a/luma/neural +++ b/luma/neural @@ -1 +1 @@ -Subproject commit 710a3daa8684f9550f61f324287c291e1f2d438c +Subproject commit 91600ebcf1b93d16fd4f1011835f5de638054628 diff --git a/luma/visual/neural.py b/luma/visual/neural.py index 0a0963a..5042e04 100644 --- a/luma/visual/neural.py +++ b/luma/visual/neural.py @@ -4,12 +4,42 @@ from luma.core.super import Visualizer from luma.neural.base import NeuralModel -from luma.neural.model import get_model, load_model_registry +from luma.neural.model import get_model, load_model_registry, load_entire_registry + + +__all__ = ("ModelScatterPlot", "ModelFamilyPlot") FormattedKey = str +def _get_key_value(reg: dict, key: FormattedKey) -> Any: + value = reg + split_key = key.split(":") + for k in split_key: + value = value[k] + return value + + +def _scale(data: list[int | float]) -> list[int | float]: + return [d / min(data) * 20 for d in data] + + +def _format_number(num: float, decimals: int = 1) -> str: + suffixes = ["K", "M", "B", "T", "P", "E", "Z", "Y"] + + magnitude = 0 + while abs(num) >= 1000 and magnitude < len(suffixes): + num /= 1000.0 + magnitude += 1 + + formatted_num = f"{num:.{decimals}f}" + if magnitude > 0: + formatted_num += suffixes[magnitude - 1] + + return formatted_num + + class ModelScatterPlot(Visualizer): def __init__( self, @@ -35,14 +65,14 @@ def __init__( self.x_data, self.y_data, self.s_data = [], [], [] - model_regs = [load_model_registry(m) for m in self.model_names] - for reg in model_regs: - x_val = self._get_key_value(reg, self.x_axis) - y_val = self._get_key_value(reg, self.y_axis) + self.model_regs = [load_model_registry(m) for m in self.model_names] + for reg in self.model_regs: + x_val = _get_key_value(reg, self.x_axis) + y_val = _get_key_value(reg, self.y_axis) s_val = None if self.s_key is not None: - s_val = self._get_key_value(reg, self.s_key) + s_val = _get_key_value(reg, self.s_key) if isinstance(x_val, dict) or isinstance(y_val, dict): raise ValueError( @@ -56,16 +86,6 @@ def __init__( if s_val is not None and isinstance(s_val, (int, float)): self.s_data.append(s_val) - def _get_key_value(self, reg: dict, key: FormattedKey) -> Any: - value = reg - split_key = key.split(":") - for k in split_key: - value = value[k] - return value - - def _scale(self, data: list[int | float]) -> list[int | float]: - return [d / min(data) * 20 for d in data] - def plot( self, ax: Optional[plt.Axes] = None, @@ -74,7 +94,7 @@ def plot( cmap: str = "viridis", scale_size: bool = False, grid: bool = True, - title: Literal["auto"] | str = "auto", # handle this further + title: Optional[str] = None, show: bool = False, ) -> plt.Axes: if ax is None: @@ -83,24 +103,34 @@ def plot( size_arr = self.s_data if self.s_data else None sc = ax.scatter( - self.x_data, - self.y_data, - s=self._scale(size_arr) if scale_size else size_arr, - c=size_arr, - marker="o", - cmap=cmap, + self.x_data, + self.y_data, + s=_scale(size_arr) if scale_size else size_arr, + c=size_arr if size_arr else self.y_data, + marker="o", + cmap=cmap, alpha=0.7, ) for x, y, name in zip(self.x_data, self.y_data, self.model_names): ax.text( - x, - y, - name, - fontsize="x-small", - alpha=0.8, - horizontalalignment="center", - verticalalignment="center", + x, + y, + name, + fontsize="x-small", + alpha=0.8, + ha="center", + va="bottom", + ) + for x, y, reg in zip(self.x_data, self.y_data, self.model_regs): + ax.text( + x, + y, + _format_number(reg["params"]), + fontsize="x-small", + alpha=0.5, + ha="center", + va="top", ) ax.set_xscale(x_scale) @@ -108,16 +138,62 @@ def plot( ax.set_xlabel(self.x_axis.split(":")[0]) ax.set_ylabel(self.y_axis.split(":")[0]) - ax.set_title(title) + if title is not None: + ax.set_title(title) if grid: ax.grid(alpha=0.2) - + cbar = ax.figure.colorbar(sc) - cbar.set_label(self.s_key.split(":")[0]) - ax.figure.tight_layout() + if self.s_key is not None: + cbar.set_label(self.s_key.split(":")[0]) + ax.figure.tight_layout() if show: plt.show() - plt.savefig("test") # Remove this later return ax + + +class ModelFamilyPlot(Visualizer): + def __init__( + self, + families: list[str], + x_axis: FormattedKey, + y_axis: FormattedKey, + ) -> None: + self.x_axis = x_axis + self.y_axis = y_axis + + self.model_families = [] + reg_json = load_entire_registry() + + for fam_name in families: + model_reg_arr = [] + + for reg_dict in reg_json: + reg_fam_name: str = reg_dict["family"] + if reg_fam_name is None: + continue + + alt_fam_name = reg_fam_name.lower().replace("_", "-") + if reg_fam_name == fam_name or alt_fam_name == fam_name: + model_reg_arr.append(reg_dict) + + self.model_families.append(model_reg_arr) + + self.x_data_arr, self.y_data_arr = [], [] + for family_reg in self.model_families: + x_data, y_data = [], [] + + for model_reg in family_reg: + x_val = _get_key_value(model_reg, self.x_axis) + y_val = _get_key_value(model_reg, self.y_axis) + + x_data.append(x_val) + y_data.append(y_val) + + self.x_data_arr.append(x_data) + self.y_data_arr.append(y_data) + + def plot(self, *args) -> None: + return super().plot(*args)