Skip to content

Commit

Permalink
Exit signature update cache (#329)
Browse files Browse the repository at this point in the history
* Add ExitSignatureUpdateCache

* Move OraclesCache to AppState

* Move OraclesCache to app_state.py
  • Loading branch information
evgeny-stakewise authored May 3, 2024
1 parent 939ab7d commit f4a9013
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 36 deletions.
26 changes: 26 additions & 0 deletions src/common/app_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from dataclasses import dataclass

from eth_typing import BlockNumber

from src.common.typings import Singleton


@dataclass
class OraclesCache:
checkpoint_block: BlockNumber
config: dict
validators_threshold: int
rewards_threshold: int


@dataclass
class ExitSignatureUpdateCache:
checkpoint_block: BlockNumber | None = None
last_event_block: BlockNumber | None = None


# pylint: disable-next=too-few-public-methods
class AppState(metaclass=Singleton):
def __init__(self):
self.exit_signature_update_cache = ExitSignatureUpdateCache()
self.oracles_cache: OraclesCache | None = None
11 changes: 8 additions & 3 deletions src/common/contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,14 @@ async def get_last_rewards_update(self) -> RewardVoteInfo | None:
)
return voting_info

async def get_exit_signatures_updated_event(self, vault: ChecksumAddress) -> EventData | None:
from_block = settings.network_config.KEEPER_GENESIS_BLOCK
to_block = await execution_client.eth.get_block_number()
async def get_exit_signatures_updated_event(
self,
vault: ChecksumAddress,
from_block: BlockNumber | None = None,
to_block: BlockNumber | None = None,
) -> EventData | None:
from_block = from_block or settings.network_config.KEEPER_GENESIS_BLOCK
to_block = to_block or await execution_client.eth.get_block_number()

last_event = await self._get_last_event(
self.events.ExitSignaturesUpdated,
Expand Down
21 changes: 10 additions & 11 deletions src/common/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
from web3.exceptions import BadFunctionCallOutput
from web3.types import TxParams, Wei

from src.common.app_state import AppState, OraclesCache
from src.common.clients import execution_client, ipfs_fetch_client
from src.common.contracts import keeper_contract, multicall_contract, vault_contract
from src.common.metrics import metrics
from src.common.tasks import BaseTask
from src.common.typings import OraclesCache
from src.common.wallet import hot_wallet
from src.config.settings import settings

Expand Down Expand Up @@ -53,18 +53,16 @@ async def check_hot_wallet_balance() -> None:
)


_oracles_cache: OraclesCache | None = None


async def update_oracles_cache() -> None:
"""
Fetches latest oracle config from IPFS. Uses cache if possible.
"""
global _oracles_cache # pylint: disable=global-statement
app_state = AppState()
oracles_cache = app_state.oracles_cache

# Find the latest block for which oracle config is cached
if _oracles_cache:
from_block = BlockNumber(_oracles_cache.checkpoint_block + 1)
if oracles_cache:
from_block = BlockNumber(oracles_cache.checkpoint_block + 1)
else:
from_block = settings.network_config.KEEPER_GENESIS_BLOCK

Expand All @@ -73,13 +71,13 @@ async def update_oracles_cache() -> None:
if from_block > to_block:
return

logger.debug('update_oracles_cache: get logs from_block %s, to_block %s', from_block, to_block)
logger.debug('update_oracles_cache: get logs from block %s to block %s', from_block, to_block)
event = await keeper_contract.get_config_updated_event(from_block=from_block, to_block=to_block)
if event:
ipfs_hash = event['args']['configIpfsHash']
config = cast(dict, await ipfs_fetch_client.fetch_json(ipfs_hash))
else:
config = _oracles_cache.config # type: ignore
config = oracles_cache.config # type: ignore

rewards_threshold_call = keeper_contract.encode_abi(fn_name='rewardsMinOracles', args=[])
validators_threshold_call = keeper_contract.encode_abi(fn_name='validatorsMinOracles', args=[])
Expand All @@ -93,7 +91,7 @@ async def update_oracles_cache() -> None:
rewards_threshold = Web3.to_int(multicall_response[0][1])
validators_threshold = Web3.to_int(multicall_response[1][1])

