From 8725549a5408d5d33ce5dc28f8ada7cd7cd8b6ba Mon Sep 17 00:00:00 2001 From: ixje Date: Mon, 9 Oct 2023 06:26:27 -0700 Subject: [PATCH] api: add `FindStorage` RPC call to `NeoRpcClient` (#288) --- neo3/api/noderpc.py | 53 ++++++++++++++++++++++++++++++++++++++- tests/api/test_noderpc.py | 49 ++++++++++++++++++++++++++++++++++++ 2 files changed, 101 insertions(+), 1 deletion(-) diff --git a/neo3/api/noderpc.py b/neo3/api/noderpc.py index 3fd393b..cfb2715 100644 --- a/neo3/api/noderpc.py +++ b/neo3/api/noderpc.py @@ -13,7 +13,17 @@ from enum import Enum, IntEnum from contextlib import suppress from dataclasses import dataclass -from typing import Optional, TypedDict, Any, Protocol, Iterator, Union, cast, Type +from typing import ( + Optional, + TypedDict, + Any, + Protocol, + Iterator, + Union, + cast, + Type, + AsyncGenerator, +) from collections.abc import Sequence from neo3.core import types, cryptography, interfaces, serialization from neo3.contracts import manifest, nef, contract, abi @@ -887,6 +897,47 @@ async def calculate_network_fee(self, tx: bytes | transaction.Transaction) -> in result = await self._do_post("calculatenetworkfee", params) return int(result["networkfee"]) + async def find_states( + self, contract_hash: types.UInt160 | str, prefix: Optional[bytes] = None + ) -> AsyncGenerator[tuple[bytes, bytes], None]: + """ + Fetch the smart contract storage state. + + Args: + contract_hash: the hash of the smart contract to call. + prefix: storage prefix to search for. If omitted will return all storage + + Returns: + a storage key/value pair + + Examples: + # prints all deployed + prefix_contract_hash = b"\x0c" + async with api.NeoRpcClient("https://testnet1.neo.coz.io:443") as client: + async for k, v in client.find_states(CONTRACT_HASHES.MANAGEMENT, prefix_contract_hash): + print(k, v) + + """ + if isinstance(contract_hash, str): + contract_hash = types.UInt160.from_string(contract_hash) + contract_hash = f"0x{contract_hash}" + + if prefix is None: + prefix = b"" + _prefix = base64.b64encode(prefix).decode() + start = 0 + while True: + response = await self._do_post( + "findstorage", [contract_hash, _prefix, start] + ) + for pair in response["results"]: + key = base64.b64decode(pair["key"]) + value = base64.b64decode(pair["value"]) + yield key, value + if not response["truncated"]: + break + start = response["next"] + async def get_application_log_transaction( self, tx_hash: types.UInt256 | str ) -> TransactionApplicationLogResponse: diff --git a/tests/api/test_noderpc.py b/tests/api/test_noderpc.py index 35d9af0..431c355 100644 --- a/tests/api/test_noderpc.py +++ b/tests/api/test_noderpc.py @@ -14,6 +14,7 @@ class TestNeoRpcClient(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self) -> None: self.client = api.NeoRpcClient("localhost") + # CAREFULL THIS PATCHES ALL aiohttp CALLS! self.helper = aioresponses() self.helper.start() @@ -43,6 +44,54 @@ async def test_calculate_network_fee(self): ) self.assertEqual(123, response) + async def test_find_states(self): + key1 = b"\x0c\x00\x00\x00\x01" + key2 = b"\x0c\x00\x00\x00\x02" + key3 = b"\x0c\x00\x00\x00\x03" + + value1 = b"\x97\"\x8dq\xd20\xaf\xde\\\xce\x8f\xf9'\x1f*\x9d(\x88u\xf0" + value2 = b"\x92,\x15\xa9\xa0\xe9\x00\x02\xed\xb4o\x1e>\xe4\xb7V\x8c\xb7%F" + value3 = b"\xe0\x98^\x9d\xf0w\xb0\x88v\x1eV\xb3m\x97\xef\x89\x08F\x12\x13" + + captured1 = { + "truncated": True, + "next": 2, + "results": [ + { + "key": base64.b64encode(key1).decode(), + "value": base64.b64encode(value1).decode(), + }, + { + "key": base64.b64encode(key2).decode(), + "value": base64.b64encode(value2).decode(), + }, + ], + } + captured2 = { + "truncated": False, + "next": 3, + "results": [ + { + "key": base64.b64encode(key3).decode(), + "value": base64.b64encode(value3).decode(), + } + ], + } + self.mock_response(captured1) + self.mock_response(captured2) + from neo3.contracts.contract import CONTRACT_HASHES + + results = [] + async for k, v in self.client.find_states(CONTRACT_HASHES.MANAGEMENT, b"\x0c"): + results.append((k, v)) + self.assertEqual(3, len(results)) + self.assertEqual(key1, results[0][0]) + self.assertEqual(value1, results[0][1]) + self.assertEqual(key2, results[1][0]) + self.assertEqual(value2, results[1][1]) + self.assertEqual(key3, results[2][0]) + self.assertEqual(value3, results[2][1]) + async def test_get_application_log_transaction(self): captured = { "txid": "0x7da6ae7ff9d0b7af3d32f3a2feb2aa96c2a27ef8b651f9a132cfaad6ef20724c",