Skip to content

Commit

Permalink
Update analyzer.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yilin-bao authored Dec 10, 2023
1 parent 7e84f47 commit a914726
Showing 1 changed file with 215 additions and 43 deletions.
258 changes: 215 additions & 43 deletions analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
from typing import Any
from matplotlib.pylab import pareto
from numpy import isin, var
from sqlalchemy import String
from sympy import false
from torch import rand
import torch.nn as nn
import torch
import random
import string
import numpy as np
Expand Down Expand Up @@ -72,6 +74,11 @@ class Color:
# ANSI escape code to reset text attributes to default
END = '\033[0m'

def print_ast_of_node(node):
code = astor.to_source(node)
tree = ast.parse(code)
ast_code = astor.to_source(tree)
print(f'{Color.BOLD_LIGHT_GRAY}{tree}{Color.END}')

def generate_random_string(length=8):
characters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
Expand Down Expand Up @@ -235,7 +242,7 @@ def start_analyze_module(self, module:nn.Module):
self.all_parameters = dict(module.named_parameters())
self.analyze_module("self", module)

def analyze_module(self, var_name, module:nn.Module):
def analyze_module(self, var_name, module):
'''
Recursively analyzes the structure of PyTorch modules and their nested children.
Expand Down Expand Up @@ -293,7 +300,7 @@ def analyze_defined_module(self, var_name, module_name, module, var_module_layer
# [var_module_layer] is the variable-module dictionary of current layer
analyzer = ModuleAstAnalyzer(var_module_layer, var_name, module_name)
analyzer.visit(module_ast)
self.print_module_map(analyzer.module_map)
# self.print_module_map(analyzer.module_map)
return 0

def update_module_map(self, var_name, replace_map):
Expand Down Expand Up @@ -410,12 +417,23 @@ def all_Attribute(self, targets):
#----------------------------------------

def find_full_name_array(self, nodes):
ret = []
for node in nodes:
node_name = self.find_full_name(node)
if node_name:
ret.append(node_name)
return ret
if not nodes == None:
ret = []
for node in nodes:
if isinstance(node, ast.Tuple):
for elt in node.elts:
elt_name = self.find_full_name(elt)
if elt_name:
ret.append(elt_name)
elif isinstance(node, String):
if 'self' in node:
ret.append(node)
else:
node_name = self.find_full_name(node)
if node_name:
ret.append(node_name)
return ret
return None

def find_full_name(self, node):
if isinstance(node, ast.Name):
Expand All @@ -426,22 +444,12 @@ def find_full_name(self, node):
return None

def remove_starting_self(self, node_name):
arr = node_name.split('.')
if arr[0] == 'self':
return '.'.join(arr[1:])
else:
return '.'.join(arr)

# def find_all_names(self, node_list):
# ret = []
# for node in node_list:
# if isinstance(node, ast.Name) or isinstance(node, ast.Attribute):
# ret.append(self.find_full_name(node))
# elif isinstance(node, ast.Constant):
# ret.append(str(node.value))
# elif isinstance(node, ast.BinOp):
# ret.append(astor.to_source(node))
# return ret
if node_name:
arr = node_name.split('.')
if arr[0] == 'self':
return '.'.join(arr[1:])
else:
return '.'.join(arr)

#--------------------------------------------------
#--Analyze special functions (such as forward())---
Expand Down Expand Up @@ -507,21 +515,152 @@ def analyze_net_ast(self, parents, this:Name):
parents = parents[::-1]
# ===
self.print_parents_and_code(this)
for parent in parents:
if isinstance(parent, ast.Assign):
pass
elif isinstance(parent, ast.Call):
pass
elif isinstance(parent, ast.Subscript):
pass
elif isinstance(parent, ast.For):
pass
elif isinstance(parent, ast.BinOp):
pass
if len(parents) == 1:
self.special_case_length_one(parents[0], this)
else:
current = None
intermediate = None
for parent in parents:
if isinstance(parent, ast.Assign):
intermediate = self.analyze_net_ast_assign(parent, this, current, intermediate)
elif isinstance(parent, ast.Call):
intermediate = self.analyze_net_ast_call(parent, this, current, intermediate)
elif isinstance(parent, ast.Attribute):
intermediate = self.analyze_net_ast_attribute(parent, this, current, intermediate)
elif isinstance(parent, ast.Subscript):
intermediate = self.analyze_net_ast_subscript(parent, this, current, intermediate)
elif isinstance(parent, ast.For):
intermediate = self.analyze_net_ast_for(parent, this, current, intermediate)
elif isinstance(parent, ast.BinOp):
intermediate = self.analyze_net_ast_binop(parent, this, current, intermediate)
elif isinstance(parent, ast.Tuple):
intermediate = self.analyze_net_ast_tuple(parent, this, current, intermediate)
current = parent
# ===
return 0

def special_case_length_one(self, parent, this:Name):
if isinstance(parent, ast.Attribute):
# [<ast.Attribute object at 0x12aa0f790>] nn.Module
# [<ast.Attribute object at 0x12aa0d4b0>] self.revised
# self.print_parents_and_code(this)
pass
elif isinstance(parent, ast.Assign) and parent.targets[0] == this:
# [<ast.Assign object at 0x130b28760>] x = self.embedding_layer(x)
# [<ast.Assign object at 0x130b28940>] x = self.transformer(x)
# [<ast.Assign object at 0x130b29a50>] x = self.post_transformer_ln(x)
# [<ast.Assign object at 0x130b29f30>] x = self.cls_layer(x)
# What we find is the variable [x] on the left side of assign, so ignore
# self.print_parents_and_code(this)
pass
elif isinstance(parent, ast.Assign) and not parent.targets[0] == this:
# seems is a case not gonna happen
# self.print_parents_and_code(this)
pass
elif isinstance(parent, ast.Call):
# self.print_parents_and_code(this)
op_name = self.from_node_to_operation(parent.func)
args = self.find_full_name_array(parent.args)
self.update_module_name(args, args, op_name)
pass
else:
# self.print_parents_and_code(this)
pass

def analyze_net_ast_assign(self, node:ast.Assign, this:Name, current, intermediate):
print("Start analyzation on a node type of ast.Assign")
op_name = "="
if current:
if node.value == current or node.targets[0] == current: # B, N, C = x.shape
op_out = self.find_full_name_array(node.targets)
if intermediate:
ret_hash = self.update_module_name_only_out([intermediate], op_out, op_name)[0]
return ret_hash
else:
pass
else:
pass
else:
pass
return intermediate

def analyze_net_ast_call(self, node:ast.Call, this:Name, current, intermediate):
print("Start analyzation on a node type of ast.Call")
if current:
if current in node.args:
op_name = self.from_node_to_operation(node.func)
ret_hash = self.update_module_name_hash_in([intermediate], op_name)[0]
return ret_hash
else:
return intermediate
else:
op_name = self.from_node_to_operation(node.func)
args = self.find_full_name_array(node.args)
ret_hash = self.update_module_name_only_in(args, op_name)[0]
return ret_hash
return intermediate

def analyze_net_ast_attribute(self, node:ast.Attribute, this:Name, current, intermediate):
print("Start analyzation on a node type of ast.Attribute")
op_name = node.attr
if current:
if node.value == current:
ret_hash = self.update_module_name_hash_in([intermediate], op_name)[0]
return ret_hash
else:
pass
else:
args = self.find_full_name_array([node.value])
ret_hash = self.update_module_name_only_in(args, op_name)[0]
return ret_hash
return intermediate

def analyze_net_ast_subscript(self, node:ast.Subscript, this:Name, current, intermediate):
print("Start analyzation on a node type of ast.Subscript")
if current:
pass
else:
op_name = node.slice
op_name = astor.to_source(op_name)
args = [this.id + f'[{op_name}]']
ret_hash = self.update_module_name_only_in(args, op_name)[0]
return ret_hash
return intermediate

def analyze_net_ast_for(self, node:ast.For, this:Name, current, intermediate):
print("Start analyzation on a node type of ast.For")
if current:
pass
else:
pass
return intermediate

def analyze_net_ast_binop(self, node:ast.BinOp, this:Name, current, intermediate):
print("Start analyzation on a node type of ast.BinOp")
op_name = node.op
if current:
pass
else:
args_left = self.find_full_name_array([node.left])
args_right = self.find_full_name_array([node.right])
print(args_left)
print(args_right)
ret_hash = self.update_module_name_only_in(args_left+args_right, op_name)[0]
return ret_hash
return intermediate

def analyze_net_ast_tuple(self, node:ast.Tuple, this:Name, current, intermediate):
print("Start analyzation on a node type of ast.BinOp")
op_name = "="
if current:
pass
else:
args = self.find_full_name_array(node.elts)
ret_hash = self.update_module_name_only_in(args, op_name)[0]
return ret_hash
return intermediate

def is_parameter_or_tensor(self, op_in_array:Array, op_out_array:Array, op_name):
def is_parameter_or_tensor(self, op_name):
return 0

def from_node_to_operation(self, node):
Expand All @@ -535,13 +674,24 @@ def from_node_to_operation(self, node):
def update_module_name_only_in(self, op_in_array:Array, op_name):
op_in_hash = []
for op_in in op_in_array:
old_hash = self.out_dict[op_in]
if op_in in self.out_dict:
old_hash = self.out_dict[op_in]
else:
old_hash = hash_code(op_in)
op_in_hash.append(old_hash)
middle_hash = hash_code()
op_out_hash = [middle_hash]
operation = (op_in_hash, op_out_hash, op_name)
if not operation in self.module_map:
self.module_map.append(operation)
self.module_map.append(operation)
self.print_operation(operation)
return [middle_hash]

def update_module_name_hash_in(self, op_in_hash:Array, op_name):
middle_hash = hash_code()
op_out_hash: list[str] = [middle_hash]
operation = (op_in_hash, op_out_hash, op_name)
self.module_map.append(operation)
self.print_operation(operation)
return [middle_hash]

def update_module_name_only_out(self, middle_hash, op_out_array:Array, op_name):
Expand All @@ -552,11 +702,15 @@ def update_module_name_only_out(self, middle_hash, op_out_array:Array, op_name):
self.hash_var_dict[new_hash] = op_out
op_out_hash.append(new_hash)
operation = (middle_hash, op_out_hash, op_name)
if not operation in self.module_map:
self.module_map.append(operation)
self.module_map.append(operation)
self.print_operation(operation)
return [op_out_hash]

def update_module_name(self, op_in_array:Array, op_out_array:Array, op_name):
if len(op_in_array) == 0:
return None
if len(op_out_array) == 0:
return None
op_in_hash = []
for op_in in op_in_array:
old_hash = self.out_dict[op_in]
Expand All @@ -568,10 +722,28 @@ def update_module_name(self, op_in_array:Array, op_out_array:Array, op_name):
self.hash_var_dict[new_hash] = op_out
op_out_hash.append(new_hash)
operation = (op_in_hash, op_out_hash, op_name)
if not operation in self.module_map:
self.module_map.append(operation)
self.module_map.append(operation)
self.print_operation(operation)
return [op_out_hash]

def update_module_name_hash(self, op_in_hash:Array, op_out_array:Array, op_name):
if len(op_out_array) == 0:
return None
op_out_hash = []
for op_out in op_out_array:
new_hash = hash_code(op_out)
self.out_dict[op_out] = new_hash
self.hash_var_dict[new_hash] = op_out
op_out_hash.append(new_hash)
operation = (op_in_hash, op_out_hash, op_name)
self.module_map.append(operation)
self.print_operation(operation)
return [op_out_hash]

def print_parents_and_code(self, this:Name):
if len(self.parent_stack) >= 1:
print(f'{Color.BOLD_BLUE}{self.parent_stack}{Color.END} {this.id} {Color.LIME}{astor.to_source(self.parent_stack[0])}{Color.END}', end='')

def print_operation(self, mm, length=8):
if mm:
print(f'Here is a layer of module map:', [a[:length] for a in mm[0]], [a[:length] for a in mm[1]], mm[2])

0 comments on commit a914726

Please sign in to comment.