Skip to content

Commit

Permalink
added itch cli tools and codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
SamDanielThangarajan committed Sep 30, 2024
1 parent cb277c6 commit 1ed1202
Show file tree
Hide file tree
Showing 14 changed files with 361 additions and 13 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ dynamic = ["version"]
[project.scripts]
nasdaq-ouch-codegen="nasdaq_protocols.ouch.codegen:generate"
nasdaq-itch-codegen="nasdaq_protocols.itch.codegen:generate"
nasdaq-itch-tools-codegen="nasdaq_protocols.itch.codegen:generate_itch_tools"
nasdaq-protocols-create-new-project="nasdaq_protocols.tools.new_project:create"


Expand Down
17 changes: 15 additions & 2 deletions src/nasdaq_protocols/common/message/structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
This module contains the structures used to represent the messages in the protocol.
"""
import inspect
import json
from enum import Enum
from itertools import chain
from collections import OrderedDict, defaultdict
Expand All @@ -10,7 +11,8 @@

from nasdaq_protocols.common.utils import logable
from nasdaq_protocols.common.types import Serializable
from .types import TypeDefinition, Short, Boolean
from nasdaq_protocols.common.types import TypeDefinition
from .types import Short, Boolean


__all__ = [
Expand Down Expand Up @@ -89,6 +91,9 @@ def to_str(cls, record: Type['_Record']):
def __str__(self):
return f"{{'{self.__class__.__name__}':{{{self.values}}}}}"

def as_collection(self):
return self.values

@classmethod
def from_str(cls, _str):
raise NotImplementedError('this method is not available')
Expand Down Expand Up @@ -232,7 +237,8 @@ def __init_subclass__(cls, **kwargs):
cls.MsgId = kwargs.get('msg_id', cls.MsgId)

if all(_ in kwargs for _ in ['app_name', 'msg_id_cls', 'msg_id']):
if cls.MsgId in CommonMessage.MsgIdToClsMap[cls.AppName]:
if (cls.MsgId in CommonMessage.MsgIdToClsMap[cls.AppName] and
cls.MsgIdToClsMap[cls.AppName][cls.MsgId] != cls):
raise DuplicateMessageException(
existing_msg=CommonMessage.MsgIdToClsMap[cls.AppName][cls.MsgId],
new_msg=cls
Expand Down Expand Up @@ -274,6 +280,13 @@ def __setattr__(self, key, value):
except KeyError:
self.__dict__[key] = value

def __str__(self):
data = {
'message': f'{self.__class__.__name__}[{self.MsgId}]',
'body': self.record.as_collection()
}
return json.dumps(data, indent=2)

@classmethod
def get_msg_classes(cls) -> list[Type['CommonMessage']]:
return CommonMessage.MsgIdToClsMap[cls.AppName].values()
Expand Down
31 changes: 31 additions & 0 deletions src/nasdaq_protocols/common/message/templates/itch_tail.mustache
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import asyncio
import click

from nasdaq_protocols.common import utils
from nasdaq_protocols.itch import tools

import {{package}} as app


@click.command()
@click.option('-h', '--host', required=True)
@click.option('-p', '--port', required=True)
@click.option('-U', '--user', required=True)
@click.option('-P', '--password', required=True)
@click.option('-S', '--session', default='', show_default=True)
@click.option('-s', '--sequence', default=1, show_default=True)
@click.option('-b', '--client-heartbeat-interval', default=1, show_default=True)
@click.option('-B', '--server-heartbeat-interval', default=1, show_default=True)
@click.option('-v', '--verbose', count=True)
def {{app}}_tail(host, port, user, password, session, sequence,
client_heartbeat_interval, server_heartbeat_interval,
verbose):
""" Simple command that tails itch messages"""
utils.enable_logging_tools(verbose)
asyncio.run(
tools.tail_itch(
(host, port), user, password, session, sequence,
app.connect_async,
client_heartbeat_interval, server_heartbeat_interval
)
)
12 changes: 11 additions & 1 deletion src/nasdaq_protocols/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
'logable',
'stop_task',
'start_server',
'Validators'
'Validators',
'enable_logging_tools'
]
_StopTaskTypes = asyncio.Task | Stoppable
_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -73,6 +74,15 @@ async def start_server(remote, session_factory, spin_timeout=0.001, *, name='ser
return server, task


def enable_logging_tools(verbose: int, format_: str = '%(asctime)s> %(levelname)s %(message)s'):
level = logging.WARN
if verbose == 1:
level = logging.INFO
elif verbose >= 2:
level = logging.DEBUG
logging.basicConfig(format=format_, level=level, datefmt='%Y-%m-%d %H:%M:%S')


class Validators:
@staticmethod
def not_none():
Expand Down
24 changes: 23 additions & 1 deletion src/nasdaq_protocols/itch/codegen.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from importlib import resources
import os

import chevron
import click
from nasdaq_protocols.common.message import Parser, Generator
from nasdaq_protocols.common.message import Parser, Generator, templates


__all__ = [
'generate'
]
TEMPLATES_PATH = resources.files(templates)


@click.command()
Expand All @@ -28,3 +33,20 @@ def generate(spec_file, app_name, prefix, op_dir, override_messages, init_file):
generate_init_file=init_file
)
generator.generate(extra_context=context)


@click.command()
@click.option('--op-dir', type=click.Path(exists=True, writable=True))
@click.option('--app-name', type=click.STRING)
@click.option('--package', type=click.STRING)
def generate_itch_tools(op_dir, app_name, package):
op_file = os.path.join(op_dir, f'itch_{app_name}_tools.py')
template = os.path.join(str(TEMPLATES_PATH), 'itch_tail.mustache')
with open(op_file, 'w', encoding='utf-8') as op, open(template, 'r', encoding='utf-8') as inp:
context = {
'package': package,
'app': app_name,
}
code_as_string = chevron.render(inp.read(), context, partials_path=str(TEMPLATES_PATH))
op.write(code_as_string)
print(f'Generated: {op_file}')
3 changes: 3 additions & 0 deletions src/nasdaq_protocols/itch/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ def from_bytes(cls, bytes_: bytes) -> tuple[int, 'ItchMessageId']:
def to_bytes(self) -> tuple[int, bytes]:
return Byte.to_bytes(self.indicator)

def __str__(self):
return f'indicator={self.indicator}'


@attrs.define
@logable
Expand Down
39 changes: 39 additions & 0 deletions src/nasdaq_protocols/itch/tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import asyncio
from itertools import count


async def tail_itch(remote,
user,
passwd,
session,
sequence,
itch_connector,
client_heartbeat_interval,
server_heartbeat_interval):
closed = asyncio.Event()
sequence_counter = count(1)

async def on_close():
print('connection closed.')
closed.set()

async def on_msg(msg):
print(next(sequence_counter), msg)
print()

itch_session = None
try:
itch_session = await itch_connector(
remote, user, passwd, session, sequence,
on_msg_coro=on_msg,
on_close_coro=on_close,
client_heartbeat_interval=client_heartbeat_interval,
server_heartbeat_interval=server_heartbeat_interval
)
print('connected')
await closed.wait()
except asyncio.CancelledError:
if itch_session is not None:
await itch_session.close()
except ConnectionRefusedError as exc:
print(exc)
3 changes: 3 additions & 0 deletions src/nasdaq_protocols/ouch/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ def from_bytes(cls, bytes_: bytes) -> tuple[int, 'OuchMessageId']:
def to_bytes(self) -> tuple[int, bytes]:
return Byte.to_bytes(self.indicator)

def __str__(self):
return f'indicator={self.indicator}'


@attrs.define
@logable
Expand Down
13 changes: 8 additions & 5 deletions src/nasdaq_protocols/soup/_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,14 @@ async def _process(self):
await self.on_close_coro()
return

self.log.debug('%s> dispatching message %s', self.session_id, str(msg))
try:
await self.on_msg_coro(msg)
except Exception: # pylint: disable=broad-except
await self.on_close_coro()
if not msg.is_heartbeat():
self.log.debug('%s> dispatching message %s', self.session_id, str(msg))
try:
await self.on_msg_coro(msg)
except Exception: # pylint: disable=broad-except
await self.on_close_coro()
else:
self.log.debug('%s> received heartbeat', self.session_id)

self._buffer = self._buffer[siz + 2:]
buff_len -= (siz+2)
33 changes: 30 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import importlib.util
from pathlib import Path
import os
import sys
import pytest

from click.testing import CliRunner
Expand Down Expand Up @@ -66,11 +67,37 @@ def generator(codegen, xml_content, app_name, generate_init_file, prefix):
return generator


@pytest.fixture(scope='function')
def tools_codegen_invoker(tmp_path):
def generator(codegen, app_name, package):
runner = CliRunner()
with runner.isolated_filesystem(temp_dir=tmp_path):
Path('output').mkdir(parents=True, exist_ok=True)
result = runner.invoke(
codegen,
[
'--op-dir', 'output',
'--app-name', app_name,
'--package', package
]
)
assert result.exit_code == 0

# Read the generated files
generated_file_contents = {}
for file in os.listdir('output'):
with open(os.path.join('output', file)) as f:
generated_file_contents[file] = f.read()
return generated_file_contents
return generator


@pytest.fixture(scope='session')
def code_loader():
def loader_(module_name, code_as_string):
spec = importlib.util.spec_from_loader(module_name, loader=None)
module = importlib.util.module_from_spec(spec)
exec(code_as_string, module.__dict__)
return module
module_ = importlib.util.module_from_spec(spec)
exec(code_as_string, module_.__dict__)
sys.modules[module_name] = module_
return module_
return loader_
16 changes: 15 additions & 1 deletion tests/test_common_message_structures.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from enum import Enum

import pytest
Expand Down Expand Up @@ -47,6 +48,10 @@ class TestArrayOfRecords(structures.RecordWithPresentBit):
records: list[TestRecordWithPresentBit]


class TestMessage(structures.CommonMessage):
BodyRecord = TestRecord


def test__record__to_bytes():
expected_bytes = b'\x02\x05\x00\x04\x00test'
record = TestRecord()
Expand Down Expand Up @@ -337,4 +342,13 @@ def test__record_with_present_bit__empty_record__to_bytes():


def test__record_with_present_bit__empty_record2__from_bytes():
assert TestRecordWithPresentBit.from_bytes(b'\x00') == (1, None)
assert TestRecordWithPresentBit.from_bytes(b'\x00') == (1, None)


def test__common_message__as_collection():
message = TestMessage()
data = {
'message': 'TestMessage[None]',
'body': {}
}
assert str(message) == json.dumps(data, indent=2)
14 changes: 14 additions & 0 deletions tests/test_itch_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,20 @@ def test__load_generated_code__code_loads_without_issue(load_generated_itch_code
assert load_generated_itch_code is not None


def test__tools_codegen__code_generated(tools_codegen_invoker):
package = 'test'
app_name = 'test'
expected_file_name = f'itch_{app_name}_tools.py'
generated_files = tools_codegen_invoker(
codegen.generate_itch_tools,
app_name,
package
)

assert len(generated_files) == 1
assert expected_file_name in generated_files


async def test__connect__using_generated_code(load_generated_itch_code, soup_clientapp_common_tests):
module = load_generated_itch_code

Expand Down
Loading

0 comments on commit 1ed1202

Please sign in to comment.