Skip to content

Commit

Permalink
fist step in multi app support
Browse files Browse the repository at this point in the history
  • Loading branch information
SamDanielThangarajan committed Oct 11, 2024
1 parent 3b41c5e commit af6a281
Show file tree
Hide file tree
Showing 17 changed files with 107 additions and 20 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
include src/nasdaq_protocols/common/message/templates/*.mustache
include src/nasdaq_protocols/tools/templates/*.mustache
include src/nasdaq_protocols/fix/parser/templates/*.mustache
include src/nasdaq_protocols/tools/templates/*.xml
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ dynamic = ["version"]
nasdaq-ouch-codegen="nasdaq_protocols.ouch.codegen:generate"
nasdaq-itch-codegen="nasdaq_protocols.itch.codegen:generate"
nasdaq-itch-tools-codegen="nasdaq_protocols.itch.codegen:generate_itch_tools"
nasdaq-fix-codegen="nasdaq_protocols.fix.codegen:generate"
nasdaq-protocols-create-new-project="nasdaq_protocols.tools.new_project:create"


Expand Down
5 changes: 2 additions & 3 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ pythonpath = src
asyncio_mode=auto
asyncio_default_fixture_loop_scope=function
log_cli=true
log_level=DEBUG
log_level=INFO
log_format = %(name)-20s: %(message)s
log_date_format = %I:%M:%S
addopts = --cov=src --cov-fail-under=97
log_date_format = %I:%M:%S
5 changes: 2 additions & 3 deletions src/nasdaq_protocols/fix/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import asyncio
from typing import Callable

from .session import FixSession
from .session import *
from .types import *
from .core import *
from ._reader import FixMessageReader


async def connect_async(remote: tuple[str, int],
login_msg: Message,
session_fac: Callable[[], FixSession],
_sequence=1):
session_fac: Callable[[], FixSession]):
loop = asyncio.get_running_loop()

_, session_1 = await loop.create_connection(session_fac, *remote)
Expand Down
18 changes: 17 additions & 1 deletion src/nasdaq_protocols/fix/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
import pprint
from collections import OrderedDict
from collections import OrderedDict, defaultdict
from enum import Enum
from typing import ClassVar, Any, Type, TypeVar, Union

Expand Down Expand Up @@ -376,15 +376,31 @@ def __eq__(self, other):


class Message(FixSerializable):
MsgIdToClsMap: ClassVar[dict] = defaultdict(dict)
MsgNameToMsgMap: ClassVar[dict] = defaultdict(dict)
Type: ClassVar[int]
Name: ClassVar[str]
Category: ClassVar[str]
SegmentCls: ClassVar[dict[MessageSegments, type[DataSegment]]]
AppName: ClassVar[str]

Def = {}
MandatoryFields = [
'Name',
'Type',
'Category',
'HeaderCls',
'TrailerCls',
'BodyCls',
]

@classmethod
def __init_subclass__(cls, **kwargs):
for field in cls.MandatoryFields:
if field not in kwargs:
return
app_name = kwargs.get('app_name', 'fix')
cls.AppName = app_name
cls.Name = kwargs['Name']
cls.Type = kwargs['Type']
cls.Category = kwargs['Category']
Expand Down
13 changes: 13 additions & 0 deletions src/nasdaq_protocols/fix/parser/definitions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import keyword
from collections import defaultdict
from itertools import count
from typing import Any
Expand All @@ -18,6 +19,8 @@
'Definitions'
]

from nasdaq_protocols.fix.parser.version_types import Version


@attrs.define
class FieldDef:
Expand All @@ -30,6 +33,7 @@ def _values_ctx(self):
output = []
if self.possible_values:
for key, value in self.possible_values.items():
value = f'{value}' if not keyword.iskeyword(value) else f'{value}_'
output.append({
'f_name': key,
'f_value': value,
Expand Down Expand Up @@ -123,6 +127,7 @@ def get_codegen_context(self, definitions):

@attrs.define
class Definitions:
version: Version
fields: dict[str, FieldDef] = attrs.field(init=False, factory=dict)
components: dict[str, Component] = attrs.field(kw_only=True, factory=dict)
header: EntryContainer = attrs.field(kw_only=True, factory=EntryContainer)
Expand All @@ -137,6 +142,7 @@ def get_codegen_context(self):
for message in self.messages
]
return {
'client_session': self._client_session(),
'fields': [field.get_codegen_context(self) for field in self.fields.values()],
'bodies': [
self.header.get_codegen_context(self) | {
Expand All @@ -149,3 +155,10 @@ def get_codegen_context(self):
'messages': message_context,
'groups': Group.Contexts, # always last, as it is dependent on other entries
}

def _client_session(self):
if self.version == Version.FIX_4_4:
return 'Fix44Session'
if self.version in (Version.FIX_5_0, Version.FIX_5_0_2):
return 'Fix50Session'
raise ValueError(f'Version {self.version} is not supported')
11 changes: 11 additions & 0 deletions src/nasdaq_protocols/fix/parser/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,20 @@ def generate(self):
generated_modules.append(module_name)
generated_files.append(generated_file)

# Generate the app module
Generator._generate(
self._context,
os.path.join(str(TEMPLATES_PATH), 'app.mustache'),
os.path.join(self.op_dir, 'app.py')
)
generated_modules.append('app')
generated_files.append(os.path.join(self.op_dir, 'app.py'))

# Generate the __init__.py file
if self.generate_init_file:
context = {
'app_name': self.app_name,
'client_session': self._context['client_session'],
'modules': [
{'name': module} for module in generated_modules
]
Expand Down
9 changes: 7 additions & 2 deletions src/nasdaq_protocols/fix/parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@ def parse(file: str) -> Definitions:
if root.tag != 'fix':
raise ValueError('root tag is not fix')

version = int(f'{root.get("major")}{root.get("minor")}')
version_str = f'{root.get("major")}{root.get("minor")}'
servicepack = int(root.get('servicepack', '0'))
if servicepack > 0:
version_str += f'{servicepack}'
version = int(version_str)

try:
version = Version(version)
except ValueError as v_error:
Expand All @@ -44,7 +49,7 @@ def parse(file: str) -> Definitions:
'trailer': _handle_trailer,
'messages': _handle_messages
}
definitions = Definitions()
definitions = Definitions(version)

for element in list(root)[::-1]:
handlers[element.tag](definitions, root, element)
Expand Down
34 changes: 34 additions & 0 deletions src/nasdaq_protocols/fix/parser/templates/app.mustache
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from nasdaq_protocols.common import logable
from nasdaq_protocols import fix


@logable
class Message(fix.Message, app_name='{{app_name}}'):
def __init_subclass__(cls, **kwargs):
cls.log.debug("{{app_name}} Message subclassed")
for field in fix.Message.MandatoryFields:
if field not in kwargs:
raise ValueError(f"{field} missing when subclassing Message[{{app_name}}]")
kwargs['app_name'] = '{{app_name}}'
super().__init_subclass__(**kwargs)


class ClientSession(fix.{{client_session}}):
@classmethod
def decode(cls, data: bytes) -> fix.Message:
return Message.from_bytes(data)


async def connect_async(remote: tuple[str, int],
login_msg: Message,
on_msg_coro = None,
on_close_coro = None,
client_heartbeat_interval: int = 10,
server_heartbeat_interval: int = 10) -> fix.FixSession:
session = ClientSession(
on_msg_coro=on_msg_coro,
on_close_coro=on_close_coro,
client_heartbeat_interval=client_heartbeat_interval,
server_heartbeat_interval=server_heartbeat_interval
)
return await fix.connect_async(remote, login_msg, lambda: session)
2 changes: 1 addition & 1 deletion src/nasdaq_protocols/fix/parser/templates/fields.mustache
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class {{name}}(fix.Field, Tag="{{tag}}", Name="{{name}}", Type=fix.{{type}}):
{{/values}}
}
{{#values}}
{{f_value}} = {{#quote}}'{{/quote}}{{f_name}}{{#quote}}'{{/quote}},
{{f_value}} = {{#quote}}'{{/quote}}{{f_name}}{{#quote}}'{{/quote}}
{{/values}}


Expand Down
2 changes: 2 additions & 0 deletions src/nasdaq_protocols/fix/parser/templates/init.mustache
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from nasdaq_protocols import fix
{{#modules}}
from .{{name}} import *
{{/modules}}

6 changes: 5 additions & 1 deletion src/nasdaq_protocols/fix/parser/templates/messages.mustache
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
from nasdaq_protocols import fix
from . import {{module_prefix}}_groups as groups
from . import {{module_prefix}}_bodies as bodies
from .app import Message as Message


{{#messages}}
class {{name}}(fix.Message,
class {{name}}(Message,
Name="{{name}}",
Type="{{tag}}",
Category="{{category}}",
HeaderCls=bodies.Header,
BodyCls=bodies.{{name}}Body,
TrailerCls=bodies.Trailer):
Header: bodies.Header
Body: bodies.{{name}}Body
Trailer: bodies.Trailer
{{#entries}}
{{^is_group}}
{{field.name}}: {{field.type_hint}}
Expand Down
9 changes: 5 additions & 4 deletions src/nasdaq_protocols/fix/parser/version_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ class Version(enum.IntEnum):
FIX_4_2 = 42
FIX_4_4 = 44
FIX_5_0 = 50
FIX_5_2 = 52
FIX_5_0_2 = 502


def get_supported_types(version: Version) -> SupportedTypes:
version_map = {
Version.FIX_4_2: fix_42_version_types,
Version.FIX_4_4: fix_44_version_types,
Version.FIX_5_0: fix_50_version_types,
Version.FIX_5_2: fix_52_version_types
Version.FIX_5_0_2: fix_502_version_types
}
try:
return version_map[version]()
Expand Down Expand Up @@ -81,10 +81,11 @@ def fix_50_version_types():
return fix_50_types


def fix_52_version_types():
def fix_502_version_types():
fix_52_types = fix_50_version_types()
fix_52_types.update({
'LOCALMKTDATE': types.FixLocalMktDate,
'TZTIMEONLY': types.FixTzTimeonly
'TZTIMEONLY': types.FixTzTimeonly,
'MULTIPLESTRINGVALUE': types.FixMultipleValueString,
})
return fix_52_types
3 changes: 2 additions & 1 deletion src/nasdaq_protocols/fix/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,12 @@ def send_msg(self, msg: core.Message) -> None:
msg.validate(segments=[core.MessageSegments.BODY])
msg.Header.SenderSubID = self.sender_sub_id
msg.Header.TargetCompID = self.target_comp_id
msg.Header.SenderCompID = self.sender_comp_id
msg.Header.MsgSeqNum = next(self.sequence)
msg.Header.SendingTime = datetime.now(timezone.utc).strftime("%Y%m%d-%H:%M:%S")

data = self._prepare_complete_msg(msg)
self.log.info(data)
self.log.debug('%s> sent message[%s]: %s', self.session_id, msg.Name, data)
self._transport.write(data)
self.log.debug('%s> sent message[%s]: %s', self.session_id, msg.Name, msg.as_collection())

Expand Down
4 changes: 2 additions & 2 deletions tests/test_fix_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test__no_init_file__no_prefix__code_generated(codegen_invoker):
prefix=prefix
)

assert len(generated_files) == 4
assert len(generated_files) == 5


def test__init_file__no_prefix__code_generated(fix_44_definitions, codegen_invoker, tmp_path, module_loader):
Expand All @@ -45,7 +45,7 @@ def test__init_file__no_prefix__code_generated(fix_44_definitions, codegen_invok
output_dir=output_dir
)

assert len(generated_files) == 5
assert len(generated_files) == 6

# This ensures the generated code is correct
generated_package = module_loader('test__init_file__no_prefix__code_generated', output_dir / '__init__.py')
Expand Down
2 changes: 1 addition & 1 deletion tests/testdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@


TEST_FIX_44_XML = """
<fix major="4" minor="4">
<fix major="4" minor="4" servicepack="0">
<header>
<field name="BeginString" required="Y"/>
<field name="BodyLength" required="N"/>
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ deps =
pytest-asyncio
pytest-sugar
pytest-cov
commands = pytest --cov-fail-under=97 # Locking down the current cov percent as baseline
commands = pytest --cov=src --cov-fail-under=97 # Locking down the current cov percent as baseline


[testenv:build]
Expand Down

0 comments on commit af6a281

Please sign in to comment.