From 2004f5dd542a6d4bc4c55e09388f5e8baabe67c3 Mon Sep 17 00:00:00 2001 From: yilin-bao <62961093+yilin-bao@users.noreply.github.com> Date: Sat, 9 Dec 2023 01:20:04 -0600 Subject: [PATCH] Complete update module map functions --- analyzer.py | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/analyzer.py b/analyzer.py index b9415f5..3cdf90a 100644 --- a/analyzer.py +++ b/analyzer.py @@ -77,7 +77,7 @@ def generate_random_string(length=8): characters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" return ''.join(random.choice(characters) for _ in range(length)) -def hash_code(code): +def hash_code(code=" abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"): code = code + generate_random_string() sha256_hash = hashlib.sha256() sha256_hash.update(code.encode('utf-8')) @@ -589,6 +589,7 @@ def special_case_length_one(self, parent, this:Name): if isinstance(parent, ast.Attribute): # [] nn.Module # [] self.revised + # self.print_parents_and_code(this) pass elif isinstance(parent, ast.Assign) and parent.targets[0] == this: # [] x = self.embedding_layer(x) @@ -596,10 +597,18 @@ def special_case_length_one(self, parent, this:Name): # [] x = self.post_transformer_ln(x) # [] x = self.cls_layer(x) # What we find is the variable [x] on the left side of assign, so ignore + # self.print_parents_and_code(this) pass elif isinstance(parent, ast.Assign) and not parent.targets[0] == this: + # seems is a case not gonna happen # self.print_parents_and_code(this) pass + elif isinstance(parent, ast.Call): + # self.print_parents_and_code(this) + op_name = self.from_node_to_operation(parent.func) + args = self.find_full_name_array(parent.args) + self.update_module_name(args, args, op_name) + pass else: # self.print_parents_and_code(this) pass @@ -620,7 +629,7 @@ def special_case_length_two(self, parent, grandparent, this:Name): # self.print_parents_and_code(this) pass elif isinstance(grandparent, ast.Call) and isinstance(parent, ast.Call): - # self.print_parents_and_code(this) + self.print_parents_and_code(this) pass elif isinstance(grandparent, ast.Assign) and isinstance(parent, ast.Attribute): print(parent.attr) @@ -640,6 +649,25 @@ def from_node_to_operation(self, node): else: return f'{self.var_name}.{var_op_noself}' + def update_module_name_only_in(self, op_in_array:Array, op_name): + op_in_hash = [] + for op_in in op_in_array: + old_hash = self.out_dict[op_in] + op_in_hash.append(old_hash) + middle_hash = hash_code() + op_out_hash = [middle_hash] + self.module_map.append((op_in_hash, op_out_hash, op_name)) + return middle_hash + + def update_module_name_only_out(self, middle_hash, op_out_array:Array, op_name): + op_out_hash = [] + for op_out in op_out_array: + new_hash = hash_code(op_out) + self.out_dict[op_out] = new_hash + self.hash_var_dict[new_hash] = op_out + op_out_hash.append(new_hash) + self.module_map.append((middle_hash, op_out_hash, op_name)) + def update_module_name(self, op_in_array:Array, op_out_array:Array, op_name): op_in_hash = [] for op_in in op_in_array: