diff --git a/examples/game256/game256_contracts.py b/examples/game256/game256_contracts.py index c068e6f..1439820 100644 --- a/examples/game256/game256_contracts.py +++ b/examples/game256/game256_contracts.py @@ -1,10 +1,11 @@ +from dataclasses import dataclass from matt import NUMS_KEY from matt.argtypes import BytesType, IntType, SignerType from matt.btctools.common import sha256 -from matt.btctools.script import OP_ADD, OP_CHECKSIG, OP_DUP, OP_FROMALTSTACK, OP_ROT, OP_SHA256, OP_SWAP, OP_TOALTSTACK, CScript -from matt.contracts import ClauseOutput, StandardClause, StandardAugmentedP2TR, StandardP2TR +from matt.btctools.script import OP_ADD, OP_CHECKSIG, OP_DUP, OP_EQUAL, OP_FROMALTSTACK, OP_NOT, OP_PICK, OP_ROT, OP_SHA256, OP_SWAP, OP_TOALTSTACK, OP_VERIFY, CScript +from matt.contracts import ClauseOutput, StandardClause, StandardAugmentedP2TR, StandardP2TR, ContractState from matt.hub.fraud import Bisect_1, Computer, Leaf from matt.merkle import MerkleTree from matt.script_helpers import check_input_contract, check_output_contract, drop, dup, merkle_root, older @@ -16,23 +17,19 @@ # TODO: how to generalize what the contract does after the leaf? We should be able to compose clauses with some external code. # Do we need "clause" algebra? -# TODO: Augmented contracts should also specify the "encoder" for its data, so that callers don't have to worry -# about handling Merkle trees by hand. -# Might also be needed to define "higher order contracts" that can be used as a gadget, then provide a result -# to some other contract provided by the caller. - class G256_S0(StandardP2TR): def __init__(self, alice_pk: bytes, bob_pk: bytes, forfait_timeout: int = 10): self.alice_pk = alice_pk self.bob_pk = bob_pk self.forfait_timeout = forfait_timeout + g256_s1 = G256_S1(alice_pk, bob_pk, forfait_timeout) # witness: choose = StandardClause( name="choose", script=CScript([ - OP_SHA256, # sha256(x) - *check_output_contract(G256_S1(alice_pk, bob_pk, forfait_timeout)), + *g256_s1.State.encoder_script(), + *check_output_contract(g256_s1), bob_pk, OP_CHECKSIG @@ -41,10 +38,10 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, forfait_timeout: int = 10): ('bob_sig', SignerType(bob_pk)), ('x', IntType()), ], - next_output_fn=lambda args: [ClauseOutput( + next_outputs_fn=lambda args, _: [ClauseOutput( n=-1, - next_contract=G256_S1(alice_pk, bob_pk, forfait_timeout), - next_data=sha256(encode_wit_element(args['x'])) + next_contract=g256_s1, + next_state=g256_s1.State(x=args['x']) )] ) @@ -52,6 +49,16 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, forfait_timeout: int = 10): class G256_S1(StandardAugmentedP2TR): + @dataclass + class State(ContractState): + x: int + + def encode(self): + return sha256(encode_wit_element(self.x)) + + def encoder_script(): + return CScript([OP_SHA256]) + def __init__(self, alice_pk: bytes, bob_pk: bytes, forfait_timeout): self.alice_pk = alice_pk self.bob_pk = bob_pk @@ -59,24 +66,18 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, forfait_timeout): g256_s2 = G256_S2(alice_pk, bob_pk, forfait_timeout) - # reveal: + # reveal: reveal = StandardClause( name="reveal", script=CScript([ OP_DUP, # check that the top of the stack is the embedded data + *self.State.encoder_script(), *check_input_contract(), - OP_TOALTSTACK, - OP_SHA256, - OP_FROMALTSTACK, - - # - *merkle_root(3), - - # - + # + *g256_s2.State.encoder_script(), *check_output_contract(g256_s2), alice_pk, @@ -86,12 +87,12 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, forfait_timeout): ('alice_sig', SignerType(alice_pk)), ('t_a', BytesType()), ('y', IntType()), - ('sha256_x', BytesType()), + ('x', IntType()), ], - next_output_fn=lambda args: [ClauseOutput( + next_outputs_fn=lambda args, _: [ClauseOutput( n=-1, next_contract=g256_s2, - next_data=MerkleTree([args['t_a'], sha256(encode_wit_element(args['y'])), args['sha256_x']]).root + next_state=g256_s2.State(t_a=args['t_a'], y=args['y'], x=args['x']) )] ) @@ -113,6 +114,21 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, forfait_timeout): class G256_S2(StandardAugmentedP2TR): + @dataclass + class State(ContractState): + t_a: bytes + y: int + x: bytes + + def encode(self): + return MerkleTree([self.t_a, sha256(encode_wit_element(self.y)), sha256(encode_wit_element(self.x))]).root + + def encoder_script(): + return CScript([ + OP_TOALTSTACK, OP_SHA256, OP_FROMALTSTACK, OP_SHA256, + *merkle_root(3) + ]) + def __init__(self, alice_pk: bytes, bob_pk: bytes, forfait_timeout: int = 10): self.alice_pk = alice_pk self.bob_pk = bob_pk @@ -133,27 +149,32 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, forfait_timeout: int = 10): def leaf_factory(i: int): return Leaf(alice_pk, bob_pk, Compute2x) bisectg256_0 = Bisect_1(alice_pk, bob_pk, 0, 7, leaf_factory, forfait_timeout) - # start_challenge: + # start_challenge: start_challenge = StandardClause( name="start_challenge", script=CScript([ OP_TOALTSTACK, - OP_SHA256, OP_TOALTSTACK, - # --- + # check that y != z + OP_DUP, 3, OP_PICK, OP_EQUAL, OP_NOT, OP_VERIFY, + + OP_TOALTSTACK, + + # --- - # verify the embedded data *dup(3), - *merkle_root(3), + + # verify the embedded data + *self.State.encoder_script(), *check_input_contract(), - # --- - OP_SWAP, - # --- + # --- + OP_SHA256, OP_SWAP, OP_SHA256, + # --- OP_ROT, # --- - OP_FROMALTSTACK, + OP_FROMALTSTACK, OP_SHA256, # --- OP_SWAP, # --- @@ -162,7 +183,7 @@ def leaf_factory(i: int): return Leaf(alice_pk, bob_pk, Compute2x) # - *merkle_root(5), + *bisectg256_0.State.encoder_script(), *check_output_contract(bisectg256_0), bob_pk, @@ -171,21 +192,21 @@ def leaf_factory(i: int): return Leaf(alice_pk, bob_pk, Compute2x) arg_specs=[ ('bob_sig', SignerType(bob_pk)), ('t_a', BytesType()), - ('sha256_y', BytesType()), - ('sha256_x', BytesType()), + ('y', IntType()), + ('x', IntType()), ('z', IntType()), ('t_b', BytesType()), ], - next_output_fn=lambda args: [ClauseOutput( + next_outputs_fn=lambda args, _: [ClauseOutput( n=-1, next_contract=bisectg256_0, - next_data=MerkleTree([ - args['sha256_x'], - args['sha256_y'], - sha256(encode_wit_element(args['z'])), - args['t_a'], - args['t_b'], - ]).root + next_state=bisectg256_0.State( + h_i=sha256(encode_wit_element(args['x'])), + h_j_plus_1_a=sha256(encode_wit_element(args['y'])), + h_j_plus_1_b=sha256(encode_wit_element(args['z'])), + t_i_j_a=args['t_a'], + t_i_j_b=args['t_b'], + ) )] ) diff --git a/examples/ram/ram_contracts.py b/examples/ram/ram_contracts.py index 9120a8b..b6adffa 100644 --- a/examples/ram/ram_contracts.py +++ b/examples/ram/ram_contracts.py @@ -1,14 +1,26 @@ +from dataclasses import dataclass from typing import List from matt import CCV_FLAG_CHECK_INPUT, NUMS_KEY from matt.argtypes import BytesType, MerkleProofType from matt.btctools.script import OP_CAT, OP_CHECKCONTRACTVERIFY, OP_DUP, OP_ELSE, OP_ENDIF, OP_EQUAL, OP_EQUALVERIFY, OP_FROMALTSTACK, OP_IF, OP_NOTIF, OP_PICK, OP_ROLL, OP_ROT, OP_SHA256, OP_SWAP, OP_TOALTSTACK, OP_TRUE, CScript -from matt.contracts import ClauseOutput, StandardClause, StandardAugmentedP2TR -from matt.merkle import is_power_of_2, floor_lg +from matt.contracts import ClauseOutput, StandardClause, StandardAugmentedP2TR, ContractState +from matt.merkle import MerkleTree, is_power_of_2, floor_lg +from matt.script_helpers import merkle_root class RAM(StandardAugmentedP2TR): - def __init__(self, size: List[bytes]): + @dataclass + class State(ContractState): + leaves: List[bytes] + + def encode(self): + return MerkleTree(self.leaves).root + + def encoder_script(size: int): + return merkle_root(size) + + def __init__(self, size: int): assert is_power_of_2(size) self.size = size @@ -57,6 +69,19 @@ def __init__(self, size: List[bytes]): ] ) + def next_outputs_fn(args: dict, state: RAM.State): + i: int = args["merkle_proof"].get_leaf_index() + + return [ + ClauseOutput( + n=-1, + next_contract=self, + next_state=self.State( + leaves=state.leaves[:i] + [args["new_value"]] + state.leaves[i+1:] + ) + ) + ] + # witness: ... write = StandardClause( name="write", @@ -86,44 +111,44 @@ def __init__(self, size: List[bytes]): # TODO: seems too verbose, there should be a way of optimizing it # top of stack is now: OP_IF, - # top of stack is now: - # right child: we want h || x - 2, OP_PICK, - # top of stack is now: - OP_SWAP, - OP_CAT, - OP_SHA256, - # top of stack is now: - - OP_SWAP, - # top of stack is now: - OP_ROT, - # top of stack is now: - OP_SWAP, - # OP_CAT, - # OP_SHA256, - # # top of stack is now: - - # OP_SWAP, - # # top of stack is now: + # top of stack is now: + # right child: we want h || x + 2, OP_PICK, + # top of stack is now: + OP_SWAP, + OP_CAT, + OP_SHA256, + # top of stack is now: + + OP_SWAP, + # top of stack is now: + OP_ROT, + # top of stack is now: + OP_SWAP, + # OP_CAT, + # OP_SHA256, + # # top of stack is now: + + # OP_SWAP, + # # top of stack is now: OP_ELSE, - # top of stack is now: - 2, OP_PICK, - # top of stack is now: - OP_CAT, - OP_SHA256, - # top of stack is now: - - OP_SWAP, - OP_ROT, - # top of stack is now: - - # OP_CAT, - # OP_SHA256, - # # top of stack is now: - - # OP_SWAP, - # # top of stack is now: + # top of stack is now: + 2, OP_PICK, + # top of stack is now: + OP_CAT, + OP_SHA256, + # top of stack is now: + + OP_SWAP, + OP_ROT, + # top of stack is now: + + # OP_CAT, + # OP_SHA256, + # # top of stack is now: + + # OP_SWAP, + # # top of stack is now: OP_ENDIF, # this is in common between the two branches, so we can put it here @@ -144,7 +169,7 @@ def __init__(self, size: List[bytes]): # stack: # Check that new_root is committed in the next output, - 0, # index + -1, # index 0, # NUMS -1, # keep current taptree 0, # default, preserve amount @@ -158,13 +183,7 @@ def __init__(self, size: List[bytes]): ('new_value', BytesType()), ('merkle_root', BytesType()), ], - next_output_fn=lambda args: [ - ClauseOutput( - n=0, - next_contract=self, - next_data=args["merkle_proof"].get_new_root_after_update(args["new_value"]) - ) - ] + next_outputs_fn=next_outputs_fn ) super().__init__(NUMS_KEY, [withdraw, write]) diff --git a/examples/rps/rps_contracts.py b/examples/rps/rps_contracts.py index 9f2d839..2eeac5e 100644 --- a/examples/rps/rps_contracts.py +++ b/examples/rps/rps_contracts.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass import hashlib from matt import CCV_FLAG_CHECK_INPUT, NUMS_KEY @@ -5,9 +6,9 @@ from matt.btctools.messages import sha256 from matt.btctools import script from matt.btctools.script import OP_ADD, OP_CAT, OP_CHECKCONTRACTVERIFY, OP_CHECKSIG, OP_CHECKTEMPLATEVERIFY, OP_DUP, OP_ENDIF, OP_EQUALVERIFY, OP_FROMALTSTACK, OP_IF, OP_LESSTHAN, OP_OVER, OP_SHA256, OP_SUB, OP_SWAP, OP_TOALTSTACK, OP_VERIFY, OP_WITHIN, CScript, bn2vch -from matt.contracts import P2TR, ClauseOutput, StandardClause, StandardP2TR, StandardAugmentedP2TR -from matt.script_helpers import check_output_contract -from matt.utils import make_ctv_template +from matt.contracts import P2TR, ClauseOutput, StandardClause, StandardP2TR, StandardAugmentedP2TR, ContractState +from matt.script_helpers import check_input_contract, check_output_contract +from matt.utils import encode_wit_element, make_ctv_template DEFAULT_STAKE: int = 1000 # amount of sats that the players bet @@ -71,14 +72,19 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, c_a: bytes, stake: int = DEFA OP_DUP, 0, 3, OP_WITHIN, OP_VERIFY, # check that m_b is 0, 1 or 2 - OP_SHA256, # data = sha256(m_b) + *S1.State.encoder_script(), *check_output_contract(S1, index=0), ]), arg_specs=[ ('m_b', IntType()), ('bob_sig', SignerType(bob_pk)), ], - next_output_fn=lambda args: [ClauseOutput(n=0, next_contract=S1, next_data=sha256(bn2vch(args['m_b'])))] + next_outputs_fn=lambda args, _: [ + ClauseOutput( + n=0, + next_contract=S1, + next_state=S1.State(m_b=args["m_b"]) + )] ) super().__init__(NUMS_KEY, bob_move) @@ -95,6 +101,16 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, c_a: bytes, stake: int = DEFA # - alice_pk, reveal losing move => ctv(bob wins) # - alice_pk, reveal tie move => ctv(tie) class RPSGameS1(StandardAugmentedP2TR): + @dataclass + class State(ContractState): + m_b: int + + def encode(self): + return sha256(encode_wit_element(self.m_b)) + + def encoder_script(): + return CScript([OP_SHA256]) + def __init__(self, alice_pk: bytes, bob_pk: bytes, c_a: bytes, stake: int): self.alice_pk = alice_pk self.bob_pk = bob_pk @@ -120,12 +136,8 @@ def make_script(diff: int, ctv_hash: bytes): OP_DUP, # stack: altstack: - OP_SHA256, # data: sha256(m_b) - -1, # index: current input's index - 0, # NUMS pubkey - -1, # taptree: current input's taptree - CCV_FLAG_CHECK_INPUT, # flags - OP_CHECKCONTRACTVERIFY, + *self.State.encoder_script(), + *check_input_contract(), # stack: altstack: @@ -164,10 +176,10 @@ def make_script(diff: int, ctv_hash: bytes): ('r_a', BytesType()), ] alice_wins = StandardClause("tie", make_script( - 0, tmpl_alice_wins.get_standard_template_hash(0)), arg_specs, lambda _: tmpl_alice_wins) + 0, tmpl_alice_wins.get_standard_template_hash(0)), arg_specs, lambda _, __: tmpl_alice_wins) bob_wins = StandardClause("bob_wins", make_script( - 1, tmpl_bob_wins.get_standard_template_hash(0)), arg_specs, lambda _: tmpl_bob_wins) + 1, tmpl_bob_wins.get_standard_template_hash(0)), arg_specs, lambda _, __: tmpl_bob_wins) tie = StandardClause("alice_wins", make_script( - 2, tmpl_tie.get_standard_template_hash(0)), arg_specs, lambda _: tmpl_tie) + 2, tmpl_tie.get_standard_template_hash(0)), arg_specs, lambda _, __: tmpl_tie) super().__init__(NUMS_KEY, [alice_wins, [bob_wins, tie]]) diff --git a/examples/vault/vault_contracts.py b/examples/vault/vault_contracts.py index 78325e0..cbed457 100644 --- a/examples/vault/vault_contracts.py +++ b/examples/vault/vault_contracts.py @@ -1,9 +1,10 @@ +from dataclasses import dataclass from typing import Optional from matt import CCV_FLAG_DEDUCT_OUTPUT_AMOUNT, NUMS_KEY from matt.argtypes import BytesType, IntType, SignerType from matt.btctools.script import OP_CHECKCONTRACTVERIFY, OP_CHECKSIG, OP_CHECKTEMPLATEVERIFY, OP_DUP, OP_SWAP, OP_TRUE, CScript -from matt.contracts import ClauseOutput, ClauseOutputAmountBehaviour, OpaqueP2TR, StandardClause, StandardP2TR, StandardAugmentedP2TR +from matt.contracts import ClauseOutput, ClauseOutputAmountBehaviour, OpaqueP2TR, StandardClause, StandardP2TR, StandardAugmentedP2TR, ContractState from matt.script_helpers import check_input_contract, older @@ -38,8 +39,11 @@ def __init__(self, alternate_pk: Optional[bytes], spend_delay: int, recover_pk: ('ctv_hash', BytesType()), ('out_i', IntType()), ], - next_output_fn=lambda args: [ClauseOutput( - n=args['out_i'], next_contract=unvaulting, next_data=args['ctv_hash'])] + next_outputs_fn=lambda args, _: [ClauseOutput( + n=args['out_i'], + next_contract=unvaulting, + next_state=unvaulting.State(ctv_hash=args["ctv_hash"]) + )] ) # witness: @@ -68,10 +72,13 @@ def __init__(self, alternate_pk: Optional[bytes], spend_delay: int, recover_pk: ('out_i', IntType()), ('revault_out_i', IntType()), ], - next_output_fn=lambda args: [ + next_outputs_fn=lambda args, _: [ ClauseOutput(n=args['revault_out_i'], next_contract=self, next_amount=ClauseOutputAmountBehaviour.DEDUCT_OUTPUT), - ClauseOutput(n=args['out_i'], next_contract=unvaulting, next_data=args['ctv_hash']), + ClauseOutput( + n=args['out_i'], + next_contract=unvaulting, + next_state=unvaulting.State(ctv_hash=args["ctv_hash"])), ] ) @@ -90,7 +97,7 @@ def __init__(self, alternate_pk: Optional[bytes], spend_delay: int, recover_pk: arg_specs=[ ('out_i', IntType()), ], - next_output_fn=lambda args: [ClauseOutput(n=args['out_i'], next_contract=OpaqueP2TR(recover_pk))] + next_outputs_fn=lambda args, _: [ClauseOutput(n=args['out_i'], next_contract=OpaqueP2TR(recover_pk))] ) if self.has_partial_revault: @@ -108,6 +115,16 @@ def __init__(self, alternate_pk: Optional[bytes], spend_delay: int, recover_pk: class Unvaulting(StandardAugmentedP2TR): + @dataclass + class State(ContractState): + ctv_hash: bytes + + def encode(self): + return self.ctv_hash + + def encoder_script(): + return CScript([]) + def __init__(self, alternate_pk: Optional[bytes], spend_delay: int, recover_pk: bytes): assert (alternate_pk is None or len(alternate_pk) == 32) and len(recover_pk) == 32 @@ -149,7 +166,7 @@ def __init__(self, alternate_pk: Optional[bytes], spend_delay: int, recover_pk: arg_specs=[ ('out_i', IntType()), ], - next_output_fn=lambda args: [ClauseOutput(n=args['out_i'], next_contract=OpaqueP2TR(recover_pk))] + next_outputs_fn=lambda args, _: [ClauseOutput(n=args['out_i'], next_contract=OpaqueP2TR(recover_pk))] ) super().__init__(NUMS_KEY if alternate_pk is None else alternate_pk, [withdrawal, recover]) diff --git a/matt/contracts.py b/matt/contracts.py index c16b8ec..224a636 100644 --- a/matt/contracts.py +++ b/matt/contracts.py @@ -1,6 +1,7 @@ +from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Type, Union from .argtypes import ArgType from .btctools import script @@ -14,6 +15,32 @@ class AbstractContract: pass +class ContractState(ABC): + """ + This class describes the "state" of a StandardAugmented contract, that is, the full data committed to + inside the data tweak. + """ + + @abstractmethod + def encode(self) -> bytes: + """ + Computes the 32-byte data tweak that represents the commitment to the state of the contract. + """ + pass + + @staticmethod + @abstractmethod + def encoder_script(*args, **kwargs) -> CScript: + """ + Returns a CScript that computes the commitment to the state, assuming that the top of the stack contains the + values of the individual stack items that allow to compute the state commitment, as output by the encode() function. + Contracts might decide not to implement this (and raise an error if this is called), but they must document how the + state commitment should be computed if not. Contracts implementing it should document what the expected stack + elements are when the encoder_script is used. + """ + pass + + class ClauseOutputAmountBehaviour(Enum): PRESERVE_OUTPUT = 0 # The output should be at least as large as the input IGNORE_OUTPUT = 1 # The output amount is not checked @@ -24,11 +51,11 @@ class ClauseOutputAmountBehaviour(Enum): class ClauseOutput: n: Optional[int] next_contract: AbstractContract # only StandardP2TR and StandardAugmentedP2TR are supported so far - next_data: Optional[bytes] = None # only meaningful if c is augmented + next_state: Optional[ContractState] = None # only meaningful if c is augmented next_amount: ClauseOutputAmountBehaviour = ClauseOutputAmountBehaviour.PRESERVE_OUTPUT def __repr__(self): - return f"ClauseOutput(n={self.n}, next_contract={self.next_contract}, next_data={self.next_data}, next_amount={self.next_amount})" + return f"ClauseOutput(n={self.n}, next_contract={self.next_contract}, next_state={self.next_state}, next_amount={self.next_amount})" class Clause: @@ -55,15 +82,15 @@ def __repr__(self): # Other types of generic treatable clauses could be defined (for example, a MiniscriptClause). # Moreover, it specifies a function that converts the arguments of the clause, to the data of the next output. class StandardClause(Clause): - def __init__(self, name: str, script: CScript, arg_specs: List[Tuple[str, ArgType]], next_output_fn: Optional[Callable[[dict], Union[List[ClauseOutput], CTransaction]]] = None): + def __init__(self, name: str, script: CScript, arg_specs: List[Tuple[str, ArgType]], next_outputs_fn: Optional[Callable[[dict, ContractState], Union[List[ClauseOutput], CTransaction]]] = None): super().__init__(name, script) self.arg_specs = arg_specs - self.next_outputs_fn = next_output_fn + self.next_outputs_fn = next_outputs_fn - def next_outputs(self, args: dict) -> Union[List[ClauseOutput], CTransaction]: + def next_outputs(self, args: dict, state: Optional[ContractState]) -> Union[List[ClauseOutput], CTransaction]: if self.next_outputs_fn is not None: - return self.next_outputs_fn(args) + return self.next_outputs_fn(args, state) else: return [] @@ -233,7 +260,7 @@ def __repr__(self): return f"{self.__class__.__name__}(internal_pubkey={self.internal_pubkey.hex()})" -class StandardAugmentedP2TR(AugmentedP2TR): +class StandardAugmentedP2TR(AugmentedP2TR, ABC): """ An AugmentedP2TR where all the transitions are given by a StandardClause. """ @@ -262,3 +289,8 @@ def decode_wit_stack(self, data: bytes, stack_elems: List[bytes]) -> Tuple[str, def __repr__(self): return f"{self.__class__.__name__}(naked_internal_pubkey={self.naked_internal_pubkey.hex()})" + + @property + @abstractmethod + def State() -> Type[ContractState]: + pass diff --git a/matt/hub/fraud.py b/matt/hub/fraud.py index 7b45371..971ca29 100644 --- a/matt/hub/fraud.py +++ b/matt/hub/fraud.py @@ -99,7 +99,7 @@ from .. import NUMS_KEY from ..argtypes import ArgType, BytesType, SignerType from ..btctools.script import OP_CAT, OP_CHECKSIG, OP_EQUAL, OP_EQUALVERIFY, OP_FROMALTSTACK, OP_NOT, OP_PICK, OP_SHA256, OP_SWAP, OP_TOALTSTACK, OP_VERIFY, CScript -from ..contracts import ClauseOutput, StandardAugmentedP2TR, StandardClause +from ..contracts import ClauseOutput, StandardAugmentedP2TR, StandardClause, ContractState from ..script_helpers import check_input_contract, check_output_contract, drop, dup, merkle_root, older @@ -111,6 +111,18 @@ class Computer: class Leaf(StandardAugmentedP2TR): + @dataclass + class State(ContractState): + h_start: bytes + h_end_alice: bytes + h_end_bob: bytes + + def encode(self): + return MerkleTree([self.h_start, self.h_end_alice, self.h_end_bob]).root + + def encoder_script(): + return CScript([*merkle_root(3)]) + def __init__(self, alice_pk: bytes, bob_pk: bytes, computer: Computer): self.alice_pk = alice_pk self.bob_pk = bob_pk @@ -212,6 +224,26 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, computer: Computer): class Bisect_1(StandardAugmentedP2TR): + @dataclass + class State(ContractState): + h_i: bytes + h_j_plus_1_a: bytes + h_j_plus_1_b: bytes + t_i_j_a: bytes + t_i_j_b: bytes + + def encode(self): + return MerkleTree([ + self.h_i, + self.h_j_plus_1_a, + self.h_j_plus_1_b, + self.t_i_j_a, + self.t_i_j_b + ]).root + + def encoder_script(): + return CScript([*merkle_root(5)]) + def __init__(self, alice_pk: bytes, bob_pk: bytes, i: int, j: int, leaf_factory: Callable[[int], Leaf], forfait_timeout: int = 10): self.alice_pk = alice_pk self.bob_pk = bob_pk @@ -240,7 +272,7 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, i: int, j: int, leaf_factory: *dup(5), # verify the embedded data - *merkle_root(5), + *self.State.encoder_script(), *check_input_contract(), OP_FROMALTSTACK, @@ -262,7 +294,7 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, i: int, j: int, leaf_factory: OP_EQUALVERIFY, # verify that computed and committed values for match # check output - *merkle_root(8), + *bisect_2.State.encoder_script(), *check_output_contract(bisect_2), alice_pk, @@ -279,19 +311,19 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, i: int, j: int, leaf_factory: ('t_left_a', BytesType()), ('t_right_a', BytesType()), ], - next_output_fn=lambda args: [ClauseOutput( + next_outputs_fn=lambda args, _: [ClauseOutput( n=-1, next_contract=bisect_2, - next_data=MerkleTree([ - args['h_i'], - args['h_j_plus_1_a'], - args['h_j_plus_1_b'], - args['t_i_j_a'], - args['t_i_j_b'], - args['h_i_plus_m_a'], - args['t_left_a'], - args['t_right_a'], - ]).root + next_state=bisect_2.State( + h_i=args['h_i'], + h_j_plus_1_a=args['h_j_plus_1_a'], + h_j_plus_1_b=args['h_j_plus_1_b'], + t_i_j_a=args['t_i_j_a'], + t_i_j_b=args['t_i_j_b'], + h_i_plus_m_a=args['h_i_plus_m_a'], + t_left_a=args['t_left_a'], + t_right_a=args['t_right_a'], + ) )] ) @@ -312,6 +344,32 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, i: int, j: int, leaf_factory: # TODO: probably more efficient to combine the _left and _right clauses class Bisect_2(StandardAugmentedP2TR): + @dataclass + class State(ContractState): + h_i: bytes + h_j_plus_1_a: bytes + h_j_plus_1_b: bytes + t_i_j_a: bytes + t_i_j_b: bytes + h_i_plus_m_a: bytes + t_left_a: bytes + t_right_a: bytes + + def encode(self): + return MerkleTree([ + self.h_i, + self.h_j_plus_1_a, + self.h_j_plus_1_b, + self.t_i_j_a, + self.t_i_j_b, + self.h_i_plus_m_a, + self.t_left_a, + self.t_right_a + ]).root + + def encoder_script(): + return CScript([*merkle_root(8)]) + def __init__(self, alice_pk: bytes, bob_pk: bytes, i: int, j: int, leaf_factory: Callable[[int], Leaf], forfait_timeout: int = 10): self.alice_pk = alice_pk self.bob_pk = bob_pk @@ -334,8 +392,8 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, i: int, j: int, leaf_factory: leaf_left = leaf_factory(i) leaf_right = leaf_factory(i + 1) else: - bisectg_1_left = Bisect_1(alice_pk, bob_pk, i, i + m - 1, leaf_factory, forfait_timeout) - bisectg_1_right = Bisect_1(alice_pk, bob_pk, i + m, j, leaf_factory, forfait_timeout) + bisect_1_left = Bisect_1(alice_pk, bob_pk, i, i + m - 1, leaf_factory, forfait_timeout) + bisect_1_right = Bisect_1(alice_pk, bob_pk, i + m, j, leaf_factory, forfait_timeout) # bob reveals a midstate that doesn't match with Alice's # (iterate on the left child) @@ -350,7 +408,7 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, i: int, j: int, leaf_factory: *dup(8), # verify the embedded data - *merkle_root(8), + *self.State.encoder_script(), *check_input_contract(), OP_FROMALTSTACK, @@ -383,7 +441,7 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, i: int, j: int, leaf_factory: 10, OP_PICK, # h_i 1 + 5, OP_PICK, 2 + 2, OP_PICK, - *merkle_root(3), + *leaf_left.State.encoder_script(), *check_output_contract(leaf_left), ] if are_children_leaves else [ # put on top of the stack: [h_i, h_{i+m; a}, h_{i+m; b}, t_{i, i+m-1; a}, t_{i, i+m-1; b}] @@ -392,8 +450,8 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, i: int, j: int, leaf_factory: 2 + 2, OP_PICK, 3 + 4, OP_PICK, 4 + 1, OP_PICK, - *merkle_root(5), - *check_output_contract(bisectg_1_left), + *bisect_1_left.State.encoder_script(), + *check_output_contract(bisect_1_left), ]), # only leave on the stack @@ -416,20 +474,20 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, i: int, j: int, leaf_factory: ('t_left_b', BytesType()), ('t_right_b', BytesType()), ], - next_output_fn=lambda args: [ClauseOutput( + next_outputs_fn=lambda args, _: [ClauseOutput( n=-1, - next_contract=leaf_left if are_children_leaves else bisectg_1_left, - next_data=MerkleTree([ - args['h_i'], - args['h_i_plus_m_a'], - args['h_i_plus_m_b'], - ]).root if are_children_leaves else MerkleTree([ - args['h_i'], - args['h_i_plus_m_a'], - args['h_i_plus_m_b'], - args['t_left_a'], - args['t_left_b'], - ]).root + next_contract=leaf_left if are_children_leaves else bisect_1_left, + next_state=leaf_left.State( + h_start=args['h_i'], + h_end_alice=args['h_i_plus_m_a'], + h_end_bob=args['h_i_plus_m_b'], + ) if are_children_leaves else bisect_1_left.State( + h_i=args['h_i'], + h_j_plus_1_a=args['h_i_plus_m_a'], + h_j_plus_1_b=args['h_i_plus_m_b'], + t_i_j_a=args['t_left_a'], + t_i_j_b=args['t_left_b'], + ) )] ) @@ -446,7 +504,7 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, i: int, j: int, leaf_factory: *dup(8), # verify the embedded data - *merkle_root(8), + *self.State.encoder_script(), *check_input_contract(), OP_FROMALTSTACK, @@ -479,7 +537,7 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, i: int, j: int, leaf_factory: 5, OP_PICK, 1 + 9, OP_PICK, 2 + 8, OP_PICK, - *merkle_root(3), + *leaf_right.State.encoder_script(), *check_output_contract(leaf_right), ] if are_children_leaves else [ # put on top of the stack: [h_{i+m}, h_{j+1; a}, h_{j+1; b}, t_{i+m, j; a}, t_{i+m, j; b}] @@ -488,8 +546,8 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, i: int, j: int, leaf_factory: 2 + 8, OP_PICK, 3 + 3, OP_PICK, 4 + 0, OP_PICK, - *merkle_root(5), - *check_output_contract(bisectg_1_right), + *bisect_1_right.State.encoder_script(), + *check_output_contract(bisect_1_right), ]), # only leave on the stack @@ -512,20 +570,20 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, i: int, j: int, leaf_factory: ('t_left_b', BytesType()), ('t_right_b', BytesType()), ], - next_output_fn=lambda args: [ClauseOutput( + next_outputs_fn=lambda args, _: [ClauseOutput( n=-1, - next_contract=leaf_right if are_children_leaves else bisectg_1_right, - next_data=MerkleTree([ - args['h_i_plus_m_a'], - args['h_j_plus_1_a'], - args['h_j_plus_1_b'], - ]).root if are_children_leaves else MerkleTree([ - args['h_i_plus_m_a'], # this is equal to h_i_plus_m_b, as it's checked in the script! - args['h_j_plus_1_a'], - args['h_j_plus_1_b'], - args['t_right_a'], - args['t_right_b'], - ]).root + next_contract=leaf_right if are_children_leaves else bisect_1_right, + next_state=leaf_right.State( + h_start=args['h_i_plus_m_a'], + h_end_alice=args['h_j_plus_1_a'], + h_end_bob=args['h_j_plus_1_b'], + ) if are_children_leaves else bisect_1_right.State( + h_i=args['h_i_plus_m_a'], # this is equal to h_i_plus_m_b, as it's checked in the script! + h_j_plus_1_a=args['h_j_plus_1_a'], + h_j_plus_1_b=args['h_j_plus_1_b'], + t_i_j_a=args['t_right_a'], + t_i_j_b=args['t_right_b'], + ) )] ) diff --git a/matt/manager.py b/matt/manager.py index e6ae78d..e7b460b 100644 --- a/matt/manager.py +++ b/matt/manager.py @@ -15,7 +15,7 @@ from .btctools.messages import COutPoint, CTransaction, CTxIn, CTxInWitness, CTxOut from .btctools.script import TaprootInfo from .btctools.segwit_addr import encode_segwit_address -from .contracts import P2TR, AugmentedP2TR, ClauseOutputAmountBehaviour, OpaqueP2TR, StandardAugmentedP2TR, StandardP2TR +from .contracts import P2TR, AugmentedP2TR, ClauseOutputAmountBehaviour, OpaqueP2TR, StandardAugmentedP2TR, StandardP2TR, ContractState from .utils import wait_for_output, wait_for_spending_tx @@ -52,9 +52,8 @@ class ContractInstanceStatus(Enum): class ContractInstance: def __init__(self, contract: Union[StandardP2TR, StandardAugmentedP2TR]): self.contract = contract - self.data = None if not self.is_augm() else b'\0'*32 - - self.data_expanded = None # TODO: figure out a good API for this + self.data: Optional[bytes] = None + self.data_expanded: Optional[ContractState] = None # TODO: figure out a good API for this self.manager: ContractManager = None @@ -195,7 +194,7 @@ def get_spend_tx( raise ValueError(f"Clause {clause_name} not found") clause = instance.contract.clauses[clause_idx] - next_outputs = clause.next_outputs(args) + next_outputs = clause.next_outputs(args, instance.data_expanded) if isinstance(next_outputs, CTransaction): if len(tx.vin) != 1 or len(next_outputs.vin) != 1: raise ValueError("CTV clauses are only supported for single-input spends") # TODO: generalize @@ -211,9 +210,9 @@ def get_spend_tx( if isinstance(out_contract, (P2TR, OpaqueP2TR)): out_scriptPubKey = out_contract.get_tr_info().scriptPubKey elif isinstance(out_contract, AugmentedP2TR): - if clause_output.next_data is None: + if clause_output.next_state is None: raise ValueError("Missing data for augmented output") - out_scriptPubKey = out_contract.get_tr_info(clause_output.next_data).scriptPubKey + out_scriptPubKey = out_contract.get_tr_info(clause_output.next_state.encode()).scriptPubKey else: raise ValueError("Unsupported contract type") @@ -336,7 +335,7 @@ def wait_for_spend(self, instances: Union[ContractInstance, List[ContractInstanc raise ValueError(f"Clause {instance.spending_clause} not found") clause = instance.contract.clauses[clause_idx] - next_outputs = clause.next_outputs(instance.spending_args) + next_outputs = clause.next_outputs(instance.spending_args, instance.data_expanded) # We go through all the outputs produced by spending this transaction, # and add them to the manager if they are standard @@ -357,9 +356,10 @@ def wait_for_spend(self, instances: Union[ContractInstance, List[ContractInstanc if isinstance(out_contract, (P2TR, OpaqueP2TR, StandardP2TR)): continue # nothing to do, will not track this output elif isinstance(out_contract, StandardAugmentedP2TR): - if clause_output.next_data is None: + if clause_output.next_state is None: raise ValueError("Missing data for augmented output") - new_instance.data = clause_output.next_data + new_instance.data = clause_output.next_state.encode() + new_instance.data_expanded = clause_output.next_state else: raise ValueError("Unsupported contract type") @@ -376,7 +376,7 @@ def wait_for_spend(self, instances: Union[ContractInstance, List[ContractInstanc self.add_instance(instance) return result - def fund_instance(self, contract: Union[StandardP2TR, StandardAugmentedP2TR], amount: int, data: Optional[bytes] = None) -> ContractInstance: + def fund_instance(self, contract: Union[StandardP2TR, StandardAugmentedP2TR], amount: int, data: Optional[ContractState] = None) -> ContractInstance: """ Convenience method to create an instance of a contract, add it to the ContractManager, and send a transaction to fund it with a certain amount. @@ -389,7 +389,8 @@ def fund_instance(self, contract: Union[StandardP2TR, StandardAugmentedP2TR], am if isinstance(contract, StandardAugmentedP2TR): if data is None: raise ValueError("The data must be provided for an augmented P2TR contract instance") - instance.data = data + instance.data_expanded = data + instance.data = data.encode() self.add_instance(instance) txid = self.rpc.sendtoaddress(instance.get_address(), amount/100_000_000) self.wait_for_outpoint(instance, txid) diff --git a/tests/test_fraud.py b/tests/test_fraud.py index f9697de..ac40c46 100644 --- a/tests/test_fraud.py +++ b/tests/test_fraud.py @@ -15,7 +15,7 @@ bob_key = key.ExtendedKey.deserialize( "tprv8ZgxMBicQKsPeDvaW4xxmiMXxqakLgvukT8A5GR6mRwBwjsDJV1jcZab8mxSerNcj22YPrusm2Pz5oR8LTw9GqpWT51VexTNBzxxm49jCZZ") -# TODO: make outputs that sends to Alice/Bob, instead of using burn addresses :P +# TODO: make outputs that sends to Alice/Bob, instead of using burn addresses def test_leaf_reveal_alice(manager: ContractManager): @@ -29,11 +29,8 @@ def test_leaf_reveal_alice(manager: ContractManager): h_end_alice = sha256(encode_wit_element(x_end_alice)) h_end_bob = sha256(encode_wit_element(x_end_bob)) - data = [h_start, h_end_alice, h_end_bob] - mt = MerkleTree(data) - - L_inst = manager.fund_instance(L, AMOUNT, data=mt.root) - L_inst.data_expanded = data + L_inst = manager.fund_instance(L, AMOUNT, data=L.State( + h_start=h_start, h_end_alice=h_end_alice, h_end_bob=h_end_bob)) outputs = [ CTxOut( @@ -62,11 +59,8 @@ def test_leaf_reveal_bob(manager: ContractManager): h_end_alice = sha256(encode_wit_element(x_end_alice)) h_end_bob = sha256(encode_wit_element(x_end_bob)) - data = [h_start, h_end_alice, h_end_bob] - mt = MerkleTree(data) - - L_inst = manager.fund_instance(L, AMOUNT, data=mt.root) - L_inst.data_expanded = data + L_inst = manager.fund_instance(L, AMOUNT, data=L.State( + h_start=h_start, h_end_alice=h_end_alice, h_end_bob=h_end_bob)) outputs = [ CTxOut( @@ -88,8 +82,6 @@ def test_fraud_proof_full(manager: ContractManager): alice_trace = [2, 4, 8, 16, 32, 64, 127, 254, 508] bob_trace = [2, 4, 8, 16, 32, 64, 128, 256, 512] - # TODO: the contract instance should be able to keep track of the data contained - assert alice_trace[0] == bob_trace[0] and len(alice_trace) == len(bob_trace) n = len(alice_trace) - 1 # the trace has n + 1 entries @@ -137,28 +129,27 @@ def t_node_b(i, j) -> bytes: [inst] = inst('choose', signer=bob_signer, x=x) assert isinstance(inst.contract, G256_S1) + assert isinstance(inst.data_expanded, G256_S1.State) and inst.data_expanded.x == x t_a = t_node_a(0, n - 1) # trace root according to Alice t_b = t_node_b(0, n - 1) # trace root according to Bob # Alice reveals her answer [inst] = inst('reveal', signer=alice_signer, + x=x, y=y, - t_a=t_a, - sha256_x=sha256(encode_wit_element(x)) - ) - inst.data = MerkleTree([t_a, sha256(encode_wit_element(y)), sha256(encode_wit_element(x))]).root + t_a=t_a) assert isinstance(inst.contract, G256_S2) + assert inst.data_expanded == G256_S2.State(t_a=t_a, x=x, y=y) # Bob disagrees and starts the challenge [inst] = inst('start_challenge', signer=bob_signer, t_a=t_a, - sha256_y=sha256(encode_wit_element(y)), - sha256_x=sha256(encode_wit_element(x)), + x=x, + y=y, z=z, - t_b=t_b - ) + t_b=t_b) # inst now represents a step in the bisection protocol corresponding to the root of the computation diff --git a/tests/test_ram.py b/tests/test_ram.py index 95320d2..116bd7e 100644 --- a/tests/test_ram.py +++ b/tests/test_ram.py @@ -17,7 +17,8 @@ def test_withdraw(rpc, manager: ContractManager): data = [sha256(i.to_bytes(1, byteorder='little')) for i in range(size)] mt = MerkleTree(data) - R_inst = manager.fund_instance(RAM(len(data)), AMOUNT, data=mt.root) + R = RAM(len(data)) + R_inst = manager.fund_instance(R, AMOUNT, data=R.State(data)) outputs = [ CTxOut( @@ -44,7 +45,8 @@ def test_write(manager: ContractManager): data = [sha256(i.to_bytes(1, byteorder='little')) for i in range(size)] mt = MerkleTree(data) - R_inst = manager.fund_instance(RAM(len(data)), AMOUNT, data=mt.root) + R = RAM(len(data)) + R_inst = manager.fund_instance(R, AMOUNT, data=R.State(data)) out_instances = R_inst("write", merkle_root=mt.root, @@ -68,7 +70,8 @@ def test_write_loop(manager: ContractManager): data = [sha256(i.to_bytes(1, byteorder='little')) for i in range(size)] - R_inst = manager.fund_instance(RAM(len(data)), AMOUNT, data=MerkleTree(data).root) + R = RAM(len(data)) + R_inst = manager.fund_instance(R, AMOUNT, data=R.State(data)) for i in range(16): leaf_index = i % size