Skip to content

Commit

Permalink
first vacay
Browse files Browse the repository at this point in the history
  • Loading branch information
ChanLumerico committed Aug 30, 2024
1 parent 1c17892 commit 26e87d6
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions luma/neural/autoprop/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ class MergeMode(Enum):
CHCAT = "chcat"
SUM = "sum"
HADAMARD = "hadamard"
AVERAGE = "average"
AVG = "avg"
MAX = "max"
MIN = "min"
DOT = "dot"
SUBTRACT = "subtract"
SUB = "sub"

def forward(self, f_queue: list[TensorLike]) -> TensorLike:
match self:
Expand All @@ -28,7 +28,7 @@ def forward(self, f_queue: list[TensorLike]) -> TensorLike:
X *= tensor
return X

case MergeMode.AVERAGE:
case MergeMode.AVG:
return np.mean(f_queue, axis=0)

case MergeMode.MAX:
Expand All @@ -40,7 +40,7 @@ def forward(self, f_queue: list[TensorLike]) -> TensorLike:
case MergeMode.DOT:
return np.dot(f_queue[0], f_queue[1])

case MergeMode.SUBTRACT:
case MergeMode.SUB:
result = f_queue[0]
for tensor in f_queue[1:]:
result -= tensor
Expand All @@ -66,7 +66,7 @@ def backward(
prod_except_current *= f_queue[j]
return d_out * prod_except_current

case MergeMode.AVERAGE:
case MergeMode.AVG:
return d_out / len(f_queue)

case MergeMode.MAX | MergeMode.MIN:
Expand All @@ -80,5 +80,5 @@ def backward(
elif i == 1:
return np.dot(f_queue[0].T, d_out)

case MergeMode.SUBTRACT:
case MergeMode.SUB:
return d_out if i == 0 else -d_out

0 comments on commit 26e87d6

Please sign in to comment.