From 5a7ff36e58722d8109c827221eceaaa34238fb72 Mon Sep 17 00:00:00 2001 From: Salvatore Ingala <6681844+bigspider@users.noreply.github.com> Date: Tue, 12 Dec 2023 11:21:24 +0100 Subject: [PATCH] Refactor to allow StandardContracts to specify the exact taptree structure --- examples/rps/rps_contracts.py | 6 ++-- examples/vault/vault_contracts.py | 6 ++-- matt/__init__.py | 54 ++++++++++++++++++++++++------- 3 files changed, 48 insertions(+), 18 deletions(-) diff --git a/examples/rps/rps_contracts.py b/examples/rps/rps_contracts.py index 7a28297..1fa9ca7 100644 --- a/examples/rps/rps_contracts.py +++ b/examples/rps/rps_contracts.py @@ -72,7 +72,7 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, c_a: bytes, stake: int = DEFA OP_SHA256, # data = sha256(m_b) 0, # index 0, # NUMS pk - S1.get_taptree(), + S1.get_taptree_merkle_root(), 0, # flags OP_CHECKCONTRACTVERIFY, ]), @@ -83,7 +83,7 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, c_a: bytes, stake: int = DEFA next_output_fn=lambda args: [ClauseOutput(n=0, next_contract=S1, next_data=sha256(bn2vch(args['m_b'])))] ) - super().__init__(NUMS_KEY, [bob_move]) + super().__init__(NUMS_KEY, bob_move) # params: @@ -169,4 +169,4 @@ def make_script(diff: int, ctv_hash: bytes): bob_wins = StandardClause("bob_wins", make_script(1, tmpl_bob_wins.get_standard_template_hash(0)), arg_specs, lambda _: tmpl_bob_wins) tie = StandardClause("alice_wins", make_script(2, tmpl_tie.get_standard_template_hash(0)), arg_specs, lambda _: tmpl_tie) - super().__init__(NUMS_KEY, [alice_wins, bob_wins, tie]) + super().__init__(NUMS_KEY, [alice_wins, [bob_wins, tie]]) diff --git a/examples/vault/vault_contracts.py b/examples/vault/vault_contracts.py index adbcb7a..dd7bf20 100644 --- a/examples/vault/vault_contracts.py +++ b/examples/vault/vault_contracts.py @@ -20,7 +20,7 @@ def __init__(self, alternate_pk: Optional[bytes], spend_delay: int, recover_pk: script=CScript([ # data and index already on the stack 0 if alternate_pk is None else alternate_pk, # pk - unvaulting.get_taptree(), # taptree + unvaulting.get_taptree_merkle_root(), # taptree 0, # standard flags OP_CHECKCONTRACTVERIFY, @@ -48,7 +48,7 @@ def __init__(self, alternate_pk: Optional[bytes], spend_delay: int, recover_pk: # data and index already on the stack 0 if alternate_pk is None else alternate_pk, # pk - unvaulting.get_taptree(), # taptree + unvaulting.get_taptree_merkle_root(), # taptree 0, # standard flags OP_CHECKCONTRACTVERIFY, @@ -85,7 +85,7 @@ def __init__(self, alternate_pk: Optional[bytes], spend_delay: int, recover_pk: next_output_fn=lambda args: [ClauseOutput(n=args['out_i'], next_contract=OpaqueP2TR(recover_pk))] ) - super().__init__(NUMS_KEY if alternate_pk is None else alternate_pk, [trigger, trigger_and_recover, recover]) + super().__init__(NUMS_KEY if alternate_pk is None else alternate_pk, [trigger, [trigger_and_recover, recover]]) class Unvaulting(StandardAugmentedP2TR): diff --git a/matt/__init__.py b/matt/__init__.py index 156fd99..0b84f20 100644 --- a/matt/__init__.py +++ b/matt/__init__.py @@ -132,13 +132,16 @@ def get_address(self) -> str: def __repr__(self) -> str: return f"{self.__class__.__name__}(pubkey={self.pubkey.hex()})" +Tapleaf = Tuple[str, Union[CScript, bytes]] +TaptreeDescription = List['TaptreeDescription'] + class P2TR(AbstractContract): """ A class representing a Pay-to-Taproot script. """ - def __init__(self, internal_pubkey: bytes, scripts: List[Tuple[str, CScript]]): + def __init__(self, internal_pubkey: bytes, scripts: TaptreeDescription): assert len(internal_pubkey) == 32 self.internal_pubkey = internal_pubkey @@ -167,10 +170,10 @@ def __init__(self, naked_internal_pubkey: bytes): self.naked_internal_pubkey = naked_internal_pubkey - def get_scripts(self) -> List[Tuple[str, CScript]]: + def get_scripts(self) -> TaptreeDescription: raise NotImplementedError("This must be implemented in subclasses") - def get_taptree(self) -> bytes: + def get_taptree_merkle_root(self) -> bytes: # use dummy data, since it doesn't affect the merkle root return self.get_tr_info(b'\0'*32).merkle_root @@ -185,15 +188,41 @@ def __repr__(self): return f"{self.__class__.__name__}(naked_internal_pubkey={self.naked_internal_pubkey.hex()}. Contracts's data: {self.data})" +StandardTaptreeDescription = Union[StandardClause, List['StandardTaptreeDescription']] + + +# converts a StandardTaptreeDescription to a TaptreeDescription, preserving the structure +def _normalize_standard_taptree_description(std_tree: StandardTaptreeDescription) -> TaptreeDescription: + if isinstance(std_tree, list): + if len(std_tree) != 2: + raise ValueError("A TapBranch must have exactly two children") + return [_normalize_standard_taptree_description(el) for el in std_tree] + else: + # std_tree is actually a single StandardClause + return [(std_tree.name, std_tree.script)] + + +# returns a flattenet list of StandardClause +def _flatten_standard_taptree_description(std_tree: StandardTaptreeDescription) -> list[StandardClause]: + if isinstance(std_tree, list): + if len(std_tree) != 2: + raise ValueError("A TapBranch must have exactly two children") + return [item for subtree in std_tree for item in _flatten_standard_taptree_description(subtree)] + else: + # std_tree is a single clause + return [std_tree] + + class StandardP2TR(P2TR): """ A StandardP2TR where all the transitions are given by a StandardClause. """ - def __init__(self, internal_pubkey: bytes, clauses: List[StandardClause]): - super().__init__(internal_pubkey, list(map(lambda x: (x.name, x.script), clauses))) - self.clauses = clauses - self._clauses_dict = {clause.name: clause for clause in clauses} + def __init__(self, internal_pubkey: bytes, standard_taptree: StandardTaptreeDescription): + super().__init__(internal_pubkey, _normalize_standard_taptree_description(standard_taptree)) + self.standard_taptree = standard_taptree + self.clauses = _flatten_standard_taptree_description(standard_taptree) + self._clauses_dict = {clause.name: clause for clause in self.clauses} def get_scripts(self) -> List[Tuple[str, CScript]]: return list(map(lambda clause: (clause.name, clause.script), self.clauses)) @@ -220,13 +249,14 @@ class StandardAugmentedP2TR(AugmentedP2TR): An AugmentedP2TR where all the transitions are given by a StandardClause. """ - def __init__(self, naked_internal_pubkey: bytes, clauses: List[StandardClause]): + def __init__(self, naked_internal_pubkey: bytes, standard_taptree: StandardTaptreeDescription): super().__init__(naked_internal_pubkey) - self.clauses = clauses - self._clauses_dict = {clause.name: clause for clause in clauses} + self.standard_taptree = standard_taptree + self.clauses = _flatten_standard_taptree_description(standard_taptree) + self._clauses_dict = {clause.name: clause for clause in self.clauses} - def get_scripts(self) -> List[Tuple[str, CScript]]: - return list(map(lambda clause: (clause.name, clause.script), self.clauses)) + def get_scripts(self) -> TaptreeDescription: + return _normalize_standard_taptree_description(self.standard_taptree) def decode_wit_stack(self, data: bytes, stack_elems: List[bytes]) -> Tuple[str, dict]: leaf_hash = stack_elems[-2]