diff --git a/src/nasdaq_protocols/fix/codegen.py b/src/nasdaq_protocols/fix/codegen.py new file mode 100644 index 0000000..7ae6223 --- /dev/null +++ b/src/nasdaq_protocols/fix/codegen.py @@ -0,0 +1,29 @@ +import click +from nasdaq_protocols.fix.parser import parse, Generator + + +__all__ = [ + 'generate' +] + + +@click.command() +@click.option('--spec-file', type=click.Path(exists=True)) +@click.option('--app-name', type=click.STRING) +@click.option('--prefix', type=click.STRING, default='') +@click.option('--op-dir', type=click.Path(exists=True, writable=True)) +@click.option('--init-file/--no-init-file', show_default=True, default=True) +def generate(spec_file, app_name, op_dir, prefix, init_file): + + try: + generator = Generator( + parse(spec_file), + app_name, + op_dir, + prefix, + generate_init_file=init_file + ) + generator.generate() + except Exception as e: + print(f'Error: {e}') + raise e diff --git a/src/nasdaq_protocols/fix/parser/__init__.py b/src/nasdaq_protocols/fix/parser/__init__.py index e69de29..d3a1fb8 100644 --- a/src/nasdaq_protocols/fix/parser/__init__.py +++ b/src/nasdaq_protocols/fix/parser/__init__.py @@ -0,0 +1,3 @@ +from .definitions import * +from .parser import * +from .generator import * diff --git a/src/nasdaq_protocols/fix/parser/definitions.py b/src/nasdaq_protocols/fix/parser/definitions.py index ec65bce..49831c7 100644 --- a/src/nasdaq_protocols/fix/parser/definitions.py +++ b/src/nasdaq_protocols/fix/parser/definitions.py @@ -1,3 +1,5 @@ +from collections import defaultdict +from itertools import count from typing import Any import attrs @@ -24,27 +26,81 @@ class FieldDef: type: TypeDefinition possible_values: dict[str, Any] = None + def _values_ctx(self): + output = [] + if self.possible_values: + for key, value in self.possible_values.items(): + output.append({ + 'f_name': key, + 'f_value': value, + 'quote': self.type.type_cls in [str, bool] + }) + return output + + def get_codegen_context(self, _definitions): + return { + 'tag': self.tag, + 'name': self.name, + 'type': self.type.__name__, + 'type_hint': self.type.hint, + 'values': self._values_ctx() + } + @attrs.define class Entry: required: bool = attrs.field(kw_only=True) + def get_codegen_context(self, _definitions): + return { + 'required': 'True' if self.required else 'False' + } + @attrs.define class EntryContainer: entries: list[Entry] = attrs.field(kw_only=True, factory=list) + def get_codegen_context(self, definitions): + return { + 'entries': [entry.get_codegen_context(definitions) for entry in self.entries] + } + @attrs.define class Field(Entry): field: FieldDef = attrs.field(kw_only=True) + def get_codegen_context(self, definitions): + return super().get_codegen_context(definitions) | { + 'field': self.field.get_codegen_context(definitions) + } + @attrs.define class Group(Entry): name: str = attrs.field(kw_only=True) entries: list[Entry] = attrs.field(kw_only=True, factory=list) + Contexts = [] + UniqueNameCounter = defaultdict(lambda: count(1)) + + def get_codegen_context(self, definitions): + unique_name = f'{self.name}_{next(Group.UniqueNameCounter[self.name])}' # to avoid name conflicts + group_context = super().get_codegen_context(definitions) | { + 'name': self.name, + 'unique_name': unique_name, + 'is_group': True, + 'entries': [entry.get_codegen_context(definitions) for entry in self.entries], + } + + Group.Contexts.append({ + 'name': self.name, + 'unique_name': unique_name, + 'entries': [entry.get_codegen_context(definitions) for entry in self.entries], + }) + return group_context + @attrs.define class Component(EntryContainer): @@ -57,6 +113,13 @@ class Message(EntryContainer): name: str category: str + def get_codegen_context(self, definitions): + return super().get_codegen_context(definitions) | { + 'tag': self.tag, + 'name': self.name, + 'category': self.category, + } + @attrs.define class Definitions: @@ -65,3 +128,24 @@ class Definitions: header: EntryContainer = attrs.field(kw_only=True, factory=EntryContainer) trailer: EntryContainer = attrs.field(kw_only=True, factory=EntryContainer) messages: list[Message] = attrs.field(kw_only=True, factory=list) + + def get_codegen_context(self): + message_context = [ + message.get_codegen_context(self) | { + 'body_name': f'{message.name}Body', + } + for message in self.messages + ] + return { + 'fields': [field.get_codegen_context(self) for field in self.fields.values()], + 'bodies': [ + self.header.get_codegen_context(self) | { + 'body_name': 'Header', + }, + self.trailer.get_codegen_context(self) | { + 'body_name': 'Trailer', + }, + ] + message_context, + 'messages': message_context, + 'groups': Group.Contexts, # always last, as it is dependent on other entries + } diff --git a/src/nasdaq_protocols/fix/parser/generator.py b/src/nasdaq_protocols/fix/parser/generator.py new file mode 100644 index 0000000..de394dc --- /dev/null +++ b/src/nasdaq_protocols/fix/parser/generator.py @@ -0,0 +1,77 @@ +import os +from importlib import resources +from pathlib import Path + +import attrs +import chevron +from .parser import Definitions +from . import templates + + +__all__ = [ + 'Generator' +] +TEMPLATES_PATH = resources.files(templates) + + +@attrs.define(auto_attribs=True) +class Generator: + definitions: Definitions + app_name: str + op_dir: str + prefix: str = '' + generate_init_file: bool = False + _init_file: str = None + _module_prefix: str = None + _context: dict = None + + def __attrs_post_init__(self): + prefix = f'{self.prefix}_' if self.prefix else '' + self._module_prefix = f'{prefix}fix_{self.app_name}' + self._init_file = os.path.join(self.op_dir, '__init__.py') + Path(self.op_dir).mkdir(parents=True, exist_ok=True) + self._prepare_context() + + def _prepare_context(self): + self._context = self.definitions.get_codegen_context() | { + 'module_prefix': self._module_prefix, + 'app_name': self.app_name, + } + + def generate(self): + files = ['fields', 'groups', 'bodies', 'messages'] + generated_modules = [] + generated_files = [] + # Generate the modules + for file in files: + module_name = f'{self._module_prefix}_{file}' + generated_file = os.path.join(self.op_dir, f'{module_name}.py') + Generator._generate( + self._context, + os.path.join(str(TEMPLATES_PATH), f'{file}.mustache'), + generated_file + ) + generated_modules.append(module_name) + generated_files.append(generated_file) + + # Generate the __init__.py file + if self.generate_init_file: + context = { + 'modules': [ + {'name': module} for module in generated_modules + ] + } + Generator._generate( + context, + os.path.join(str(TEMPLATES_PATH), 'init.mustache'), + self._init_file + ) + generated_files.append(self._init_file) + return generated_files + + @staticmethod + def _generate(context, template, op_file): + with open(op_file, 'a', encoding='utf-8') as op, open(template, 'r', encoding='utf-8') as inp: + code_as_string = chevron.render(inp.read(), context, partials_path=str(TEMPLATES_PATH)) + op.write(code_as_string) + print(f'Generated: {op_file}') diff --git a/src/nasdaq_protocols/fix/parser/parser.py b/src/nasdaq_protocols/fix/parser/parser.py index c1ed4b8..0391a68 100644 --- a/src/nasdaq_protocols/fix/parser/parser.py +++ b/src/nasdaq_protocols/fix/parser/parser.py @@ -18,6 +18,9 @@ ) +__all__ = [ + 'parse' +] LOG = logging.getLogger(__name__) diff --git a/src/nasdaq_protocols/fix/parser/templates/__init__.py b/src/nasdaq_protocols/fix/parser/templates/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/nasdaq_protocols/fix/parser/templates/bodies.mustache b/src/nasdaq_protocols/fix/parser/templates/bodies.mustache new file mode 100644 index 0000000..8bff9b9 --- /dev/null +++ b/src/nasdaq_protocols/fix/parser/templates/bodies.mustache @@ -0,0 +1,29 @@ +from nasdaq_protocols import fix +from . import {{module_prefix}}_fields as fields +from . import {{module_prefix}}_groups as groups + + +{{#bodies}} +class {{body_name}}(fix.DataSegment): + Entries = [ +{{#entries}} + {{^is_group}} + fix.Entry(fields.{{field.name}}, {{required}}), + {{/is_group}} + {{#is_group}} + fix.Entry(groups.{{unique_name}}_List, {{required}}), + {{/is_group}} +{{/entries}} + ] + +{{#entries}} + {{^is_group}} + {{field.name}}: {{field.type_hint}} + {{/is_group}} + {{#is_group}} + {{name}}: groups.{{unique_name}}_List + {{/is_group}} +{{/entries}} + + +{{/bodies}} \ No newline at end of file diff --git a/src/nasdaq_protocols/fix/parser/templates/fields.mustache b/src/nasdaq_protocols/fix/parser/templates/fields.mustache new file mode 100644 index 0000000..a16d1cf --- /dev/null +++ b/src/nasdaq_protocols/fix/parser/templates/fields.mustache @@ -0,0 +1,17 @@ +from nasdaq_protocols import fix + + +{{#fields}} +class {{name}}(fix.Field, Tag="{{tag}}", Name="{{name}}", Type=fix.{{type}}): + Values = { +{{#values}} + {{#quote}}'{{/quote}}{{f_name}}{{#quote}}'{{/quote}}: '{{f_value}}', +{{/values}} + } +{{#values}} + {{f_value}} = {{#quote}}'{{/quote}}{{f_name}}{{#quote}}'{{/quote}}, +{{/values}} + + +{{/fields}} + diff --git a/src/nasdaq_protocols/fix/parser/templates/groups.mustache b/src/nasdaq_protocols/fix/parser/templates/groups.mustache new file mode 100644 index 0000000..92e0150 --- /dev/null +++ b/src/nasdaq_protocols/fix/parser/templates/groups.mustache @@ -0,0 +1,33 @@ +from nasdaq_protocols import fix +from . import {{module_prefix}}_fields as fields + + +{{#groups}} +class {{unique_name}}(fix.Group): + Entries = [ +{{#entries}} + {{^is_group}} + fix.Entry(fields.{{field.name}}, {{required}}), + {{/is_group}} + {{#is_group}} + fix.Entry({{unique_name}}_List, {{required}}), + {{/is_group}} +{{/entries}} + ] + +{{#entries}} + {{^is_group}} + {{field.name}}: {{field.type_hint}} + {{/is_group}} + {{#is_group}} + {{name}}: {{unique_name}}_List + {{/is_group}} +{{/entries}} + + +class {{unique_name}}_List(fix.GroupContainer, CountCls=fields.{{name}}, GroupCls={{unique_name}}): + def __getitem__(self, idx) -> {{unique_name}}: + return super({{unique_name}}_List, self).__getitem__(idx) + + +{{/groups}} diff --git a/src/nasdaq_protocols/fix/parser/templates/init.mustache b/src/nasdaq_protocols/fix/parser/templates/init.mustache new file mode 100644 index 0000000..3520e8c --- /dev/null +++ b/src/nasdaq_protocols/fix/parser/templates/init.mustache @@ -0,0 +1,3 @@ +{{#modules}} +from .{{name}} import * +{{/modules}} diff --git a/src/nasdaq_protocols/fix/parser/templates/messages.mustache b/src/nasdaq_protocols/fix/parser/templates/messages.mustache new file mode 100644 index 0000000..3fea8f2 --- /dev/null +++ b/src/nasdaq_protocols/fix/parser/templates/messages.mustache @@ -0,0 +1,24 @@ +from nasdaq_protocols import fix +from . import {{module_prefix}}_groups as groups +from . import {{module_prefix}}_bodies as bodies + + +{{#messages}} +class {{name}}(fix.Message, + Name="{{name}}", + Type="{{tag}}", + Category="{{category}}", + HeaderCls=bodies.Header, + BodyCls=bodies.{{name}}Body, + TrailerCls=bodies.Trailer): +{{#entries}} + {{^is_group}} + {{field.name}}: {{field.type_hint}} + {{/is_group}} + {{#is_group}} + {{name}}: groups.{{unique_name}}_List + {{/is_group}} +{{/entries}} + + +{{/messages}} \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 3f3111e..ff1775f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -39,19 +39,20 @@ async def mock_server_session(unused_tcp_port): @pytest.fixture(scope='function') -def codegen_invoker(tmp_path): - def generator(codegen, xml_content, app_name, generate_init_file, prefix): +def codegen_invoker(capsys, tmp_path): + def generator(codegen, xml_content, app_name, generate_init_file, prefix, output_dir=None): runner = CliRunner() - with runner.isolated_filesystem(temp_dir=tmp_path): + with capsys.disabled(), runner.isolated_filesystem(temp_dir=tmp_path): with open('spec.xml', 'w') as spec_file: spec_file.write(xml_content) - Path('output').mkdir(parents=True, exist_ok=True) + output_dir = output_dir or 'output' + Path(output_dir).mkdir(parents=True, exist_ok=True) result = runner.invoke( codegen, [ '--spec-file', 'spec.xml', '--app-name', app_name, - '--op-dir', 'output', + '--op-dir', output_dir, '--prefix', prefix, '--init-file' if generate_init_file else '--no-init-file' ] @@ -60,18 +61,18 @@ def generator(codegen, xml_content, app_name, generate_init_file, prefix): # Read the generated files generated_file_contents = {} - for file in os.listdir('output'): - with open(os.path.join('output', file)) as f: + for file in os.listdir(output_dir): + with open(os.path.join(output_dir, file)) as f: generated_file_contents[file] = f.read() return generated_file_contents return generator @pytest.fixture(scope='function') -def tools_codegen_invoker(tmp_path): +def tools_codegen_invoker(capsys, tmp_path): def generator(codegen, app_name, package): runner = CliRunner() - with runner.isolated_filesystem(temp_dir=tmp_path): + with capsys.disabled(), runner.isolated_filesystem(temp_dir=tmp_path): Path('output').mkdir(parents=True, exist_ok=True) result = runner.invoke( codegen, @@ -100,4 +101,16 @@ def loader_(module_name, code_as_string): exec(code_as_string, module_.__dict__) sys.modules[module_name] = module_ return module_ - return loader_ \ No newline at end of file + return loader_ + + +@pytest.fixture(scope='session') +def module_loader(): + def load_module(module_name, file): + spec = importlib.util.spec_from_file_location(module_name, file) + assert spec is not None + generated_package = importlib.util.module_from_spec(spec) + sys.modules[module_name] = generated_package + spec.loader.exec_module(generated_package) + return generated_package + return load_module diff --git a/tests/test_fix_codegen.py b/tests/test_fix_codegen.py new file mode 100644 index 0000000..4cae8ac --- /dev/null +++ b/tests/test_fix_codegen.py @@ -0,0 +1,57 @@ +import importlib +import os +import sys +from pathlib import Path + +import pytest + +from nasdaq_protocols.fix import codegen +from nasdaq_protocols.fix.parser import parse +from tests.testdata import TEST_FIX_44_XML, TEST_XML_ITCH_MESSAGE + + +@pytest.fixture(scope='function') +def fix_44_definitions(tmp_file_writer): + file = tmp_file_writer(TEST_FIX_44_XML) + definitions = parse(file) + assert definitions is not None + yield definitions + + +def test__no_init_file__no_prefix__code_generated(codegen_invoker): + prefix = '' + app_name = 'gwy_44' + generated_files = codegen_invoker( + codegen.generate, + TEST_FIX_44_XML, + app_name, + generate_init_file=False, + prefix=prefix + ) + + assert len(generated_files) == 4 + + +def test__init_file__no_prefix__code_generated(fix_44_definitions, codegen_invoker, tmp_path, module_loader): + prefix = '' + app_name = 'gwy_44' + output_dir = tmp_path / app_name + generated_files = codegen_invoker( + codegen.generate, + TEST_FIX_44_XML, + app_name, + generate_init_file=True, + prefix=prefix, + output_dir=output_dir + ) + + assert len(generated_files) == 5 + + # This ensures the generated code is correct + generated_package = module_loader('test__init_file__no_prefix__code_generated', output_dir / '__init__.py') + + for field in fix_44_definitions.fields: + assert hasattr(generated_package, field) + + for message in fix_44_definitions.messages: + assert hasattr(generated_package, message.name) diff --git a/tests/test_fix_parser.py b/tests/test_fix_parser.py index f28b689..b205523 100644 --- a/tests/test_fix_parser.py +++ b/tests/test_fix_parser.py @@ -125,12 +125,12 @@ def assert_component(name, *fields_defn): assert_component('InstrumentExtensionNoInstrAttribSubGroup', ('PossResend', True), ('QuoteStatus', False)) assert 'GroupOfComponents' in fix_44_definitions.components assert_group( - 'InstrmtLegGrp_Group', + 'NoLegs', fix_44_definitions.components['GroupOfComponents'].entries[0], ('PossResend', False), ('QuoteStatus', False) ) assert_group( - 'Instrument_Group', + 'NoStreams', fix_44_definitions.components['GroupOfComponents'].entries[1], ('PossResend', True), ('QuoteStatus', True) ) @@ -172,7 +172,7 @@ def test__fix_parser__message_is_parsed(fix_44_definitions): assert message.entries[2].field.name == 'QuoteStatus' assert not message.entries[2].required assert_group( - 'Instrument_Group', + 'NoStreams', message.entries[3], ('PossResend', True), ('QuoteStatus', True) ) diff --git a/tests/testdata.py b/tests/testdata.py index f4309eb..d40c844 100644 --- a/tests/testdata.py +++ b/tests/testdata.py @@ -204,7 +204,7 @@ - + @@ -215,10 +215,10 @@ - + - + @@ -249,6 +249,8 @@ + + """ \ No newline at end of file diff --git a/tox.ini b/tox.ini index b98f0a0..f1d9386 100644 --- a/tox.ini +++ b/tox.ini @@ -60,3 +60,4 @@ per-file-ignores = ./src/nasdaq_protocols/common/message/__init__.py:F401,F403 ./src/nasdaq_protocols/fix/__init__.py:F401,F403,F405 ./src/nasdaq_protocols/itch/__init__.py:F401,F403 + ./src/nasdaq_protocols/fix/parser/__init__.py:F401,F403