From a9147264b9424428ae7d3808b3e0d4f782a5c189 Mon Sep 17 00:00:00 2001
From: "Yilin (Eleen) Bao" <62961093+yilin-bao@users.noreply.github.com>
Date: Sat, 9 Dec 2023 18:14:39 -0600
Subject: [PATCH] Update analyzer.py

Dispatch every parent AST node seen by analyze_net_ast to a dedicated
handler (Assign, Call, Attribute, Subscript, For, BinOp, Tuple) instead
of the previous no-op branches, and route single-parent stacks through
special_case_length_one.

find_full_name_array now flattens ast.Tuple elements, keeps plain
strings that contain 'self', and returns None when given None.
remove_starting_self returns None for a falsy name instead of raising.

The update_module_name* helpers now always append the operation to
module_map and print it through the new print_operation helper.
update_module_name_only_in hashes inputs that are not yet in out_dict.
New helpers: update_module_name_hash_in, update_module_name_hash,
print_ast_of_node.
---
Review notes (text in this section is not part of the commit message):

* The exported patch had lost its line structure; it is restored here.
  Blank context lines were placed where the hunk line counts require
  them. Indentation of context lines (notably inside
  analyze_defined_module) is reconstructed and should be checked
  against blob ffa2802 before applying.
* Dropped the added "from sqlalchemy import String"; the isinstance
  test in find_full_name_array now uses the builtin str, which is what
  the following "'self' in node" check needs.
* "not nodes == None" became "nodes is not None".
* analyze_net_ast_tuple logged "ast.BinOp"; it now logs "ast.Tuple".
* analyze_net_ast_subscript strips the trailing newline that
  astor.to_source appends before building the "x[...]" name.
* New-file hunk offsets and the diffstat were adjusted for the removed
  import line (214 insertions instead of 215).

 analyzer.py | 257 +++++++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 214 insertions(+), 43 deletions(-)

diff --git a/analyzer.py b/analyzer.py
index ffa2802..430e617 100644
--- a/analyzer.py
+++ b/analyzer.py
@@ -10,9 +10,10 @@ from typing import Any
 
 from matplotlib.pylab import pareto
 from numpy import isin, var
 from sympy import false
 from torch import rand
 import torch.nn as nn
+import torch
 import random
 import string
 import numpy as np
@@ -72,6 +73,11 @@ class Color:
     # ANSI escape code to reset text attributes to default
     END = '\033[0m'
 
+def print_ast_of_node(node):
+    code = astor.to_source(node)
+    tree = ast.parse(code)
+    ast_code = astor.to_source(tree)
+    print(f'{Color.BOLD_LIGHT_GRAY}{tree}{Color.END}')
 
 def generate_random_string(length=8):
     characters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
@@ -235,7 +241,7 @@ def start_analyze_module(self, module:nn.Module):
         self.all_parameters = dict(module.named_parameters())
         self.analyze_module("self", module)
 
