Skip to content

Commit

Permalink
codegen for fix
Browse files Browse the repository at this point in the history
  • Loading branch information
SamDanielThangarajan committed Oct 10, 2024
1 parent 7e7a050 commit 3b41c5e
Show file tree
Hide file tree
Showing 16 changed files with 391 additions and 16 deletions.
29 changes: 29 additions & 0 deletions src/nasdaq_protocols/fix/codegen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import click
from nasdaq_protocols.fix.parser import parse, Generator


__all__ = [
'generate'
]


@click.command()
@click.option('--spec-file', type=click.Path(exists=True))
@click.option('--app-name', type=click.STRING)
@click.option('--prefix', type=click.STRING, default='')
@click.option('--op-dir', type=click.Path(exists=True, writable=True))
@click.option('--init-file/--no-init-file', show_default=True, default=True)
def generate(spec_file, app_name, op_dir, prefix, init_file):

try:
generator = Generator(
parse(spec_file),
app_name,
op_dir,
prefix,
generate_init_file=init_file
)
generator.generate()
except Exception as e:
print(f'Error: {e}')
raise e
3 changes: 3 additions & 0 deletions src/nasdaq_protocols/fix/parser/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .definitions import *
from .parser import *
from .generator import *
84 changes: 84 additions & 0 deletions src/nasdaq_protocols/fix/parser/definitions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections import defaultdict
from itertools import count
from typing import Any

import attrs
Expand All @@ -24,27 +26,81 @@ class FieldDef:
type: TypeDefinition
possible_values: dict[str, Any] = None

def _values_ctx(self):
output = []
if self.possible_values:
for key, value in self.possible_values.items():
output.append({
'f_name': key,
'f_value': value,
'quote': self.type.type_cls in [str, bool]
})
return output

def get_codegen_context(self, _definitions):
return {
'tag': self.tag,
'name': self.name,
'type': self.type.__name__,
'type_hint': self.type.hint,
'values': self._values_ctx()
}


@attrs.define
class Entry:
required: bool = attrs.field(kw_only=True)

def get_codegen_context(self, _definitions):
return {
'required': 'True' if self.required else 'False'
}


@attrs.define
class EntryContainer:
entries: list[Entry] = attrs.field(kw_only=True, factory=list)

def get_codegen_context(self, definitions):
return {
'entries': [entry.get_codegen_context(definitions) for entry in self.entries]
}


@attrs.define
class Field(Entry):
field: FieldDef = attrs.field(kw_only=True)

def get_codegen_context(self, definitions):
return super().get_codegen_context(definitions) | {
'field': self.field.get_codegen_context(definitions)
}


@attrs.define
class Group(Entry):
name: str = attrs.field(kw_only=True)
entries: list[Entry] = attrs.field(kw_only=True, factory=list)

Contexts = []
UniqueNameCounter = defaultdict(lambda: count(1))

def get_codegen_context(self, definitions):
unique_name = f'{self.name}_{next(Group.UniqueNameCounter[self.name])}' # to avoid name conflicts
group_context = super().get_codegen_context(definitions) | {
'name': self.name,
'unique_name': unique_name,
'is_group': True,
'entries': [entry.get_codegen_context(definitions) for entry in self.entries],
}

Group.Contexts.append({
'name': self.name,
'unique_name': unique_name,
'entries': [entry.get_codegen_context(definitions) for entry in self.entries],
})
return group_context


@attrs.define
class Component(EntryContainer):
Expand All @@ -57,6 +113,13 @@ class Message(EntryContainer):
name: str
category: str

def get_codegen_context(self, definitions):
return super().get_codegen_context(definitions) | {
'tag': self.tag,
'name': self.name,
'category': self.category,
}


@attrs.define
class Definitions:
Expand All @@ -65,3 +128,24 @@ class Definitions:
header: EntryContainer = attrs.field(kw_only=True, factory=EntryContainer)
trailer: EntryContainer = attrs.field(kw_only=True, factory=EntryContainer)
messages: list[Message] = attrs.field(kw_only=True, factory=list)

def get_codegen_context(self):
message_context = [
message.get_codegen_context(self) | {
'body_name': f'{message.name}Body',
}
for message in self.messages
]
return {
'fields': [field.get_codegen_context(self) for field in self.fields.values()],
'bodies': [
self.header.get_codegen_context(self) | {
'body_name': 'Header',
},
self.trailer.get_codegen_context(self) | {
'body_name': 'Trailer',
},
] + message_context,
'messages': message_context,
'groups': Group.Contexts, # always last, as it is dependent on other entries
}
77 changes: 77 additions & 0 deletions src/nasdaq_protocols/fix/parser/generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import os
from importlib import resources
from pathlib import Path

import attrs
import chevron
from .parser import Definitions
from . import templates


__all__ = [
'Generator'
]
TEMPLATES_PATH = resources.files(templates)


@attrs.define(auto_attribs=True)
class Generator:
definitions: Definitions
app_name: str
op_dir: str
prefix: str = ''
generate_init_file: bool = False
_init_file: str = None
_module_prefix: str = None
_context: dict = None

