diff --git a/examples/vault/vault_contracts.py b/examples/vault/vault_contracts.py index 8a6afec..3995ce5 100644 --- a/examples/vault/vault_contracts.py +++ b/examples/vault/vault_contracts.py @@ -7,7 +7,7 @@ class Vault(StandardP2TR): - def __init__(self, alternate_pk: Optional[bytes], spend_delay: int, recover_pk: bytes, unvault_pk: bytes): + def __init__(self, alternate_pk: Optional[bytes], spend_delay: int, recover_pk: bytes, unvault_pk: bytes, *, has_partial_revault=True, has_early_recover=True): assert (alternate_pk is None or len(alternate_pk) == 32) and len(recover_pk) == 32 and len(unvault_pk) self.alternate_pk = alternate_pk @@ -16,6 +16,9 @@ def __init__(self, alternate_pk: Optional[bytes], spend_delay: int, recover_pk: unvaulting = Unvaulting(alternate_pk, spend_delay, recover_pk) + self.has_partial_revault = has_partial_revault + self.has_early_recover = has_early_recover + # witness: trigger = StandardClause( name="trigger", @@ -89,7 +92,18 @@ def __init__(self, alternate_pk: Optional[bytes], spend_delay: int, recover_pk: next_output_fn=lambda args: [ClauseOutput(n=args['out_i'], next_contract=OpaqueP2TR(recover_pk))] ) - super().__init__(NUMS_KEY if alternate_pk is None else alternate_pk, [trigger, [trigger_and_recover, recover]]) + if self.has_partial_revault: + if self.has_early_recover: + clauses = [trigger, [trigger_and_recover, recover]] + else: + clauses = [trigger, trigger_and_recover] + else: + if self.has_early_recover: + clauses = [trigger, recover] + else: + clauses = trigger + + super().__init__(NUMS_KEY if alternate_pk is None else alternate_pk, clauses) class Unvaulting(StandardAugmentedP2TR): diff --git a/tests/test_vault.py b/tests/test_vault.py index e5154cd..96940c0 100644 --- a/tests/test_vault.py +++ b/tests/test_vault.py @@ -1,4 +1,5 @@ from io import TextIOWrapper +from typing import Tuple import pytest from examples.vault.vault_contracts import Vault @@ -19,12 +20,39 @@ "tprv8ZgxMBicQKsPeDvaW4xxmiMXxqakLgvukT8A5GR6mRwBwjsDJV1jcZab8mxSerNcj22YPrusm2Pz5oR8LTw9GqpWT51VexTNBzxxm49jCZZ") -def test_vault_recover(manager: ContractManager, report_file: TextIOWrapper): - V = Vault(None, 10, recover_priv_key.pubkey[1:], unvault_priv_key.pubkey[1:]) +locktime = 10 + +VaultSpecs = Tuple[str, Vault] + + +V_full: VaultSpecs = ( + "Vault", + Vault(None, locktime, recover_priv_key.pubkey[1:], unvault_priv_key.pubkey[1:]) +) +V_no_partial_revault: VaultSpecs = ( + "Vault[no partial revault]", + Vault(None, locktime, recover_priv_key.pubkey[1:], unvault_priv_key.pubkey[1:], has_partial_revault=False) +) + +V_no_early_recover: VaultSpecs = ( + "Vault[no early recover]", + Vault(None, locktime, recover_priv_key.pubkey[1:], unvault_priv_key.pubkey[1:], has_early_recover=False) +) + +V_light: VaultSpecs = ( + "Vault[light]", + Vault(None, locktime, recover_priv_key.pubkey[1:], unvault_priv_key.pubkey[1:], + has_partial_revault=False, has_early_recover=False) +) + + +@pytest.mark.parametrize("vault_specs", [V_full, V_no_partial_revault]) +def test_vault_recover(vault_specs: VaultSpecs, manager: ContractManager, report_file: TextIOWrapper): + vault_description, vault_contract = vault_specs amount = 20_000 - V_inst = manager.fund_instance(V, amount) + V_inst = manager.fund_instance(vault_contract, amount) out_instances = V_inst("recover", out_i=0) @@ -33,20 +61,21 @@ def test_vault_recover(manager: ContractManager, report_file: TextIOWrapper): assert out.nValue == amount assert out.scriptPubKey == OpaqueP2TR(recover_priv_key.pubkey[1:]).get_tr_info().scriptPubKey - report_file.write(format_tx_markdown(V_inst.spending_tx, "Recovery from vault, 1 input [NoRecoveryAuth]")) + report_file.write(format_tx_markdown(V_inst.spending_tx, + f"{vault_description}: Recovery from vault, 1 input [NoRecoveryAuth]")) assert len(out_instances) == 0 -def test_vault_trigger_and_recover(manager: ContractManager, report_file: TextIOWrapper): - locktime = 10 - V = Vault(None, locktime, recover_priv_key.pubkey[1:], unvault_priv_key.pubkey[1:]) +@pytest.mark.parametrize("vault_specs", [V_full, V_no_partial_revault, V_no_early_recover, V_light]) +def test_vault_trigger_and_recover(vault_specs: VaultSpecs, manager: ContractManager, report_file: TextIOWrapper): + vault_description, vault_contract = vault_specs signer = SchnorrSigner(unvault_priv_key) amount = 4999990000 - V_inst = manager.fund_instance(V, amount) + V_inst = manager.fund_instance(vault_contract, amount) ctv_tmpl = make_ctv_template([ ("bcrt1qqy0kdmv0ckna90ap6efd6z39wcdtpfa3a27437", 4999990000), @@ -57,24 +86,24 @@ def test_vault_trigger_and_recover(manager: ContractManager, report_file: TextIO [U_inst] = V_inst("trigger", signer=signer, out_i=0, ctv_hash=ctv_tmpl.get_standard_template_hash(0)) - report_file.write(format_tx_markdown(V_inst.spending_tx, "Trigger [3 vault inputs]")) + report_file.write(format_tx_markdown(V_inst.spending_tx, f"{vault_description}: Trigger")) out_instances = U_inst("recover", out_i=0) assert len(out_instances) == 0 - report_file.write(format_tx_markdown(U_inst.spending_tx, "Recovery from trigger")) + report_file.write(format_tx_markdown(U_inst.spending_tx, f"{vault_description}: Recovery from trigger")) -def test_vault_trigger_and_withdraw(rpc: AuthServiceProxy, manager: ContractManager, report_file: TextIOWrapper): - locktime = 10 - V = Vault(None, locktime, recover_priv_key.pubkey[1:], unvault_priv_key.pubkey[1:]) +@pytest.mark.parametrize("vault_specs", [V_full, V_no_partial_revault, V_no_early_recover, V_light]) +def test_vault_trigger_and_withdraw(vault_specs: VaultSpecs, rpc: AuthServiceProxy, manager: ContractManager, report_file: TextIOWrapper): + vault_description, vault_contract = vault_specs signer = SchnorrSigner(unvault_priv_key) amount = 4999990000 - V_inst = manager.fund_instance(V, amount) + V_inst = manager.fund_instance(vault_contract, amount) ctv_tmpl = make_ctv_template([ ("bcrt1qqy0kdmv0ckna90ap6efd6z39wcdtpfa3a27437", 1666663333), @@ -85,7 +114,7 @@ def test_vault_trigger_and_withdraw(rpc: AuthServiceProxy, manager: ContractMana [U_inst] = V_inst("trigger", signer=signer, out_i=0, ctv_hash=ctv_tmpl.get_standard_template_hash(0)) - report_file.write(format_tx_markdown(V_inst.spending_tx, "Trigger [3 vault inputs]")) + report_file.write(format_tx_markdown(V_inst.spending_tx, f"{vault_description}: Trigger")) spend_tx, _ = manager.get_spend_tx( (U_inst, "withdraw", {"ctv_hash": ctv_tmpl.get_standard_template_hash(0)}) @@ -112,21 +141,21 @@ def test_vault_trigger_and_withdraw(rpc: AuthServiceProxy, manager: ContractMana manager.spend_and_wait(U_inst, spend_tx) - report_file.write(format_tx_markdown(U_inst.spending_tx, "Withdraw [3 outputs]")) + report_file.write(format_tx_markdown(U_inst.spending_tx, f"{vault_description}: Withdraw [3 outputs]")) -def test_vault_trigger_with_revault_and_withdraw(rpc: AuthServiceProxy, manager: ContractManager, report_file: TextIOWrapper): +@pytest.mark.parametrize("vault_specs", [V_full, V_no_early_recover]) +def test_vault_trigger_with_revault_and_withdraw(vault_specs: VaultSpecs, rpc: AuthServiceProxy, manager: ContractManager, report_file: TextIOWrapper): # get coins on 3 different Vaults, then trigger with partial withdrawal # one of the vault uses "trigger_with_revault", the others us normal "trigger" - locktime = 10 - amount = 4999990000 + vault_description, vault_contract = vault_specs - V = Vault(None, locktime, recover_priv_key.pubkey[1:], unvault_priv_key.pubkey[1:]) + amount = 4999990000 - V_inst_1 = manager.fund_instance(V, amount) - V_inst_2 = manager.fund_instance(V, amount) - V_inst_3 = manager.fund_instance(V, amount) + V_inst_1 = manager.fund_instance(vault_contract, amount) + V_inst_2 = manager.fund_instance(vault_contract, amount) + V_inst_3 = manager.fund_instance(vault_contract, amount) ctv_tmpl = make_ctv_template([ ("bcrt1qqy0kdmv0ckna90ap6efd6z39wcdtpfa3a27437", 4999990000), @@ -158,7 +187,7 @@ def test_vault_trigger_with_revault_and_withdraw(rpc: AuthServiceProxy, manager: [U_inst] = manager.spend_and_wait([V_inst_1, V_inst_2, V_inst_3], spend_tx) - report_file.write(format_tx_markdown(spend_tx, "Trigger (with revault) [3 vault inputs]")) + report_file.write(format_tx_markdown(spend_tx, f"{vault_description}: Trigger (with revault) [3 vault inputs]")) spend_tx, _ = manager.get_spend_tx( (U_inst, "withdraw", {"ctv_hash": ctv_tmpl.get_standard_template_hash(0)}) @@ -182,4 +211,4 @@ def test_vault_trigger_with_revault_and_withdraw(rpc: AuthServiceProxy, manager: manager.spend_and_wait(U_inst, spend_tx) - report_file.write(format_tx_markdown(U_inst.spending_tx, "Withdraw (3 outputs)")) + report_file.write(format_tx_markdown(U_inst.spending_tx, f"{vault_description}: Withdraw (3 outputs)"))