Skip to content

Commit

Permalink
Update analyzer.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yilin-bao authored Dec 9, 2023
1 parent 4cf602e commit d963faf
Showing 1 changed file with 99 additions and 45 deletions.
144 changes: 99 additions & 45 deletions analyzer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import ast, astor
import hashlib
from ast import Assign, Attribute, BinOp, Call, For, Name, Return, Tuple, mod
from ast import Assign, Attribute, BinOp, Call, For, Name, Return, Subscript, Tuple, mod
from gettext import find
import inspect
from re import L
Expand Down Expand Up @@ -207,6 +207,15 @@ def remove_var_flag(self):
self.var_flag = '.'.join(arr[:-1])
else:
self.var_flag = None

def var_name_from_whole(self, var_name):
arr = var_name.split('.')
if len(arr) > 1:
return arr[-1]
elif len(arr) == 1:
return arr[0]
else:
return var_name

def start_analyze_module(self, module:nn.Module):
'''
Expand Down Expand Up @@ -239,21 +248,22 @@ def analyze_module(self, var_name, module:nn.Module):
# var_mod_list = {name:layer for name, layer in module.named_children()}
# Print the name:class pair of the module
# self.print_analyze_status(var_whole_name, module_name, module)
self.analyze_module_by_cases(var_name, module)
# print(var_name, var_whole_name)
self.analyze_module_by_cases(var_whole_name, module_name, module)
return var_whole_name, module_name

def analyze_module_by_cases(self, var_name, module):
def analyze_module_by_cases(self, var_name, module_name, module):
# If anymore following work is needed here
if self.is_torch_module(module):
pass
else:
pass
self.update_module_flag(module)
self.update_var_flag(var_name)
self.update_var_flag(self.var_name_from_whole(var_name))
var_module_layer = {}
for name, layer in module.named_children():
var_whole_name, module_name = self.analyze_module(name, layer)
var_module_layer[var_whole_name] = module_name
var_whole_name, sub_module_name = self.analyze_module(name, layer)
var_module_layer[var_whole_name] = sub_module_name
self.remove_module_flag()
self.remove_var_flag()
# Either if current module is a pyTorch in-built module
Expand All @@ -262,19 +272,20 @@ def analyze_module_by_cases(self, var_name, module):
if self.is_torch_module(module):
self.analyze_inbuild_module()
else:
self.analyze_defined_module(module, var_module_layer)
# print("self.layer_flag", self.var_flag)
self.analyze_defined_module(var_name, module_name, module, var_module_layer)
return 0

def analyze_inbuild_module(self):
return 0

def analyze_defined_module(self, module, var_module_layer):
def analyze_defined_module(self, var_name, module_name, module, var_module_layer):
module_code = inspect.getsource(type(module))
module_ast = ast.parse(module_code)
# [var_module_layer] is the variable-module dictionary of current layer
analyzer = ModuleAstAnalyzer(var_module_layer)
analyzer = ModuleAstAnalyzer(var_module_layer, var_name, module_name)
analyzer.visit(module_ast)
print("Results:", analyzer.module_map)
# print("Results:", analyzer.module_map)
return 0

# def analyze_inbuild_module(self, var_name, var_whole_name, module_name, module):
Expand Down Expand Up @@ -390,11 +401,13 @@ def print_current_layer_information(self, var_whole_name, module_name, depth=2):
#============================================================

class ModuleAstAnalyzer(ast.NodeVisitor):
def __init__(self, var_module_dict):
def __init__(self, var_module_dict, var_name, module_name):
# Parent stack, write in all the parents node visited before
self.parent_stack = []
# In [ModuleAstAnalyzer], [var_module_dict] is just the current analyzed layer
self.var_module_dict:dict = var_module_dict
self.var_name = var_name
self.module_name = module_name
self.module_map = []

# forward_tensor_list: tensor, matrix, vector usd in deep learning
Expand Down Expand Up @@ -437,31 +450,54 @@ def all_Attribute(self, targets):
#---------Generically find names---------
#----------------------------------------

def find_full_name(self, node):# -> _Identifier | Any | tuple[str, _Identifier] | None:
# def find_full_name(self, node):# -> _Identifier | Any | tuple[str, _Identifier] | None:
# if isinstance(node, ast.Name):
# return node.id
# elif isinstance(node, ast.Attribute):
# if isinstance(node.value, ast.Attribute):
# return node.attr + '.' + self.find_full_name(node.value)
# elif isinstance(node.value, ast.Name) and not node.value.id == 'self':
# return node.attr + '.', node.value.id
# elif isinstance(node.value, ast.Name) and node.value.id == 'self':
# return node.attr
# elif isinstance(node.value, ast.Call):
# return node.attr
# else:
# pass # This won't happen

def find_full_name_array(self, nodes):
ret = []
for node in nodes:
node_name = self.find_full_name(node)
if node_name:
ret.append(node_name)
return ret

def find_full_name(self, node):
if isinstance(node, ast.Name):
return node.id
elif isinstance(node, ast.Attribute):
if isinstance(node.value, ast.Attribute):
return node.attr + '.' + self.find_full_name(node.value)
elif isinstance(node.value, ast.Name) and not node.value.id == 'self':
return node.attr + '.', node.value.id
elif isinstance(node.value, ast.Name) and node.value.id == 'self':
return node.attr
elif isinstance(node.value, ast.Call):
return node.attr
else:
pass # This won't happen
return f'{self.find_full_name(node.value)}.{node.attr}'
else:
return None

def find_all_names(self, node_list):
ret = []
for node in node_list:
if isinstance(node, ast.Name) or isinstance(node, ast.Attribute):
ret.append(self.find_full_name(node))
elif isinstance(node, ast.Constant):
ret.append(str(node.value))
elif isinstance(node, ast.BinOp):
ret.append(astor.to_source(node))
return ret
def remove_starting_self(self, node_name):
arr = node_name.split('.')
if arr[0] == 'self':
return '.'.join(arr[1:])
else:
return '.'.join(arr)

# def find_all_names(self, node_list):
# ret = []
# for node in node_list:
# if isinstance(node, ast.Name) or isinstance(node, ast.Attribute):
# ret.append(self.find_full_name(node))
# elif isinstance(node, ast.Constant):
# ret.append(str(node.value))
# elif isinstance(node, ast.BinOp):
# ret.append(astor.to_source(node))
# return ret

#--------------------------------------------------
#--Analyze special functions (such as forward())---
Expand All @@ -473,8 +509,14 @@ def visit_FunctionDef(self, node):
for arg in node.args.args:
# print("arguments in forward", arg.arg)
if not arg.arg == 'self':
# self.forward_tensor_list = []
# self.forward_param_list = []
# self.out_dict = {}
# self.hash_var_dict = {}
self.forward_tensor_list.append(arg.arg)
hashlib.sha256(arg.arg)
arg_hashed = hash_code(arg.arg)
self.out_dict[arg.arg] = arg_hashed
self.hash_var_dict[arg_hashed] = arg.arg
self.generic_visit(node)

#--------------------------------------------------
Expand All @@ -497,7 +539,10 @@ def visit_Assign(self, node: Assign) -> Any:
self.generic_visit_with_parent_stack(node)

def visit_Tuple(self, node: Tuple) -> Any:
return self.generic_visit_with_parent_stack(node)
self.generic_visit_with_parent_stack(node)

def visit_Subscript(self, node: Subscript) -> Any:
self.generic_visit_with_parent_stack(node)

#--------------------------------------------------
#-------------Core code of this class--------------
Expand All @@ -521,7 +566,8 @@ def analyze_net_name(self, parents, this:Name):
elif len(parents) == 2:
self.special_case_length_two(parents[0], parents[1], this)
else:
self.print_parents_and_code()
# self.print_parents_and_code(this)
pass
return 0

def special_case_length_one(self, parent, this:Name):
Expand All @@ -537,28 +583,36 @@ def special_case_length_one(self, parent, this:Name):
# What we find is the variable [x] on the left side of assign, so ignore
pass
elif isinstance(parent, ast.Assign) and not parent.targets[0] == this:
self.print_parents_and_code()
# self.print_parents_and_code(this)
pass
else:
self.print_parents_and_code()
# self.print_parents_and_code(this)
pass

def special_case_length_two(self, parent, grandparent, this:Name):
# self.forward_tensor_list = []
# self.forward_param_list = []
# self.out_dict = {}
# self.hash_var_dict = {}
if isinstance(grandparent, ast.Assign) and isinstance(parent, ast.Call):
print(parent.func)
self.print_parents_and_code()
print(self.var_name)
print(self.module_name)
print(self.remove_starting_self(self.find_full_name(parent.func)))
print(self.find_full_name_array(parent.args))
self.print_parents_and_code(this)
elif isinstance(grandparent, ast.Call) and isinstance(parent, ast.Call):
self.print_parents_and_code()
# self.print_parents_and_code(this)
pass
elif isinstance(grandparent, ast.Assign) and isinstance(parent, ast.Attribute):
print(parent.attr)
self.print_parents_and_code()
# self.print_parents_and_code(this)
elif isinstance(grandparent, ast.For) and isinstance(parent, ast.Assign):
self.print_parents_and_code()
# self.print_parents_and_code(this)
pass
else:
self.print_parents_and_code()
# self.print_parents_and_code(this)
pass

def print_parents_and_code(self):
def print_parents_and_code(self, this:Name):
if len(self.parent_stack) >= 1:
print(f'{Color.BOLD_BLUE}{self.parent_stack}{Color.END} {Color.LIME}{astor.to_source(self.parent_stack[0])}{Color.END}', end=' ')
print(f'{Color.BOLD_BLUE}{self.parent_stack}{Color.END} {this.id} {Color.LIME}{astor.to_source(self.parent_stack[0])}{Color.END}', end=' ')

0 comments on commit d963faf

Please sign in to comment.