Skip to content

Commit

Permalink
Compatibility of type annotations with older python versions
Browse files Browse the repository at this point in the history
  • Loading branch information
bigspider committed Dec 10, 2023
1 parent 0154a53 commit 15cc354
Show file tree
Hide file tree
Showing 10 changed files with 78 additions and 71 deletions.
6 changes: 4 additions & 2 deletions examples/ram/ram.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import argparse
from ast import Tuple
import json

import os

import logging
import shlex
import traceback
from typing import Dict, List

from dotenv import load_dotenv

Expand Down Expand Up @@ -63,7 +65,7 @@ def get_completions(self, document, complete_event):
rpc_port = os.getenv("RPC_PORT", 18443)


def parse_outputs(output_strings: list[str]) -> list[tuple[str, int]]:
def parse_outputs(output_strings: List[str]) -> List[Tuple[str, int]]:
"""Parses a list of strings in the form "address:amount" into a list of (address, amount) tuples.
Args:
Expand Down Expand Up @@ -250,7 +252,7 @@ def script_main(script_filename: str):
environment = Environment(rpc, manager, None, None, False)

# map from known ctv hashes to the corresponding template (used for withdrawals)
ctv_templates: dict[bytes, CTransaction] = {}
ctv_templates: Dict[bytes, CTransaction] = {}


if args.script:
Expand Down
3 changes: 2 additions & 1 deletion examples/ram/ram_contracts.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import List
from matt.argtypes import BytesType, MerkleProofType
from matt.btctools.script import OP_CAT, OP_CHECKCONTRACTVERIFY, OP_DUP, OP_ELSE, OP_ENDIF, OP_EQUAL, OP_EQUALVERIFY, OP_FROMALTSTACK, OP_IF, OP_NOTIF, OP_PICK, OP_ROLL, OP_ROT, OP_SHA256, OP_SWAP, OP_TOALTSTACK, OP_TRUE, CScript
from matt import CCV_FLAG_CHECK_INPUT, NUMS_KEY, ClauseOutput, StandardClause, StandardAugmentedP2TR

from matt.merkle import is_power_of_2, floor_lg

class RAM(StandardAugmentedP2TR):
def __init__(self, size: list[bytes]):
def __init__(self, size: List[bytes]):
assert is_power_of_2(size)

self.size = size
Expand Down
7 changes: 4 additions & 3 deletions examples/vault/vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import logging
import shlex
import traceback
from typing import Dict, List, Tuple

from dotenv import load_dotenv

Expand Down Expand Up @@ -81,7 +82,7 @@ def segwit_addr_to_scriptpubkey(addr: str) -> bytes:
])


def parse_outputs(output_strings: list[str]) -> list[tuple[str, int]]:
def parse_outputs(output_strings: List[str]) -> List[Tuple[str, int]]:
"""Parses a list of strings in the form "address:amount" into a list of (address, amount) tuples.
Args:
Expand Down Expand Up @@ -160,7 +161,7 @@ def execute_command(input_line: str):
if not isinstance(items_idx, list) or len(set(items_idx)) != len(items_idx):
raise ValueError("Invalid items")

spending_vaults: list[ContractInstance] = []
spending_vaults: List[ContractInstance] = []
for idx in items_idx:
if idx >= len(manager.instances):
raise ValueError(f"No such instance: {idx}")
Expand Down Expand Up @@ -343,7 +344,7 @@ def script_main(script_filename: str):
print(f"Vault address: {V.get_address()}\n")

# map from known ctv hashes to the corresponding template (used for withdrawals)
ctv_templates: dict[bytes, CTransaction] = {}
ctv_templates: Dict[bytes, CTransaction] = {}


