Skip to content

Commit

Permalink
Delete those no longer used functions
Browse files Browse the repository at this point in the history
  • Loading branch information
yilin-bao authored Dec 9, 2023
1 parent 9952866 commit df4ddba
Showing 1 changed file with 31 additions and 140 deletions.
171 changes: 31 additions & 140 deletions analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,62 +295,6 @@ def analyze_defined_module(self, var_name, module_name, module, var_module_layer
analyzer.visit(module_ast)
self.print_module_map(analyzer.module_map)
return 0

# def analyze_inbuild_module(self, var_name, var_whole_name, module_name, module):
# self.update_module_flag(module)
# self.update_var_flag(var_name)
# lf_array = self.layer_flag.split('.')
# vf_array = self.var_flag.split('.')
# lf_len = len(lf_array)
# vf_len = len(vf_array)
# # [We should make this another separated function]
# if self.nn_module_flag and self.nn_module_flag in var_whole_name:
# op_in = self.moudle_map[self.nn_module_index][0]
# op_out = self.moudle_map[self.nn_module_index][1]
# self.nn_module_pack.append((op_in, op_out, f"{self.nn_module_flag}.{module_name}"))
# elif self.nn_module_flag and not self.nn_module_flag in var_whole_name:
# self.moudle_map = self.update_module_map(self.nn_module_flag, self.nn_module_pack)
# self.nn_module_flag = None
# self.nn_module_index = None
# map_modules = [item[-1] for item in self.moudle_map]
# if var_name in map_modules:
# self.nn_module_flag = var_name
# self.nn_module_index = map_modules.index(var_name)
# # [This function should ends here]
# for name, layer in module.named_children():
# self.analyze_module(name, layer)
# if lf_len > 1:
# self.layer_flag = '.'.join(lf_array[:-1])
# else:
# self.layer_flag = None
# if vf_len > 1:
# self.var_flag = '.'.join(vf_array[:-1])
# else:
# self.var_flag = None
# return 0

# def analyze_defined_module(self, var_whole_name, var_name, module_name, module, var_mod_list):
# # [This should be notice]
# if self.nn_module_flag and not self.nn_module_flag in var_whole_name:
# self.moudle_map = self.update_module_map(self.nn_module_flag, self.nn_module_pack)
# self.nn_module_flag = None
# self.nn_module_index = None
# # [This should be notice]
# if module not in list(self.var_module_dict.keys()):
# module_code = inspect.getsource(type(module))
# module_ast = ast.parse(module_code)
# analyzer = ModuleAstAnalyzer(var_mod_list)
# analyzer.visit(module_ast)
# result = analyzer.module_map
# self.moudle_map = self.update_module_map(var_name, result)
# else:
# result = self.var_module_dict[module_name]
# self.moudle_map = self.update_module_map(var_name, result)
# # print(f"{Color.GREEN}{self.moudle_map}{Color.END}")
# for name, layer in module.named_children():
# self.analyze_module(name, layer)
# self.var_module_dict[module_name] = analyzer.module_map
# return 0

def update_module_map(self, var_name, replace_map):
# Determine if our current module is mentioned in previous analysis
Expand Down Expand Up @@ -464,21 +408,6 @@ def all_Attribute(self, targets):
#----------------------------------------
#---------Generically find names---------
#----------------------------------------

# def find_full_name(self, node):# -> _Identifier | Any | tuple[str, _Identifier] | None:
# if isinstance(node, ast.Name):
# return node.id
# elif isinstance(node, ast.Attribute):
# if isinstance(node.value, ast.Attribute):
# return node.attr + '.' + self.find_full_name(node.value)
# elif isinstance(node.value, ast.Name) and not node.value.id == 'self':
# return node.attr + '.', node.value.id
# elif isinstance(node.value, ast.Name) and node.value.id == 'self':
# return node.attr
# elif isinstance(node.value, ast.Call):
# return node.attr
# else:
# pass # This won't happen

def find_full_name_array(self, nodes):
ret = []
Expand Down Expand Up @@ -565,81 +494,35 @@ def visit_Subscript(self, node: Subscript) -> Any:

def visit_Name(self, node: Name) -> Any:
if node.id in self.forward_tensor_list or node.id in self.forward_param_list:
self.analyze_net_name(self.parent_stack, node)
self.analyze_net_ast(self.parent_stack, node)
self.generic_visit_with_parent_stack(node)

def analyze_net_name(self, parents, this:Name):
def analyze_net_ast(self, parents, this:Name):
if len(parents) == 0:
return 0
# self.forward_tensor_list = []
# self.forward_param_list = []
# self.out_dict = {}
# self.hash_var_dict = {}
parents = parents[::-1]
if len(parents) == 1:
self.special_case_length_one(parents[0], this)
elif len(parents) == 2:
self.special_case_length_two(parents[0], parents[1], this)
else:
# self.print_parents_and_code(this)
pass
return 0

def special_case_length_one(self, parent, this:Name):
if isinstance(parent, ast.Attribute):
# [<ast.Attribute object at 0x12aa0f790>] nn.Module
# [<ast.Attribute object at 0x12aa0d4b0>] self.revised
# self.print_parents_and_code(this)
pass
elif isinstance(parent, ast.Assign) and parent.targets[0] == this:
# [<ast.Assign object at 0x130b28760>] x = self.embedding_layer(x)
# [<ast.Assign object at 0x130b28940>] x = self.transformer(x)
# [<ast.Assign object at 0x130b29a50>] x = self.post_transformer_ln(x)
# [<ast.Assign object at 0x130b29f30>] x = self.cls_layer(x)
# What we find is the variable [x] on the left side of assign, so ignore
# self.print_parents_and_code(this)
pass
elif isinstance(parent, ast.Assign) and not parent.targets[0] == this:
# seems is a case not gonna happen
# self.print_parents_and_code(this)
pass
elif isinstance(parent, ast.Call):
# self.print_parents_and_code(this)
op_name = self.from_node_to_operation(parent.func)
args = self.find_full_name_array(parent.args)
self.update_module_name(args, args, op_name)
pass
else:
# self.print_parents_and_code(this)
pass

def special_case_length_two(self, parent, grandparent, this:Name):
# self.forward_tensor_list = []
# self.forward_param_list = []
# self.out_dict = {}
# self.hash_var_dict = {}
if isinstance(grandparent, ast.Assign) and isinstance(parent, ast.Call):
if grandparent.value == parent:
op_name = self.from_node_to_operation(parent.func)
# print(self.var_module_dict)
args = self.find_full_name_array(parent.args)
targets = self.find_full_name_array(grandparent.targets)
self.update_module_name(args, targets, op_name)
else:
# self.print_parents_and_code(this)
# ===
self.print_parents_and_code(this)
for parent in parents:
if isinstance(parent, ast.Assign):
pass
elif isinstance(grandparent, ast.Call) and isinstance(parent, ast.Call):
self.print_parents_and_code(this)
pass
elif isinstance(grandparent, ast.Assign) and isinstance(parent, ast.Attribute):
print(parent.attr)
# self.print_parents_and_code(this)
elif isinstance(grandparent, ast.For) and isinstance(parent, ast.Assign):
# self.print_parents_and_code(this)
pass
else:
# self.print_parents_and_code(this)
pass
elif isinstance(parent, ast.Call):
pass
elif isinstance(parent, ast.Subscript):
pass
elif isinstance(parent, ast.For):
pass
elif isinstance(parent, ast.BinOp):
pass
# ===
return 0

def is_parameter_or_tensor(self, op_in_array:Array, op_out_array:Array, op_name):
return 0

def from_node_to_operation(self, node):
var_op = self.find_full_name(node)
Expand All @@ -656,8 +539,10 @@ def update_module_name_only_in(self, op_in_array:Array, op_name):
op_in_hash.append(old_hash)
middle_hash = hash_code()
op_out_hash = [middle_hash]
self.module_map.append((op_in_hash, op_out_hash, op_name))
return middle_hash
operation = (op_in_hash, op_out_hash, op_name)
if not operation in self.module_map:
self.module_map.append(operation)
return [middle_hash]

def update_module_name_only_out(self, middle_hash, op_out_array:Array, op_name):
op_out_hash = []
Expand All @@ -666,7 +551,10 @@ def update_module_name_only_out(self, middle_hash, op_out_array:Array, op_name):
self.out_dict[op_out] = new_hash
self.hash_var_dict[new_hash] = op_out
op_out_hash.append(new_hash)
self.module_map.append((middle_hash, op_out_hash, op_name))
operation = (middle_hash, op_out_hash, op_name)
if not operation in self.module_map:
self.module_map.append(operation)
return [op_out_hash]

def update_module_name(self, op_in_array:Array, op_out_array:Array, op_name):
op_in_hash = []
Expand All @@ -679,7 +567,10 @@ def update_module_name(self, op_in_array:Array, op_out_array:Array, op_name):
self.out_dict[op_out] = new_hash
self.hash_var_dict[new_hash] = op_out
op_out_hash.append(new_hash)
self.module_map.append((op_in_hash, op_out_hash, op_name))
operation = (op_in_hash, op_out_hash, op_name)
if not operation in self.module_map:
self.module_map.append(operation)
return [op_out_hash]

def print_parents_and_code(self, this:Name):
if len(self.parent_stack) >= 1:
Expand Down

0 comments on commit df4ddba

Please sign in to comment.