diff --git a/src/nasdaq_protocols/fix/codegen.py b/src/nasdaq_protocols/fix/codegen.py
new file mode 100644
index 0000000..7ae6223
--- /dev/null
+++ b/src/nasdaq_protocols/fix/codegen.py
@@ -0,0 +1,29 @@
+import click
+from nasdaq_protocols.fix.parser import parse, Generator
+
+
+__all__ = [
+ 'generate'
+]
+
+
+@click.command()
+@click.option('--spec-file', type=click.Path(exists=True))
+@click.option('--app-name', type=click.STRING)
+@click.option('--prefix', type=click.STRING, default='')
+@click.option('--op-dir', type=click.Path(exists=True, writable=True))
+@click.option('--init-file/--no-init-file', show_default=True, default=True)
+def generate(spec_file, app_name, op_dir, prefix, init_file):
+
+ try:
+ generator = Generator(
+ parse(spec_file),
+ app_name,
+ op_dir,
+ prefix,
+ generate_init_file=init_file
+ )
+ generator.generate()
+ except Exception as e:
+ print(f'Error: {e}')
+ raise e
diff --git a/src/nasdaq_protocols/fix/parser/__init__.py b/src/nasdaq_protocols/fix/parser/__init__.py
index e69de29..d3a1fb8 100644
--- a/src/nasdaq_protocols/fix/parser/__init__.py
+++ b/src/nasdaq_protocols/fix/parser/__init__.py
@@ -0,0 +1,3 @@
+from .definitions import *
+from .parser import *
+from .generator import *
diff --git a/src/nasdaq_protocols/fix/parser/definitions.py b/src/nasdaq_protocols/fix/parser/definitions.py
index ec65bce..49831c7 100644
--- a/src/nasdaq_protocols/fix/parser/definitions.py
+++ b/src/nasdaq_protocols/fix/parser/definitions.py
@@ -1,3 +1,5 @@
+from collections import defaultdict
+from itertools import count
from typing import Any
import attrs
@@ -24,27 +26,81 @@ class FieldDef:
type: TypeDefinition
possible_values: dict[str, Any] = None
+ def _values_ctx(self):
+ output = []
+ if self.possible_values:
+ for key, value in self.possible_values.items():
+ output.append({
+ 'f_name': key,
+ 'f_value': value,
+ 'quote': self.type.type_cls in [str, bool]
+ })
+ return output
+
+ def get_codegen_context(self, _definitions):
+ return {
+ 'tag': self.tag,
+ 'name': self.name,
+ 'type': self.type.__name__,
+ 'type_hint': self.type.hint,
+ 'values': self._values_ctx()
+ }
+
@attrs.define
class Entry:
required: bool = attrs.field(kw_only=True)
+ def get_codegen_context(self, _definitions):
+ return {
+ 'required': 'True' if self.required else 'False'
+ }
+
@attrs.define
class EntryContainer:
entries: list[Entry] = attrs.field(kw_only=True, factory=list)
+ def get_codegen_context(self, definitions):
+ return {
+ 'entries': [entry.get_codegen_context(definitions) for entry in self.entries]
+ }
+
@attrs.define
class Field(Entry):
field: FieldDef = attrs.field(kw_only=True)
+ def get_codegen_context(self, definitions):
+ return super().get_codegen_context(definitions) | {
+ 'field': self.field.get_codegen_context(definitions)
+ }
+
@attrs.define
class Group(Entry):
name: str = attrs.field(kw_only=True)
entries: list[Entry] = attrs.field(kw_only=True, factory=list)
+ Contexts = []
+ UniqueNameCounter = defaultdict(lambda: count(1))
+
+ def get_codegen_context(self, definitions):
+ unique_name = f'{self.name}_{next(Group.UniqueNameCounter[self.name])}' # to avoid name conflicts
+ group_context = super().get_codegen_context(definitions) | {
+ 'name': self.name,
+ 'unique_name': unique_name,
+ 'is_group': True,
+ 'entries': [entry.get_codegen_context(definitions) for entry in self.entries],
+ }
+
+ Group.Contexts.append({
+ 'name': self.name,
+ 'unique_name': unique_name,
+ 'entries': [entry.get_codegen_context(definitions) for entry in self.entries],
+ })
+ return group_context
+
@attrs.define
class Component(EntryContainer):
@@ -57,6 +113,13 @@ class Message(EntryContainer):
name: str
category: str
+ def get_codegen_context(self, definitions):
+ return super().get_codegen_context(definitions) | {
+ 'tag': self.tag,
+ 'name': self.name,
+ 'category': self.category,
+ }
+
@attrs.define
class Definitions:
@@ -65,3 +128,24 @@ class Definitions:
header: EntryContainer = attrs.field(kw_only=True, factory=EntryContainer)
trailer: EntryContainer = attrs.field(kw_only=True, factory=EntryContainer)
messages: list[Message] = attrs.field(kw_only=True, factory=list)
+
+ def get_codegen_context(self):
+ message_context = [
+ message.get_codegen_context(self) | {
+ 'body_name': f'{message.name}Body',
+ }
+ for message in self.messages
+ ]
+ return {
+ 'fields': [field.get_codegen_context(self) for field in self.fields.values()],
+ 'bodies': [
+ self.header.get_codegen_context(self) | {
+ 'body_name': 'Header',
+ },
+ self.trailer.get_codegen_context(self) | {
+ 'body_name': 'Trailer',
+ },
+ ] + message_context,
+ 'messages': message_context,
+ 'groups': Group.Contexts, # always last, as it is dependent on other entries
+ }
diff --git a/src/nasdaq_protocols/fix/parser/generator.py b/src/nasdaq_protocols/fix/parser/generator.py
new file mode 100644
index 0000000..de394dc
--- /dev/null
+++ b/src/nasdaq_protocols/fix/parser/generator.py
@@ -0,0 +1,77 @@
+import os
+from importlib import resources
+from pathlib import Path
+
+import attrs
+import chevron
+from .parser import Definitions
+from . import templates
+
+
+__all__ = [
+ 'Generator'
+]
+TEMPLATES_PATH = resources.files(templates)
+
+
+@attrs.define(auto_attribs=True)
+class Generator:
+ definitions: Definitions
+ app_name: str
+ op_dir: str
+ prefix: str = ''
+ generate_init_file: bool = False
+ _init_file: str = None
+ _module_prefix: str = None
+ _context: dict = None
+
+ def __attrs_post_init__(self):
+ prefix = f'{self.prefix}_' if self.prefix else ''
+ self._module_prefix = f'{prefix}fix_{self.app_name}'
+ self._init_file = os.path.join(self.op_dir, '__init__.py')
+ Path(self.op_dir).mkdir(parents=True, exist_ok=True)
+ self._prepare_context()
+
+ def _prepare_context(self):
+ self._context = self.definitions.get_codegen_context() | {
+ 'module_prefix': self._module_prefix,
+ 'app_name': self.app_name,
+ }
+
+ def generate(self):
+ files = ['fields', 'groups', 'bodies', 'messages']
+ generated_modules = []
+ generated_files = []
+ # Generate the modules
+ for file in files:
+ module_name = f'{self._module_prefix}_{file}'
+ generated_file = os.path.join(self.op_dir, f'{module_name}.py')
+ Generator._generate(
+ self._context,
+ os.path.join(str(TEMPLATES_PATH), f'{file}.mustache'),
+ generated_file
+ )
+ generated_modules.append(module_name)
+ generated_files.append(generated_file)
+
+ # Generate the __init__.py file
+ if self.generate_init_file:
+ context = {
+ 'modules': [
+ {'name': module} for module in generated_modules
+ ]
+ }
+ Generator._generate(
+ context,
+ os.path.join(str(TEMPLATES_PATH), 'init.mustache'),
+ self._init_file
+ )
+ generated_files.append(self._init_file)
+ return generated_files
+
+ @staticmethod
+ def _generate(context, template, op_file):
+ with open(op_file, 'a', encoding='utf-8') as op, open(template, 'r', encoding='utf-8') as inp:
+ code_as_string = chevron.render(inp.read(), context, partials_path=str(TEMPLATES_PATH))
+ op.write(code_as_string)
+ print(f'Generated: {op_file}')
diff --git a/src/nasdaq_protocols/fix/parser/parser.py b/src/nasdaq_protocols/fix/parser/parser.py
index c1ed4b8..0391a68 100644
--- a/src/nasdaq_protocols/fix/parser/parser.py
+++ b/src/nasdaq_protocols/fix/parser/parser.py
@@ -18,6 +18,9 @@
)
+__all__ = [
+ 'parse'
+]
LOG = logging.getLogger(__name__)
diff --git a/src/nasdaq_protocols/fix/parser/templates/__init__.py b/src/nasdaq_protocols/fix/parser/templates/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/nasdaq_protocols/fix/parser/templates/bodies.mustache b/src/nasdaq_protocols/fix/parser/templates/bodies.mustache
new file mode 100644
index 0000000..8bff9b9
--- /dev/null
+++ b/src/nasdaq_protocols/fix/parser/templates/bodies.mustache
@@ -0,0 +1,29 @@
+from nasdaq_protocols import fix
+from . import {{module_prefix}}_fields as fields
+from . import {{module_prefix}}_groups as groups
+
+
+{{#bodies}}
+class {{body_name}}(fix.DataSegment):
+ Entries = [
+{{#entries}}
+ {{^is_group}}
+ fix.Entry(fields.{{field.name}}, {{required}}),
+ {{/is_group}}
+ {{#is_group}}
+ fix.Entry(groups.{{unique_name}}_List, {{required}}),
+ {{/is_group}}
+{{/entries}}
+ ]
+
+{{#entries}}
+ {{^is_group}}
+ {{field.name}}: {{field.type_hint}}
+ {{/is_group}}
+ {{#is_group}}
+ {{name}}: groups.{{unique_name}}_List
+ {{/is_group}}
+{{/entries}}
+
+
+{{/bodies}}
\ No newline at end of file
diff --git a/src/nasdaq_protocols/fix/parser/templates/fields.mustache b/src/nasdaq_protocols/fix/parser/templates/fields.mustache
new file mode 100644
index 0000000..a16d1cf
--- /dev/null
+++ b/src/nasdaq_protocols/fix/parser/templates/fields.mustache
@@ -0,0 +1,17 @@
+from nasdaq_protocols import fix
+
+
+{{#fields}}
+class {{name}}(fix.Field, Tag="{{tag}}", Name="{{name}}", Type=fix.{{type}}):
+ Values = {
+{{#values}}
+ {{#quote}}'{{/quote}}{{f_name}}{{#quote}}'{{/quote}}: '{{f_value}}',
+{{/values}}
+ }
+{{#values}}
+ {{f_value}} = {{#quote}}'{{/quote}}{{f_name}}{{#quote}}'{{/quote}},
+{{/values}}
+
+
+{{/fields}}
+
diff --git a/src/nasdaq_protocols/fix/parser/templates/groups.mustache b/src/nasdaq_protocols/fix/parser/templates/groups.mustache
new file mode 100644
index 0000000..92e0150
--- /dev/null
+++ b/src/nasdaq_protocols/fix/parser/templates/groups.mustache
@@ -0,0 +1,33 @@
+from nasdaq_protocols import fix
+from . import {{module_prefix}}_fields as fields
+
+
+{{#groups}}
+class {{unique_name}}(fix.Group):
+ Entries = [
+{{#entries}}
+ {{^is_group}}
+ fix.Entry(fields.{{field.name}}, {{required}}),
+ {{/is_group}}
+ {{#is_group}}
+ fix.Entry({{unique_name}}_List, {{required}}),
+ {{/is_group}}
+{{/entries}}
+ ]
+
+{{#entries}}
+ {{^is_group}}
+ {{field.name}}: {{field.type_hint}}
+ {{/is_group}}
+ {{#is_group}}
+ {{name}}: {{unique_name}}_List
+ {{/is_group}}
+{{/entries}}
+
+
+class {{unique_name}}_List(fix.GroupContainer, CountCls=fields.{{name}}, GroupCls={{unique_name}}):
+ def __getitem__(self, idx) -> {{unique_name}}:
+ return super({{unique_name}}_List, self).__getitem__(idx)
+
+
+{{/groups}}
diff --git a/src/nasdaq_protocols/fix/parser/templates/init.mustache b/src/nasdaq_protocols/fix/parser/templates/init.mustache
new file mode 100644
index 0000000..3520e8c
--- /dev/null
+++ b/src/nasdaq_protocols/fix/parser/templates/init.mustache
@@ -0,0 +1,3 @@
+{{#modules}}
+from .{{name}} import *
+{{/modules}}
diff --git a/src/nasdaq_protocols/fix/parser/templates/messages.mustache b/src/nasdaq_protocols/fix/parser/templates/messages.mustache
new file mode 100644
index 0000000..3fea8f2
--- /dev/null
+++ b/src/nasdaq_protocols/fix/parser/templates/messages.mustache
@@ -0,0 +1,24 @@
+from nasdaq_protocols import fix
+from . import {{module_prefix}}_groups as groups
+from . import {{module_prefix}}_bodies as bodies
+
+
+{{#messages}}
+class {{name}}(fix.Message,
+ Name="{{name}}",
+ Type="{{tag}}",
+ Category="{{category}}",
+ HeaderCls=bodies.Header,
+ BodyCls=bodies.{{name}}Body,
+ TrailerCls=bodies.Trailer):
+{{#entries}}
+ {{^is_group}}
+ {{field.name}}: {{field.type_hint}}
+ {{/is_group}}
+ {{#is_group}}
+ {{name}}: groups.{{unique_name}}_List
+ {{/is_group}}
+{{/entries}}
+
+
+{{/messages}}
\ No newline at end of file
diff --git a/tests/conftest.py b/tests/conftest.py
index 3f3111e..ff1775f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -39,19 +39,20 @@ async def mock_server_session(unused_tcp_port):
@pytest.fixture(scope='function')
-def codegen_invoker(tmp_path):
- def generator(codegen, xml_content, app_name, generate_init_file, prefix):
+def codegen_invoker(capsys, tmp_path):
+ def generator(codegen, xml_content, app_name, generate_init_file, prefix, output_dir=None):
runner = CliRunner()
- with runner.isolated_filesystem(temp_dir=tmp_path):
+ with capsys.disabled(), runner.isolated_filesystem(temp_dir=tmp_path):
with open('spec.xml', 'w') as spec_file:
spec_file.write(xml_content)
- Path('output').mkdir(parents=True, exist_ok=True)
+ output_dir = output_dir or 'output'
+ Path(output_dir).mkdir(parents=True, exist_ok=True)
result = runner.invoke(
codegen,
[
'--spec-file', 'spec.xml',
'--app-name', app_name,
- '--op-dir', 'output',
+ '--op-dir', output_dir,
'--prefix', prefix,
'--init-file' if generate_init_file else '--no-init-file'
]
@@ -60,18 +61,18 @@ def generator(codegen, xml_content, app_name, generate_init_file, prefix):
# Read the generated files
generated_file_contents = {}
- for file in os.listdir('output'):
- with open(os.path.join('output', file)) as f:
+ for file in os.listdir(output_dir):
+ with open(os.path.join(output_dir, file)) as f:
generated_file_contents[file] = f.read()
return generated_file_contents
return generator
@pytest.fixture(scope='function')
-def tools_codegen_invoker(tmp_path):
+def tools_codegen_invoker(capsys, tmp_path):
def generator(codegen, app_name, package):
runner = CliRunner()
- with runner.isolated_filesystem(temp_dir=tmp_path):
+ with capsys.disabled(), runner.isolated_filesystem(temp_dir=tmp_path):
Path('output').mkdir(parents=True, exist_ok=True)
result = runner.invoke(
codegen,
@@ -100,4 +101,16 @@ def loader_(module_name, code_as_string):
exec(code_as_string, module_.__dict__)
sys.modules[module_name] = module_
return module_
- return loader_
\ No newline at end of file
+ return loader_
+
+
+@pytest.fixture(scope='session')
+def module_loader():
+ def load_module(module_name, file):
+ spec = importlib.util.spec_from_file_location(module_name, file)
+ assert spec is not None
+ generated_package = importlib.util.module_from_spec(spec)
+ sys.modules[module_name] = generated_package
+ spec.loader.exec_module(generated_package)
+ return generated_package
+ return load_module
diff --git a/tests/test_fix_codegen.py b/tests/test_fix_codegen.py
new file mode 100644
index 0000000..4cae8ac
--- /dev/null
+++ b/tests/test_fix_codegen.py
@@ -0,0 +1,57 @@
+import importlib
+import os
+import sys
+from pathlib import Path
+
+import pytest
+
+from nasdaq_protocols.fix import codegen
+from nasdaq_protocols.fix.parser import parse
+from tests.testdata import TEST_FIX_44_XML, TEST_XML_ITCH_MESSAGE
+
+
+@pytest.fixture(scope='function')
+def fix_44_definitions(tmp_file_writer):
+ file = tmp_file_writer(TEST_FIX_44_XML)
+ definitions = parse(file)
+ assert definitions is not None
+ yield definitions
+
+
+def test__no_init_file__no_prefix__code_generated(codegen_invoker):
+ prefix = ''
+ app_name = 'gwy_44'
+ generated_files = codegen_invoker(
+ codegen.generate,
+ TEST_FIX_44_XML,
+ app_name,
+ generate_init_file=False,
+ prefix=prefix
+ )
+
+ assert len(generated_files) == 4
+
+
+def test__init_file__no_prefix__code_generated(fix_44_definitions, codegen_invoker, tmp_path, module_loader):
+ prefix = ''
+ app_name = 'gwy_44'
+ output_dir = tmp_path / app_name
+ generated_files = codegen_invoker(
+ codegen.generate,
+ TEST_FIX_44_XML,
+ app_name,
+ generate_init_file=True,
+ prefix=prefix,
+ output_dir=output_dir
+ )
+
+ assert len(generated_files) == 5
+
+ # This ensures the generated code is correct
+ generated_package = module_loader('test__init_file__no_prefix__code_generated', output_dir / '__init__.py')
+
+ for field in fix_44_definitions.fields:
+ assert hasattr(generated_package, field)
+
+ for message in fix_44_definitions.messages:
+ assert hasattr(generated_package, message.name)
diff --git a/tests/test_fix_parser.py b/tests/test_fix_parser.py
index f28b689..b205523 100644
--- a/tests/test_fix_parser.py
+++ b/tests/test_fix_parser.py
@@ -125,12 +125,12 @@ def assert_component(name, *fields_defn):
assert_component('InstrumentExtensionNoInstrAttribSubGroup', ('PossResend', True), ('QuoteStatus', False))
assert 'GroupOfComponents' in fix_44_definitions.components
assert_group(
- 'InstrmtLegGrp_Group',
+ 'NoLegs',
fix_44_definitions.components['GroupOfComponents'].entries[0],
('PossResend', False), ('QuoteStatus', False)
)
assert_group(
- 'Instrument_Group',
+ 'NoStreams',
fix_44_definitions.components['GroupOfComponents'].entries[1],
('PossResend', True), ('QuoteStatus', True)
)
@@ -172,7 +172,7 @@ def test__fix_parser__message_is_parsed(fix_44_definitions):
assert message.entries[2].field.name == 'QuoteStatus'
assert not message.entries[2].required
assert_group(
- 'Instrument_Group',
+ 'NoStreams',
message.entries[3],
('PossResend', True), ('QuoteStatus', True)
)
diff --git a/tests/testdata.py b/tests/testdata.py
index f4309eb..d40c844 100644
--- a/tests/testdata.py
+++ b/tests/testdata.py
@@ -204,7 +204,7 @@
-
+
@@ -215,10 +215,10 @@
-
+
-
+
@@ -249,6 +249,8 @@
+
+
"""
\ No newline at end of file
diff --git a/tox.ini b/tox.ini
index b98f0a0..f1d9386 100644
--- a/tox.ini
+++ b/tox.ini
@@ -60,3 +60,4 @@ per-file-ignores =
./src/nasdaq_protocols/common/message/__init__.py:F401,F403
./src/nasdaq_protocols/fix/__init__.py:F401,F403,F405
./src/nasdaq_protocols/itch/__init__.py:F401,F403
+ ./src/nasdaq_protocols/fix/parser/__init__.py:F401,F403