Skip to content

Commit

Permalink
Merge pull request #5 from mhchia/feat/user-interface
Browse files Browse the repository at this point in the history
Let user define computation in functions
  • Loading branch information
JernKunpittaya committed Jan 24, 2024
2 parents 11cccb8 + 9cc9c62 commit adffa15
Show file tree
Hide file tree
Showing 13 changed files with 854 additions and 205 deletions.
326 changes: 326 additions & 0 deletions examples/computation/computation.ipynb

Large diffs are not rendered by default.

31 changes: 31 additions & 0 deletions examples/computation/data.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
{
"input_data": [
[
23.2, 92.8, 91.0, 37.2, 82.0, 15.5, 79.3, 46.6, 98.1, 75.5, 78.9, 77.6,
33.8, 75.7, 96.8, 12.3, 18.4, 13.4, 6.0, 8.2, 25.8, 41.3, 68.5, 15.2,
74.7, 72.7, 18.0, 42.2, 36.1, 76.7, 1.2, 96.4, 4.9, 92.0, 12.8, 28.2,
61.8, 56.9, 44.3, 50.4, 81.6, 72.5, 12.9, 40.3, 12.8, 28.8, 36.3, 16.1,
68.4, 35.3, 79.2, 48.4, 97.1, 93.7, 77.0, 48.7, 93.7, 54.1, 65.4, 30.8,
34.4, 31.4, 78.7, 12.7, 90.7, 39.4, 86.0, 55.9, 6.8, 22.2, 65.3, 18.8,
7.1, 55.9, 38.6, 15.6, 59.2, 77.3, 76.9, 11.9, 19.9, 19.4, 54.3, 39.4,
4.0, 61.1, 16.8, 81.9, 49.3, 76.9, 19.2, 68.2, 54.4, 70.2, 89.8, 23.4,
67.5, 18.7, 10.8, 80.7, 80.3, 96.2, 62.3, 17.2, 23.0, 98.0, 19.1, 8.1,
36.2, 7.5, 55.9, 1.2, 56.8, 85.1, 18.9, 23.0, 13.5, 64.3, 9.1, 14.1, 14.1,
23.1, 73.2, 86.6, 39.1, 45.5, 85.0, 79.0, 15.8, 5.2, 81.5, 34.3, 24.3,
14.2, 84.6, 33.7, 86.3, 83.3, 62.8, 72.7, 14.7, 36.8, 92.5, 4.7, 30.0,
59.4, 57.6, 37.4, 22.0, 20.9, 61.6, 26.8, 47.1, 63.6, 6.0, 96.6, 61.2,
80.2, 59.3, 23.1, 29.3, 46.3, 89.2, 77.6, 83.2, 87.2, 63.2, 81.8, 55.0,
59.7, 57.8, 43.4, 92.4, 66.9, 82.1, 51.0, 22.1, 29.9, 41.0, 85.2, 61.5,
14.6, 48.0, 52.7, 31.4, 83.9, 35.5, 77.3, 35.8, 32.6, 22.2, 19.3, 49.1,
70.9, 43.9, 88.8, 56.3, 41.8, 90.3, 20.4, 80.4, 36.4, 91.5, 69.6, 75.3,
92.4, 84.8, 17.7, 2.3, 41.3, 91.3, 68.6, 73.3, 62.5, 60.5, 73.5, 70.7,
77.5, 76.8, 98.1, 40.9, 66.3, 8.6, 48.9, 75.4, 14.7, 35.9, 89.6, 15.1,
45.0, 77.6, 30.5, 76.1, 46.9, 34.3, 65.1, 43.9, 91.6, 88.8, 8.9, 42.9,
11.8, 32.1, 20.1, 48.9, 79.7, 15.3, 45.4, 80.1, 73.1, 76.5, 52.4, 9.6,
41.9, 52.7, 55.1, 30.9, 83.7, 46.7, 39.3, 40.5, 52.4, 19.2, 25.8, 52.7,
81.0, 38.0, 54.5, 15.3, 64.3, 88.3, 49.8, 90.5, 90.4, 79.7, 87.3, 32.3,
11.9, 5.7, 33.6, 75.1, 65.9, 29.1, 39.4, 87.5, 3.3, 66.3, 79.0, 97.9,
69.6, 22.0, 62.8, 97.1, 90.4, 39.5, 11.7, 30.3, 18.9, 34.6, 6.6
]
]
}
Empty file added tests/__init__.py
Empty file.
18 changes: 18 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest
import torch


@pytest.fixture
def error() -> float:
    """Provide the error tolerance (0.01) to tests that request ``error``."""
    tolerance = 0.01
    return tolerance


@pytest.fixture
def column_0():
    """Provide the first sample column: a tensor built from seven float values."""
    values = [3.0, 4.5, 1.0, 2.0, 7.5, 6.4, 5.5]
    return torch.tensor(values)


@pytest.fixture
def column_1():
    """Provide the second sample column: a tensor built from seven float values."""
    values = [2.7, 3.3, 1.1, 2.2, 3.8, 8.2, 4.4]
    return torch.tensor(values)

57 changes: 57 additions & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import json
from typing import Type
from pathlib import Path

import torch

from zkstats.core import prover_gen_settings, verifier_setup, prover_gen_proof, verifier_verify
from zkstats.computation import IModel, IsResultPrecise


def compute(basepath: Path, data: list[torch.Tensor], model: Type[IModel]) -> None:
    """
    Run the settings -> setup -> prove -> verify pipeline for ``model``.

    Every artifact (combined data, onnx model, settings, witness, compiled
    model, proof and keys) is written under ``basepath``. Each tensor in
    ``data`` is first dumped to ``basepath / "data_<i>.json"`` in the form
    ``{"input_data": [<values>]}``.

    Args:
        basepath: Directory that receives all generated files.
        data: One tensor per input column.
        model: Model class handed to ``prover_gen_settings``.

    Returns:
        None. The function is called for its side effects only; any failure
        surfaces as an exception raised by one of the pipeline steps.
    """
    comb_data_path = basepath / "comb_data.json"
    model_path = basepath / "model.onnx"
    settings_path = basepath / "settings.json"
    witness_path = basepath / "witness.json"
    compiled_model_path = basepath / "model.compiled"
    proof_path = basepath / "model.proof"
    pk_path = basepath / "model.pk"
    vk_path = basepath / "model.vk"
    data_paths = [basepath / f"data_{i}.json" for i in range(len(data))]

    # One JSON file per input column, paired positionally with its path.
    for data_path, column in zip(data_paths, data):
        with open(data_path, "w") as f:
            json.dump({"input_data": [column.tolist()]}, f)

    prover_gen_settings(
        data_path_array=[str(path) for path in data_paths],
        comb_data_path=str(comb_data_path),
        prover_model=model,
        prover_model_path=str(model_path),
        scale="default",
        mode="resources",
        settings_path=str(settings_path),
    )
    verifier_setup(
        str(model_path),
        str(compiled_model_path),
        str(settings_path),
        str(vk_path),
        str(pk_path),
    )
    prover_gen_proof(
        str(model_path),
        str(comb_data_path),
        str(witness_path),
        str(compiled_model_path),
        str(settings_path),
        str(proof_path),
        str(pk_path),
    )
    verifier_verify(
        str(proof_path),
        str(settings_path),
        str(vk_path),
    )
31 changes: 31 additions & 0 deletions tests/test_computation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import statistics
import torch
import torch

from zkstats.computation import State, create_model
from zkstats.ops import Mean, Median

from .helpers import compute


def computation(state: State, x: list[torch.Tensor]):
    """Take the median of each of the two input columns, then the mean of both medians."""
    medians = [state.median(x[0]), state.median(x[1])]
    stacked = torch.tensor(medians).reshape(1, -1, 1)
    return state.mean(stacked)


def test_computation(tmp_path, column_0: torch.Tensor, column_1: torch.Tensor, error: float):
    """Run ``computation`` end to end and check the recorded operations and their results."""
    state, model = create_model(computation, error)
    compute(tmp_path, [column_0, column_1], model)
    # Three operations are expected to have been replayed: median, median, mean.
    assert state.current_op_index == 3

    recorded = state.ops

    first_median = recorded[0]
    assert isinstance(first_median, Median)
    assert first_median.result == statistics.median(column_0)

    second_median = recorded[1]
    assert isinstance(second_median, Median)
    assert second_median.result == statistics.median(column_1)

    final_mean = recorded[2]
    assert isinstance(final_mean, Mean)
    expected_mean = statistics.mean([first_median.result.tolist(), second_median.result.tolist()])
    assert final_mean.result == expected_mean
36 changes: 36 additions & 0 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import json
from typing import Type, Callable
from dataclasses import dataclass
from pathlib import Path
import statistics

import pytest

import torch
from zkstats.computation import Operation, Mean, Median, IModel, IsResultPrecise

from .helpers import compute


