diff --git a/analyzer.py b/analyzer.py index 3cdf90a..ffa2802 100644 --- a/analyzer.py +++ b/analyzer.py @@ -295,62 +295,6 @@ def analyze_defined_module(self, var_name, module_name, module, var_module_layer analyzer.visit(module_ast) self.print_module_map(analyzer.module_map) return 0 - - # def analyze_inbuild_module(self, var_name, var_whole_name, module_name, module): - # self.update_module_flag(module) - # self.update_var_flag(var_name) - # lf_array = self.layer_flag.split('.') - # vf_array = self.var_flag.split('.') - # lf_len = len(lf_array) - # vf_len = len(vf_array) - # # [We should make this another separated function] - # if self.nn_module_flag and self.nn_module_flag in var_whole_name: - # op_in = self.moudle_map[self.nn_module_index][0] - # op_out = self.moudle_map[self.nn_module_index][1] - # self.nn_module_pack.append((op_in, op_out, f"{self.nn_module_flag}.{module_name}")) - # elif self.nn_module_flag and not self.nn_module_flag in var_whole_name: - # self.moudle_map = self.update_module_map(self.nn_module_flag, self.nn_module_pack) - # self.nn_module_flag = None - # self.nn_module_index = None - # map_modules = [item[-1] for item in self.moudle_map] - # if var_name in map_modules: - # self.nn_module_flag = var_name - # self.nn_module_index = map_modules.index(var_name) - # # [This function should ends here] - # for name, layer in module.named_children(): - # self.analyze_module(name, layer) - # if lf_len > 1: - # self.layer_flag = '.'.join(lf_array[:-1]) - # else: - # self.layer_flag = None - # if vf_len > 1: - # self.var_flag = '.'.join(vf_array[:-1]) - # else: - # self.var_flag = None - # return 0 - - # def analyze_defined_module(self, var_whole_name, var_name, module_name, module, var_mod_list): - # # [This should be notice] - # if self.nn_module_flag and not self.nn_module_flag in var_whole_name: - # self.moudle_map = self.update_module_map(self.nn_module_flag, self.nn_module_pack) - # self.nn_module_flag = None - # self.nn_module_index = None - # # [This should be notice] - # if module not in list(self.var_module_dict.keys()): - # module_code = inspect.getsource(type(module)) - # module_ast = ast.parse(module_code) - # analyzer = ModuleAstAnalyzer(var_mod_list) - # analyzer.visit(module_ast) - # result = analyzer.module_map - # self.moudle_map = self.update_module_map(var_name, result) - # else: - # result = self.var_module_dict[module_name] - # self.moudle_map = self.update_module_map(var_name, result) - # # print(f"{Color.GREEN}{self.moudle_map}{Color.END}") - # for name, layer in module.named_children(): - # self.analyze_module(name, layer) - # self.var_module_dict[module_name] = analyzer.module_map - # return 0 def update_module_map(self, var_name, replace_map): # Determine if our current module is mentioned in previous analysis @@ -464,21 +408,6 @@ def all_Attribute(self, targets): #---------------------------------------- #---------Generically find names--------- #---------------------------------------- - - # def find_full_name(self, node):# -> _Identifier | Any | tuple[str, _Identifier] | None: - # if isinstance(node, ast.Name): - # return node.id - # elif isinstance(node, ast.Attribute): - # if isinstance(node.value, ast.Attribute): - # return node.attr + '.' + self.find_full_name(node.value) - # elif isinstance(node.value, ast.Name) and not node.value.id == 'self': - # return node.attr + '.', node.value.id - # elif isinstance(node.value, ast.Name) and node.value.id == 'self': - # return node.attr - # elif isinstance(node.value, ast.Call): - # return node.attr - # else: - # pass # This won't happen def find_full_name_array(self, nodes): ret = [] @@ -565,10 +494,10 @@ def visit_Subscript(self, node: Subscript) -> Any: def visit_Name(self, node: Name) -> Any: if node.id in self.forward_tensor_list or node.id in self.forward_param_list: - self.analyze_net_name(self.parent_stack, node) + self.analyze_net_ast(self.parent_stack, node) self.generic_visit_with_parent_stack(node) - def analyze_net_name(self, parents, this:Name): + def analyze_net_ast(self, parents, this:Name): if len(parents) == 0: return 0 # self.forward_tensor_list = [] @@ -576,70 +505,24 @@ def analyze_net_name(self, parents, this:Name): # self.out_dict = {} # self.hash_var_dict = {} parents = parents[::-1] - if len(parents) == 1: - self.special_case_length_one(parents[0], this) - elif len(parents) == 2: - self.special_case_length_two(parents[0], parents[1], this) - else: - # self.print_parents_and_code(this) - pass - return 0 - - def special_case_length_one(self, parent, this:Name): - if isinstance(parent, ast.Attribute): - # [] nn.Module - # [] self.revised - # self.print_parents_and_code(this) - pass - elif isinstance(parent, ast.Assign) and parent.targets[0] == this: - # [] x = self.embedding_layer(x) - # [] x = self.transformer(x) - # [] x = self.post_transformer_ln(x) - # [] x = self.cls_layer(x) - # What we find is the variable [x] on the left side of assign, so ignore - # self.print_parents_and_code(this) - pass - elif isinstance(parent, ast.Assign) and not parent.targets[0] == this: - # seems is a case not gonna happen - # self.print_parents_and_code(this) - pass - elif isinstance(parent, ast.Call): - # self.print_parents_and_code(this) - op_name = self.from_node_to_operation(parent.func) - args = self.find_full_name_array(parent.args) - self.update_module_name(args, args, op_name) - pass - else: - # self.print_parents_and_code(this) - pass - - def special_case_length_two(self, parent, grandparent, this:Name): - # self.forward_tensor_list = [] - # self.forward_param_list = [] - # self.out_dict = {} - # self.hash_var_dict = {} - if isinstance(grandparent, ast.Assign) and isinstance(parent, ast.Call): - if grandparent.value == parent: - op_name = self.from_node_to_operation(parent.func) - # print(self.var_module_dict) - args = self.find_full_name_array(parent.args) - targets = self.find_full_name_array(grandparent.targets) - self.update_module_name(args, targets, op_name) - else: - # self.print_parents_and_code(this) + # === + self.print_parents_and_code(this) + for parent in parents: + if isinstance(parent, ast.Assign): pass - elif isinstance(grandparent, ast.Call) and isinstance(parent, ast.Call): - self.print_parents_and_code(this) - pass - elif isinstance(grandparent, ast.Assign) and isinstance(parent, ast.Attribute): - print(parent.attr) - # self.print_parents_and_code(this) - elif isinstance(grandparent, ast.For) and isinstance(parent, ast.Assign): - # self.print_parents_and_code(this) - pass - else: - # self.print_parents_and_code(this) - pass + elif isinstance(parent, ast.Call): + pass + elif isinstance(parent, ast.Subscript): + pass + elif isinstance(parent, ast.For): + pass + elif isinstance(parent, ast.BinOp): + pass + # === + return 0 + + def is_parameter_or_tensor(self, op_in_array:Array, op_out_array:Array, op_name): + return 0 def from_node_to_operation(self, node): var_op = self.find_full_name(node) @@ -656,8 +539,10 @@ def update_module_name_only_in(self, op_in_array:Array, op_name): op_in_hash.append(old_hash) middle_hash = hash_code() op_out_hash = [middle_hash] - self.module_map.append((op_in_hash, op_out_hash, op_name)) - return middle_hash + operation = (op_in_hash, op_out_hash, op_name) + if not operation in self.module_map: + self.module_map.append(operation) + return [middle_hash] def update_module_name_only_out(self, middle_hash, op_out_array:Array, op_name): op_out_hash = [] @@ -666,7 +551,10 @@ def update_module_name_only_out(self, middle_hash, op_out_array:Array, op_name): self.out_dict[op_out] = new_hash self.hash_var_dict[new_hash] = op_out op_out_hash.append(new_hash) - self.module_map.append((middle_hash, op_out_hash, op_name)) + operation = (middle_hash, op_out_hash, op_name) + if not operation in self.module_map: + self.module_map.append(operation) + return [op_out_hash] def update_module_name(self, op_in_array:Array, op_out_array:Array, op_name): op_in_hash = [] @@ -679,7 +567,10 @@ def update_module_name(self, op_in_array:Array, op_out_array:Array, op_name): self.out_dict[op_out] = new_hash self.hash_var_dict[new_hash] = op_out op_out_hash.append(new_hash) - self.module_map.append((op_in_hash, op_out_hash, op_name)) + operation = (op_in_hash, op_out_hash, op_name) + if not operation in self.module_map: + self.module_map.append(operation) + return [op_out_hash] def print_parents_and_code(self, this:Name): if len(self.parent_stack) >= 1: