AutoProp -.0
ChanLumerico committed Jul 21, 2024
1 parent 48b0a63 commit 40f4bc6
Showing 1 changed file with 25 additions and 28 deletions.
53 changes: 25 additions & 28 deletions luma/neural/autoprop.py
@@ -14,14 +14,12 @@ class LayerNode:
     def __init__(
         self,
         layer: LayerLike,
-        prev_nodes: List[LayerLike] = [],
-        next_nodes: List[LayerLike] = [],
-        merge_mode: Literal["chcat", "sum"] = "chcat",
+        merge_mode: Literal["chcat", "sum"] = "sum",
         name: str | None = None,
     ) -> None:
         self.layer: LayerLike = layer
-        self.prev_nodes: List[LayerNode] = prev_nodes
-        self.next_nodes: List[LayerNode] = next_nodes
+        self.prev_nodes: List[LayerNode] = []
+        self.next_nodes: List[LayerNode] = []
         self.merge_mode = merge_mode
         self.name = name
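
This hunk removes two mutable default arguments. In Python, a default such as
prev_nodes: List[LayerLike] = [] is evaluated once, when the function is defined,
so every LayerNode constructed without explicit lists would share the same
prev_nodes/next_nodes objects and silently leak edges across nodes. (The default
merge_mode also flips from "chcat", presumably channel concatenation, to "sum".)
A minimal standalone repro of the pitfall, using an illustrative Node class that
is not part of luma:

class Node:
    def __init__(self, children: list = []):  # one list, created once, shared by all calls
        self.children = children

a, b = Node(), Node()
a.children.append("x")
print(b.children)  # ['x'] -- b sees a's mutation through the shared default

class FixedNode:
    def __init__(self) -> None:
        self.children = []  # fresh list per instance, as the new __init__ does

c, d = FixedNode(), FixedNode()
c.children.append("x")
print(d.children)  # []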

@@ -82,20 +80,20 @@ def flush(self) -> None:
 
     def __call__(self, is_train: bool = False) -> TensorLike:
         return self.forward(is_train)
 
     def __str__(self) -> str:
         if self.name is None:
             return type(self).__name__
         return self.name
 
     def __repr__(self) -> str:
         return f"({str(self)}: {self.layer})"
 
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, LayerNode):
             return False
         return self.name == other.name and self.layer == other.layer
 
     def __hash__(self) -> int:
         return hash((self.name, self.layer))

@@ -118,20 +116,21 @@ def build(self) -> None:
         all_nodes = set(self.graph.keys())
         for kn, vn in self.graph.items():
             kn.next_nodes.extend(vn)
-            for v in vn:
+            for v in vn:
                 v.prev_nodes.append(kn)
                 all_nodes.add(v)
 
         self.nodes = list(all_nodes)
 
         visited = set()
-        def dfs(node: LayerNode):
+
+        def _dfs(node: LayerNode) -> None:
             if node in visited:
                 return
             visited.add(node)
             for next_node in node.next_nodes:
-                dfs(next_node)
-        dfs(self.root)
+                _dfs(next_node)
+
+        _dfs(self.root)
 
         if visited != set(self.nodes):
             raise RuntimeError(f"'{self}' is not fully connected!")
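
After wiring prev/next pointers from the adjacency dict, build() runs a
depth-first walk from self.root and rejects the graph if the set of reached
nodes differs from the full node set, i.e. if some node is unreachable. The
same check on a plain adjacency dict; fully_connected is an illustrative
helper, not part of luma's API:

def fully_connected(adj: dict, root) -> bool:
    # Every node that appears as a key or as a neighbor.
    nodes = set(adj) | {v for vs in adj.values() for v in vs}
    visited = set()

    def dfs(node) -> None:
        if node in visited:
            return
        visited.add(node)
        for nxt in adj.get(node, []):
            dfs(nxt)

    dfs(root)
    return visited == nodes

assert fully_connected({"root": ["a"], "a": []}, "root")
assert not fully_connected({"root": ["a"], "b": []}, "root")  # "b" unreachable
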
@@ -140,31 +139,30 @@ def dfs(node: LayerNode):
             raise RuntimeError(f"'{self}' contains a cycle!")
 
         self.built = True
-        return
 
     def detect_cycle(self) -> bool:
         visited = set()
         rec_stack = set()
 
-        def visit(node: LayerNode) -> bool:
+        def _visit(node: LayerNode) -> bool:
             if node in rec_stack:
                 return True
             if node in visited:
                 return False
 
             visited.add(node)
             rec_stack.add(node)
 
             for next_node in node.next_nodes:
-                if visit(next_node):
+                if _visit(next_node):
                     return True
 
             rec_stack.remove(node)
             return False
 
         for node in self.nodes:
-            if visit(node):
+            if _visit(node):
                 return True
 
         return False
 
     def forward(self, X: TensorLike, is_train: bool = False) -> TensorLike:
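
detect_cycle() is the standard DFS back-edge test: rec_stack holds exactly the
nodes on the current recursion path, so reaching a node that is already on the
stack means the path loops back on itself. The same idea on a plain adjacency
dict (has_cycle is illustrative, not luma API):

def has_cycle(adj: dict) -> bool:
    visited, rec_stack = set(), set()

    def visit(node) -> bool:
        if node in rec_stack:  # back edge onto the current DFS path: cycle
            return True
        if node in visited:    # fully explored earlier, no cycle through it
            return False
        visited.add(node)
        rec_stack.add(node)
        if any(visit(nxt) for nxt in adj.get(node, [])):
            return True
        rec_stack.remove(node)  # leaving this path
        return False

    return any(visit(node) for node in adj)

assert has_cycle({"a": ["b"], "b": ["a"]})
assert not has_cycle({"a": ["b"], "b": []})
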
@@ -187,15 +185,14 @@ def _forward_bfs(self, X: TensorLike, is_train: bool) -> TensorLike:
         self.root.f_visited = True
 
         while queue:
-            cur = queue.pop()
+            cur = queue.popleft()
             X = cur(is_train)
-            print(X.shape)
 
             for next in cur.next_nodes:
-                if next.f_visited:
-                    continue
                 next.for_enqueue(X)
-                next.f_visited = True
-                queue.append(next)
+                if not next.f_visited:
+                    next.f_visited = True
+                    queue.append(next)
 
         return X
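
Two fixes land in the forward pass. First, deque.pop() removes from the right
end, which turns the intended queue into a stack and the traversal into
depth-first order; popleft() restores the FIFO order a BFS needs (a stray
print(X.shape) debug line is dropped along the way):

from collections import deque

q = deque([1, 2, 3])
assert q.pop() == 3      # right end: stack (LIFO) order
assert q.popleft() == 1  # left end: queue (FIFO) order, true BFS

Second, next.for_enqueue(X) now runs unconditionally, so a merge node still
receives the output of every parent even after it has already been enqueued
once; only the enqueueing itself stays guarded by f_visited.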

@@ -207,12 +204,12 @@ def _backward_bfs(self, d_out: TensorLike) -> TensorLike:
         while queue:
             cur = queue.pop()
             d_out_arr = cur.backward()
 
             for prev, dx in zip(cur.prev_nodes, d_out_arr):
-                if prev.b_visited:
-                    continue
                 prev.back_enqueue(dx)
-                prev.b_visited = True
-                queue.append(prev)
+                if not prev.b_visited:
+                    prev.b_visited = True
+                    queue.append(prev)
 
             cur.flush()