Skip to content

Commit

Permalink
message_parser: add option to override messages
Browse files Browse the repository at this point in the history
  • Loading branch information
SamDanielThangarajan committed Sep 24, 2024
1 parent 5571a19 commit c9c0cd1
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 9 deletions.
6 changes: 5 additions & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
[FORMAT]
max-line-length=180
class-naming-style=PascalCase
attr-naming-style=any

[MASTER]
ignore-paths=src/nasdaq_protocols/_version.py
Expand All @@ -12,4 +14,6 @@ disable=
R0903, #too-few-public-methods
R0902, #too-many-instance-attributes
W0707, #too-many-statements
R0801, #similarities
R0801, #similarities
R0913, #too-many-arguments
R0917, #too-many-positional-arguments
15 changes: 11 additions & 4 deletions src/nasdaq_protocols/common/message/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def get_codegen_context(self):

class Parser:
@staticmethod
def parse(file: str) -> Definitions:
def parse(file: str, override_messages: bool = False) -> Definitions:
root = ElementTree.parse(file).getroot()
elements = list(root)
definitions = Definitions()
Expand All @@ -170,7 +170,7 @@ def parse(file: str) -> Definitions:
elif element.tag == 'records-root':
definitions.records = {_.name: _ for _ in Parser._parse_records(element)}
elif element.tag == 'messages-root':
definitions.messages = Parser._parse_messages(element)
definitions.messages = Parser._parse_messages(element, override_messages)

return definitions

Expand All @@ -183,14 +183,21 @@ def _parse_records(element) -> list[RecordDef]:
return [RecordDef(record.get('id'), Parser._parse_fields(record[0])) for record in list(element)]

@staticmethod
def _parse_messages(element) -> list[MessageDef]:
def _parse_messages(element, override_messages: bool = False) -> list[MessageDef]:
def create_msg(child):
return MessageDef(
child.get('id'), child.get('message-id'), child.get('message-group'),
Parser._parse_fields(child[0] if len(child) else None),
child.get('direction')
)
return [create_msg(child) for child in list(element) if child is not None]
messages = {}
for child in list(element):
msg = create_msg(child)
key = f'{msg.id}-{msg.group}-{msg.direction}'
if key in messages and not override_messages:
raise ValueError(f'Message {key} already exists')
messages[key] = msg
return list(messages.values())

@staticmethod
def _parse_fields(element) -> list[FieldDef]:
Expand Down
5 changes: 3 additions & 2 deletions src/nasdaq_protocols/itch/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
@click.option('--app-name', type=click.STRING)
@click.option('--prefix', type=click.STRING, default='')
@click.option('--op-dir', type=click.Path(exists=True, writable=True))
@click.option('--override-messages/--no-override-messages', show_default=True, default=True)
@click.option('--init-file/--no-init-file', show_default=True, default=True)
def generate(spec_file, app_name, prefix, op_dir, init_file):
def generate(spec_file, app_name, prefix, op_dir, override_messages, init_file):
context = {
'record_type': 'Record',
}
generator = Generator(
Parser.parse(spec_file),
Parser.parse(spec_file, override_messages=override_messages),
'itch',
app_name,
op_dir,
Expand Down
5 changes: 3 additions & 2 deletions src/nasdaq_protocols/ouch/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
@click.option('--app-name', type=click.STRING)
@click.option('--prefix', type=click.STRING, default='')
@click.option('--op-dir', type=click.Path(exists=True, writable=True))
@click.option('--override-messages/--no-override-messages', show_default=True, default=True)
@click.option('--init-file/--no-init-file', show_default=True, default=True)
def generate(spec_file, app_name, op_dir, prefix, init_file):
def generate(spec_file, app_name, op_dir, prefix, override_messages, init_file):
context = {
'record_type': 'Record',
}
generator = Generator(
Parser.parse(spec_file),
Parser.parse(spec_file, override_messages=override_messages),
'ouch',
app_name,
op_dir,
Expand Down
18 changes: 18 additions & 0 deletions tests/test_common_message_parser.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

from nasdaq_protocols.common.message import parser
from .testdata import *

Expand Down Expand Up @@ -176,3 +178,19 @@ def test__parser__messages__codegen_context(tmp_file_writer):
}
],
}


def test__parser__repeats_messages__no_overriding__parse(tmp_file_writer):
with pytest.raises(ValueError):
definitions = parser.Parser.parse(tmp_file_writer(TEST_XML_MESSAGES_REPEAT), override_messages=False)


def test__parser__repeats_messages__overriding__parse(tmp_file_writer):

definitions = parser.Parser.parse(tmp_file_writer(TEST_XML_MESSAGES_REPEAT), override_messages=True)

assert len(definitions.messages) == 1
assert 'test_message_extn' in [msg.name for msg in definitions.messages]

message: parser.MessageDef = definitions.messages[0]
assert len(message.fields) == 2
25 changes: 25 additions & 0 deletions tests/testdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,4 +163,29 @@
</message>
</messages-root>
</root>
"""


TEST_XML_MESSAGES_REPEAT = """
<root>
<enums-root>
</enums-root>
<fielddef-root>
</fielddef-root>
<records-root>
</records-root>
<messages-root>
<message id="test_message" message-id="1" message-group="2" direction="incoming">
<fields>
<field name="msg_field1" type="uint_4" default="0"/>
</fields>
</message>
<message id="test_message_extn" message-id="1" message-group="2" direction="incoming">
<fields>
<field name="msg_field1" type="uint_4" default="0"/>
<field name="msg_field2" type="uint_4" default="0"/>
</fields>
</message>
</messages-root>
</root>
"""

0 comments on commit c9c0cd1

Please sign in to comment.