Skip to content

Commit

Permalink
feat(lnprototest): Add stash function to the connection
Browse files Browse the repository at this point in the history
This commit introduce the ability to stash information inside
the connection.

This function allow lnprototest to build connection that can
keep state by connection.

Signed-off-by: Vincenzo Palazzo <vincenzopalazzodev@gmail.com>
  • Loading branch information
vincenzopalazzo committed Nov 11, 2024
1 parent 3de9543 commit ff7fc08
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 8 deletions.
13 changes: 7 additions & 6 deletions lnprototest/dummyrunner.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
#! /usr/bin/python3
# #### Dummy runner which you should replace with real one. ####
import io
from .runner import Runner, Conn
from .event import Event, ExpectMsg, MustNotMsg
from typing import List, Optional
from .keyset import KeySet
from typing import Any

from pyln.proto.message import (
Message,
FieldType,
DynamicArrayType,
EllipsisArrayType,
SizedArrayType,
)
from typing import Any

from .runner import Runner, Conn
from .event import Event, ExpectMsg, MustNotMsg
from typing import List, Optional
from .keyset import KeySet


class DummyRunner(Runner):
Expand Down
23 changes: 21 additions & 2 deletions lnprototest/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,13 @@ def __str__(self) -> str:


class RunnerConn(Conn):
"""Default Connection implementation for a runner that use the pyln.proto
to open a connection over a socket."""
"""
Default Connection implementation for a runner that use the pyln.proto
to open a connection over a socket.
Each connection has an internal memory to stash information
and keep connection state.
"""

def __init__(
self,
Expand All @@ -58,6 +63,20 @@ def __init__(
host,
port,
)
self.stash: Dict[str, Dict[str, Any]] = {}
self.logger = logging.getLogger(__name__)

def add_stash(self, stashname: str, vals: Any) -> None:
"""Add a dict to the stash."""
self.stash[stashname] = vals

def get_stash(self, event: Event, stashname: str, default: Any = None) -> Any:
"""Get an entry from the stash."""
if stashname not in self.stash:
if default is not None:
return default
raise SpecFileError(event, "Unknown stash name {}".format(stashname))
return self.stash[stashname]


class Runner(ABC):
Expand Down

0 comments on commit ff7fc08

Please sign in to comment.