@pytest.mark.parametrize(
    "op_type, expected_func",
    [
        (Mean, statistics.mean),
        (Median, statistics.median),
    ],
)
def test_1d(tmp_path, column_0: torch.Tensor, error: float, op_type: Type[Operation], expected_func: Callable[[list[float]], float]):
    """Check a single-column operation against its ``statistics`` counterpart, then run the pipeline."""
    operation = op_type.create(column_0, error)
    reference = expected_func(column_0.tolist())
    assert reference == operation.result
    compute(tmp_path, [column_0], op_to_model(operation))


def op_to_model(op: Operation) -> Type[IModel]:
    """Wrap one already-created operation in a minimal model class whose forward reports on it."""
    class Model(IModel):
        def forward(self, x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
            # Precision check first, then the stored result, as a pair.
            is_precise = op.ezkl(x)
            result = op.result
            return is_precise, result

    return Model


Empty file added zkstats/__init__.py
Empty file.
35 changes: 32 additions & 3 deletions zkstats/cli.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import os
import sys
from typing import Type
import importlib.util

import click
import torch

from .core import prover_gen_proof, prover_setup, load_model, verifier_verify, gen_data_commitment
from .core import prover_gen_proof, prover_gen_settings, verifier_setup, verifier_verify, gen_data_commitment

cwd = os.getcwd()
# TODO: Should make this configurable
Expand Down Expand Up @@ -29,15 +34,19 @@ def cli():
def prove(model_path: str, data_path: str):
model = load_model(model_path)
print("Loaded model:", model)
prover_setup(
prover_gen_settings(
[data_path],
comb_data_path,
model,
model_onnx_path,
compiled_model_path,
"default",
"resources",
settings_path,
)
verifier_setup(
model_path,
compiled_model_path,
settings_path,
vk_path,
pk_path,
)
Expand Down Expand Up @@ -80,6 +89,26 @@ def main():
cli()


def load_model(module_path: str) -> Type[torch.nn.Module]:
    """
    Load the class named ``Model`` from the Python source file at ``module_path``.

    The module is registered in ``sys.modules`` under the file's base name
    (extension stripped) and then executed.

    Args:
        module_path: Path to the Python file that defines ``Model``.

    Returns:
        The ``Model`` attribute of the loaded module.

    Raises:
        ImportError: If no import spec/loader can be created for
            ``module_path`` (e.g. the suffix is not an importable one), or
            the loaded module has no ``Model`` attribute.
    """
    # FIXME: This is unsafe since malicious code can be executed

    model_name = "Model"
    module_name = os.path.splitext(os.path.basename(module_path))[0]
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    # spec_from_file_location returns None when no loader matches the file suffix.
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load a module from {module_path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)

    try:
        return getattr(module, model_name)
    except AttributeError as err:
        raise ImportError(f"class {model_name} does not exist in {module_name}") from err


# Register commands
cli.add_command(prove)
cli.add_command(verify)
Expand Down
121 changes: 121 additions & 0 deletions zkstats/computation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from abc import abstractmethod
from typing import Callable, Type, Optional, Union

import torch
from torch import nn

from .ops import Operation, Mean, Median, IsResultPrecise


DEFAULT_ERROR = 0.01


class State:
    """
    State is a container for intermediate results of computation.

    Stage 1 (current_op_index is None): for every call to State (mean, median, etc.), the
    operation is created, its result is calculated, and the operation is stored in `self.ops`.
    Stage 2: all operations are calculated and results are ready to be used. Call
    `set_ready_for_exporting_onnx` to indicate it's ready to generate settings.
    Stage 3 (current_op_index is not None): when exporting to onnx, the same operations are
    called again in the same order. Each stored operation is looked up by index, its result is
    returned, and its precision condition is collected. On the last operation all collected
    conditions are aggregated and returned together with the final result.
    """
    def __init__(self, error: float) -> None:
        # Operations recorded during stage 1, in call order.
        self.ops: list[Operation] = []
        # Deferred precision checks collected during stage 3, one per replayed operation.
        self.bools: list[Callable[[], IsResultPrecise]] = []
        self.error: float = error
        # Pointer to the current operation index. If None, it's in stage 1. If not None, it's in stage 3.
        self.current_op_index: Optional[int] = None

    def set_ready_for_exporting_onnx(self) -> None:
        """Switch from stage 1 to stage 3 by pointing at the first recorded operation."""
        self.current_op_index = 0

    def mean(self, X: torch.Tensor) -> Union[torch.Tensor, tuple[IsResultPrecise, torch.Tensor]]:
        """Record (stage 1) or replay (stage 3) a `Mean` operation on `X`. See `_call_op`."""
        return self._call_op(X, Mean)

    def median(self, X: torch.Tensor) -> Union[torch.Tensor, tuple[IsResultPrecise, torch.Tensor]]:
        """Record (stage 1) or replay (stage 3) a `Median` operation on `X`. See `_call_op`."""
        return self._call_op(X, Median)

    # TODO: add the rest of the operations

    def _call_op(self, x: torch.Tensor, op_type: Type[Operation]) -> Union[torch.Tensor, tuple[IsResultPrecise, torch.Tensor]]:
        """
        Record or replay one operation of type `op_type` on `x`.

        Returns:
            Stage 1: the result of the newly created operation.
            Stage 3, not the last operation: the stored result.
            Stage 3, last operation: `(is_precise_aggregated, result)`.

        Raises:
            Exception: in stage 3, if there is no recorded operation left to replay, if the
                recorded operation is not an instance of `op_type`, or if the number of
                collected conditions differs from the number of operations.
        """
        if self.current_op_index is None:
            # Stage 1: create the operation, remember it, hand back its result.
            op = op_type.create(x, self.error)
            self.ops.append(op)
            return op.result

        # Stage 3.
        # Copy the current op index to a local variable since self.current_op_index will be incremented.
        current_op_index = self.current_op_index
        # Sanity check that current op index is not out of bound
        len_ops = len(self.ops)
        if current_op_index >= len_ops:
            raise Exception(f"current_op_index out of bound: {current_op_index=} >= {len_ops=}")

        op = self.ops[current_op_index]
        # Sanity check that the operation type matches the current op type
        if not isinstance(op, op_type):
            raise Exception(f"operation type mismatch: {op_type=} != {type(op)=}")

        # Increment the current op index
        self.current_op_index += 1

        # Push the ezkl condition, which is checked only in the last operation
        def is_precise() -> IsResultPrecise:
            return op.ezkl(x)
        self.bools.append(is_precise)

        # The bound check above guarantees current_op_index <= len_ops - 1, so the only two
        # cases left are "not the last operation" and "the last operation".
        if current_op_index < len_ops - 1:
            # It's not the last operation, just return the result
            return op.result

        # Last operation: aggregate all `is_precise` in `self.bools`.
        # Sanity check for length of self.ops and self.bools
        len_bools = len(self.bools)
        if len_ops != len_bools:
            raise Exception(f"length mismatch: {len_ops=} != {len_bools=}")
        is_precise_aggregated = torch.tensor(1.0)
        for check in self.bools:
            is_precise_aggregated = torch.logical_and(is_precise_aggregated, check())
        return is_precise_aggregated, op.result


class IModel(nn.Module):
    # Interface for the model classes used in this package: a `preprocess` step that receives
    # the input tensors, and a `forward` that returns `(is_result_precise, result)`.
    # NOTE(review): `abstractmethod` only takes effect on classes whose metaclass derives from
    # `ABCMeta`; `nn.Module` does not appear to use one, so these decorators document intent
    # but do not prevent instantiation of subclasses that omit a method - confirm.
    @abstractmethod
    def preprocess(self, x: list[torch.Tensor]) -> None:
        """Prepare the model from the input tensors `x`. Placeholder; subclasses implement it."""
        ...

    @abstractmethod
    def forward(self, *x: torch.Tensor) -> tuple[IsResultPrecise, torch.Tensor]:
        """Return `(is_result_precise, result)` for the input tensors, passed as positional arguments."""
        ...



# A computation function. Example:
# def computation(state: State, x: list[torch.Tensor]):
# out_0 = state.median(x[0])
# out_1 = state.median(x[1])
# return state.mean(torch.tensor([out_0, out_1]).reshape(1,-1,1))
TComputation = Callable[[State, list[torch.Tensor]], tuple[IsResultPrecise, torch.Tensor]]


def create_model(computation: TComputation, error: float = DEFAULT_ERROR) -> tuple[State, Type[IModel]]:
    """
    Build a torch model class around a user-defined `computation` function.

    Returns the `State` object the model operates on together with the model class itself.
    The state is shared by every instance of the returned class.
    """
    shared_state = State(error)

    class Model(IModel):
        def preprocess(self, x: list[torch.Tensor]) -> None:
            # First pass over the computation, then flag the state as ready for onnx export.
            computation(shared_state, x)
            shared_state.set_ready_for_exporting_onnx()

        def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
            # Run the same computation against the shared state; `x` arrives as a tuple.
            outputs = computation(shared_state, x)
            return outputs

    return shared_state, Model
Loading

0 comments on commit adffa15

Please sign in to comment.