diff --git a/luma/interface/typing.py b/luma/interface/typing.py
index 8eb35f4..7594231 100644
--- a/luma/interface/typing.py
+++ b/luma/interface/typing.py
@@ -1,5 +1,14 @@
 from functools import wraps
-from typing import Any, Callable, Generic, NoReturn, Self, Type, TypeVar
+from collections import defaultdict
+from typing import (
+    Any,
+    Callable,
+    Generic,
+    NoReturn,
+    Self,
+    Type,
+    TypeVar,
+)
 
 import sys
 import numpy as np
@@ -162,7 +171,9 @@ def wrapper(self, *args: Any, **kwargs: Any) -> Any:
                 if param_name in all_args:
                     tensor = all_args[param_name]
                     if not isinstance(tensor, (Tensor, np.ndarray)):
-                        raise TypeError(f"'{param_name}' must be of type Tensor.")
+                        raise TypeError(
+                            f"'{param_name}' must be an instance of Tensor.",
+                        )
                     if tensor.ndim != n_dim:
                         raise ValueError(
                             f"'{param_name}' must be {n_dim}D-tensor",
@@ -175,6 +186,66 @@
         return decorator
 
+    @classmethod
+    def force_shape(cls, *shape_consts: tuple[int, ...]) -> Callable:
+
+        def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
+            @wraps(func)
+            def wrapper(self, *args: Any, **kwargs: Any) -> Any:
+                arg_names = func.__code__.co_varnames
+                all_args = {**dict(zip(arg_names, (self,) + args)), **kwargs}
+
+                # Collect (axis, expected, received) per offending argument.
+                mismatch_dict = defaultdict(list)
+                for i, shape in enumerate(shape_consts):
+                    param_name = arg_names[i + 1]
+
+                    if param_name in all_args:
+                        tensor = all_args[param_name]
+                        if not isinstance(tensor, (Tensor, np.ndarray)):
+                            raise TypeError(
+                                f"'{param_name}' must be an instance of Tensor.",
+                            )
+
+                        if tensor.ndim != len(shape):
+                            raise ValueError(
+                                f"Dimensionalities of '{param_name}' and"
+                                + f" the constraint '{shape}' do not match!"
+                            )
+
+                        # A constraint of -1 leaves that axis unconstrained.
+                        for axis, (s, ts) in enumerate(zip(shape, tensor.shape)):
+                            if s == -1:
+                                continue
+                            if s != ts:
+                                mismatch_dict[param_name].append((axis, s, ts))
+
+                if len(mismatch_dict):
+                    title = f"{'Argument':^15} {'Axes':^15} {'Expected':^15} {'Received':^15}"
+                    msg = ""
+
+                    for name in mismatch_dict:
+                        errmat = np.asarray(mismatch_dict[name])
+
+                        axes = str(tuple(errmat[:, 0].tolist()))
+                        expect = str(tuple(errmat[:, 1].tolist()))
+                        got = str(tuple(errmat[:, 2].tolist()))
+
+                        msg += f"{name:^15} {axes:^15} {expect:^15} {got:^15}\n"
+
+                    raise ValueError(
+                        "Shape mismatch(es) detected as follows:"
+                        + f"\n{title}"
+                        + f"\n{'-' * 63}"
+                        + f"\n{msg}",
+                    )
+
+                return func(self, *args, **kwargs)
+
+            return wrapper
+
+        return decorator
+
 
 class Scalar:
 
     """
diff --git a/luma/neural/scheduler.py b/luma/neural/scheduler.py
index 0eb43bd..b8f9fc4 100644
--- a/luma/neural/scheduler.py
+++ b/luma/neural/scheduler.py
@@ -165,8 +165,7 @@ def new_learning_rate(self) -> float:
         else:
             new_lr = (
                 self.max_lr
-                - (self.max_lr - self.init_lr / self.final_div_factor)
-                * factor
+                - (self.max_lr - self.init_lr / self.final_div_factor) * factor
             )
 
         self.lr_trace.append(new_lr)
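
For context, a minimal usage sketch of the new `force_shape` decorator. This assumes `force_shape` is exposed as a classmethod on `Tensor` (consistent with the `cls` signature and the existing dimensionality decorator edited in the same file); `DummyLayer`, `forward`, and the concrete shapes are illustrative only.

```python
import numpy as np

from luma.interface.typing import Tensor


class DummyLayer:
    # Constrain `x` to 4D inputs shaped (N, 3, 32, 32); the -1 entry
    # leaves the batch axis unconstrained.
    @Tensor.force_shape((-1, 3, 32, 32))
    def forward(self, x: Tensor) -> Tensor:
        return x


layer = DummyLayer()
layer.forward(np.zeros((8, 3, 32, 32)))  # passes the shape check
layer.forward(np.zeros((8, 1, 64, 64)))  # raises ValueError listing axes 1, 2, 3
```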
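
The scheduler hunk only reflows the annealing expression, but the operator precedence is worth noting: the division binds tighter than the subtraction, so the parenthesized term is `max_lr` minus the final learning rate (`init_lr / final_div_factor`), not `(max_lr - init_lr)` divided by the factor. A standalone sketch with illustrative values, assuming `factor` rises from 0 to 1 over the decay phase as in one-cycle schedules:

```python
# Names mirror the scheduler diff; the values and the loop are illustrative.
max_lr, init_lr, final_div_factor = 0.1, 0.01, 1e4

final_lr = init_lr / final_div_factor  # the floor the schedule decays toward

for factor in (0.0, 0.5, 1.0):  # 0 at decay start, 1 at the end
    new_lr = max_lr - (max_lr - final_lr) * factor
    print(f"factor={factor:.1f} -> new_lr={new_lr:.6f}")
```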