
Commit

Merge pull request #4 from ChanLumerico:force-shape
Force-shape
ChanLumerico authored Aug 10, 2024
2 parents 693cf60 + daf79a5 commit 2cbcff6
Showing 2 changed files with 79 additions and 4 deletions.
80 changes: 78 additions & 2 deletions luma/interface/typing.py
@@ -1,5 +1,14 @@
from functools import wraps
-from typing import Any, Callable, Generic, NoReturn, Self, Type, TypeVar
+from collections import defaultdict
+from typing import (
+    Any,
+    Callable,
+    Generic,
+    NoReturn,
+    Self,
+    Type,
+    TypeVar,
+)
import sys
import numpy as np

@@ -162,7 +171,9 @@ def wrapper(self, *args: Any, **kwargs: Any) -> Any:
                    if param_name in all_args:
                        tensor = all_args[param_name]
                        if not isinstance(tensor, (Tensor, np.ndarray)):
-                            raise TypeError(f"'{param_name}' must be of type Tensor.")
+                            raise TypeError(
+                                f"'{param_name}' must be an instance of Tensor.",
+                            )
                        if tensor.ndim != n_dim:
                            raise ValueError(
                                f"'{param_name}' must be {n_dim}D-tensor",
@@ -175,6 +186,71 @@ def wrapper(self, *args: Any, **kwargs: Any) -> Any:

        return decorator

+    @classmethod
+    def force_shape(cls, *shape_consts: tuple[int, ...]) -> Callable:
+
+        def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
+            @wraps(func)
+            def wrapper(self, *args: Any, **kwargs: Any) -> Any:
+                arg_names = func.__code__.co_varnames
+                all_args = {**dict(zip(arg_names, (self,) + args)), **kwargs}
+
+                mismatch_dict = defaultdict(lambda: np.empty((0, 3)))
+                for i, shape in enumerate(shape_consts):
+                    param_name = arg_names[i + 1]
+
+                    if param_name in all_args:
+                        tensor = all_args[param_name]
+                        if not isinstance(tensor, (Tensor, np.ndarray)):
+                            raise TypeError(
+                                f"'{param_name}' must be an instance of Tensor.",
+                            )
+
+                        if tensor.ndim != len(shape):
+                            raise ValueError(
+                                f"Dimensionalities of '{param_name}' and"
+                                + f" the constraint '{shape}' do not match!"
+                            )
+
+                        for axis, (s, ts) in enumerate(zip(shape, tensor.shape)):
+                            if s == -1:
+                                continue
+                            if s != ts:
+                                mismatch_dict[param_name] = np.vstack(
+                                    (mismatch_dict[param_name], [axis, s, ts])
+                                )
+
+                def _tuplize(vec: Vector):
+                    return tuple(int(v) for v in vec)
+
+                if len(mismatch_dict):
+                    title = (
+                        f"{"Argument":^14} {"Axes":^14} {"Expected":^14} {"Shape":^14}"
+                    )
+                    msg = str()
+
+                    for name in mismatch_dict.keys():
+                        errmat = mismatch_dict[name]
+
+                        axes = str(_tuplize(errmat[:, 0]))
+                        expect = str(_tuplize(errmat[:, 1]))
+                        got = str(_tuplize(errmat[:, 2]))
+
+                        msg += f"{name:^14} {axes:<14} {expect:<14} {got:<14}\n"
+
+                    raise ValueError(
+                        f"Shape mismatch(es) detected as follows:"
+                        + f"\n{title}"
+                        + f"\n{"-" * (14 * 4 + 3)}"
+                        + f"\n{msg}",
+                    )
+
+                return func(self, *args, **kwargs)
+
+            return wrapper
+
+        return decorator


class Scalar:
"""
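For reference, a minimal usage sketch of the new decorator, assuming force_shape is exposed on the Tensor class alongside force_dim; the Dense class, its forward method, and the (-1, 128) constraint are illustrative and not part of this commit:

import numpy as np
from luma.interface.typing import Tensor


class Dense:
    # -1 leaves the batch axis unconstrained; axis 1 must be exactly 128.
    @Tensor.force_shape((-1, 128))
    def forward(self, X: Tensor) -> Tensor:
        return X


layer = Dense()
layer.forward(np.zeros((32, 128)))  # passes: only axis 1 is checked
layer.forward(np.zeros((32, 64)))   # raises ValueError with the Argument/Axes/Expected/Shape table

A tensor whose dimensionality differs from the constraint (e.g. a 3D array here) raises a ValueError before any per-axis comparison is made.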
3 changes: 1 addition & 2 deletions luma/neural/scheduler.py
@@ -165,8 +165,7 @@ def new_learning_rate(self) -> float:
        else:
            new_lr = (
                self.max_lr
-                - (self.max_lr - self.init_lr / self.final_div_factor)
-                * factor
+                - (self.max_lr - self.init_lr / self.final_div_factor) * factor
            )

self.lr_trace.append(new_lr)
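For context, the reflowed expression is the annealing branch of new_learning_rate; a standalone sketch of that formula under the assumption that factor runs from 0.0 to 1.0 over the annealing phase (the enclosing scheduler class is not shown in this diff):

def annealed_lr(init_lr: float, max_lr: float, final_div_factor: float, factor: float) -> float:
    # The rate decays from max_lr toward init_lr / final_div_factor as factor grows.
    min_lr = init_lr / final_div_factor
    return max_lr - (max_lr - min_lr) * factor


annealed_lr(0.01, 0.1, 1e4, 0.0)  # 0.1 at the start of annealing
annealed_lr(0.01, 0.1, 1e4, 1.0)  # 1e-06 at the end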