def __attrs_post_init__(self):
prefix = f'{self.prefix}_' if self.prefix else ''
self._module_prefix = f'{prefix}fix_{self.app_name}'
self._init_file = os.path.join(self.op_dir, '__init__.py')
Path(self.op_dir).mkdir(parents=True, exist_ok=True)
self._prepare_context()

def _prepare_context(self):
self._context = self.definitions.get_codegen_context() | {
'module_prefix': self._module_prefix,
'app_name': self.app_name,
}

def generate(self):
files = ['fields', 'groups', 'bodies', 'messages']
generated_modules = []
generated_files = []
# Generate the modules
for file in files:
module_name = f'{self._module_prefix}_{file}'
generated_file = os.path.join(self.op_dir, f'{module_name}.py')
Generator._generate(
self._context,
os.path.join(str(TEMPLATES_PATH), f'{file}.mustache'),
generated_file
)
generated_modules.append(module_name)
generated_files.append(generated_file)

# Generate the __init__.py file
if self.generate_init_file:
context = {
'modules': [
{'name': module} for module in generated_modules
]
}
Generator._generate(
context,
os.path.join(str(TEMPLATES_PATH), 'init.mustache'),
self._init_file
)
generated_files.append(self._init_file)
return generated_files

@staticmethod
def _generate(context, template, op_file):
with open(op_file, 'a', encoding='utf-8') as op, open(template, 'r', encoding='utf-8') as inp:
code_as_string = chevron.render(inp.read(), context, partials_path=str(TEMPLATES_PATH))
op.write(code_as_string)
print(f'Generated: {op_file}')
3 changes: 3 additions & 0 deletions src/nasdaq_protocols/fix/parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
)


__all__ = [
'parse'
]
LOG = logging.getLogger(__name__)


Expand Down
Empty file.
29 changes: 29 additions & 0 deletions src/nasdaq_protocols/fix/parser/templates/bodies.mustache
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from nasdaq_protocols import fix
from . import {{module_prefix}}_fields as fields
from . import {{module_prefix}}_groups as groups


{{#bodies}}
class {{body_name}}(fix.DataSegment):
Entries = [
{{#entries}}
{{^is_group}}
fix.Entry(fields.{{field.name}}, {{required}}),
{{/is_group}}
{{#is_group}}
fix.Entry(groups.{{unique_name}}_List, {{required}}),
{{/is_group}}
{{/entries}}
]

{{#entries}}
{{^is_group}}
{{field.name}}: {{field.type_hint}}
{{/is_group}}
{{#is_group}}
{{name}}: groups.{{unique_name}}_List
{{/is_group}}
{{/entries}}


{{/bodies}}
17 changes: 17 additions & 0 deletions src/nasdaq_protocols/fix/parser/templates/fields.mustache
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from nasdaq_protocols import fix


{{#fields}}
class {{name}}(fix.Field, Tag="{{tag}}", Name="{{name}}", Type=fix.{{type}}):
Values = {
{{#values}}
{{#quote}}'{{/quote}}{{f_name}}{{#quote}}'{{/quote}}: '{{f_value}}',
{{/values}}
}
{{#values}}
{{f_value}} = {{#quote}}'{{/quote}}{{f_name}}{{#quote}}'{{/quote}},
{{/values}}


{{/fields}}

33 changes: 33 additions & 0 deletions src/nasdaq_protocols/fix/parser/templates/groups.mustache
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from nasdaq_protocols import fix
from . import {{module_prefix}}_fields as fields


{{#groups}}
class {{unique_name}}(fix.Group):
Entries = [
{{#entries}}
{{^is_group}}
fix.Entry(fields.{{field.name}}, {{required}}),
{{/is_group}}
{{#is_group}}
fix.Entry({{unique_name}}_List, {{required}}),
{{/is_group}}
{{/entries}}
]

{{#entries}}
{{^is_group}}
{{field.name}}: {{field.type_hint}}
{{/is_group}}
{{#is_group}}
{{name}}: {{unique_name}}_List
{{/is_group}}
{{/entries}}


class {{unique_name}}_List(fix.GroupContainer, CountCls=fields.{{name}}, GroupCls={{unique_name}}):
def __getitem__(self, idx) -> {{unique_name}}:
return super({{unique_name}}_List, self).__getitem__(idx)


{{/groups}}
3 changes: 3 additions & 0 deletions src/nasdaq_protocols/fix/parser/templates/init.mustache
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{{#modules}}
from .{{name}} import *
{{/modules}}
24 changes: 24 additions & 0 deletions src/nasdaq_protocols/fix/parser/templates/messages.mustache
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from nasdaq_protocols import fix
from . import {{module_prefix}}_groups as groups
from . import {{module_prefix}}_bodies as bodies


{{#messages}}
class {{name}}(fix.Message,
Name="{{name}}",
Type="{{tag}}",
Category="{{category}}",
HeaderCls=bodies.Header,
BodyCls=bodies.{{name}}Body,
TrailerCls=bodies.Trailer):
{{#entries}}
{{^is_group}}
{{field.name}}: {{field.type_hint}}
{{/is_group}}
{{#is_group}}
{{name}}: groups.{{unique_name}}_List
{{/is_group}}
{{/entries}}


{{/messages}}
Loading

0 comments on commit 3b41c5e

Please sign in to comment.