if args.script:
Expand Down
5 changes: 3 additions & 2 deletions examples/vault/vault_contracts.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Optional
from matt.argtypes import BytesType, IntType, SignerType
from matt.btctools.script import OP_CHECKCONTRACTVERIFY, OP_CHECKSEQUENCEVERIFY, OP_CHECKSIG, OP_CHECKTEMPLATEVERIFY, OP_DROP, OP_DUP, OP_SWAP, OP_TRUE, CScript
from matt import CCV_FLAG_CHECK_INPUT, CCV_FLAG_DEDUCT_OUTPUT_AMOUNT, NUMS_KEY, ClauseOutput, ClauseOutputAmountBehaviour, OpaqueP2TR, StandardClause, StandardP2TR, StandardAugmentedP2TR


class Vault(StandardP2TR):
def __init__(self, alternate_pk: bytes | None, spend_delay: int, recover_pk: bytes, unvault_pk: bytes):
def __init__(self, alternate_pk: Optional[bytes], spend_delay: int, recover_pk: bytes, unvault_pk: bytes):
assert (alternate_pk is None or len(alternate_pk) == 32) and len(recover_pk) == 32 and len(unvault_pk)

self.alternate_pk = alternate_pk
Expand Down Expand Up @@ -88,7 +89,7 @@ def __init__(self, alternate_pk: bytes | None, spend_delay: int, recover_pk: byt


class Unvaulting(StandardAugmentedP2TR):
def __init__(self, alternate_pk: bytes | None, spend_delay: int, recover_pk: bytes):
def __init__(self, alternate_pk: Optional[bytes], spend_delay: int, recover_pk: bytes):
assert (alternate_pk is None or len(alternate_pk) == 32) and len(recover_pk) == 32

self.alternate_pk = alternate_pk
Expand Down
82 changes: 41 additions & 41 deletions matt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from enum import Enum
from io import BytesIO
from typing import Callable, Optional
from typing import Callable, Dict, List, Optional, Tuple, Union

from .argtypes import ArgType, SignerType
from .btctools import script, key
Expand Down Expand Up @@ -32,9 +32,9 @@ class ClauseOutputAmountBehaviour(Enum):

@dataclass
class ClauseOutput:
n: None | int
n: Optional[int]
next_contract: AbstractContract # only StandardP2TR and StandardAugmentedP2TR are supported so far
next_data: None | bytes = None # only meaningful if c is augmented
next_data: Optional[bytes] = None # only meaningful if c is augmented
next_amount: ClauseOutputAmountBehaviour = ClauseOutputAmountBehaviour.PRESERVE_OUTPUT

def __repr__(self):
Expand All @@ -46,13 +46,13 @@ def __init__(self, name: str, script: CScript):
self.name = name
self.script = script

def stack_elements_from_args(self, args: dict) -> list[bytes]:
def stack_elements_from_args(self, args: dict) -> List[bytes]:
raise NotImplementedError

def next_outputs(self, args: dict) -> list[ClauseOutput]:
def next_outputs(self, args: dict) -> List[ClauseOutput]:
raise NotImplementedError

def args_from_stack_elements(self, elements: list[bytes]) -> dict:
def args_from_stack_elements(self, elements: List[bytes]) -> dict:
raise NotImplementedError

def __repr__(self):
Expand All @@ -65,22 +65,22 @@ def __repr__(self):
# Other types of generic treatable clauses could be defined (for example, a MiniscriptClause).
# Moreover, it specifies a function that converts the arguments of the clause, to the data of the next output.
class StandardClause(Clause):
def __init__(self, name: str, script: CScript, arg_specs: list[tuple[str, ArgType]], next_output_fn: Callable[[dict], list[ClauseOutput] | CTransaction] | None = None):
def __init__(self, name: str, script: CScript, arg_specs: List[Tuple[str, ArgType]], next_output_fn: Optional[Callable[[dict], Union[List[ClauseOutput], CTransaction]]] = None):
super().__init__(name, script)
self.arg_specs = arg_specs

self.next_outputs_fn = next_output_fn



def next_outputs(self, args: dict) -> list[ClauseOutput] | CTransaction:
def next_outputs(self, args: dict) -> Union[List[ClauseOutput], CTransaction]:
if self.next_outputs_fn is not None:
return self.next_outputs_fn(args)
else:
return []

def stack_elements_from_args(self, args: dict) -> list[bytes]:
result: list[bytes] = []
def stack_elements_from_args(self, args: dict) -> List[bytes]:
result: List[bytes] = []
for arg_name, arg_cls in self.arg_specs:
if arg_name not in args:
raise ValueError(f"Missing argument: {arg_name}")
Expand All @@ -90,7 +90,7 @@ def stack_elements_from_args(self, args: dict) -> list[bytes]:

return result

def args_from_stack_elements(self, elements: list[bytes]) -> dict:
def args_from_stack_elements(self, elements: List[bytes]) -> dict:
result = {}
cur = 0
for arg_name, arg_cls in self.arg_specs:
Expand Down Expand Up @@ -138,7 +138,7 @@ class P2TR(AbstractContract):
A class representing a Pay-to-Taproot script.
"""

def __init__(self, internal_pubkey: bytes, scripts: list[tuple[str, CScript]]):
def __init__(self, internal_pubkey: bytes, scripts: List[Tuple[str, CScript]]):
assert len(internal_pubkey) == 32

self.internal_pubkey = internal_pubkey
Expand Down Expand Up @@ -167,7 +167,7 @@ def __init__(self, naked_internal_pubkey: bytes):

self.naked_internal_pubkey = naked_internal_pubkey

def get_scripts(self) -> list[tuple[str, CScript]]:
def get_scripts(self) -> List[Tuple[str, CScript]]:
raise NotImplementedError("This must be implemented in subclasses")

def get_taptree(self) -> bytes:
Expand All @@ -190,15 +190,15 @@ class StandardP2TR(P2TR):
A StandardP2TR where all the transitions are given by a StandardClause.
"""

def __init__(self, internal_pubkey: bytes, clauses: list[StandardClause]):
def __init__(self, internal_pubkey: bytes, clauses: List[StandardClause]):
super().__init__(internal_pubkey, list(map(lambda x: (x.name, x.script), clauses)))
self.clauses = clauses
self._clauses_dict = {clause.name: clause for clause in clauses}

def get_scripts(self) -> list[tuple[str, CScript]]:
def get_scripts(self) -> List[Tuple[str, CScript]]:
return list(map(lambda clause: (clause.name, clause.script), self.clauses))

def decode_wit_stack(self, stack_elems: list[bytes]) -> tuple[str, dict]:
def decode_wit_stack(self, stack_elems: List[bytes]) -> Tuple[str, dict]:
leaf_hash = stack_elems[-2]

clause_name = None
Expand All @@ -220,15 +220,15 @@ class StandardAugmentedP2TR(AugmentedP2TR):
An AugmentedP2TR where all the transitions are given by a StandardClause.
"""

def __init__(self, naked_internal_pubkey: bytes, clauses: list[StandardClause]):
def __init__(self, naked_internal_pubkey: bytes, clauses: List[StandardClause]):
super().__init__(naked_internal_pubkey)
self.clauses = clauses
self._clauses_dict = {clause.name: clause for clause in clauses}

def get_scripts(self) -> list[tuple[str, CScript]]:
def get_scripts(self) -> List[Tuple[str, CScript]]:
return list(map(lambda clause: (clause.name, clause.script), self.clauses))

def decode_wit_stack(self, data: bytes, stack_elems: list[bytes]) -> tuple[str, dict]:
def decode_wit_stack(self, data: bytes, stack_elems: List[bytes]) -> Tuple[str, dict]:
leaf_hash = stack_elems[-2]

clause_name = None
Expand All @@ -252,7 +252,7 @@ def __repr__(self):
# would include other info to help the signer decide (e.g.: the transaction)
# There are no bad people here, though, so we keep it simple for now.
class SchnorrSigner:
def __init__(self, keys: key.ExtendedKey | list[key.ExtendedKey]):
def __init__(self, keys: Union[key.ExtendedKey, List[key.ExtendedKey]]):
if not isinstance(keys, list):
keys = [keys]

Expand All @@ -262,7 +262,7 @@ def __init__(self, keys: key.ExtendedKey | list[key.ExtendedKey]):

self.keys = keys

def sign(self, msg: bytes, pubkey: bytes) -> bytes | None:
def sign(self, msg: bytes, pubkey: bytes) -> Optional[bytes]:
if len(msg) != 32:
raise ValueError("msg should be 32 bytes long")
if len(pubkey) != 32:
Expand All @@ -282,7 +282,7 @@ class ContractInstanceStatus(Enum):


class ContractInstance:
def __init__(self, contract: StandardP2TR | StandardAugmentedP2TR):
def __init__(self, contract: Union[StandardP2TR, StandardAugmentedP2TR]):
self.contract = contract
self.data = None if not self.is_augm() else b'\0'*32

Expand All @@ -293,10 +293,10 @@ def __init__(self, contract: StandardP2TR | StandardAugmentedP2TR):
self.last_height = 0

self.status = ContractInstanceStatus.ABSTRACT
self.outpoint: COutPoint | None = None
self.funding_tx: CTransaction | None = None
self.outpoint: Optional[COutPoint] = None
self.funding_tx: Optional[CTransaction] = None

self.spending_tx: CTransaction | None = None
self.spending_tx: Optional[CTransaction] = None
self.spending_vin = None

self.spending_clause = None
Expand All @@ -323,7 +323,7 @@ def get_value(self) -> int:
raise ValueError("contract not funded, or funding transaction unknown")
return self.funding_tx.vout[self.outpoint.n].nValue

def decode_wit_stack(self, stack_elems: list[bytes]) -> tuple[str, dict]:
def decode_wit_stack(self, stack_elems: List[bytes]) -> Tuple[str, dict]:
if self.is_augm():
return self.contract.decode_wit_stack(self.data, stack_elems)
else:
Expand All @@ -335,7 +335,7 @@ def __repr__(self):
value = self.funding_tx.vout[self.outpoint.n].nValue
return f"{self.__class__.__name__}(contract={self.contract}, data={self.data if self.data is None else self.data.hex()}, value={value}, status={self.status}, outpoint={self.outpoint})"

def __call__(self, clause_name: str, *, signer: Optional[SchnorrSigner] = None, outputs: list[CTxOut] = [], **kwargs) -> list['ContractInstance']:
def __call__(self, clause_name: str, *, signer: Optional[SchnorrSigner] = None, outputs: List[CTxOut] = [], **kwargs) -> List['ContractInstance']:
if self.manager is None:
raise ValueError("Direct invocation is only allowed after adding the instance to a ContractManager")

Expand All @@ -346,13 +346,13 @@ def __call__(self, clause_name: str, *, signer: Optional[SchnorrSigner] = None,


class ContractManager:
def __init__(self, contract_instances: list[ContractInstance], rpc: AuthServiceProxy, *, poll_interval: float = 1, mine_automatically: bool = False):
def __init__(self, contract_instances: List[ContractInstance], rpc: AuthServiceProxy, *, poll_interval: float = 1, mine_automatically: bool = False):
self.instances = contract_instances
self.mine_automatically = mine_automatically
self.rpc = rpc
self.poll_interval = poll_interval

def _check_instance(self, instance: ContractInstance, exp_statuses: None | ContractInstanceStatus | list[ContractInstanceStatus] = None):
def _check_instance(self, instance: ContractInstance, exp_statuses: Optional[Union[ContractInstanceStatus, List[ContractInstanceStatus]]] = None):
if exp_statuses is not None:
if isinstance(exp_statuses, ContractInstanceStatus):
if instance.status != exp_statuses:
Expand All @@ -371,7 +371,7 @@ def add_instance(self, instance: ContractInstance):
instance.manager = self
self.instances.append(instance)

def wait_for_outpoint(self, instance: ContractInstance, txid: str | None = None):
def wait_for_outpoint(self, instance: ContractInstance, txid: Optional[str] = None):
self._check_instance(instance, exp_statuses=ContractInstanceStatus.ABSTRACT)
if instance.is_augm():
if instance.data is None:
Expand All @@ -394,15 +394,15 @@ def wait_for_outpoint(self, instance: ContractInstance, txid: str | None = None)

def get_spend_tx(
self,
spends: tuple[ContractInstance, str, dict] | list[tuple[ContractInstance, str, dict]],
output_amounts: dict[int, int] = {}
) -> tuple[CTransaction, list[bytes]]:
spends: Union[Tuple[ContractInstance, str, dict], List[Tuple[ContractInstance, str, dict]]],
output_amounts: Dict[int, int] = {}
) -> Tuple[CTransaction, List[bytes]]:
if not isinstance(spends, list):
spends = [spends]

tx = CTransaction()
tx.nVersion = 2
outputs_map: dict[int, CTxOut] = {}
outputs_map: Dict[int, CTxOut] = {}

tx.vin = [CTxIn(outpoint=instance.outpoint) for instance, _, _ in spends]

Expand Down Expand Up @@ -464,7 +464,7 @@ def get_spend_tx(
tx.vout = [outputs_map[i] for i in range(len(outputs_map))]

# TODO: generalize for keypath spend?
sighashes: list[bytes] = []
sighashes: List[bytes] = []
spent_utxos = []

# TODO: simplify
Expand Down Expand Up @@ -501,11 +501,11 @@ def get_spend_wit(self, instance: ContractInstance, clause_name: str, wargs: dic
]
return in_wit

def _mine_blocks(self, n_blocks: int = 1) -> list[str]:
def _mine_blocks(self, n_blocks: int = 1) -> List[str]:
address = self.rpc.getnewaddress()
return self.rpc.generatetoaddress(n_blocks, address)

def spend_and_wait(self, instances: ContractInstance | list[ContractInstance], tx: CTransaction) -> list[ContractInstance]:
def spend_and_wait(self, instances: Union[ContractInstance, List[ContractInstance]], tx: CTransaction) -> List[ContractInstance]:
if isinstance(instances, ContractInstance):
instances = [instances]

Expand All @@ -520,11 +520,11 @@ def spend_and_wait(self, instances: ContractInstance | list[ContractInstance], t
self._mine_blocks(1)
return self.wait_for_spend(instances)

def wait_for_spend(self, instances: ContractInstance | list[ContractInstance]) -> list[ContractInstance]:
def wait_for_spend(self, instances: Union[ContractInstance, List[ContractInstance]]) -> List[ContractInstance]:
if isinstance(instances, ContractInstance):
instances = [instances]

out_contracts: dict[int, ContractInstance] = {}
out_contracts: Dict[int, ContractInstance] = {}

for instance in instances:
self._check_instance(instance, exp_statuses=ContractInstanceStatus.FUNDED)
Expand Down Expand Up @@ -590,7 +590,7 @@ def wait_for_spend(self, instances: ContractInstance | list[ContractInstance]) -
self.add_instance(instance)
return result

def fund_instance(self, contract: StandardP2TR | StandardAugmentedP2TR, amount: int, data: Optional[bytes] = None) -> ContractInstance:
def fund_instance(self, contract: Union[StandardP2TR, StandardAugmentedP2TR], amount: int, data: Optional[bytes] = None) -> ContractInstance:
"""
Convenience method to create an instance of a contract, add it to the ContractManager,
and send a transaction to fund it with a certain amount.
Expand All @@ -609,7 +609,7 @@ def fund_instance(self, contract: StandardP2TR | StandardAugmentedP2TR, amount:
self.wait_for_outpoint(instance, txid)
return instance

def spend_instance(self, instance: ContractInstance, clause_name: str, args: dict, *, signer: Optional[SchnorrSigner], outputs: Optional[list[CTxOut]] = None) -> list[ContractInstance]:
def spend_instance(self, instance: ContractInstance, clause_name: str, args: dict, *, signer: Optional[SchnorrSigner], outputs: Optional[List[CTxOut]] = None) -> List[ContractInstance]:
"""
Creates and broadcasts a transaction that spends a contract instance using a specified clause and arguments.
Expand Down
Loading

0 comments on commit 15cc354

Please sign in to comment.