Skip to content

Commit

Permalink
Write method update module map
Browse files Browse the repository at this point in the history
  • Loading branch information
yilin-bao authored Dec 9, 2023
1 parent b6bb46a commit 04a19a4
Showing 1 changed file with 37 additions and 6 deletions.
43 changes: 37 additions & 6 deletions analyzer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import ast, astor
from ctypes import Array
import time
import hashlib
from ast import Assign, Attribute, BinOp, Call, For, Name, Return, Subscript, Tuple, mod
from gettext import find
Expand Down Expand Up @@ -70,7 +72,12 @@ class Color:
END = '\033[0m'


def generate_random_string(length=8):
characters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
return ''.join(random.choice(characters) for _ in range(length))

def hash_code(code):
code = code + generate_random_string()
sha256_hash = hashlib.sha256()
sha256_hash.update(code.encode('utf-8'))
hashed_code = sha256_hash.hexdigest()
Expand Down Expand Up @@ -408,7 +415,7 @@ def __init__(self, var_module_dict, var_name, module_name):
self.var_module_dict:dict = var_module_dict
self.var_name = var_name
self.module_name = module_name
self.module_map = []
self.module_map = [] # ([Inputs], [Outputs], Function/Operation)

# forward_tensor_list: tensor, matrix, vector usd in deep learning
# forward_param_list: int, float, dimension variables, other variables
Expand Down Expand Up @@ -595,10 +602,13 @@ def special_case_length_two(self, parent, grandparent, this:Name):
# self.out_dict = {}
# self.hash_var_dict = {}
if isinstance(grandparent, ast.Assign) and isinstance(parent, ast.Call):
print(self.var_name)
print(self.module_name)
print(self.remove_starting_self(self.find_full_name(parent.func)))
print(self.find_full_name_array(parent.args))
if grandparent.value == parent:
op = self.from_node_to_operation(parent.func)
# print(self.var_module_dict)
args = self.find_full_name_array(parent.args)
targets = self.find_full_name_array(grandparent.targets)
print(args, targets)

self.print_parents_and_code(this)
elif isinstance(grandparent, ast.Call) and isinstance(parent, ast.Call):
# self.print_parents_and_code(this)
Expand All @@ -612,7 +622,28 @@ def special_case_length_two(self, parent, grandparent, this:Name):
else:
# self.print_parents_and_code(this)
pass

def from_node_to_operation(self, node):
var_op = self.find_full_name(node)
var_op_noself = self.remove_starting_self(var_op)
if self.var_name == 'self':
return var_op_noself
else:
return f'{self.var_name}.{var_op_noself}'

def update_module_name(self, op_in_array:Array, op_out_array:Array, op_name):
op_in_hash = []
for op_in in op_in_array:
old_hash = self.out_dict[op_in]
op_in_hash.append(old_hash)
op_out_hash = []
for op_out in op_out_array:
new_hash = hash_code(op_out)
self.out_dict[op_out] = new_hash
self.hash_var_dict[new_hash] = op_out
op_out_hash.append(new_hash)
self.module_map.append((op_in_hash, op_out_hash, op_name))

def print_parents_and_code(self, this:Name):
if len(self.parent_stack) >= 1:
print(f'{Color.BOLD_BLUE}{self.parent_stack}{Color.END} {this.id} {Color.LIME}{astor.to_source(self.parent_stack[0])}{Color.END}', end=' ')
print(f'{Color.BOLD_BLUE}{self.parent_stack}{Color.END} {this.id} {Color.LIME}{astor.to_source(self.parent_stack[0])}{Color.END}', end='')

0 comments on commit 04a19a4

Please sign in to comment.