From d963fafdd3ea2465d77ed59a1012ac6d6b85ed2a Mon Sep 17 00:00:00 2001 From: yilin-bao <62961093+yilin-bao@users.noreply.github.com> Date: Fri, 8 Dec 2023 20:27:33 -0600 Subject: [PATCH] Update analyzer.py --- analyzer.py | 144 ++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 99 insertions(+), 45 deletions(-) diff --git a/analyzer.py b/analyzer.py index f98792d..9ee25d1 100644 --- a/analyzer.py +++ b/analyzer.py @@ -1,6 +1,6 @@ import ast, astor import hashlib -from ast import Assign, Attribute, BinOp, Call, For, Name, Return, Tuple, mod +from ast import Assign, Attribute, BinOp, Call, For, Name, Return, Subscript, Tuple, mod from gettext import find import inspect from re import L @@ -207,6 +207,15 @@ def remove_var_flag(self): self.var_flag = '.'.join(arr[:-1]) else: self.var_flag = None + + def var_name_from_whole(self, var_name): + arr = var_name.split('.') + if len(arr) > 1: + return arr[-1] + elif len(arr) == 1: + return arr[0] + else: + return var_name def start_analyze_module(self, module:nn.Module): ''' @@ -239,21 +248,22 @@ def analyze_module(self, var_name, module:nn.Module): # var_mod_list = {name:layer for name, layer in module.named_children()} # Print the name:class pair of the module # self.print_analyze_status(var_whole_name, module_name, module) - self.analyze_module_by_cases(var_name, module) + # print(var_name, var_whole_name) + self.analyze_module_by_cases(var_whole_name, module_name, module) return var_whole_name, module_name - def analyze_module_by_cases(self, var_name, module): + def analyze_module_by_cases(self, var_name, module_name, module): # If anymore following work is needed here if self.is_torch_module(module): pass else: pass self.update_module_flag(module) - self.update_var_flag(var_name) + self.update_var_flag(self.var_name_from_whole(var_name)) var_module_layer = {} for name, layer in module.named_children(): - var_whole_name, module_name = self.analyze_module(name, layer) - var_module_layer[var_whole_name] = module_name + var_whole_name, sub_module_name = self.analyze_module(name, layer) + var_module_layer[var_whole_name] = sub_module_name self.remove_module_flag() self.remove_var_flag() # Either if current module is a pyTorch in-built module @@ -262,19 +272,20 @@ def analyze_module_by_cases(self, var_name, module): if self.is_torch_module(module): self.analyze_inbuild_module() else: - self.analyze_defined_module(module, var_module_layer) + # print("self.layer_flag", self.var_flag) + self.analyze_defined_module(var_name, module_name, module, var_module_layer) return 0 def analyze_inbuild_module(self): return 0 - def analyze_defined_module(self, module, var_module_layer): + def analyze_defined_module(self, var_name, module_name, module, var_module_layer): module_code = inspect.getsource(type(module)) module_ast = ast.parse(module_code) # [var_module_layer] is the variable-module dictionary of current layer - analyzer = ModuleAstAnalyzer(var_module_layer) + analyzer = ModuleAstAnalyzer(var_module_layer, var_name, module_name) analyzer.visit(module_ast) - print("Results:", analyzer.module_map) + # print("Results:", analyzer.module_map) return 0 # def analyze_inbuild_module(self, var_name, var_whole_name, module_name, module): @@ -390,11 +401,13 @@ def print_current_layer_information(self, var_whole_name, module_name, depth=2): #============================================================ class ModuleAstAnalyzer(ast.NodeVisitor): - def __init__(self, var_module_dict): + def __init__(self, var_module_dict, var_name, module_name): # Parent stack, write in all the parents node visited before self.parent_stack = [] # In [ModuleAstAnalyzer], [var_module_dict] is just the current analyzed layer self.var_module_dict:dict = var_module_dict + self.var_name = var_name + self.module_name = module_name self.module_map = [] # forward_tensor_list: tensor, matrix, vector usd in deep learning @@ -437,31 +450,54 @@ def all_Attribute(self, targets): #---------Generically find names--------- #---------------------------------------- - def find_full_name(self, node):# -> _Identifier | Any | tuple[str, _Identifier] | None: + # def find_full_name(self, node):# -> _Identifier | Any | tuple[str, _Identifier] | None: + # if isinstance(node, ast.Name): + # return node.id + # elif isinstance(node, ast.Attribute): + # if isinstance(node.value, ast.Attribute): + # return node.attr + '.' + self.find_full_name(node.value) + # elif isinstance(node.value, ast.Name) and not node.value.id == 'self': + # return node.attr + '.', node.value.id + # elif isinstance(node.value, ast.Name) and node.value.id == 'self': + # return node.attr + # elif isinstance(node.value, ast.Call): + # return node.attr + # else: + # pass # This won't happen + + def find_full_name_array(self, nodes): + ret = [] + for node in nodes: + node_name = self.find_full_name(node) + if node_name: + ret.append(node_name) + return ret + + def find_full_name(self, node): if isinstance(node, ast.Name): return node.id elif isinstance(node, ast.Attribute): - if isinstance(node.value, ast.Attribute): - return node.attr + '.' + self.find_full_name(node.value) - elif isinstance(node.value, ast.Name) and not node.value.id == 'self': - return node.attr + '.', node.value.id - elif isinstance(node.value, ast.Name) and node.value.id == 'self': - return node.attr - elif isinstance(node.value, ast.Call): - return node.attr - else: - pass # This won't happen + return f'{self.find_full_name(node.value)}.{node.attr}' + else: + return None - def find_all_names(self, node_list): - ret = [] - for node in node_list: - if isinstance(node, ast.Name) or isinstance(node, ast.Attribute): - ret.append(self.find_full_name(node)) - elif isinstance(node, ast.Constant): - ret.append(str(node.value)) - elif isinstance(node, ast.BinOp): - ret.append(astor.to_source(node)) - return ret + def remove_starting_self(self, node_name): + arr = node_name.split('.') + if arr[0] == 'self': + return '.'.join(arr[1:]) + else: + return '.'.join(arr) + + # def find_all_names(self, node_list): + # ret = [] + # for node in node_list: + # if isinstance(node, ast.Name) or isinstance(node, ast.Attribute): + # ret.append(self.find_full_name(node)) + # elif isinstance(node, ast.Constant): + # ret.append(str(node.value)) + # elif isinstance(node, ast.BinOp): + # ret.append(astor.to_source(node)) + # return ret #-------------------------------------------------- #--Analyze special functions (such as forward())--- @@ -473,8 +509,14 @@ def visit_FunctionDef(self, node): for arg in node.args.args: # print("arguments in forward", arg.arg) if not arg.arg == 'self': + # self.forward_tensor_list = [] + # self.forward_param_list = [] + # self.out_dict = {} + # self.hash_var_dict = {} self.forward_tensor_list.append(arg.arg) - hashlib.sha256(arg.arg) + arg_hashed = hash_code(arg.arg) + self.out_dict[arg.arg] = arg_hashed + self.hash_var_dict[arg_hashed] = arg.arg self.generic_visit(node) #-------------------------------------------------- @@ -497,7 +539,10 @@ def visit_Assign(self, node: Assign) -> Any: self.generic_visit_with_parent_stack(node) def visit_Tuple(self, node: Tuple) -> Any: - return self.generic_visit_with_parent_stack(node) + self.generic_visit_with_parent_stack(node) + + def visit_Subscript(self, node: Subscript) -> Any: + self.generic_visit_with_parent_stack(node) #-------------------------------------------------- #-------------Core code of this class-------------- @@ -521,7 +566,8 @@ def analyze_net_name(self, parents, this:Name): elif len(parents) == 2: self.special_case_length_two(parents[0], parents[1], this) else: - self.print_parents_and_code() + # self.print_parents_and_code(this) + pass return 0 def special_case_length_one(self, parent, this:Name): @@ -537,9 +583,11 @@ def special_case_length_one(self, parent, this:Name): # What we find is the variable [x] on the left side of assign, so ignore pass elif isinstance(parent, ast.Assign) and not parent.targets[0] == this: - self.print_parents_and_code() + # self.print_parents_and_code(this) + pass else: - self.print_parents_and_code() + # self.print_parents_and_code(this) + pass def special_case_length_two(self, parent, grandparent, this:Name): # self.forward_tensor_list = [] @@ -547,18 +595,24 @@ def special_case_length_two(self, parent, grandparent, this:Name): # self.out_dict = {} # self.hash_var_dict = {} if isinstance(grandparent, ast.Assign) and isinstance(parent, ast.Call): - print(parent.func) - self.print_parents_and_code() + print(self.var_name) + print(self.module_name) + print(self.remove_starting_self(self.find_full_name(parent.func))) + print(self.find_full_name_array(parent.args)) + self.print_parents_and_code(this) elif isinstance(grandparent, ast.Call) and isinstance(parent, ast.Call): - self.print_parents_and_code() + # self.print_parents_and_code(this) + pass elif isinstance(grandparent, ast.Assign) and isinstance(parent, ast.Attribute): print(parent.attr) - self.print_parents_and_code() + # self.print_parents_and_code(this) elif isinstance(grandparent, ast.For) and isinstance(parent, ast.Assign): - self.print_parents_and_code() + # self.print_parents_and_code(this) + pass else: - self.print_parents_and_code() + # self.print_parents_and_code(this) + pass - def print_parents_and_code(self): + def print_parents_and_code(self, this:Name): if len(self.parent_stack) >= 1: - print(f'{Color.BOLD_BLUE}{self.parent_stack}{Color.END} {Color.LIME}{astor.to_source(self.parent_stack[0])}{Color.END}', end=' ') + print(f'{Color.BOLD_BLUE}{self.parent_stack}{Color.END} {this.id} {Color.LIME}{astor.to_source(self.parent_stack[0])}{Color.END}', end=' ')