_oracles_cache = OraclesCache(
app_state.oracles_cache = OraclesCache(
config=config,
validators_threshold=validators_threshold,
rewards_threshold=rewards_threshold,
Expand All @@ -103,8 +101,9 @@ async def update_oracles_cache() -> None:

async def get_protocol_config() -> ProtocolConfig:
await update_oracles_cache()
app_state = AppState()

oracles_cache = cast(OraclesCache, _oracles_cache)
oracles_cache = cast(OraclesCache, app_state.oracles_cache)
pc = build_protocol_config(
config_data=oracles_cache.config,
rewards_threshold=oracles_cache.rewards_threshold,
Expand Down
18 changes: 9 additions & 9 deletions src/common/typings.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,8 @@
from dataclasses import dataclass

from eth_typing import BlockNumber
from web3.types import Wei


@dataclass
class OraclesCache:
checkpoint_block: BlockNumber
config: dict
validators_threshold: int
rewards_threshold: int


@dataclass
class RewardVoteInfo:
rewards_root: bytes
Expand All @@ -38,3 +29,12 @@ class OraclesApproval:
signatures: bytes
ipfs_hash: str
deadline: int


class Singleton(type):
_instances: dict = {}

def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
10 changes: 1 addition & 9 deletions src/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from web3 import Web3
from web3.types import ChecksumAddress

from src.common.typings import Singleton
from src.config.networks import HOLESKY, MAINNET, NETWORKS, NetworkConfig
from src.validators.typings import ValidatorsRegistrationMode

Expand All @@ -20,15 +21,6 @@
DEFAULT_API_PORT = 8000


class Singleton(type):
_instances: dict = {}

def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]


# pylint: disable-next=too-many-public-methods,too-many-instance-attributes
class Settings(metaclass=Singleton):
vault: ChecksumAddress
Expand Down
28 changes: 25 additions & 3 deletions src/exits/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from tenacity import RetryError
from web3.types import HexStr

from src.common.app_state import AppState
from src.common.clients import execution_client
from src.common.contracts import keeper_contract
from src.common.exceptions import NotEnoughOracleApprovalsError
from src.common.execution import get_protocol_config
Expand Down Expand Up @@ -41,6 +43,9 @@ async def process_block(self, interrupt_handler: InterruptHandler) -> None:

protocol_config = await get_protocol_config()
update_block = await _fetch_last_update_block()

logger.debug('last exit signature update block %s', update_block)

if update_block and not await is_block_finalized(update_block):
logger.info('Waiting for signatures update block %d to finalize...', update_block)
return
Expand Down Expand Up @@ -95,10 +100,27 @@ async def _fetch_last_update_block_replicas(replicas: list[str]) -> BlockNumber


async def _fetch_last_update_block() -> BlockNumber | None:
last_event = await keeper_contract.get_exit_signatures_updated_event(vault=settings.vault)
app_state = AppState()
update_cache = app_state.exit_signature_update_cache

from_block: BlockNumber | None = None
if (checkpoint_block := update_cache.checkpoint_block) is not None:
from_block = BlockNumber(checkpoint_block + 1)

to_block = await execution_client.eth.get_block_number()

if from_block is not None and from_block > to_block:
return update_cache.last_event_block

last_event = await keeper_contract.get_exit_signatures_updated_event(
vault=settings.vault, from_block=from_block, to_block=to_block
)
update_cache.checkpoint_block = to_block

if last_event:
return BlockNumber(last_event['blockNumber'])
return None
update_cache.last_event_block = BlockNumber(last_event['blockNumber'])

return update_cache.last_event_block


async def _fetch_outdated_indexes(
Expand Down
57 changes: 56 additions & 1 deletion src/exits/tests/test_tasks.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import contextlib
from pathlib import Path
from random import randint
from typing import Callable
from unittest import mock
from unittest.mock import AsyncMock

import pytest
from eth_typing import ChecksumAddress
from sw_utils.typings import ConsensusFork, ProtocolConfig

from src.common.utils import get_current_timestamp
from src.config.settings import settings
from src.exits.tasks import _get_oracles_request
from src.exits.tasks import _fetch_last_update_block, _get_oracles_request
from src.validators.keystores.local import Keys, LocalKeystore
from src.validators.keystores.remote import RemoteSignerKeystore

Expand Down Expand Up @@ -79,3 +81,56 @@ async def test_remote_signer(
== list(validators.values())[: protocol_config.validators_exit_rotation_batch_limit]
)
assert request.deadline == deadline


@contextlib.contextmanager
def patch_latest_block(block_number):
with mock.patch('src.exits.tasks.execution_client', new=AsyncMock()) as execution_client_mock:
execution_client_mock.eth.get_block_number.return_value = block_number
yield


@pytest.mark.usefixtures('fake_settings')
class TestFetchLastExitSignatureUpdateBlock:
async def test_normal(self):
get_event_func = 'src.exits.tasks.keeper_contract.get_exit_signatures_updated_event'

# no events, checkpoint moved from None to 8
with (
mock.patch(get_event_func, return_value=None) as get_event_mock,
patch_latest_block(8),
):
last_update_block = await _fetch_last_update_block()

assert last_update_block is None
get_event_mock.assert_called_once_with(vault=settings.vault, from_block=None, to_block=8)

# no events, checkpoint moved to 9
with (
mock.patch(get_event_func, return_value=None) as get_event_mock,
patch_latest_block(9),
):
last_update_block = await _fetch_last_update_block()

assert last_update_block is None
get_event_mock.assert_called_once_with(vault=settings.vault, from_block=9, to_block=9)

# event is found, checkpoint moved to 15
with (
mock.patch(get_event_func, return_value=dict(blockNumber=11)) as get_event_mock,
patch_latest_block(15),
):
last_update_block = await _fetch_last_update_block()

assert last_update_block == 11
get_event_mock.assert_called_once_with(vault=settings.vault, from_block=10, to_block=15)

# no events, checkpoint moved to 20
with (
mock.patch(get_event_func, return_value=None) as get_event_mock,
patch_latest_block(20),
):
last_update_block = await _fetch_last_update_block()

assert last_update_block == 11
get_event_mock.assert_called_once_with(vault=settings.vault, from_block=16, to_block=20)

0 comments on commit f4a9013

Please sign in to comment.