Skip to content

Commit

Permalink
Refactor to allow StandardContracts to specify the exact taptree stru…
Browse files Browse the repository at this point in the history
…cture
  • Loading branch information
bigspider committed Dec 12, 2023
1 parent a63d49f commit 5a7ff36
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 18 deletions.
6 changes: 3 additions & 3 deletions examples/rps/rps_contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, c_a: bytes, stake: int = DEFA
OP_SHA256, # data = sha256(m_b)
0, # index
0, # NUMS pk
S1.get_taptree(),
S1.get_taptree_merkle_root(),
0, # flags
OP_CHECKCONTRACTVERIFY,
]),
Expand All @@ -83,7 +83,7 @@ def __init__(self, alice_pk: bytes, bob_pk: bytes, c_a: bytes, stake: int = DEFA
next_output_fn=lambda args: [ClauseOutput(n=0, next_contract=S1, next_data=sha256(bn2vch(args['m_b'])))]
)

super().__init__(NUMS_KEY, [bob_move])
super().__init__(NUMS_KEY, bob_move)


# params:
Expand Down Expand Up @@ -169,4 +169,4 @@ def make_script(diff: int, ctv_hash: bytes):
bob_wins = StandardClause("bob_wins", make_script(1, tmpl_bob_wins.get_standard_template_hash(0)), arg_specs, lambda _: tmpl_bob_wins)
tie = StandardClause("alice_wins", make_script(2, tmpl_tie.get_standard_template_hash(0)), arg_specs, lambda _: tmpl_tie)

super().__init__(NUMS_KEY, [alice_wins, bob_wins, tie])
super().__init__(NUMS_KEY, [alice_wins, [bob_wins, tie]])
6 changes: 3 additions & 3 deletions examples/vault/vault_contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self, alternate_pk: Optional[bytes], spend_delay: int, recover_pk:
script=CScript([
# data and index already on the stack
0 if alternate_pk is None else alternate_pk, # pk
unvaulting.get_taptree(), # taptree
unvaulting.get_taptree_merkle_root(), # taptree
0, # standard flags
OP_CHECKCONTRACTVERIFY,

Expand Down Expand Up @@ -48,7 +48,7 @@ def __init__(self, alternate_pk: Optional[bytes], spend_delay: int, recover_pk:

# data and index already on the stack
0 if alternate_pk is None else alternate_pk, # pk
unvaulting.get_taptree(), # taptree
unvaulting.get_taptree_merkle_root(), # taptree
0, # standard flags
OP_CHECKCONTRACTVERIFY,

Expand Down Expand Up @@ -85,7 +85,7 @@ def __init__(self, alternate_pk: Optional[bytes], spend_delay: int, recover_pk:
next_output_fn=lambda args: [ClauseOutput(n=args['out_i'], next_contract=OpaqueP2TR(recover_pk))]
)

super().__init__(NUMS_KEY if alternate_pk is None else alternate_pk, [trigger, trigger_and_recover, recover])
super().__init__(NUMS_KEY if alternate_pk is None else alternate_pk, [trigger, [trigger_and_recover, recover]])


class Unvaulting(StandardAugmentedP2TR):
Expand Down
54 changes: 42 additions & 12 deletions matt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,16 @@ def get_address(self) -> str:
def __repr__(self) -> str:
return f"{self.__class__.__name__}(pubkey={self.pubkey.hex()})"

Tapleaf = Tuple[str, Union[CScript, bytes]]
TaptreeDescription = List['TaptreeDescription']


class P2TR(AbstractContract):
"""
A class representing a Pay-to-Taproot script.
"""

def __init__(self, internal_pubkey: bytes, scripts: List[Tuple[str, CScript]]):
def __init__(self, internal_pubkey: bytes, scripts: TaptreeDescription):
assert len(internal_pubkey) == 32

self.internal_pubkey = internal_pubkey
Expand Down Expand Up @@ -167,10 +170,10 @@ def __init__(self, naked_internal_pubkey: bytes):

self.naked_internal_pubkey = naked_internal_pubkey

def get_scripts(self) -> List[Tuple[str, CScript]]:
def get_scripts(self) -> TaptreeDescription:
raise NotImplementedError("This must be implemented in subclasses")

def get_taptree(self) -> bytes:
def get_taptree_merkle_root(self) -> bytes:
# use dummy data, since it doesn't affect the merkle root
return self.get_tr_info(b'\0'*32).merkle_root

Expand All @@ -185,15 +188,41 @@ def __repr__(self):
return f"{self.__class__.__name__}(naked_internal_pubkey={self.naked_internal_pubkey.hex()}. Contracts's data: {self.data})"


StandardTaptreeDescription = Union[StandardClause, List['StandardTaptreeDescription']]


# converts a StandardTaptreeDescription to a TaptreeDescription, preserving the structure
def _normalize_standard_taptree_description(std_tree: StandardTaptreeDescription) -> TaptreeDescription:
if isinstance(std_tree, list):
if len(std_tree) != 2:
raise ValueError("A TapBranch must have exactly two children")
return [_normalize_standard_taptree_description(el) for el in std_tree]
else:
# std_tree is actually a single StandardClause
return [(std_tree.name, std_tree.script)]


# returns a flattenet list of StandardClause
def _flatten_standard_taptree_description(std_tree: StandardTaptreeDescription) -> list[StandardClause]:
if isinstance(std_tree, list):
if len(std_tree) != 2:
raise ValueError("A TapBranch must have exactly two children")
return [item for subtree in std_tree for item in _flatten_standard_taptree_description(subtree)]
else:
# std_tree is a single clause
return [std_tree]


class StandardP2TR(P2TR):
"""
A StandardP2TR where all the transitions are given by a StandardClause.
"""

def __init__(self, internal_pubkey: bytes, clauses: List[StandardClause]):
super().__init__(internal_pubkey, list(map(lambda x: (x.name, x.script), clauses)))
self.clauses = clauses
self._clauses_dict = {clause.name: clause for clause in clauses}
def __init__(self, internal_pubkey: bytes, standard_taptree: StandardTaptreeDescription):
super().__init__(internal_pubkey, _normalize_standard_taptree_description(standard_taptree))
self.standard_taptree = standard_taptree
self.clauses = _flatten_standard_taptree_description(standard_taptree)
self._clauses_dict = {clause.name: clause for clause in self.clauses}

def get_scripts(self) -> List[Tuple[str, CScript]]:
return list(map(lambda clause: (clause.name, clause.script), self.clauses))
Expand All @@ -220,13 +249,14 @@ class StandardAugmentedP2TR(AugmentedP2TR):
An AugmentedP2TR where all the transitions are given by a StandardClause.
"""

def __init__(self, naked_internal_pubkey: bytes, clauses: List[StandardClause]):
def __init__(self, naked_internal_pubkey: bytes, standard_taptree: StandardTaptreeDescription):
super().__init__(naked_internal_pubkey)
self.clauses = clauses
self._clauses_dict = {clause.name: clause for clause in clauses}
self.standard_taptree = standard_taptree
self.clauses = _flatten_standard_taptree_description(standard_taptree)
self._clauses_dict = {clause.name: clause for clause in self.clauses}

def get_scripts(self) -> List[Tuple[str, CScript]]:
return list(map(lambda clause: (clause.name, clause.script), self.clauses))
def get_scripts(self) -> TaptreeDescription:
return _normalize_standard_taptree_description(self.standard_taptree)

def decode_wit_stack(self, data: bytes, stack_elems: List[bytes]) -> Tuple[str, dict]:
leaf_hash = stack_elems[-2]
Expand Down

0 comments on commit 5a7ff36

Please sign in to comment.