From 43332f3c61c9257799cd8e6943bcf1b5be145758 Mon Sep 17 00:00:00 2001
From: SamDanielThangarajan
<12202554+SamDanielThangarajan@users.noreply.github.com>
Date: Fri, 11 Oct 2024 13:37:53 +0200
Subject: [PATCH] fist step in multi app support
---
MANIFEST.in | 1 +
pyproject.toml | 1 +
pytest.ini | 5 ++-
src/nasdaq_protocols/fix/__init__.py | 5 ++-
src/nasdaq_protocols/fix/core.py | 18 +++++++++-
.../fix/parser/definitions.py | 13 +++++++
src/nasdaq_protocols/fix/parser/generator.py | 11 ++++++
src/nasdaq_protocols/fix/parser/parser.py | 9 +++--
.../fix/parser/templates/app.mustache | 34 +++++++++++++++++++
.../fix/parser/templates/fields.mustache | 2 +-
.../fix/parser/templates/init.mustache | 2 ++
.../fix/parser/templates/messages.mustache | 6 +++-
.../fix/parser/version_types.py | 9 ++---
src/nasdaq_protocols/fix/session.py | 3 +-
tests/test_fix_codegen.py | 4 +--
tests/test_fix_parser.py | 31 +++++++++++++++++
tests/testdata.py | 2 +-
tox.ini | 2 +-
18 files changed, 138 insertions(+), 20 deletions(-)
create mode 100644 src/nasdaq_protocols/fix/parser/templates/app.mustache
diff --git a/MANIFEST.in b/MANIFEST.in
index b1a1401..162890c 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,3 +1,4 @@
include src/nasdaq_protocols/common/message/templates/*.mustache
include src/nasdaq_protocols/tools/templates/*.mustache
+include src/nasdaq_protocols/fix/parser/templates/*.mustache
include src/nasdaq_protocols/tools/templates/*.xml
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 1cfa33d..3955102 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,6 +24,7 @@ dynamic = ["version"]
nasdaq-ouch-codegen="nasdaq_protocols.ouch.codegen:generate"
nasdaq-itch-codegen="nasdaq_protocols.itch.codegen:generate"
nasdaq-itch-tools-codegen="nasdaq_protocols.itch.codegen:generate_itch_tools"
+nasdaq-fix-codegen="nasdaq_protocols.fix.codegen:generate"
nasdaq-protocols-create-new-project="nasdaq_protocols.tools.new_project:create"
diff --git a/pytest.ini b/pytest.ini
index 91864a5..a7fd31d 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -3,7 +3,6 @@ pythonpath = src
asyncio_mode=auto
asyncio_default_fixture_loop_scope=function
log_cli=true
-log_level=DEBUG
+log_level=INFO
log_format = %(name)-20s: %(message)s
-log_date_format = %I:%M:%S
-addopts = --cov=src --cov-fail-under=97
\ No newline at end of file
+log_date_format = %I:%M:%S
\ No newline at end of file
diff --git a/src/nasdaq_protocols/fix/__init__.py b/src/nasdaq_protocols/fix/__init__.py
index 4898beb..127a9b4 100644
--- a/src/nasdaq_protocols/fix/__init__.py
+++ b/src/nasdaq_protocols/fix/__init__.py
@@ -1,7 +1,7 @@
import asyncio
from typing import Callable
-from .session import FixSession
+from .session import *
from .types import *
from .core import *
from ._reader import FixMessageReader
@@ -9,8 +9,7 @@
async def connect_async(remote: tuple[str, int],
login_msg: Message,
- session_fac: Callable[[], FixSession],
- _sequence=1):
+ session_fac: Callable[[], FixSession]):
loop = asyncio.get_running_loop()
_, session_1 = await loop.create_connection(session_fac, *remote)
diff --git a/src/nasdaq_protocols/fix/core.py b/src/nasdaq_protocols/fix/core.py
index b5848c2..c0396fb 100644
--- a/src/nasdaq_protocols/fix/core.py
+++ b/src/nasdaq_protocols/fix/core.py
@@ -1,6 +1,6 @@
import abc
import pprint
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
from enum import Enum
from typing import ClassVar, Any, Type, TypeVar, Union
@@ -376,15 +376,31 @@ def __eq__(self, other):
class Message(FixSerializable):
+ MsgIdToClsMap: ClassVar[dict] = defaultdict(dict)
+ MsgNameToMsgMap: ClassVar[dict] = defaultdict(dict)
Type: ClassVar[int]
Name: ClassVar[str]
Category: ClassVar[str]
SegmentCls: ClassVar[dict[MessageSegments, type[DataSegment]]]
+ AppName: ClassVar[str]
Def = {}
+ MandatoryFields = [
+ 'Name',
+ 'Type',
+ 'Category',
+ 'HeaderCls',
+ 'TrailerCls',
+ 'BodyCls',
+ ]
@classmethod
def __init_subclass__(cls, **kwargs):
+ for field in cls.MandatoryFields:
+ if field not in kwargs:
+ return
+ app_name = kwargs.get('app_name', 'fix')
+ cls.AppName = app_name
cls.Name = kwargs['Name']
cls.Type = kwargs['Type']
cls.Category = kwargs['Category']
diff --git a/src/nasdaq_protocols/fix/parser/definitions.py b/src/nasdaq_protocols/fix/parser/definitions.py
index 49831c7..017662a 100644
--- a/src/nasdaq_protocols/fix/parser/definitions.py
+++ b/src/nasdaq_protocols/fix/parser/definitions.py
@@ -1,3 +1,4 @@
+import keyword
from collections import defaultdict
from itertools import count
from typing import Any
@@ -18,6 +19,8 @@
'Definitions'
]
+from nasdaq_protocols.fix.parser.version_types import Version
+
@attrs.define
class FieldDef:
@@ -30,6 +33,7 @@ def _values_ctx(self):
output = []
if self.possible_values:
for key, value in self.possible_values.items():
+ value = f'{value}' if not keyword.iskeyword(value) else f'{value}_'
output.append({
'f_name': key,
'f_value': value,
@@ -123,6 +127,7 @@ def get_codegen_context(self, definitions):
@attrs.define
class Definitions:
+ version: Version
fields: dict[str, FieldDef] = attrs.field(init=False, factory=dict)
components: dict[str, Component] = attrs.field(kw_only=True, factory=dict)
header: EntryContainer = attrs.field(kw_only=True, factory=EntryContainer)
@@ -137,6 +142,7 @@ def get_codegen_context(self):
for message in self.messages
]
return {
+ 'client_session': self._client_session(),
'fields': [field.get_codegen_context(self) for field in self.fields.values()],
'bodies': [
self.header.get_codegen_context(self) | {
@@ -149,3 +155,10 @@ def get_codegen_context(self):
'messages': message_context,
'groups': Group.Contexts, # always last, as it is dependent on other entries
}
+
+ def _client_session(self):
+ if self.version == Version.FIX_4_4:
+ return 'Fix44Session'
+ if self.version in (Version.FIX_5_0, Version.FIX_5_0_2):
+ return 'Fix50Session'
+ raise ValueError(f'Version {self.version} is not supported')
diff --git a/src/nasdaq_protocols/fix/parser/generator.py b/src/nasdaq_protocols/fix/parser/generator.py
index de394dc..d4f28f6 100644
--- a/src/nasdaq_protocols/fix/parser/generator.py
+++ b/src/nasdaq_protocols/fix/parser/generator.py
@@ -54,9 +54,20 @@ def generate(self):
generated_modules.append(module_name)
generated_files.append(generated_file)
+ # Generate the app module
+ Generator._generate(
+ self._context,
+ os.path.join(str(TEMPLATES_PATH), 'app.mustache'),
+ os.path.join(self.op_dir, 'app.py')
+ )
+ generated_modules.append('app')
+ generated_files.append(os.path.join(self.op_dir, 'app.py'))
+
# Generate the __init__.py file
if self.generate_init_file:
context = {
+ 'app_name': self.app_name,
+ 'client_session': self._context['client_session'],
'modules': [
{'name': module} for module in generated_modules
]
diff --git a/src/nasdaq_protocols/fix/parser/parser.py b/src/nasdaq_protocols/fix/parser/parser.py
index 0391a68..3ab9f03 100644
--- a/src/nasdaq_protocols/fix/parser/parser.py
+++ b/src/nasdaq_protocols/fix/parser/parser.py
@@ -31,7 +31,12 @@ def parse(file: str) -> Definitions:
if root.tag != 'fix':
raise ValueError('root tag is not fix')
- version = int(f'{root.get("major")}{root.get("minor")}')
+ version_str = f'{root.get("major")}{root.get("minor")}'
+ servicepack = int(root.get('servicepack', '0'))
+ if servicepack > 0:
+ version_str += f'{servicepack}'
+ version = int(version_str)
+
try:
version = Version(version)
except ValueError as v_error:
@@ -44,7 +49,7 @@ def parse(file: str) -> Definitions:
'trailer': _handle_trailer,
'messages': _handle_messages
}
- definitions = Definitions()
+ definitions = Definitions(version)
for element in list(root)[::-1]:
handlers[element.tag](definitions, root, element)
diff --git a/src/nasdaq_protocols/fix/parser/templates/app.mustache b/src/nasdaq_protocols/fix/parser/templates/app.mustache
new file mode 100644
index 0000000..ed7fa74
--- /dev/null
+++ b/src/nasdaq_protocols/fix/parser/templates/app.mustache
@@ -0,0 +1,34 @@
+from nasdaq_protocols.common import logable
+from nasdaq_protocols import fix
+
+
+@logable
+class Message(fix.Message, app_name='{{app_name}}'):
+ def __init_subclass__(cls, **kwargs):
+ cls.log.debug("{{app_name}} Message subclassed")
+ for field in fix.Message.MandatoryFields:
+ if field not in kwargs:
+ raise ValueError(f"{field} missing when subclassing Message[{{app_name}}]")
+ kwargs['app_name'] = '{{app_name}}'
+ super().__init_subclass__(**kwargs)
+
+
+class ClientSession(fix.{{client_session}}):
+ @classmethod
+ def decode(cls, data: bytes) -> fix.Message:
+ return Message.from_bytes(data)
+
+
+async def connect_async(remote: tuple[str, int],
+ login_msg: Message,
+ on_msg_coro = None,
+ on_close_coro = None,
+ client_heartbeat_interval: int = 10,
+ server_heartbeat_interval: int = 10) -> fix.FixSession:
+ session = ClientSession(
+ on_msg_coro=on_msg_coro,
+ on_close_coro=on_close_coro,
+ client_heartbeat_interval=client_heartbeat_interval,
+ server_heartbeat_interval=server_heartbeat_interval
+ )
+ return await fix.connect_async(remote, login_msg, lambda: session)
diff --git a/src/nasdaq_protocols/fix/parser/templates/fields.mustache b/src/nasdaq_protocols/fix/parser/templates/fields.mustache
index a16d1cf..cb8770c 100644
--- a/src/nasdaq_protocols/fix/parser/templates/fields.mustache
+++ b/src/nasdaq_protocols/fix/parser/templates/fields.mustache
@@ -9,7 +9,7 @@ class {{name}}(fix.Field, Tag="{{tag}}", Name="{{name}}", Type=fix.{{type}}):
{{/values}}
}
{{#values}}
- {{f_value}} = {{#quote}}'{{/quote}}{{f_name}}{{#quote}}'{{/quote}},
+ {{f_value}} = {{#quote}}'{{/quote}}{{f_name}}{{#quote}}'{{/quote}}
{{/values}}
diff --git a/src/nasdaq_protocols/fix/parser/templates/init.mustache b/src/nasdaq_protocols/fix/parser/templates/init.mustache
index 3520e8c..2a81d03 100644
--- a/src/nasdaq_protocols/fix/parser/templates/init.mustache
+++ b/src/nasdaq_protocols/fix/parser/templates/init.mustache
@@ -1,3 +1,5 @@
+from nasdaq_protocols import fix
{{#modules}}
from .{{name}} import *
{{/modules}}
+
diff --git a/src/nasdaq_protocols/fix/parser/templates/messages.mustache b/src/nasdaq_protocols/fix/parser/templates/messages.mustache
index 3fea8f2..d42cafd 100644
--- a/src/nasdaq_protocols/fix/parser/templates/messages.mustache
+++ b/src/nasdaq_protocols/fix/parser/templates/messages.mustache
@@ -1,16 +1,20 @@
from nasdaq_protocols import fix
from . import {{module_prefix}}_groups as groups
from . import {{module_prefix}}_bodies as bodies
+from .app import Message as Message
{{#messages}}
-class {{name}}(fix.Message,
+class {{name}}(Message,
Name="{{name}}",
Type="{{tag}}",
Category="{{category}}",
HeaderCls=bodies.Header,
BodyCls=bodies.{{name}}Body,
TrailerCls=bodies.Trailer):
+ Header: bodies.Header
+ Body: bodies.{{name}}Body
+ Trailer: bodies.Trailer
{{#entries}}
{{^is_group}}
{{field.name}}: {{field.type_hint}}
diff --git a/src/nasdaq_protocols/fix/parser/version_types.py b/src/nasdaq_protocols/fix/parser/version_types.py
index f1af4cd..1b35800 100644
--- a/src/nasdaq_protocols/fix/parser/version_types.py
+++ b/src/nasdaq_protocols/fix/parser/version_types.py
@@ -16,7 +16,7 @@ class Version(enum.IntEnum):
FIX_4_2 = 42
FIX_4_4 = 44
FIX_5_0 = 50
- FIX_5_2 = 52
+ FIX_5_0_2 = 502
def get_supported_types(version: Version) -> SupportedTypes:
@@ -24,7 +24,7 @@ def get_supported_types(version: Version) -> SupportedTypes:
Version.FIX_4_2: fix_42_version_types,
Version.FIX_4_4: fix_44_version_types,
Version.FIX_5_0: fix_50_version_types,
- Version.FIX_5_2: fix_52_version_types
+ Version.FIX_5_0_2: fix_502_version_types
}
try:
return version_map[version]()
@@ -81,10 +81,11 @@ def fix_50_version_types():
return fix_50_types
-def fix_52_version_types():
+def fix_502_version_types():
fix_52_types = fix_50_version_types()
fix_52_types.update({
'LOCALMKTDATE': types.FixLocalMktDate,
- 'TZTIMEONLY': types.FixTzTimeonly
+ 'TZTIMEONLY': types.FixTzTimeonly,
+ 'MULTIPLESTRINGVALUE': types.FixMultipleValueString,
})
return fix_52_types
diff --git a/src/nasdaq_protocols/fix/session.py b/src/nasdaq_protocols/fix/session.py
index 2afe2bf..b56cd3f 100644
--- a/src/nasdaq_protocols/fix/session.py
+++ b/src/nasdaq_protocols/fix/session.py
@@ -75,11 +75,12 @@ def send_msg(self, msg: core.Message) -> None:
msg.validate(segments=[core.MessageSegments.BODY])
msg.Header.SenderSubID = self.sender_sub_id
msg.Header.TargetCompID = self.target_comp_id
+ msg.Header.SenderCompID = self.sender_comp_id
msg.Header.MsgSeqNum = next(self.sequence)
msg.Header.SendingTime = datetime.now(timezone.utc).strftime("%Y%m%d-%H:%M:%S")
data = self._prepare_complete_msg(msg)
- self.log.info(data)
+ self.log.debug('%s> sent message[%s]: %s', self.session_id, msg.Name, data)
self._transport.write(data)
self.log.debug('%s> sent message[%s]: %s', self.session_id, msg.Name, msg.as_collection())
diff --git a/tests/test_fix_codegen.py b/tests/test_fix_codegen.py
index 4cae8ac..0b890e6 100644
--- a/tests/test_fix_codegen.py
+++ b/tests/test_fix_codegen.py
@@ -29,7 +29,7 @@ def test__no_init_file__no_prefix__code_generated(codegen_invoker):
prefix=prefix
)
- assert len(generated_files) == 4
+ assert len(generated_files) == 5
def test__init_file__no_prefix__code_generated(fix_44_definitions, codegen_invoker, tmp_path, module_loader):
@@ -45,7 +45,7 @@ def test__init_file__no_prefix__code_generated(fix_44_definitions, codegen_invok
output_dir=output_dir
)
- assert len(generated_files) == 5
+ assert len(generated_files) == 6
# This ensures the generated code is correct
generated_package = module_loader('test__init_file__no_prefix__code_generated', output_dir / '__init__.py')
diff --git a/tests/test_fix_parser.py b/tests/test_fix_parser.py
index b205523..ad9d3a6 100644
--- a/tests/test_fix_parser.py
+++ b/tests/test_fix_parser.py
@@ -96,6 +96,37 @@ def test__fix_parser__parse__field_not_found(tmp_file_writer):
assert str(e.value) == 'Field definition for NotFound not found'
+def test__fix_parser__parse__xml_with_service_pack(tmp_file_writer):
+ fix_502 = '''
+
+
+ '''
+ file = tmp_file_writer(fix_502)
+
+ definitions = parse(file)
+ assert definitions.version == 502
+
+
+def test__fix_parser__parse__xml_with_keywords__keywords_are_transformed(tmp_file_writer):
+ fix_502 = '''
+
+
+
+
+
+
+
+
+ '''
+ file = tmp_file_writer(fix_502)
+
+ definitions = parse(file)
+ assert definitions.version == 502
+ context = definitions.fields['MsgType'].get_codegen_context(None)
+ assert context['values'][0]['f_value'] == 'None_'
+ assert context['values'][1]['f_value'] == 'if_'
+
+
def test__fix_parser__fields_are_parsed(fix_44_definitions):
def assert_field(name, tag, typ, total_possible_values):
assert name in fix_44_definitions.fields
diff --git a/tests/testdata.py b/tests/testdata.py
index d40c844..cf9f693 100644
--- a/tests/testdata.py
+++ b/tests/testdata.py
@@ -192,7 +192,7 @@
TEST_FIX_44_XML = """
-
+
diff --git a/tox.ini b/tox.ini
index f1d9386..970921b 100644
--- a/tox.ini
+++ b/tox.ini
@@ -28,7 +28,7 @@ deps =
pytest-asyncio
pytest-sugar
pytest-cov
-commands = pytest --cov-fail-under=97 # Locking down the current cov percent as baseline
+commands = pytest --cov=src --cov-fail-under=97 # Locking down the current cov percent as baseline
[testenv:build]