Skip to content

Commit

Permalink
Complete update module map functions
Browse files Browse the repository at this point in the history
  • Loading branch information
yilin-bao authored Dec 9, 2023
1 parent e001d96 commit 2004f5d
Showing 1 changed file with 30 additions and 2 deletions.
32 changes: 30 additions & 2 deletions analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def generate_random_string(length=8):
characters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
return ''.join(random.choice(characters) for _ in range(length))

def hash_code(code):
def hash_code(code=" abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"):
code = code + generate_random_string()
sha256_hash = hashlib.sha256()
sha256_hash.update(code.encode('utf-8'))
Expand Down Expand Up @@ -589,17 +589,26 @@ def special_case_length_one(self, parent, this:Name):
if isinstance(parent, ast.Attribute):
# [<ast.Attribute object at 0x12aa0f790>] nn.Module
# [<ast.Attribute object at 0x12aa0d4b0>] self.revised
# self.print_parents_and_code(this)
pass
elif isinstance(parent, ast.Assign) and parent.targets[0] == this:
# [<ast.Assign object at 0x130b28760>] x = self.embedding_layer(x)
# [<ast.Assign object at 0x130b28940>] x = self.transformer(x)
# [<ast.Assign object at 0x130b29a50>] x = self.post_transformer_ln(x)
# [<ast.Assign object at 0x130b29f30>] x = self.cls_layer(x)
# What we find is the variable [x] on the left side of assign, so ignore
# self.print_parents_and_code(this)
pass
elif isinstance(parent, ast.Assign) and not parent.targets[0] == this:
# seems is a case not gonna happen
# self.print_parents_and_code(this)
pass
elif isinstance(parent, ast.Call):
# self.print_parents_and_code(this)
op_name = self.from_node_to_operation(parent.func)
args = self.find_full_name_array(parent.args)
self.update_module_name(args, args, op_name)
pass
else:
# self.print_parents_and_code(this)
pass
Expand All @@ -620,7 +629,7 @@ def special_case_length_two(self, parent, grandparent, this:Name):
# self.print_parents_and_code(this)
pass
elif isinstance(grandparent, ast.Call) and isinstance(parent, ast.Call):
# self.print_parents_and_code(this)
self.print_parents_and_code(this)
pass
elif isinstance(grandparent, ast.Assign) and isinstance(parent, ast.Attribute):
print(parent.attr)
Expand All @@ -640,6 +649,25 @@ def from_node_to_operation(self, node):
else:
return f'{self.var_name}.{var_op_noself}'

def update_module_name_only_in(self, op_in_array:Array, op_name):
op_in_hash = []
for op_in in op_in_array:
old_hash = self.out_dict[op_in]
op_in_hash.append(old_hash)
middle_hash = hash_code()
op_out_hash = [middle_hash]
self.module_map.append((op_in_hash, op_out_hash, op_name))
return middle_hash

def update_module_name_only_out(self, middle_hash, op_out_array:Array, op_name):
op_out_hash = []
for op_out in op_out_array:
new_hash = hash_code(op_out)
self.out_dict[op_out] = new_hash
self.hash_var_dict[new_hash] = op_out
op_out_hash.append(new_hash)
self.module_map.append((middle_hash, op_out_hash, op_name))

def update_module_name(self, op_in_array:Array, op_out_array:Array, op_name):
op_in_hash = []
for op_in in op_in_array:
Expand Down

0 comments on commit 2004f5d

Please sign in to comment.