fix: outputs order is wrong when more than 1 outputs #54

Merged: 2 commits, merged on Jun 21, 2024
54 changes: 22 additions & 32 deletions zkstats/computation.py
@@ -161,7 +161,7 @@ def _call_op(self, x: list[torch.Tensor], op_type: Type[Operation]) -> Union[tor
             if op_class_str not in self.op_dict:
                 self.precal_witness[op_class_str+"_0"] = [op.result.data.item()]
                 self.op_dict[op_class_str] = 1
-            else:
+            else:
                 self.precal_witness[op_class_str+"_"+str(self.op_dict[op_class_str])] = [op.result.data.item()]
                 self.op_dict[op_class_str]+=1
         elif isinstance(op, Median):
@@ -177,7 +177,7 @@ def _call_op(self, x: list[torch.Tensor], op_type: Type[Operation]) -> Union[tor
             if op_class_str not in self.op_dict:
                 self.precal_witness[op_class_str+"_0"] = [op.result.data.item(), op.data_mean.data.item()]
                 self.op_dict[op_class_str] = 1
-            else:
+            else:
                 self.precal_witness[op_class_str+"_"+str(self.op_dict[op_class_str])] = [op.result.data.item(), op.data_mean.data.item()]
                 self.op_dict[op_class_str]+=1
         elif isinstance(op, Covariance):
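
The two hunks above share one bookkeeping pattern: self.op_dict counts how many operations of each class have been seen so far, and self.precal_witness stores each operation's witness values under a key of the form <OpClass>_<index>. The following standalone sketch reproduces just that pattern for illustration; the helper name record and the sample values are made up and are not library code.

# Standalone sketch of the precal_witness / op_dict bookkeeping shown above.
precal_witness: dict[str, list[float]] = {}
op_dict: dict[str, int] = {}

def record(op_class_str: str, values: list[float]) -> None:
    if op_class_str not in op_dict:
        # First operation of this class gets index 0.
        precal_witness[op_class_str + "_0"] = values
        op_dict[op_class_str] = 1
    else:
        # Later operations of the same class get the running count as their index.
        precal_witness[op_class_str + "_" + str(op_dict[op_class_str])] = values
        op_dict[op_class_str] += 1

record("Mean", [2.0])
record("Mean", [3.5])
record("Median", [2.5, 2.7])
print(precal_witness)  # {'Mean_0': [2.0], 'Mean_1': [3.5], 'Median_0': [2.5, 2.7]}
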
@@ -229,7 +229,7 @@ def _call_op(self, x: list[torch.Tensor], op_type: Type[Operation]) -> Union[tor
         current_op_index = self.current_op_index
         # Sanity check that current op index is not out of bound
         len_ops = len(self.ops)
-        if current_op_index >= len(self.ops):
+        if current_op_index >= len_ops:
             raise Exception(f"current_op_index out of bound: {current_op_index=} >= {len_ops=}")
 
         op = self.ops[current_op_index]
@@ -245,28 +245,12 @@ def is_precise() -> IsResultPrecise:
             return op.ezkl(x)
         self.bools.append(is_precise)
 
-        # If this is the last operation, aggregate all `is_precise` in `self.bools`, and return (is_precise_aggregated, result)
-        # else, return only result
-
-        if current_op_index == len_ops - 1:
-            # print('final op: ', op)
-            # Sanity check for length of self.ops and self.bools
-            len_bools = len(self.bools)
-            if len_ops != len_bools:
-                raise Exception(f"length mismatch: {len_ops=} != {len_bools=}")
-            is_precise_aggregated = torch.tensor(1.0)
-            for i in range(len_bools):
-                res = self.bools[i]()
-                is_precise_aggregated = torch.logical_and(is_precise_aggregated, res)
-            if self.isProver:
-                json.dump(self.precal_witness, open(self.precal_witness_path, 'w'))
-            return is_precise_aggregated, op.result+(x[0]-x[0])[0][0]
-
-        elif current_op_index > len_ops - 1:
+        if current_op_index > len_ops - 1:
             # Sanity check that current op index does not exceed the length of ops
             raise Exception(f"current_op_index out of bound: {current_op_index=} > {len_ops=}")
-        else:
-            return op.result+(x[0]-x[0])[0][0]
+        if self.isProver:
+            json.dump(self.precal_witness, open(self.precal_witness_path, 'w'))
+        return op.result+(x[0]-x[0])[0][0]
 
 
 class IModel(nn.Module):
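
One detail worth noting in the hunk above: both the old and the new code return op.result+(x[0]-x[0])[0][0] rather than op.result alone. Adding a term that is always zero but computed from the input appears to be a way of keeping the returned value connected to the model input when the graph is traced for export. The snippet below is a small illustration of that idea with made-up tensors; it is not code from the library.

import torch

x = [torch.tensor([[1.0], [2.0], [3.0]])]  # illustrative input, shaped like a column of data
op_result = torch.tensor(2.0)              # illustrative operation result

# (x[0] - x[0])[0][0] is exactly zero, so the value is unchanged,
# but the output now depends on the input tensor instead of being a constant.
output = op_result + (x[0] - x[0])[0][0]
print(output)  # tensor(2.)
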
@@ -296,23 +280,29 @@ def computation_to_model(computation: TComputation, precal_witness_path:str, isP
     State is a container for intermediate results of computation, which can be useful when debugging.
     """
     state = State(error)
-    # if it's verifier
 
 
     state.precal_witness_path= precal_witness_path
     state.isProver = isProver
 
     class Model(IModel):
         def preprocess(self, x: list[torch.Tensor]) -> None:
+            """
+            Calculate the witnesses of the computation and store them in the state.
+            """
+            # In the preprocess step, the operations are calculated and the results are stored in the state.
+            # So we don't need to get the returned result
             computation(state, x)
             state.set_ready_for_exporting_onnx()
 
         def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
-            # print('x sy: ')
-            result = computation(state, x)
-            if len(result) ==1:
-                return (x[0]-x[0])[0][0]+torch.tensor(1.0), result
-            else:
-                return result
-            # print('state:: ', state.aggregate_witness_path)
+            """
+            Called by torch.onnx.export.
+            """
+            result = computation(state, x)
+            is_computation_result_accurate = state.bools[0]()
+            for op_precise_check in state.bools[1:]:
+                is_op_result_accurate = op_precise_check()
+                is_computation_result_accurate = torch.logical_and(is_computation_result_accurate, is_op_result_accurate)
+            return is_computation_result_accurate, result
     return state, Model
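
For context on the new forward above: instead of having the last operation aggregate the precision flags inside _call_op, the model now folds every entry of state.bools into a single flag and always returns it ahead of the computation result, which keeps the output order stable when a computation produces more than one output (the issue named in the PR title). The sketch below reproduces just that reduction with plain PyTorch; the checks list and the tensors stand in for state.bools and the real operation results and are illustrative only.

import torch

# Stand-ins for state.bools: zero-argument callables returning boolean tensors.
checks = [
    lambda: torch.tensor(True),
    lambda: torch.tensor(True),
    lambda: torch.tensor(False),
]

# Fold the per-operation checks into one flag, as the patched forward() does.
is_computation_result_accurate = checks[0]()
for check in checks[1:]:
    is_computation_result_accurate = torch.logical_and(is_computation_result_accurate, check())

results = (torch.tensor(2.0), torch.tensor(2.5))  # stand-ins for a multi-output computation

# The aggregated flag always comes first, then the result(s).
print(is_computation_result_accurate, results)  # tensor(False) (tensor(2.), tensor(2.5000))
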