-    def analyze_module(self, var_name, module:nn.Module):
+    def analyze_module(self, var_name, module):
         '''
         Recursively analyzes the structure of PyTorch modules and their nested children.
 
@@ -293,7 +299,7 @@ def analyze_defined_module(self, var_name, module_name, module, var_module_layer
         # [var_module_layer] is the variable-module dictionary of current layer
         analyzer = ModuleAstAnalyzer(var_module_layer, var_name, module_name)
         analyzer.visit(module_ast)
-        self.print_module_map(analyzer.module_map)
+        # self.print_module_map(analyzer.module_map)
         return 0
 
     def update_module_map(self, var_name, replace_map):
@@ -410,12 +416,23 @@ def all_Attribute(self, targets):
     #----------------------------------------
 
     def find_full_name_array(self, nodes):
-        ret = []
-        for node in nodes:
-            node_name = self.find_full_name(node)
-            if node_name:
-                ret.append(node_name)
-        return ret
+        if nodes is not None:
+            ret = []
+            for node in nodes:
+                if isinstance(node, ast.Tuple):
+                    for elt in node.elts:
+                        elt_name = self.find_full_name(elt)
+                        if elt_name:
+                            ret.append(elt_name)
+                elif isinstance(node, str):
+                    if 'self' in node:
+                        ret.append(node)
+                else:
+                    node_name = self.find_full_name(node)
+                    if node_name:
+                        ret.append(node_name)
+            return ret
+        return None
 
     def find_full_name(self, node):
         if isinstance(node, ast.Name):
@@ -426,22 +443,12 @@ def find_full_name(self, node):
         return None
 
     def remove_starting_self(self, node_name):
-        arr = node_name.split('.')
-        if arr[0] == 'self':
-            return '.'.join(arr[1:])
-        else:
-            return '.'.join(arr)
-
-    # def find_all_names(self, node_list):
-    #     ret = []
-    #     for node in node_list:
-    #         if isinstance(node, ast.Name) or isinstance(node, ast.Attribute):
-    #             ret.append(self.find_full_name(node))
-    #         elif isinstance(node, ast.Constant):
-    #             ret.append(str(node.value))
-    #         elif isinstance(node, ast.BinOp):
-    #             ret.append(astor.to_source(node))
-    #     return ret
+        if node_name:
+            arr = node_name.split('.')
+            if arr[0] == 'self':
+                return '.'.join(arr[1:])
+            else:
+                return '.'.join(arr)
 
     #--------------------------------------------------
     #--Analyze special functions (such as forward())---
@@ -507,21 +514,152 @@ def analyze_net_ast(self, parents, this:Name):
         parents = parents[::-1]
         # ===
         self.print_parents_and_code(this)
-        for parent in parents:
-            if isinstance(parent, ast.Assign):
-                pass
-            elif isinstance(parent, ast.Call):
-                pass
-            elif isinstance(parent, ast.Subscript):
-                pass
-            elif isinstance(parent, ast.For):
-                pass
-            elif isinstance(parent, ast.BinOp):
-                pass
+        if len(parents) == 1:
+            self.special_case_length_one(parents[0], this)
+        else:
+            current = None
+            intermediate = None
+            for parent in parents:
+                if isinstance(parent, ast.Assign):
+                    intermediate = self.analyze_net_ast_assign(parent, this, current, intermediate)
+                elif isinstance(parent, ast.Call):
+                    intermediate = self.analyze_net_ast_call(parent, this, current, intermediate)
+                elif isinstance(parent, ast.Attribute):
+                    intermediate = self.analyze_net_ast_attribute(parent, this, current, intermediate)
+                elif isinstance(parent, ast.Subscript):
+                    intermediate = self.analyze_net_ast_subscript(parent, this, current, intermediate)
+                elif isinstance(parent, ast.For):
+                    intermediate = self.analyze_net_ast_for(parent, this, current, intermediate)
+                elif isinstance(parent, ast.BinOp):
+                    intermediate = self.analyze_net_ast_binop(parent, this, current, intermediate)
+                elif isinstance(parent, ast.Tuple):
+                    intermediate = self.analyze_net_ast_tuple(parent, this, current, intermediate)
+                current = parent
         # ===
         return 0
+
+    def special_case_length_one(self, parent, this:Name):
+        if isinstance(parent, ast.Attribute):
+            # [] nn.Module
+            # [] self.revised
+            # self.print_parents_and_code(this)
+            pass
+        elif isinstance(parent, ast.Assign) and parent.targets[0] == this:
+            # [] x = self.embedding_layer(x)
+            # [] x = self.transformer(x)
+            # [] x = self.post_transformer_ln(x)
+            # [] x = self.cls_layer(x)
+            # What we find is the variable [x] on the left side of assign, so ignore
+            # self.print_parents_and_code(this)
+            pass
+        elif isinstance(parent, ast.Assign) and not parent.targets[0] == this:
+            # seems is a case not gonna happen
+            # self.print_parents_and_code(this)
+            pass
+        elif isinstance(parent, ast.Call):
+            # self.print_parents_and_code(this)
+            op_name = self.from_node_to_operation(parent.func)
+            args = self.find_full_name_array(parent.args)
+            self.update_module_name(args, args, op_name)
+            pass
+        else:
+            # self.print_parents_and_code(this)
+            pass
+
+    def analyze_net_ast_assign(self, node:ast.Assign, this:Name, current, intermediate):
+        print("Start analyzation on a node type of ast.Assign")
+        op_name = "="
+        if current:
+            if node.value == current or node.targets[0] == current:  # B, N, C = x.shape
+                op_out = self.find_full_name_array(node.targets)
+                if intermediate:
+                    ret_hash = self.update_module_name_only_out([intermediate], op_out, op_name)[0]
+                    return ret_hash
+                else:
+                    pass
+            else:
+                pass
+        else:
+            pass
+        return intermediate
+
+    def analyze_net_ast_call(self, node:ast.Call, this:Name, current, intermediate):
+        print("Start analyzation on a node type of ast.Call")
+        if current:
+            if current in node.args:
+                op_name = self.from_node_to_operation(node.func)
+                ret_hash = self.update_module_name_hash_in([intermediate], op_name)[0]
+                return ret_hash
+            else:
+                return intermediate
+        else:
+            op_name = self.from_node_to_operation(node.func)
+            args = self.find_full_name_array(node.args)
+            ret_hash = self.update_module_name_only_in(args, op_name)[0]
+            return ret_hash
+        return intermediate
+
+    def analyze_net_ast_attribute(self, node:ast.Attribute, this:Name, current, intermediate):
+        print("Start analyzation on a node type of ast.Attribute")
+        op_name = node.attr
+        if current:
+            if node.value == current:
+                ret_hash = self.update_module_name_hash_in([intermediate], op_name)[0]
+                return ret_hash
+            else:
+                pass
+        else:
+            args = self.find_full_name_array([node.value])
+            ret_hash = self.update_module_name_only_in(args, op_name)[0]
+            return ret_hash
+        return intermediate
+
+    def analyze_net_ast_subscript(self, node:ast.Subscript, this:Name, current, intermediate):
+        print("Start analyzation on a node type of ast.Subscript")
+        if current:
+            pass
+        else:
+            op_name = node.slice
+            op_name = astor.to_source(op_name).strip()
+            args = [this.id + f'[{op_name}]']
+            ret_hash = self.update_module_name_only_in(args, op_name)[0]
+            return ret_hash
+        return intermediate
+
+    def analyze_net_ast_for(self, node:ast.For, this:Name, current, intermediate):
+        print("Start analyzation on a node type of ast.For")
+        if current:
+            pass
+        else:
+            pass
+        return intermediate
+
+    def analyze_net_ast_binop(self, node:ast.BinOp, this:Name, current, intermediate):
+        print("Start analyzation on a node type of ast.BinOp")
+        op_name = node.op
+        if current:
+            pass
+        else:
+            args_left = self.find_full_name_array([node.left])
+            args_right = self.find_full_name_array([node.right])
+            print(args_left)
+            print(args_right)
+            ret_hash = self.update_module_name_only_in(args_left+args_right, op_name)[0]
+            return ret_hash
+        return intermediate
+
+    def analyze_net_ast_tuple(self, node:ast.Tuple, this:Name, current, intermediate):
+        print("Start analyzation on a node type of ast.Tuple")
+        op_name = "="
+        if current:
+            pass
+        else:
+            args = self.find_full_name_array(node.elts)
+            ret_hash = self.update_module_name_only_in(args, op_name)[0]
+            return ret_hash
+        return intermediate
 
-    def is_parameter_or_tensor(self, op_in_array:Array, op_out_array:Array, op_name):
+    def is_parameter_or_tensor(self, op_name):
         return 0
 
     def from_node_to_operation(self, node):
@@ -535,13 +673,24 @@ def from_node_to_operation(self, node):
     def update_module_name_only_in(self, op_in_array:Array, op_name):
         op_in_hash = []
         for op_in in op_in_array:
-            old_hash = self.out_dict[op_in]
+            if op_in in self.out_dict:
+                old_hash = self.out_dict[op_in]
+            else:
+                old_hash = hash_code(op_in)
             op_in_hash.append(old_hash)
         middle_hash = hash_code()
         op_out_hash = [middle_hash]
         operation = (op_in_hash, op_out_hash, op_name)
-        if not operation in self.module_map:
-            self.module_map.append(operation)
+        self.module_map.append(operation)
+        self.print_operation(operation)
+        return [middle_hash]
+
+    def update_module_name_hash_in(self, op_in_hash:Array, op_name):
+        middle_hash = hash_code()
+        op_out_hash: list[str] = [middle_hash]
+        operation = (op_in_hash, op_out_hash, op_name)
+        self.module_map.append(operation)
+        self.print_operation(operation)
         return [middle_hash]
 
     def update_module_name_only_out(self, middle_hash, op_out_array:Array, op_name):
@@ -552,11 +701,15 @@ def update_module_name_only_out(self, middle_hash, op_out_array:Array, op_name):
             self.hash_var_dict[new_hash] = op_out
             op_out_hash.append(new_hash)
         operation = (middle_hash, op_out_hash, op_name)
-        if not operation in self.module_map:
-            self.module_map.append(operation)
+        self.module_map.append(operation)
+        self.print_operation(operation)
         return [op_out_hash]
 
     def update_module_name(self, op_in_array:Array, op_out_array:Array, op_name):
+        if len(op_in_array) == 0:
+            return None
+        if len(op_out_array) == 0:
+            return None
         op_in_hash = []
         for op_in in op_in_array:
             old_hash = self.out_dict[op_in]
@@ -568,10 +721,28 @@ def update_module_name(self, op_in_array:Array, op_out_array:Array, op_name):
             self.hash_var_dict[new_hash] = op_out
             op_out_hash.append(new_hash)
         operation = (op_in_hash, op_out_hash, op_name)
-        if not operation in self.module_map:
-            self.module_map.append(operation)
+        self.module_map.append(operation)
+        self.print_operation(operation)
+        return [op_out_hash]
+
+    def update_module_name_hash(self, op_in_hash:Array, op_out_array:Array, op_name):
+        if len(op_out_array) == 0:
+            return None
+        op_out_hash = []
+        for op_out in op_out_array:
+            new_hash = hash_code(op_out)
+            self.out_dict[op_out] = new_hash
+            self.hash_var_dict[new_hash] = op_out
+            op_out_hash.append(new_hash)
+        operation = (op_in_hash, op_out_hash, op_name)
+        self.module_map.append(operation)
+        self.print_operation(operation)
         return [op_out_hash]
 
     def print_parents_and_code(self, this:Name):
         if len(self.parent_stack) >= 1:
             print(f'{Color.BOLD_BLUE}{self.parent_stack}{Color.END} {this.id} {Color.LIME}{astor.to_source(self.parent_stack[0])}{Color.END}', end='')
+
+    def print_operation(self, mm, length=8):
+        if mm:
+            print(f'Here is a layer of module map:', [a[:length] for a in mm[0]], [a[:length] for a in mm[1]], mm[2])