-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #5 from mhchia/feat/user-interface
Let user define computation in functions
- Loading branch information
Showing
13 changed files
with
854 additions
and
205 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
{ | ||
"input_data": [ | ||
[ | ||
23.2, 92.8, 91.0, 37.2, 82.0, 15.5, 79.3, 46.6, 98.1, 75.5, 78.9, 77.6, | ||
33.8, 75.7, 96.8, 12.3, 18.4, 13.4, 6.0, 8.2, 25.8, 41.3, 68.5, 15.2, | ||
74.7, 72.7, 18.0, 42.2, 36.1, 76.7, 1.2, 96.4, 4.9, 92.0, 12.8, 28.2, | ||
61.8, 56.9, 44.3, 50.4, 81.6, 72.5, 12.9, 40.3, 12.8, 28.8, 36.3, 16.1, | ||
68.4, 35.3, 79.2, 48.4, 97.1, 93.7, 77.0, 48.7, 93.7, 54.1, 65.4, 30.8, | ||
34.4, 31.4, 78.7, 12.7, 90.7, 39.4, 86.0, 55.9, 6.8, 22.2, 65.3, 18.8, | ||
7.1, 55.9, 38.6, 15.6, 59.2, 77.3, 76.9, 11.9, 19.9, 19.4, 54.3, 39.4, | ||
4.0, 61.1, 16.8, 81.9, 49.3, 76.9, 19.2, 68.2, 54.4, 70.2, 89.8, 23.4, | ||
67.5, 18.7, 10.8, 80.7, 80.3, 96.2, 62.3, 17.2, 23.0, 98.0, 19.1, 8.1, | ||
36.2, 7.5, 55.9, 1.2, 56.8, 85.1, 18.9, 23.0, 13.5, 64.3, 9.1, 14.1, 14.1, | ||
23.1, 73.2, 86.6, 39.1, 45.5, 85.0, 79.0, 15.8, 5.2, 81.5, 34.3, 24.3, | ||
14.2, 84.6, 33.7, 86.3, 83.3, 62.8, 72.7, 14.7, 36.8, 92.5, 4.7, 30.0, | ||
59.4, 57.6, 37.4, 22.0, 20.9, 61.6, 26.8, 47.1, 63.6, 6.0, 96.6, 61.2, | ||
80.2, 59.3, 23.1, 29.3, 46.3, 89.2, 77.6, 83.2, 87.2, 63.2, 81.8, 55.0, | ||
59.7, 57.8, 43.4, 92.4, 66.9, 82.1, 51.0, 22.1, 29.9, 41.0, 85.2, 61.5, | ||
14.6, 48.0, 52.7, 31.4, 83.9, 35.5, 77.3, 35.8, 32.6, 22.2, 19.3, 49.1, | ||
70.9, 43.9, 88.8, 56.3, 41.8, 90.3, 20.4, 80.4, 36.4, 91.5, 69.6, 75.3, | ||
92.4, 84.8, 17.7, 2.3, 41.3, 91.3, 68.6, 73.3, 62.5, 60.5, 73.5, 70.7, | ||
77.5, 76.8, 98.1, 40.9, 66.3, 8.6, 48.9, 75.4, 14.7, 35.9, 89.6, 15.1, | ||
45.0, 77.6, 30.5, 76.1, 46.9, 34.3, 65.1, 43.9, 91.6, 88.8, 8.9, 42.9, | ||
11.8, 32.1, 20.1, 48.9, 79.7, 15.3, 45.4, 80.1, 73.1, 76.5, 52.4, 9.6, | ||
41.9, 52.7, 55.1, 30.9, 83.7, 46.7, 39.3, 40.5, 52.4, 19.2, 25.8, 52.7, | ||
81.0, 38.0, 54.5, 15.3, 64.3, 88.3, 49.8, 90.5, 90.4, 79.7, 87.3, 32.3, | ||
11.9, 5.7, 33.6, 75.1, 65.9, 29.1, 39.4, 87.5, 3.3, 66.3, 79.0, 97.9, | ||
69.6, 22.0, 62.8, 97.1, 90.4, 39.5, 11.7, 30.3, 18.9, 34.6, 6.6 | ||
] | ||
] | ||
} |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import pytest | ||
import torch | ||
|
||
|
||
@pytest.fixture | ||
def error() -> float: | ||
return 0.01 | ||
|
||
|
||
@pytest.fixture | ||
def column_0(): | ||
return torch.tensor([3.0, 4.5, 1.0, 2.0, 7.5, 6.4, 5.5]) | ||
|
||
|
||
@pytest.fixture | ||
def column_1(): | ||
return torch.tensor([2.7, 3.3, 1.1, 2.2, 3.8, 8.2, 4.4]) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import json | ||
from typing import Type | ||
from pathlib import Path | ||
|
||
import torch | ||
|
||
from zkstats.core import prover_gen_settings, verifier_setup, prover_gen_proof, verifier_verify | ||
from zkstats.computation import IModel, IsResultPrecise | ||
|
||
|
||
def compute(basepath: Path, data: list[torch.Tensor], model: Type[IModel]) -> IsResultPrecise: | ||
comb_data_path = basepath / "comb_data.json" | ||
model_path = basepath / "model.onnx" | ||
settings_path = basepath / "settings.json" | ||
witness_path = basepath / "witness.json" | ||
compiled_model_path = basepath / "model.compiled" | ||
proof_path = basepath / "model.proof" | ||
pk_path = basepath / "model.pk" | ||
vk_path = basepath / "model.vk" | ||
data_paths = [basepath / f"data_{i}.json" for i in range(len(data))] | ||
|
||
for i, d in enumerate(data): | ||
filename = data_paths[i] | ||
data_json = {"input_data": [d.tolist()]} | ||
with open(filename, "w") as f: | ||
f.write(json.dumps(data_json)) | ||
|
||
prover_gen_settings( | ||
data_path_array=[str(i) for i in data_paths], | ||
comb_data_path=str(comb_data_path), | ||
prover_model=model, | ||
prover_model_path=str(model_path), | ||
scale="default", | ||
mode="resources", | ||
settings_path=str(settings_path), | ||
) | ||
verifier_setup( | ||
str(model_path), | ||
str(compiled_model_path), | ||
str(settings_path), | ||
str(vk_path), | ||
str(pk_path), | ||
) | ||
prover_gen_proof( | ||
str(model_path), | ||
str(comb_data_path), | ||
str(witness_path), | ||
str(compiled_model_path), | ||
str(settings_path), | ||
str(proof_path), | ||
str(pk_path), | ||
) | ||
verifier_verify( | ||
str(proof_path), | ||
str(settings_path), | ||
str(vk_path), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import statistics | ||
import torch | ||
import torch | ||
|
||
from zkstats.computation import State, create_model | ||
from zkstats.ops import Mean, Median | ||
|
||
from .helpers import compute | ||
|
||
|
||
def computation(state: State, x: list[torch.Tensor]): | ||
out_0 = state.median(x[0]) | ||
out_1 = state.median(x[1]) | ||
return state.mean(torch.tensor([out_0, out_1]).reshape(1,-1,1)) | ||
|
||
|
||
def test_computation(tmp_path, column_0: torch.Tensor, column_1: torch.Tensor, error: float): | ||
state, model = create_model(computation, error) | ||
compute(tmp_path, [column_0, column_1], model) | ||
assert state.current_op_index == 3 | ||
|
||
ops = state.ops | ||
op0 = ops[0] | ||
assert isinstance(op0, Median) | ||
assert op0.result == statistics.median(column_0) | ||
op1 = ops[1] | ||
assert isinstance(op1, Median) | ||
assert op1.result == statistics.median(column_1) | ||
op2 = ops[2] | ||
assert isinstance(op2, Mean) | ||
assert op2.result == statistics.mean([op0.result.tolist(), op1.result.tolist()]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import json | ||
from typing import Type, Callable | ||
from dataclasses import dataclass | ||
from pathlib import Path | ||
import statistics | ||
|
||
import pytest | ||
|
||
import torch | ||
from zkstats.computation import Operation, Mean, Median, IModel, IsResultPrecise | ||
|
||
from .helpers import compute | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"op_type, expected_func", | ||
[ | ||
(Mean, statistics.mean), | ||
(Median, statistics.median), | ||
] | ||
) | ||
def test_1d(tmp_path, column_0: torch.Tensor, error: float, op_type: Type[Operation], expected_func: Callable[[list[float]], float]): | ||
op = op_type.create(column_0, error) | ||
expected_res = expected_func(column_0.tolist()) | ||
assert expected_res == op.result | ||
model = op_to_model(op) | ||
compute(tmp_path, [column_0], model) | ||
|
||
|
||
def op_to_model(op: Operation) -> Type[IModel]: | ||
class Model(IModel): | ||
def forward(self, x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]: | ||
return op.ezkl(x), op.result | ||
return Model | ||
|
||
|
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
from abc import abstractmethod | ||
from typing import Callable, Type, Optional, Union | ||
|
||
import torch | ||
from torch import nn | ||
|
||
from .ops import Operation, Mean, Median, IsResultPrecise | ||
|
||
|
||
DEFAULT_ERROR = 0.01 | ||
|
||
|
||
class State: | ||
""" | ||
State is a container for intermediate results of computation. | ||
Stage 1 (current_op_index is None): for every call to State (mean, median, etc.), result | ||
is calculated and temporarily stored in the state. Call `set_ready_for_exporting_onnx` to indicate | ||
Stage 2: all operations are calculated and results are ready to be used. Call `set_ready_for_exporting_onnx` | ||
to indicate it's ready to generate settings. | ||
Stage 3 (current_op_index is not None): when exporting to onnx, when the operations are called, the results and | ||
the conditions are popped from the state and filled in the onnx graph. | ||
""" | ||
def __init__(self, error: float) -> None: | ||
self.ops: list[Operation] = [] | ||
self.bools: list[Callable[[], torch.Tensor]] = [] | ||
self.error: float = error | ||
# Pointer to the current operation index. If None, it's in stage 1. If not None, it's in stage 3. | ||
self.current_op_index: Optional[int] = None | ||
|
||
def set_ready_for_exporting_onnx(self) -> None: | ||
self.current_op_index = 0 | ||
|
||
def mean(self, X: torch.Tensor) -> tuple[IsResultPrecise, torch.Tensor]: | ||
return self._call_op(X, Mean) | ||
|
||
def median(self, X: torch.Tensor) -> tuple[IsResultPrecise, torch.Tensor]: | ||
return self._call_op(X, Median) | ||
|
||
# TODO: add the rest of the operations | ||
|
||
def _call_op(self, x: torch.Tensor, op_type: Type[Operation]) -> Union[torch.Tensor, tuple[IsResultPrecise, torch.Tensor]]: | ||
if self.current_op_index is None: | ||
op = op_type.create(x, self.error) | ||
self.ops.append(op) | ||
return op.result | ||
else: | ||
# Copy the current op index to a local variable since self.current_op_index will be incremented. | ||
current_op_index = self.current_op_index | ||
# Sanity check that current op index is not out of bound | ||
len_ops = len(self.ops) | ||
if current_op_index >= len(self.ops): | ||
raise Exception(f"current_op_index out of bound: {current_op_index=} >= {len_ops=}") | ||
|
||
op = self.ops[current_op_index] | ||
# Sanity check that the operation type matches the current op type | ||
if not isinstance(op, op_type): | ||
raise Exception(f"operation type mismatch: {op_type=} != {type(op)=}") | ||
|
||
# Increment the current op index | ||
self.current_op_index += 1 | ||
|
||
# Push the ezkl condition, which is checked only in the last operation | ||
def is_precise() -> IsResultPrecise: | ||
return op.ezkl(x) | ||
self.bools.append(is_precise) | ||
|
||
# If this is the last operation, aggregate all `is_precise` in `self.bools`, and return (is_precise_aggregated, result) | ||
# else, return only result | ||
if current_op_index == len_ops - 1: | ||
# Sanity check for length of self.ops and self.bools | ||
len_bools = len(self.bools) | ||
if len_ops != len_bools: | ||
raise Exception(f"length mismatch: {len_ops=} != {len_bools=}") | ||
is_precise_aggregated = torch.tensor(1.0) | ||
for i in range(len_bools): | ||
res = self.bools[i]() | ||
is_precise_aggregated = torch.logical_and(is_precise_aggregated, res) | ||
return is_precise_aggregated, op.result | ||
elif current_op_index > len_ops - 1: | ||
# Sanity check that current op index does not exceed the length of ops | ||
raise Exception(f"current_op_index out of bound: {current_op_index=} > {len_ops=}") | ||
else: | ||
# It's not the last operation, just return the result | ||
return op.result | ||
|
||
|
||
class IModel(nn.Module): | ||
@abstractmethod | ||
def preprocess(self, x: list[torch.Tensor]) -> None: | ||
... | ||
|
||
@abstractmethod | ||
def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]: | ||
... | ||
|
||
|
||
|
||
# An computation function. Example: | ||
# def computation(state: State, x: list[torch.Tensor]): | ||
# out_0 = state.median(x[0]) | ||
# out_1 = state.median(x[1]) | ||
# return state.mean(torch.tensor([out_0, out_1]).reshape(1,-1,1)) | ||
TComputation = Callable[[State, list[torch.Tensor]], tuple[IsResultPrecise, torch.Tensor]] | ||
|
||
|
||
def create_model(computation: TComputation, error: float = DEFAULT_ERROR) -> tuple[State, Type[IModel]]: | ||
""" | ||
Create a torch model from a `computation` function defined by user | ||
""" | ||
state = State(error) | ||
|
||
class Model(IModel): | ||
def preprocess(self, x: list[torch.Tensor]) -> None: | ||
computation(state, x) | ||
state.set_ready_for_exporting_onnx() | ||
|
||
def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]: | ||
return computation(state, x) | ||
|
||
return state, Model |
Oops, something went wrong.