From 43332f3c61c9257799cd8e6943bcf1b5be145758 Mon Sep 17 00:00:00 2001 From: SamDanielThangarajan <12202554+SamDanielThangarajan@users.noreply.github.com> Date: Fri, 11 Oct 2024 13:37:53 +0200 Subject: [PATCH] fist step in multi app support --- MANIFEST.in | 1 + pyproject.toml | 1 + pytest.ini | 5 ++- src/nasdaq_protocols/fix/__init__.py | 5 ++- src/nasdaq_protocols/fix/core.py | 18 +++++++++- .../fix/parser/definitions.py | 13 +++++++ src/nasdaq_protocols/fix/parser/generator.py | 11 ++++++ src/nasdaq_protocols/fix/parser/parser.py | 9 +++-- .../fix/parser/templates/app.mustache | 34 +++++++++++++++++++ .../fix/parser/templates/fields.mustache | 2 +- .../fix/parser/templates/init.mustache | 2 ++ .../fix/parser/templates/messages.mustache | 6 +++- .../fix/parser/version_types.py | 9 ++--- src/nasdaq_protocols/fix/session.py | 3 +- tests/test_fix_codegen.py | 4 +-- tests/test_fix_parser.py | 31 +++++++++++++++++ tests/testdata.py | 2 +- tox.ini | 2 +- 18 files changed, 138 insertions(+), 20 deletions(-) create mode 100644 src/nasdaq_protocols/fix/parser/templates/app.mustache diff --git a/MANIFEST.in b/MANIFEST.in index b1a1401..162890c 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,4 @@ include src/nasdaq_protocols/common/message/templates/*.mustache include src/nasdaq_protocols/tools/templates/*.mustache +include src/nasdaq_protocols/fix/parser/templates/*.mustache include src/nasdaq_protocols/tools/templates/*.xml \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 1cfa33d..3955102 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dynamic = ["version"] nasdaq-ouch-codegen="nasdaq_protocols.ouch.codegen:generate" nasdaq-itch-codegen="nasdaq_protocols.itch.codegen:generate" nasdaq-itch-tools-codegen="nasdaq_protocols.itch.codegen:generate_itch_tools" +nasdaq-fix-codegen="nasdaq_protocols.fix.codegen:generate" nasdaq-protocols-create-new-project="nasdaq_protocols.tools.new_project:create" diff --git a/pytest.ini b/pytest.ini index 91864a5..a7fd31d 100644 --- a/pytest.ini +++ b/pytest.ini @@ -3,7 +3,6 @@ pythonpath = src asyncio_mode=auto asyncio_default_fixture_loop_scope=function log_cli=true -log_level=DEBUG +log_level=INFO log_format = %(name)-20s: %(message)s -log_date_format = %I:%M:%S -addopts = --cov=src --cov-fail-under=97 \ No newline at end of file +log_date_format = %I:%M:%S \ No newline at end of file diff --git a/src/nasdaq_protocols/fix/__init__.py b/src/nasdaq_protocols/fix/__init__.py index 4898beb..127a9b4 100644 --- a/src/nasdaq_protocols/fix/__init__.py +++ b/src/nasdaq_protocols/fix/__init__.py @@ -1,7 +1,7 @@ import asyncio from typing import Callable -from .session import FixSession +from .session import * from .types import * from .core import * from ._reader import FixMessageReader @@ -9,8 +9,7 @@ async def connect_async(remote: tuple[str, int], login_msg: Message, - session_fac: Callable[[], FixSession], - _sequence=1): + session_fac: Callable[[], FixSession]): loop = asyncio.get_running_loop() _, session_1 = await loop.create_connection(session_fac, *remote) diff --git a/src/nasdaq_protocols/fix/core.py b/src/nasdaq_protocols/fix/core.py index b5848c2..c0396fb 100644 --- a/src/nasdaq_protocols/fix/core.py +++ b/src/nasdaq_protocols/fix/core.py @@ -1,6 +1,6 @@ import abc import pprint -from collections import OrderedDict +from collections import OrderedDict, defaultdict from enum import Enum from typing import ClassVar, Any, Type, TypeVar, Union @@ -376,15 +376,31 @@ def __eq__(self, other): class Message(FixSerializable): + MsgIdToClsMap: ClassVar[dict] = defaultdict(dict) + MsgNameToMsgMap: ClassVar[dict] = defaultdict(dict) Type: ClassVar[int] Name: ClassVar[str] Category: ClassVar[str] SegmentCls: ClassVar[dict[MessageSegments, type[DataSegment]]] + AppName: ClassVar[str] Def = {} + MandatoryFields = [ + 'Name', + 'Type', + 'Category', + 'HeaderCls', + 'TrailerCls', + 'BodyCls', + ] @classmethod def __init_subclass__(cls, **kwargs): + for field in cls.MandatoryFields: + if field not in kwargs: + return + app_name = kwargs.get('app_name', 'fix') + cls.AppName = app_name cls.Name = kwargs['Name'] cls.Type = kwargs['Type'] cls.Category = kwargs['Category'] diff --git a/src/nasdaq_protocols/fix/parser/definitions.py b/src/nasdaq_protocols/fix/parser/definitions.py index 49831c7..017662a 100644 --- a/src/nasdaq_protocols/fix/parser/definitions.py +++ b/src/nasdaq_protocols/fix/parser/definitions.py @@ -1,3 +1,4 @@ +import keyword from collections import defaultdict from itertools import count from typing import Any @@ -18,6 +19,8 @@ 'Definitions' ] +from nasdaq_protocols.fix.parser.version_types import Version + @attrs.define class FieldDef: @@ -30,6 +33,7 @@ def _values_ctx(self): output = [] if self.possible_values: for key, value in self.possible_values.items(): + value = f'{value}' if not keyword.iskeyword(value) else f'{value}_' output.append({ 'f_name': key, 'f_value': value, @@ -123,6 +127,7 @@ def get_codegen_context(self, definitions): @attrs.define class Definitions: + version: Version fields: dict[str, FieldDef] = attrs.field(init=False, factory=dict) components: dict[str, Component] = attrs.field(kw_only=True, factory=dict) header: EntryContainer = attrs.field(kw_only=True, factory=EntryContainer) @@ -137,6 +142,7 @@ def get_codegen_context(self): for message in self.messages ] return { + 'client_session': self._client_session(), 'fields': [field.get_codegen_context(self) for field in self.fields.values()], 'bodies': [ self.header.get_codegen_context(self) | { @@ -149,3 +155,10 @@ def get_codegen_context(self): 'messages': message_context, 'groups': Group.Contexts, # always last, as it is dependent on other entries } + + def _client_session(self): + if self.version == Version.FIX_4_4: + return 'Fix44Session' + if self.version in (Version.FIX_5_0, Version.FIX_5_0_2): + return 'Fix50Session' + raise ValueError(f'Version {self.version} is not supported') diff --git a/src/nasdaq_protocols/fix/parser/generator.py b/src/nasdaq_protocols/fix/parser/generator.py index de394dc..d4f28f6 100644 --- a/src/nasdaq_protocols/fix/parser/generator.py +++ b/src/nasdaq_protocols/fix/parser/generator.py @@ -54,9 +54,20 @@ def generate(self): generated_modules.append(module_name) generated_files.append(generated_file) + # Generate the app module + Generator._generate( + self._context, + os.path.join(str(TEMPLATES_PATH), 'app.mustache'), + os.path.join(self.op_dir, 'app.py') + ) + generated_modules.append('app') + generated_files.append(os.path.join(self.op_dir, 'app.py')) + # Generate the __init__.py file if self.generate_init_file: context = { + 'app_name': self.app_name, + 'client_session': self._context['client_session'], 'modules': [ {'name': module} for module in generated_modules ] diff --git a/src/nasdaq_protocols/fix/parser/parser.py b/src/nasdaq_protocols/fix/parser/parser.py index 0391a68..3ab9f03 100644 --- a/src/nasdaq_protocols/fix/parser/parser.py +++ b/src/nasdaq_protocols/fix/parser/parser.py @@ -31,7 +31,12 @@ def parse(file: str) -> Definitions: if root.tag != 'fix': raise ValueError('root tag is not fix') - version = int(f'{root.get("major")}{root.get("minor")}') + version_str = f'{root.get("major")}{root.get("minor")}' + servicepack = int(root.get('servicepack', '0')) + if servicepack > 0: + version_str += f'{servicepack}' + version = int(version_str) + try: version = Version(version) except ValueError as v_error: @@ -44,7 +49,7 @@ def parse(file: str) -> Definitions: 'trailer': _handle_trailer, 'messages': _handle_messages } - definitions = Definitions() + definitions = Definitions(version) for element in list(root)[::-1]: handlers[element.tag](definitions, root, element) diff --git a/src/nasdaq_protocols/fix/parser/templates/app.mustache b/src/nasdaq_protocols/fix/parser/templates/app.mustache new file mode 100644 index 0000000..ed7fa74 --- /dev/null +++ b/src/nasdaq_protocols/fix/parser/templates/app.mustache @@ -0,0 +1,34 @@ +from nasdaq_protocols.common import logable +from nasdaq_protocols import fix + + +@logable +class Message(fix.Message, app_name='{{app_name}}'): + def __init_subclass__(cls, **kwargs): + cls.log.debug("{{app_name}} Message subclassed") + for field in fix.Message.MandatoryFields: + if field not in kwargs: + raise ValueError(f"{field} missing when subclassing Message[{{app_name}}]") + kwargs['app_name'] = '{{app_name}}' + super().__init_subclass__(**kwargs) + + +class ClientSession(fix.{{client_session}}): + @classmethod + def decode(cls, data: bytes) -> fix.Message: + return Message.from_bytes(data) + + +async def connect_async(remote: tuple[str, int], + login_msg: Message, + on_msg_coro = None, + on_close_coro = None, + client_heartbeat_interval: int = 10, + server_heartbeat_interval: int = 10) -> fix.FixSession: + session = ClientSession( + on_msg_coro=on_msg_coro, + on_close_coro=on_close_coro, + client_heartbeat_interval=client_heartbeat_interval, + server_heartbeat_interval=server_heartbeat_interval + ) + return await fix.connect_async(remote, login_msg, lambda: session) diff --git a/src/nasdaq_protocols/fix/parser/templates/fields.mustache b/src/nasdaq_protocols/fix/parser/templates/fields.mustache index a16d1cf..cb8770c 100644 --- a/src/nasdaq_protocols/fix/parser/templates/fields.mustache +++ b/src/nasdaq_protocols/fix/parser/templates/fields.mustache @@ -9,7 +9,7 @@ class {{name}}(fix.Field, Tag="{{tag}}", Name="{{name}}", Type=fix.{{type}}): {{/values}} } {{#values}} - {{f_value}} = {{#quote}}'{{/quote}}{{f_name}}{{#quote}}'{{/quote}}, + {{f_value}} = {{#quote}}'{{/quote}}{{f_name}}{{#quote}}'{{/quote}} {{/values}} diff --git a/src/nasdaq_protocols/fix/parser/templates/init.mustache b/src/nasdaq_protocols/fix/parser/templates/init.mustache index 3520e8c..2a81d03 100644 --- a/src/nasdaq_protocols/fix/parser/templates/init.mustache +++ b/src/nasdaq_protocols/fix/parser/templates/init.mustache @@ -1,3 +1,5 @@ +from nasdaq_protocols import fix {{#modules}} from .{{name}} import * {{/modules}} + diff --git a/src/nasdaq_protocols/fix/parser/templates/messages.mustache b/src/nasdaq_protocols/fix/parser/templates/messages.mustache index 3fea8f2..d42cafd 100644 --- a/src/nasdaq_protocols/fix/parser/templates/messages.mustache +++ b/src/nasdaq_protocols/fix/parser/templates/messages.mustache @@ -1,16 +1,20 @@ from nasdaq_protocols import fix from . import {{module_prefix}}_groups as groups from . import {{module_prefix}}_bodies as bodies +from .app import Message as Message {{#messages}} -class {{name}}(fix.Message, +class {{name}}(Message, Name="{{name}}", Type="{{tag}}", Category="{{category}}", HeaderCls=bodies.Header, BodyCls=bodies.{{name}}Body, TrailerCls=bodies.Trailer): + Header: bodies.Header + Body: bodies.{{name}}Body + Trailer: bodies.Trailer {{#entries}} {{^is_group}} {{field.name}}: {{field.type_hint}} diff --git a/src/nasdaq_protocols/fix/parser/version_types.py b/src/nasdaq_protocols/fix/parser/version_types.py index f1af4cd..1b35800 100644 --- a/src/nasdaq_protocols/fix/parser/version_types.py +++ b/src/nasdaq_protocols/fix/parser/version_types.py @@ -16,7 +16,7 @@ class Version(enum.IntEnum): FIX_4_2 = 42 FIX_4_4 = 44 FIX_5_0 = 50 - FIX_5_2 = 52 + FIX_5_0_2 = 502 def get_supported_types(version: Version) -> SupportedTypes: @@ -24,7 +24,7 @@ def get_supported_types(version: Version) -> SupportedTypes: Version.FIX_4_2: fix_42_version_types, Version.FIX_4_4: fix_44_version_types, Version.FIX_5_0: fix_50_version_types, - Version.FIX_5_2: fix_52_version_types + Version.FIX_5_0_2: fix_502_version_types } try: return version_map[version]() @@ -81,10 +81,11 @@ def fix_50_version_types(): return fix_50_types -def fix_52_version_types(): +def fix_502_version_types(): fix_52_types = fix_50_version_types() fix_52_types.update({ 'LOCALMKTDATE': types.FixLocalMktDate, - 'TZTIMEONLY': types.FixTzTimeonly + 'TZTIMEONLY': types.FixTzTimeonly, + 'MULTIPLESTRINGVALUE': types.FixMultipleValueString, }) return fix_52_types diff --git a/src/nasdaq_protocols/fix/session.py b/src/nasdaq_protocols/fix/session.py index 2afe2bf..b56cd3f 100644 --- a/src/nasdaq_protocols/fix/session.py +++ b/src/nasdaq_protocols/fix/session.py @@ -75,11 +75,12 @@ def send_msg(self, msg: core.Message) -> None: msg.validate(segments=[core.MessageSegments.BODY]) msg.Header.SenderSubID = self.sender_sub_id msg.Header.TargetCompID = self.target_comp_id + msg.Header.SenderCompID = self.sender_comp_id msg.Header.MsgSeqNum = next(self.sequence) msg.Header.SendingTime = datetime.now(timezone.utc).strftime("%Y%m%d-%H:%M:%S") data = self._prepare_complete_msg(msg) - self.log.info(data) + self.log.debug('%s> sent message[%s]: %s', self.session_id, msg.Name, data) self._transport.write(data) self.log.debug('%s> sent message[%s]: %s', self.session_id, msg.Name, msg.as_collection()) diff --git a/tests/test_fix_codegen.py b/tests/test_fix_codegen.py index 4cae8ac..0b890e6 100644 --- a/tests/test_fix_codegen.py +++ b/tests/test_fix_codegen.py @@ -29,7 +29,7 @@ def test__no_init_file__no_prefix__code_generated(codegen_invoker): prefix=prefix ) - assert len(generated_files) == 4 + assert len(generated_files) == 5 def test__init_file__no_prefix__code_generated(fix_44_definitions, codegen_invoker, tmp_path, module_loader): @@ -45,7 +45,7 @@ def test__init_file__no_prefix__code_generated(fix_44_definitions, codegen_invok output_dir=output_dir ) - assert len(generated_files) == 5 + assert len(generated_files) == 6 # This ensures the generated code is correct generated_package = module_loader('test__init_file__no_prefix__code_generated', output_dir / '__init__.py') diff --git a/tests/test_fix_parser.py b/tests/test_fix_parser.py index b205523..ad9d3a6 100644 --- a/tests/test_fix_parser.py +++ b/tests/test_fix_parser.py @@ -96,6 +96,37 @@ def test__fix_parser__parse__field_not_found(tmp_file_writer): assert str(e.value) == 'Field definition for NotFound not found' +def test__fix_parser__parse__xml_with_service_pack(tmp_file_writer): + fix_502 = ''' + + + ''' + file = tmp_file_writer(fix_502) + + definitions = parse(file) + assert definitions.version == 502 + + +def test__fix_parser__parse__xml_with_keywords__keywords_are_transformed(tmp_file_writer): + fix_502 = ''' + + + + + + + + + ''' + file = tmp_file_writer(fix_502) + + definitions = parse(file) + assert definitions.version == 502 + context = definitions.fields['MsgType'].get_codegen_context(None) + assert context['values'][0]['f_value'] == 'None_' + assert context['values'][1]['f_value'] == 'if_' + + def test__fix_parser__fields_are_parsed(fix_44_definitions): def assert_field(name, tag, typ, total_possible_values): assert name in fix_44_definitions.fields diff --git a/tests/testdata.py b/tests/testdata.py index d40c844..cf9f693 100644 --- a/tests/testdata.py +++ b/tests/testdata.py @@ -192,7 +192,7 @@ TEST_FIX_44_XML = """ - +
diff --git a/tox.ini b/tox.ini index f1d9386..970921b 100644 --- a/tox.ini +++ b/tox.ini @@ -28,7 +28,7 @@ deps = pytest-asyncio pytest-sugar pytest-cov -commands = pytest --cov-fail-under=97 # Locking down the current cov percent as baseline +commands = pytest --cov=src --cov-fail-under=97 # Locking down the current cov percent as baseline [testenv:build]