Skip to content

Commit

Permalink
@force_shape finished
Browse files Browse the repository at this point in the history
  • Loading branch information
ChanLumerico committed Aug 10, 2024
1 parent 28c580a commit daf79a5
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions luma/interface/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def wrapper(self, *args: Any, **kwargs: Any) -> Any:
arg_names = func.__code__.co_varnames
all_args = {**dict(zip(arg_names, (self,) + args)), **kwargs}

mismatch_dict = defaultdict(lambda: np.empty((0, 3), dtype=np.int64))
mismatch_dict = defaultdict(lambda: np.empty((0, 3)))
for i, shape in enumerate(shape_consts):
param_name = arg_names[i + 1]

Expand All @@ -216,25 +216,32 @@ def wrapper(self, *args: Any, **kwargs: Any) -> Any:
if s == -1:
continue
if s != ts:
np.vstack((mismatch_dict[param_name], [axis, s, ts]))
mismatch_dict[param_name] = np.vstack(
(mismatch_dict[param_name], [axis, s, ts])
)

def _tuplize(vec: Vector):
return tuple(int(v) for v in vec)

if len(mismatch_dict):
title = f"{"Argument":^15} {"Axes":^15} {"Expexted":^15} {"Shape":^15}"
title = (
f"{"Argument":^14} {"Axes":^14} {"Expected":^14} {"Shape":^14}"
)
msg = str()

for name in mismatch_dict.keys():
errmat = mismatch_dict[name]

axes = tuple(errmat[:, 0])
expect = tuple(errmat[:, 1])
got = tuple(errmat[:, 2])
axes = str(_tuplize(errmat[:, 0]))
expect = str(_tuplize(errmat[:, 1]))
got = str(_tuplize(errmat[:, 2]))

msg += f"{name:^15} {axes:^15} {expect:^15} {got:^15}\n"
msg += f"{name:^14} {axes:<14} {expect:<14} {got:<14}\n"

raise ValueError(
f"Shape mismatch(es) detected as follows:"
+ f"\n{title}"
+ f"\n{"-" * 63}"
+ f"\n{"-" * (14 * 4 + 3)}"
+ f"\n{msg}",
)

Expand Down

0 comments on commit daf79a5

Please sign in to comment.