
Commit

@force_shape wip
ChanLumerico committed Aug 10, 2024
1 parent 693cf60 commit 28c580a
Showing 2 changed files with 72 additions and 4 deletions.
73 changes: 71 additions & 2 deletions luma/interface/typing.py
@@ -1,5 +1,14 @@
from functools import wraps
from typing import Any, Callable, Generic, NoReturn, Self, Type, TypeVar
from collections import defaultdict
from typing import (
Any,
Callable,
Generic,
NoReturn,
Self,
Type,
TypeVar,
)
import sys
import numpy as np

@@ -162,7 +171,9 @@ def wrapper(self, *args: Any, **kwargs: Any) -> Any:
if param_name in all_args:
tensor = all_args[param_name]
if not isinstance(tensor, (Tensor, np.ndarray)):
raise TypeError(f"'{param_name}' must be of type Tensor.")
raise TypeError(
f"'{param_name}' must be an insatnce of Tensor.",
)
if tensor.ndim != n_dim:
raise ValueError(
f"'{param_name}' must be {n_dim}D-tensor",
@@ -175,6 +186,64 @@ def wrapper(self, *args: Any, **kwargs: Any) -> Any:

return decorator

@classmethod
def force_shape(cls, *shape_consts: tuple[int, ...]) -> Callable:

def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(func)
def wrapper(self, *args: Any, **kwargs: Any) -> Any:
arg_names = func.__code__.co_varnames
all_args = {**dict(zip(arg_names, (self,) + args)), **kwargs}

mismatch_dict = defaultdict(lambda: np.empty((0, 3), dtype=np.int64))
for i, shape in enumerate(shape_consts):
param_name = arg_names[i + 1]

if param_name in all_args:
tensor = all_args[param_name]
if not isinstance(tensor, (Tensor, np.ndarray)):
raise TypeError(
f"'{param_name}' must be an instance of Tensor.",
)

if tensor.ndim != len(shape):
raise ValueError(
f"Dimensionalities of '{param_name}' and"
+ f" the constraint '{shape}' does not match!"
)

for axis, (s, ts) in enumerate(zip(shape, tensor.shape)):
if s == -1:
continue
if s != ts:
mismatch_dict[param_name] = np.vstack((mismatch_dict[param_name], [axis, s, ts]))

if len(mismatch_dict):
title = f"{"Argument":^15} {"Axes":^15} {"Expexted":^15} {"Shape":^15}"
msg = str()

for name in mismatch_dict.keys():
errmat = mismatch_dict[name]

axes = tuple(errmat[:, 0])
expect = tuple(errmat[:, 1])
got = tuple(errmat[:, 2])

msg += f"{name:^15} {axes:^15} {expect:^15} {got:^15}\n"

raise ValueError(
f"Shape mismatch(es) detected as follows:"
+ f"\n{title}"
+ f"\n{"-" * 63}"
+ f"\n{msg}",
)

return func(self, *args, **kwargs)

return wrapper

return decorator
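
As a usage sketch (not part of this commit; the enclosing class of force_shape is not visible in this hunk): assuming the decorator ends up exposed as a classmethod of Tensor in luma.interface.typing, a hypothetical layer method could pin its tensor arguments as below, where each constraint maps to the corresponding parameter after self and -1 leaves an axis unconstrained.

import numpy as np
from luma.interface.typing import Tensor

class Dense:
    # (-1, 64) constrains X, (64, 32) constrains W; -1 is a wildcard axis
    @Tensor.force_shape((-1, 64), (64, 32))
    def forward(self, X: Tensor, W: Tensor) -> Tensor:
        return X @ W

layer = Dense()
out = layer.forward(np.zeros((8, 64)), np.zeros((64, 32)))  # passes: every fixed axis matches
# layer.forward(np.zeros((8, 32)), np.zeros((64, 32)))  # would raise ValueError reporting axis 1 of 'X'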


class Scalar:
"""
3 changes: 1 addition & 2 deletions luma/neural/scheduler.py
@@ -165,8 +165,7 @@ def new_learning_rate(self) -> float:
else:
new_lr = (
self.max_lr
- (self.max_lr - self.init_lr / self.final_div_factor)
* factor
- (self.max_lr - self.init_lr / self.final_div_factor) * factor
)

self.lr_trace.append(new_lr)
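
A small arithmetic sketch of the annealing branch above, with illustrative values not taken from the repository: as factor runs from 0 to 1, the learning rate falls linearly from max_lr down to init_lr / final_div_factor.

max_lr, init_lr, final_div_factor = 0.1, 0.004, 1e4
for factor in (0.0, 0.5, 1.0):
    new_lr = max_lr - (max_lr - init_lr / final_div_factor) * factor
    print(f"factor={factor:.1f} -> lr={new_lr:.7f}")
# factor=0.0 -> lr=0.1000000
# factor=0.5 -> lr=0.0500002
# factor=1.0 -> lr=0.0000004  (= init_lr / final_div_factor)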
