From 1ed1202c0dff7c64c5b3639e5701bcbad3392f39 Mon Sep 17 00:00:00 2001 From: SamDanielThangarajan <12202554+SamDanielThangarajan@users.noreply.github.com> Date: Mon, 30 Sep 2024 12:16:23 +0200 Subject: [PATCH] added itch cli tools and codegen --- pyproject.toml | 1 + .../common/message/structures.py | 17 +- .../message/templates/itch_tail.mustache | 31 ++++ src/nasdaq_protocols/common/utils.py | 12 +- src/nasdaq_protocols/itch/codegen.py | 24 ++- src/nasdaq_protocols/itch/core.py | 3 + src/nasdaq_protocols/itch/tools.py | 39 +++++ src/nasdaq_protocols/ouch/core.py | 3 + src/nasdaq_protocols/soup/_reader.py | 13 +- tests/conftest.py | 33 +++- tests/test_common_message_structures.py | 16 +- tests/test_itch_codegen.py | 14 ++ tests/test_itch_tools.py | 161 ++++++++++++++++++ tests/test_soup_session.py | 7 + 14 files changed, 361 insertions(+), 13 deletions(-) create mode 100644 src/nasdaq_protocols/common/message/templates/itch_tail.mustache create mode 100644 src/nasdaq_protocols/itch/tools.py create mode 100644 tests/test_itch_tools.py diff --git a/pyproject.toml b/pyproject.toml index 6d153ab..1cfa33d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dynamic = ["version"] [project.scripts] nasdaq-ouch-codegen="nasdaq_protocols.ouch.codegen:generate" nasdaq-itch-codegen="nasdaq_protocols.itch.codegen:generate" +nasdaq-itch-tools-codegen="nasdaq_protocols.itch.codegen:generate_itch_tools" nasdaq-protocols-create-new-project="nasdaq_protocols.tools.new_project:create" diff --git a/src/nasdaq_protocols/common/message/structures.py b/src/nasdaq_protocols/common/message/structures.py index 2b5d5b9..7fb3062 100644 --- a/src/nasdaq_protocols/common/message/structures.py +++ b/src/nasdaq_protocols/common/message/structures.py @@ -2,6 +2,7 @@ This module contains the structures used to represent the messages in the protocol. """ import inspect +import json from enum import Enum from itertools import chain from collections import OrderedDict, defaultdict @@ -10,7 +11,8 @@ from nasdaq_protocols.common.utils import logable from nasdaq_protocols.common.types import Serializable -from .types import TypeDefinition, Short, Boolean +from nasdaq_protocols.common.types import TypeDefinition +from .types import Short, Boolean __all__ = [ @@ -89,6 +91,9 @@ def to_str(cls, record: Type['_Record']): def __str__(self): return f"{{'{self.__class__.__name__}':{{{self.values}}}}}" + def as_collection(self): + return self.values + @classmethod def from_str(cls, _str): raise NotImplementedError('this method is not available') @@ -232,7 +237,8 @@ def __init_subclass__(cls, **kwargs): cls.MsgId = kwargs.get('msg_id', cls.MsgId) if all(_ in kwargs for _ in ['app_name', 'msg_id_cls', 'msg_id']): - if cls.MsgId in CommonMessage.MsgIdToClsMap[cls.AppName]: + if (cls.MsgId in CommonMessage.MsgIdToClsMap[cls.AppName] and + cls.MsgIdToClsMap[cls.AppName][cls.MsgId] != cls): raise DuplicateMessageException( existing_msg=CommonMessage.MsgIdToClsMap[cls.AppName][cls.MsgId], new_msg=cls @@ -274,6 +280,13 @@ def __setattr__(self, key, value): except KeyError: self.__dict__[key] = value + def __str__(self): + data = { + 'message': f'{self.__class__.__name__}[{self.MsgId}]', + 'body': self.record.as_collection() + } + return json.dumps(data, indent=2) + @classmethod def get_msg_classes(cls) -> list[Type['CommonMessage']]: return CommonMessage.MsgIdToClsMap[cls.AppName].values() diff --git a/src/nasdaq_protocols/common/message/templates/itch_tail.mustache b/src/nasdaq_protocols/common/message/templates/itch_tail.mustache new file mode 100644 index 0000000..addaabe --- /dev/null +++ b/src/nasdaq_protocols/common/message/templates/itch_tail.mustache @@ -0,0 +1,31 @@ +import asyncio +import click + +from nasdaq_protocols.common import utils +from nasdaq_protocols.itch import tools + +import {{package}} as app + + +@click.command() +@click.option('-h', '--host', required=True) +@click.option('-p', '--port', required=True) +@click.option('-U', '--user', required=True) +@click.option('-P', '--password', required=True) +@click.option('-S', '--session', default='', show_default=True) +@click.option('-s', '--sequence', default=1, show_default=True) +@click.option('-b', '--client-heartbeat-interval', default=1, show_default=True) +@click.option('-B', '--server-heartbeat-interval', default=1, show_default=True) +@click.option('-v', '--verbose', count=True) +def {{app}}_tail(host, port, user, password, session, sequence, + client_heartbeat_interval, server_heartbeat_interval, + verbose): + """ Simple command that tails itch messages""" + utils.enable_logging_tools(verbose) + asyncio.run( + tools.tail_itch( + (host, port), user, password, session, sequence, + app.connect_async, + client_heartbeat_interval, server_heartbeat_interval + ) + ) diff --git a/src/nasdaq_protocols/common/utils.py b/src/nasdaq_protocols/common/utils.py index f63c613..f006ae2 100644 --- a/src/nasdaq_protocols/common/utils.py +++ b/src/nasdaq_protocols/common/utils.py @@ -10,7 +10,8 @@ 'logable', 'stop_task', 'start_server', - 'Validators' + 'Validators', + 'enable_logging_tools' ] _StopTaskTypes = asyncio.Task | Stoppable _logger = logging.getLogger(__name__) @@ -73,6 +74,15 @@ async def start_server(remote, session_factory, spin_timeout=0.001, *, name='ser return server, task +def enable_logging_tools(verbose: int, format_: str = '%(asctime)s> %(levelname)s %(message)s'): + level = logging.WARN + if verbose == 1: + level = logging.INFO + elif verbose >= 2: + level = logging.DEBUG + logging.basicConfig(format=format_, level=level, datefmt='%Y-%m-%d %H:%M:%S') + + class Validators: @staticmethod def not_none(): diff --git a/src/nasdaq_protocols/itch/codegen.py b/src/nasdaq_protocols/itch/codegen.py index 257be1b..78dc57a 100644 --- a/src/nasdaq_protocols/itch/codegen.py +++ b/src/nasdaq_protocols/itch/codegen.py @@ -1,10 +1,15 @@ +from importlib import resources +import os + +import chevron import click -from nasdaq_protocols.common.message import Parser, Generator +from nasdaq_protocols.common.message import Parser, Generator, templates __all__ = [ 'generate' ] +TEMPLATES_PATH = resources.files(templates) @click.command() @@ -28,3 +33,20 @@ def generate(spec_file, app_name, prefix, op_dir, override_messages, init_file): generate_init_file=init_file ) generator.generate(extra_context=context) + + +@click.command() +@click.option('--op-dir', type=click.Path(exists=True, writable=True)) +@click.option('--app-name', type=click.STRING) +@click.option('--package', type=click.STRING) +def generate_itch_tools(op_dir, app_name, package): + op_file = os.path.join(op_dir, f'itch_{app_name}_tools.py') + template = os.path.join(str(TEMPLATES_PATH), 'itch_tail.mustache') + with open(op_file, 'w', encoding='utf-8') as op, open(template, 'r', encoding='utf-8') as inp: + context = { + 'package': package, + 'app': app_name, + } + code_as_string = chevron.render(inp.read(), context, partials_path=str(TEMPLATES_PATH)) + op.write(code_as_string) + print(f'Generated: {op_file}') diff --git a/src/nasdaq_protocols/itch/core.py b/src/nasdaq_protocols/itch/core.py index 43352ad..ea3f0c2 100644 --- a/src/nasdaq_protocols/itch/core.py +++ b/src/nasdaq_protocols/itch/core.py @@ -20,6 +20,9 @@ def from_bytes(cls, bytes_: bytes) -> tuple[int, 'ItchMessageId']: def to_bytes(self) -> tuple[int, bytes]: return Byte.to_bytes(self.indicator) + def __str__(self): + return f'indicator={self.indicator}' + @attrs.define @logable diff --git a/src/nasdaq_protocols/itch/tools.py b/src/nasdaq_protocols/itch/tools.py new file mode 100644 index 0000000..2b74736 --- /dev/null +++ b/src/nasdaq_protocols/itch/tools.py @@ -0,0 +1,39 @@ +import asyncio +from itertools import count + + +async def tail_itch(remote, + user, + passwd, + session, + sequence, + itch_connector, + client_heartbeat_interval, + server_heartbeat_interval): + closed = asyncio.Event() + sequence_counter = count(1) + + async def on_close(): + print('connection closed.') + closed.set() + + async def on_msg(msg): + print(next(sequence_counter), msg) + print() + + itch_session = None + try: + itch_session = await itch_connector( + remote, user, passwd, session, sequence, + on_msg_coro=on_msg, + on_close_coro=on_close, + client_heartbeat_interval=client_heartbeat_interval, + server_heartbeat_interval=server_heartbeat_interval + ) + print('connected') + await closed.wait() + except asyncio.CancelledError: + if itch_session is not None: + await itch_session.close() + except ConnectionRefusedError as exc: + print(exc) diff --git a/src/nasdaq_protocols/ouch/core.py b/src/nasdaq_protocols/ouch/core.py index f97acb7..5ba96f4 100644 --- a/src/nasdaq_protocols/ouch/core.py +++ b/src/nasdaq_protocols/ouch/core.py @@ -22,6 +22,9 @@ def from_bytes(cls, bytes_: bytes) -> tuple[int, 'OuchMessageId']: def to_bytes(self) -> tuple[int, bytes]: return Byte.to_bytes(self.indicator) + def __str__(self): + return f'indicator={self.indicator}' + @attrs.define @logable diff --git a/src/nasdaq_protocols/soup/_reader.py b/src/nasdaq_protocols/soup/_reader.py index 16b106b..2d927d2 100644 --- a/src/nasdaq_protocols/soup/_reader.py +++ b/src/nasdaq_protocols/soup/_reader.py @@ -40,11 +40,14 @@ async def _process(self): await self.on_close_coro() return - self.log.debug('%s> dispatching message %s', self.session_id, str(msg)) - try: - await self.on_msg_coro(msg) - except Exception: # pylint: disable=broad-except - await self.on_close_coro() + if not msg.is_heartbeat(): + self.log.debug('%s> dispatching message %s', self.session_id, str(msg)) + try: + await self.on_msg_coro(msg) + except Exception: # pylint: disable=broad-except + await self.on_close_coro() + else: + self.log.debug('%s> received heartbeat', self.session_id) self._buffer = self._buffer[siz + 2:] buff_len -= (siz+2) diff --git a/tests/conftest.py b/tests/conftest.py index 3cff289..3f3111e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ import importlib.util from pathlib import Path import os +import sys import pytest from click.testing import CliRunner @@ -66,11 +67,37 @@ def generator(codegen, xml_content, app_name, generate_init_file, prefix): return generator +@pytest.fixture(scope='function') +def tools_codegen_invoker(tmp_path): + def generator(codegen, app_name, package): + runner = CliRunner() + with runner.isolated_filesystem(temp_dir=tmp_path): + Path('output').mkdir(parents=True, exist_ok=True) + result = runner.invoke( + codegen, + [ + '--op-dir', 'output', + '--app-name', app_name, + '--package', package + ] + ) + assert result.exit_code == 0 + + # Read the generated files + generated_file_contents = {} + for file in os.listdir('output'): + with open(os.path.join('output', file)) as f: + generated_file_contents[file] = f.read() + return generated_file_contents + return generator + + @pytest.fixture(scope='session') def code_loader(): def loader_(module_name, code_as_string): spec = importlib.util.spec_from_loader(module_name, loader=None) - module = importlib.util.module_from_spec(spec) - exec(code_as_string, module.__dict__) - return module + module_ = importlib.util.module_from_spec(spec) + exec(code_as_string, module_.__dict__) + sys.modules[module_name] = module_ + return module_ return loader_ \ No newline at end of file diff --git a/tests/test_common_message_structures.py b/tests/test_common_message_structures.py index 5f118e9..0c504bb 100644 --- a/tests/test_common_message_structures.py +++ b/tests/test_common_message_structures.py @@ -1,3 +1,4 @@ +import json from enum import Enum import pytest @@ -47,6 +48,10 @@ class TestArrayOfRecords(structures.RecordWithPresentBit): records: list[TestRecordWithPresentBit] +class TestMessage(structures.CommonMessage): + BodyRecord = TestRecord + + def test__record__to_bytes(): expected_bytes = b'\x02\x05\x00\x04\x00test' record = TestRecord() @@ -337,4 +342,13 @@ def test__record_with_present_bit__empty_record__to_bytes(): def test__record_with_present_bit__empty_record2__from_bytes(): - assert TestRecordWithPresentBit.from_bytes(b'\x00') == (1, None) \ No newline at end of file + assert TestRecordWithPresentBit.from_bytes(b'\x00') == (1, None) + + +def test__common_message__as_collection(): + message = TestMessage() + data = { + 'message': 'TestMessage[None]', + 'body': {} + } + assert str(message) == json.dumps(data, indent=2) diff --git a/tests/test_itch_codegen.py b/tests/test_itch_codegen.py index 3f06456..8af5226 100644 --- a/tests/test_itch_codegen.py +++ b/tests/test_itch_codegen.py @@ -157,6 +157,20 @@ def test__load_generated_code__code_loads_without_issue(load_generated_itch_code assert load_generated_itch_code is not None +def test__tools_codegen__code_generated(tools_codegen_invoker): + package = 'test' + app_name = 'test' + expected_file_name = f'itch_{app_name}_tools.py' + generated_files = tools_codegen_invoker( + codegen.generate_itch_tools, + app_name, + package + ) + + assert len(generated_files) == 1 + assert expected_file_name in generated_files + + async def test__connect__using_generated_code(load_generated_itch_code, soup_clientapp_common_tests): module = load_generated_itch_code diff --git a/tests/test_itch_tools.py b/tests/test_itch_tools.py new file mode 100644 index 0000000..8691793 --- /dev/null +++ b/tests/test_itch_tools.py @@ -0,0 +1,161 @@ +import asyncio +import pytest + +from nasdaq_protocols import soup +from nasdaq_protocols.common import stop_task +from nasdaq_protocols.itch import codegen +from nasdaq_protocols.itch.tools import tail_itch +from nasdaq_protocols.soup import LoginRejectReason +from .testdata import * +from .mocks import * + + +APP_NAME = 'test_itch_tools' +LOGIN_REQUEST = soup.LoginRequest('test-u', 'test-p', '', '1') +LOGIN_REJECTED = soup.LoginRejected(LoginRejectReason.NOT_AUTHORIZED) +LOGIN_ACCEPTED = soup.LoginAccepted('test', 1) +LOGOUT_REQUEST = soup.EndOfSession() + + +@pytest.fixture(scope='function') +def load_itch_definitions(code_loader, codegen_invoker): + def generator(app_name): + generated_file_name = f'itch_{app_name}.py' + generated_files = codegen_invoker( + codegen.generate, + TEST_XML_ITCH_MESSAGE, + app_name, + generate_init_file=False, + prefix='' + ) + + module_ = code_loader(app_name, generated_files[generated_file_name]) + assert module_ is not None + return module_ + yield generator + + +@pytest.fixture(scope='function') +def load_itch_tools(code_loader, load_itch_definitions, tools_codegen_invoker): + def generator(app_name): + definitions = load_itch_definitions(app_name) + generated_file_name = f'itch_{app_name}_tools.py' + generated_files = tools_codegen_invoker( + codegen.generate_itch_tools, + app_name, + app_name + ) + + tools = code_loader(f'{app_name}_tools', generated_files[generated_file_name]) + assert tools is not None + return definitions, tools, + yield generator + + +def get_message_feed(module): + message1 = module.TestMessage1() + message1.field1 = 1 + message1.field2 = '2' + message1.field3 = '3' + + message2 = module.TestMessage2() + message2.field1_1 = 4 + message2.field2_1 = '5' + message2.field3_1 = '6' + return [message1, message2] + + +async def test__itch_tools__tail_itch(mock_server_session, load_itch_tools): + definitions, tools = load_itch_tools('test__itch_tools__tail_itch') + message_feed = get_message_feed(definitions) + port, server_session = mock_server_session + + # setup login + session = server_session.when( + matches(LOGIN_REQUEST), 'match-login-request' + ).do( + send(LOGIN_ACCEPTED), 'send-login-accept' + ) + # send itch feeds + for msg in message_feed: + session.do( + send(soup.SequencedData(msg.to_bytes()[1])), 'send-test-message' + ) + + # start tailing + tailer = asyncio.create_task(tail_itch( + ('127.0.0.1', port), 'test-u', 'test-p', '', 1, + definitions.connect_async, 10, 10 + )) + assert not tailer.done() + + # give some time for tail + await asyncio.sleep(1) + assert not tailer.done() + + # logout + server_session.send(LOGOUT_REQUEST) + + # wait for tailer to finish + await asyncio.wait_for(tailer, timeout=5) + assert tailer.done() + + +async def test__itch_tools__tail_itch__login_failed(mock_server_session, load_itch_tools): + definitions, tools = load_itch_tools('test__itch_tools__tail_itch__login_failed') + port, server_session = mock_server_session + + # setup login + server_session.when( + matches(LOGIN_REQUEST), 'match-login-request' + ).do( + send(LOGIN_REJECTED), 'send-login-reject' + ) + + # start tailing + tailer = asyncio.create_task(tail_itch( + ('127.0.0.1', port), 'test-u', 'test-p', '', 1, + definitions.connect_async, 10, 10 + )) + # give some time for tail + await asyncio.sleep(1) + assert tailer.done() + + +async def test__itch_tools__tail_itch__wrong_server(mock_server_session, load_itch_tools): + definitions, tools = load_itch_tools('test__itch_tools__tail_itch__wrong_server') + port, server_session = mock_server_session + + # start tailing + tailer = asyncio.create_task(tail_itch( + ('no.such.host', port), 'test-u', 'test-p', '', 1, + definitions.connect_async, 10, 10 + )) + # give some time for tail + await asyncio.sleep(1) + assert tailer.done() + + +async def test__itch_tools__tail_itch__ctrl_c(mock_server_session, load_itch_tools): + definitions, tools = load_itch_tools('test__itch_tools__tail_itch__ctrl_c') + port, server_session = mock_server_session + + # setup login + session = server_session.when( + matches(LOGIN_REQUEST), 'match-login-request' + ).do( + send(LOGIN_ACCEPTED), 'send-login-accept' + ) + + # start tailing + tailer = asyncio.create_task(tail_itch( + ('127.0.0.1', port), 'test-u', 'test-p', '', 1, + definitions.connect_async, 10, 10 + )) + # give some time for tail + await asyncio.sleep(1) + assert not tailer.done() + + await stop_task(tailer) + assert tailer.done() + diff --git a/tests/test_soup_session.py b/tests/test_soup_session.py index 0c7e7cf..55b30c9 100644 --- a/tests/test_soup_session.py +++ b/tests/test_soup_session.py @@ -140,3 +140,10 @@ async def on_close(): server_session.generate_load(100) await asyncio.wait_for(closed.wait(), 1) + + +def test__soup_session__session_with_no_session_type(): + class BaseSessionType(soup.SoupSession): + pass + + assert BaseSessionType.SessionType is 'base'