diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 9ae6e853..bc0b59dd 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -4,16 +4,30 @@ on: [push, pull_request] jobs: test: runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] + defaults: + run: + working-directory: python + steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v4 with: fetch-depth: 0 - - name: Test odin-data + - name: Set up python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Install odin-data python package run: | - python3 -m venv venv && source venv/bin/activate && pip install --upgrade pip - cd python + python -m pip install --upgrade pip pip install .[dev,meta_writer] python -c "from odin_data import __version__; print(__version__)" - pytest tests/test_hdf5dataset.py + + - name: Run tests + run: + pytest -vs --cov=odin_data --cov-report=term-missing diff --git a/python/setup.cfg b/python/setup.cfg index d322cd97..66cc6b8b 100644 --- a/python/setup.cfg +++ b/python/setup.cfg @@ -14,6 +14,8 @@ classifiers = Programming Language :: Python :: 3.7 Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.11 [options] packages = find: @@ -21,7 +23,7 @@ package_dir = =src install_requires = - odin-control @ git+https://git@github.com/odin-detector/odin-control.git@1.3.0 + odin-control @ git+https://git@github.com/odin-detector/odin-control.git@1.4.0 posix_ipc>=1.0.4 pysnmp>=4.4.4 numpy>=1.14.0 @@ -35,6 +37,7 @@ meta_writer = # For development tests/docs dev = pytest + pytest-cov # Docs sphinx-autobuild sphinx-external-toc diff --git a/python/src/odin_data/control/live_view_proxy_adapter.py b/python/src/odin_data/control/live_view_proxy_adapter.py index 8b386633..6cd25310 100644 --- 
a/python/src/odin_data/control/live_view_proxy_adapter.py +++ b/python/src/odin_data/control/live_view_proxy_adapter.py @@ -61,7 +61,7 @@ def __init__(self, **kwargs): self.publish_channel.bind() except ZMQError as channel_err: # ZMQError raised here if the socket addr is already in use. - logging.error("Connection Failed. Error given: %s", channel_err.message) + logging.error("Connection Failed. Error given: %s", str(channel_err)) self.max_queue = self.options.get(QUEUE_LENGTH_CONFIG_NAME, DEFAULT_QUEUE_LENGTH) if SOURCE_ENDPOINTS_CONFIG_NAME in self.options: @@ -76,7 +76,7 @@ def __init__(self, **kwargs): self.drop_warn_percent, self.add_to_queue)) except (ValueError, ZMQError): - logging.debug("Error parsing target list: %s", target_str) + logging.error("Error parsing target list: %s", target_str) else: self.source_endpoints = [LiveViewProxyNode( "node_1", @@ -84,6 +84,9 @@ def __init__(self, **kwargs): self.drop_warn_percent, self.add_to_queue)] + self.last_sent_frame = (0, 0) + self.dropped_frame_count = 0 + tree = { "target_endpoint": (lambda: self.dest_endpoint, None), 'last_sent_frame': (lambda: self.last_sent_frame, None), @@ -98,9 +101,6 @@ def __init__(self, **kwargs): self.queue = PriorityQueue(self.max_queue) - self.last_sent_frame = (0, 0) - self.dropped_frame_count = 0 - self.get_frame_from_queue() def cleanup(self): diff --git a/python/tests/test_hdf5dataset.py b/python/tests/test_hdf5dataset.py index ba9b2cbb..615d3a87 100644 --- a/python/tests/test_hdf5dataset.py +++ b/python/tests/test_hdf5dataset.py @@ -1,11 +1,13 @@ import os -from unittest import TestCase -import numpy import time import h5py as h5 +import numpy -from odin_data.meta_writer.hdf5dataset import HDF5UnlimitedCache, StringHDF5Dataset +from odin_data.meta_writer.hdf5dataset import ( + HDF5UnlimitedCache, + StringHDF5Dataset, +) class _TestMockDataset: @@ -27,7 +29,7 @@ def __setitem__(self, key, value): self.values.append(value) -class TestHDF5Dataset(TestCase): +class 
TestHDF5Dataset: def test_unlimited_cache_1D(self): """verify that the cache functions as expected""" @@ -42,70 +44,68 @@ def test_unlimited_cache_1D(self): ) # Verify 1 block has been created, with a size of 10 - self.assertEqual(len(cache._blocks), 1) - self.assertEqual(0 in cache._blocks, True) - self.assertEqual(cache._blocks[0].has_new_data, True) - self.assertEqual(cache._blocks[0].data.size, 10) + assert len(cache._blocks) == 1 + assert 0 in cache._blocks + assert cache._blocks[0].has_new_data + assert cache._blocks[0].data.size == 10 # Add 11 items at increasing indexes and verify another # block is created for offset in range(11): cache.add_value(2, offset) - self.assertEqual(len(cache._blocks), 2) - self.assertEqual(0 in cache._blocks, True) - self.assertEqual(1 in cache._blocks, True) - self.assertEqual(cache._blocks[0].has_new_data, True) - self.assertEqual(cache._blocks[0].data.size, 10) - self.assertEqual(sum(cache._blocks[0].data), 20) - self.assertEqual(cache._blocks[1].has_new_data, True) - self.assertEqual(cache._blocks[1].data.size, 10) + assert len(cache._blocks) == 2 + assert 0 in cache._blocks + assert 1 in cache._blocks + assert cache._blocks[0].has_new_data + assert cache._blocks[0].data.size == 10 + assert sum(cache._blocks[0].data) == 20 + assert cache._blocks[1].has_new_data + assert cache._blocks[1].data.size == 10 # Add 1 item at index to create another block cache.add_value(3, 25) - self.assertEqual(len(cache._blocks), 3) - self.assertEqual(0 in cache._blocks, True) - self.assertEqual(1 in cache._blocks, True) - self.assertEqual(2 in cache._blocks, True) - self.assertEqual(cache._blocks[0].has_new_data, True) - self.assertEqual(cache._blocks[0].data.size, 10) - self.assertEqual(sum(cache._blocks[0].data), 20) - self.assertEqual(cache._blocks[1].has_new_data, True) - self.assertEqual(cache._blocks[1].data.size, 10) - self.assertEqual(sum(cache._blocks[1].data), -16) # 2 + (9 * -2) fillvalue - 
self.assertEqual(cache._blocks[2].has_new_data, True) - self.assertEqual(cache._blocks[2].data.size, 10) - self.assertEqual(sum(cache._blocks[2].data), -15) # 3 + (9 * -2) fillvalue + assert len(cache._blocks) == 3 + assert 0 in cache._blocks + assert 1 in cache._blocks + assert 2 in cache._blocks + assert cache._blocks[0].has_new_data + assert cache._blocks[0].data.size == 10 + assert sum(cache._blocks[0].data) == 20 + assert cache._blocks[1].has_new_data + assert cache._blocks[1].data.size == 10 + assert sum(cache._blocks[1].data) == -16 # 2 + (9 * -2) fillvalue + assert cache._blocks[2].has_new_data + assert cache._blocks[2].data.size == 10 + assert sum(cache._blocks[2].data) == -15 # 3 + (9 * -2) fillvalue # Now flush the blocks ds = _TestMockDataset() cache.flush(ds) # Verify the flushed slices are expected (2 full blocks and the last partial) - self.assertEqual( - ds.keys, [slice(0, 10, None), slice(10, 20, None), slice(20, 26, None)] - ) - self.assertEqual(ds.values[0][0], 2) - self.assertEqual(ds.values[0][1], 2) - self.assertEqual(ds.values[0][2], 2) - self.assertEqual(ds.values[0][3], 2) - self.assertEqual(ds.values[0][4], 2) - self.assertEqual(ds.values[0][5], 2) - self.assertEqual(ds.values[0][6], 2) - self.assertEqual(ds.values[0][7], 2) - self.assertEqual(ds.values[0][8], 2) - self.assertEqual(ds.values[0][9], 2) - self.assertEqual(ds.values[1][0], 2) - self.assertEqual(ds.values[1][1], -2) - self.assertEqual(ds.values[1][2], -2) - self.assertEqual(ds.values[1][3], -2) - self.assertEqual(ds.values[1][4], -2) - self.assertEqual(ds.values[1][5], -2) - self.assertEqual(ds.values[1][6], -2) - self.assertEqual(ds.values[1][7], -2) - self.assertEqual(ds.values[1][8], -2) - self.assertEqual(ds.values[1][9], -2) + assert ds.keys == [slice(0, 10, None), slice(10, 20, None), slice(20, 26, None)] + assert ds.values[0][0] == 2 + assert ds.values[0][1] == 2 + assert ds.values[0][2] == 2 + assert ds.values[0][3] == 2 + assert ds.values[0][4] == 2 + assert ds.values[0][5] == 2 
+ assert ds.values[0][6] == 2 + assert ds.values[0][7] == 2 + assert ds.values[0][8] == 2 + assert ds.values[0][9] == 2 + assert ds.values[1][0] == 2 + assert ds.values[1][1] == -2 + assert ds.values[1][2] == -2 + assert ds.values[1][3] == -2 + assert ds.values[1][4] == -2 + assert ds.values[1][5] == -2 + assert ds.values[1][6] == -2 + assert ds.values[1][7] == -2 + assert ds.values[1][8] == -2 + assert ds.values[1][9] == -2 # Now wait time.sleep(0.1) @@ -118,22 +118,22 @@ def test_unlimited_cache_1D(self): cache.flush(ds) cache.purge_blocks() - self.assertEqual(len(cache._blocks), 1) - self.assertEqual(0 in cache._blocks, False) - self.assertEqual(1 in cache._blocks, False) - self.assertEqual(2 in cache._blocks, True) - self.assertEqual(cache._blocks[2].has_new_data, False) - self.assertEqual(cache._blocks[2].data.size, 10) - self.assertEqual(sum(cache._blocks[2].data), -10) # 6 + (8 * -2) fillvalue - - self.assertEqual(ds.keys, [slice(20, 27, None)]) - self.assertEqual(ds.values[0][0], -2) - self.assertEqual(ds.values[0][1], -2) - self.assertEqual(ds.values[0][2], -2) - self.assertEqual(ds.values[0][3], -2) - self.assertEqual(ds.values[0][4], -2) - self.assertEqual(ds.values[0][5], 3) - self.assertEqual(ds.values[0][6], 3) + assert len(cache._blocks) == 1 + assert 0 not in cache._blocks + assert 1 not in cache._blocks + assert 2 in cache._blocks + assert not cache._blocks[2].has_new_data + assert cache._blocks[2].data.size == 10 + assert sum(cache._blocks[2].data) == -10 # 6 + (8 * -2) fillvalue + + assert ds.keys == [slice(20, 27, None)] + assert ds.values[0][0] == -2 + assert ds.values[0][1] == -2 + assert ds.values[0][2] == -2 + assert ds.values[0][3] == -2 + assert ds.values[0][4] == -2 + assert ds.values[0][5] == 3 + assert ds.values[0][6] == 3 def test_unlimited_cache_2D(self): """verify that the cache functions as expected""" @@ -149,10 +149,10 @@ def test_unlimited_cache_2D(self): ) # Verify 1 block has been created, with a size of 10x2x3 - 
self.assertEqual(len(cache._blocks), 1) - self.assertEqual(0 in cache._blocks, True) - self.assertEqual(cache._blocks[0].has_new_data, True) - self.assertEqual(cache._blocks[0].data.size, 60) + assert len(cache._blocks) == 1 + assert 0 in cache._blocks + assert cache._blocks[0].has_new_data + assert cache._blocks[0].data.size == 60 # Add 11 items at increasing indexes and verify another # block is created @@ -160,62 +160,56 @@ def test_unlimited_cache_2D(self): for offset in range(11): cache.add_value(value, offset) - self.assertEqual(len(cache._blocks), 2) - self.assertEqual(0 in cache._blocks, True) - self.assertEqual(1 in cache._blocks, True) - self.assertEqual(cache._blocks[0].has_new_data, True) - self.assertEqual(cache._blocks[0].data.size, 60) - self.assertEqual(numpy.sum(cache._blocks[0].data), 210) - self.assertEqual(cache._blocks[1].has_new_data, True) - self.assertEqual(cache._blocks[1].data.size, 60) + assert len(cache._blocks) == 2 + assert 0 in cache._blocks + assert 1 in cache._blocks + assert cache._blocks[0].has_new_data + assert cache._blocks[0].data.size == 60 + assert numpy.sum(cache._blocks[0].data) == 210 + assert cache._blocks[1].has_new_data + assert cache._blocks[1].data.size == 60 # Add 1 item at index to create another block cache.add_value(value, 25) - self.assertEqual(len(cache._blocks), 3) - self.assertEqual(0 in cache._blocks, True) - self.assertEqual(1 in cache._blocks, True) - self.assertEqual(2 in cache._blocks, True) - self.assertEqual(cache._blocks[0].has_new_data, True) - self.assertEqual(cache._blocks[0].data.size, 60) - self.assertEqual(numpy.sum(cache._blocks[0].data), 210) - self.assertEqual(cache._blocks[1].has_new_data, True) - self.assertEqual(cache._blocks[1].data.size, 60) - self.assertEqual( - numpy.sum(cache._blocks[1].data), -87 - ) # 21 + (9 * -12) fillvalue - self.assertEqual(cache._blocks[2].has_new_data, True) - self.assertEqual(cache._blocks[2].data.size, 60) - self.assertEqual( - 
numpy.sum(cache._blocks[2].data), -87 - ) # 21 + (9 * -12) fillvalue + assert len(cache._blocks) == 3 + assert 0 in cache._blocks + assert 1 in cache._blocks + assert 2 in cache._blocks + assert cache._blocks[0].has_new_data + assert cache._blocks[0].data.size == 60 + assert numpy.sum(cache._blocks[0].data) == 210 + assert cache._blocks[1].has_new_data + assert cache._blocks[1].data.size == 60 + assert numpy.sum(cache._blocks[1].data) == -87 # 21 + (9 * -12) fillvalue + assert cache._blocks[2].has_new_data + assert cache._blocks[2].data.size == 60 + assert numpy.sum(cache._blocks[2].data) == -87 # 21 + (9 * -12) fillvalue # Now flush the blocks ds = _TestMockDataset() cache.flush(ds) # Verify the flushed slices are expected (2 full blocks and the last partial) - self.assertEqual( - ds.keys, [slice(0, 10, None), slice(10, 20, None), slice(20, 26, None)] - ) - self.assertEqual(ds.values[0][0][0][0], 1) - self.assertEqual(ds.values[0][0][0][1], 2) - self.assertEqual(ds.values[0][0][0][2], 3) - self.assertEqual(ds.values[0][0][1][0], 4) - self.assertEqual(ds.values[0][0][1][1], 5) - self.assertEqual(ds.values[0][0][1][2], 6) - self.assertEqual(ds.values[0][5][0][0], 1) - self.assertEqual(ds.values[0][5][0][1], 2) - self.assertEqual(ds.values[0][5][0][2], 3) - self.assertEqual(ds.values[0][5][1][0], 4) - self.assertEqual(ds.values[0][5][1][1], 5) - self.assertEqual(ds.values[0][1][1][2], 6) - self.assertEqual(ds.values[1][1][0][0], -2) - self.assertEqual(ds.values[1][1][0][1], -2) - self.assertEqual(ds.values[1][1][0][2], -2) - self.assertEqual(ds.values[1][1][1][0], -2) - self.assertEqual(ds.values[1][1][1][1], -2) - self.assertEqual(ds.values[1][1][1][2], -2) + assert ds.keys == [slice(0, 10, None), slice(10, 20, None), slice(20, 26, None)] + assert ds.values[0][0][0][0] == 1 + assert ds.values[0][0][0][1] == 2 + assert ds.values[0][0][0][2] == 3 + assert ds.values[0][0][1][0] == 4 + assert ds.values[0][0][1][1] == 5 + assert ds.values[0][0][1][2] == 6 + assert 
ds.values[0][5][0][0] == 1 + assert ds.values[0][5][0][1] == 2 + assert ds.values[0][5][0][2] == 3 + assert ds.values[0][5][1][0] == 4 + assert ds.values[0][5][1][1] == 5 + assert ds.values[0][1][1][2] == 6 + assert ds.values[1][1][0][0] == -2 + assert ds.values[1][1][0][1] == -2 + assert ds.values[1][1][0][2] == -2 + assert ds.values[1][1][1][0] == -2 + assert ds.values[1][1][1][1] == -2 + assert ds.values[1][1][1][2] == -2 # Now wait time.sleep(0.1) @@ -228,37 +222,35 @@ def test_unlimited_cache_2D(self): cache.flush(ds) cache.purge_blocks() - self.assertEqual(len(cache._blocks), 1) - self.assertEqual(0 in cache._blocks, False) - self.assertEqual(1 in cache._blocks, False) - self.assertEqual(2 in cache._blocks, True) - self.assertEqual(cache._blocks[2].has_new_data, False) - self.assertEqual(cache._blocks[2].data.size, 60) - self.assertEqual( - numpy.sum(cache._blocks[2].data), -54 - ) # 42 + (8 * -12) fillvalue - - self.assertEqual(ds.keys, [slice(20, 27, None)]) - self.assertEqual(ds.values[0][0][0][0], -2) - self.assertEqual(ds.values[0][0][0][1], -2) - self.assertEqual(ds.values[0][0][0][2], -2) - self.assertEqual(ds.values[0][0][1][0], -2) - self.assertEqual(ds.values[0][0][1][1], -2) - self.assertEqual(ds.values[0][0][1][2], -2) - - self.assertEqual(ds.values[0][5][0][0], 1) - self.assertEqual(ds.values[0][5][0][1], 2) - self.assertEqual(ds.values[0][5][0][2], 3) - self.assertEqual(ds.values[0][5][1][0], 4) - self.assertEqual(ds.values[0][5][1][1], 5) - self.assertEqual(ds.values[0][5][1][2], 6) - - self.assertEqual(ds.values[0][6][0][0], 1) - self.assertEqual(ds.values[0][6][0][1], 2) - self.assertEqual(ds.values[0][6][0][2], 3) - self.assertEqual(ds.values[0][6][1][0], 4) - self.assertEqual(ds.values[0][6][1][1], 5) - self.assertEqual(ds.values[0][6][1][2], 6) + assert len(cache._blocks) == 1 + assert 0 not in cache._blocks + assert 1 not in cache._blocks + assert 2 in cache._blocks + assert not cache._blocks[2].has_new_data + assert 
cache._blocks[2].data.size == 60 + assert numpy.sum(cache._blocks[2].data) == -54 # 42 + (8 * -12) fillvalue + + assert ds.keys == [slice(20, 27, None)] + assert ds.values[0][0][0][0] == -2 + assert ds.values[0][0][0][1] == -2 + assert ds.values[0][0][0][2] == -2 + assert ds.values[0][0][1][0] == -2 + assert ds.values[0][0][1][1] == -2 + assert ds.values[0][0][1][2] == -2 + + assert ds.values[0][5][0][0] == 1 + assert ds.values[0][5][0][1] == 2 + assert ds.values[0][5][0][2] == 3 + assert ds.values[0][5][1][0] == 4 + assert ds.values[0][5][1][1] == 5 + assert ds.values[0][5][1][2] == 6 + + assert ds.values[0][6][0][0] == 1 + assert ds.values[0][6][0][1] == 2 + assert ds.values[0][6][0][2] == 3 + assert ds.values[0][6][1][0] == 4 + assert ds.values[0][6][1][1] == 5 + assert ds.values[0][6][1][2] == 6 def test_string_types(): diff --git a/python/tests/test_ipc_channel.py b/python/tests/test_ipc_channel.py index c5665b0d..45b79838 100644 --- a/python/tests/test_ipc_channel.py +++ b/python/tests/test_ipc_channel.py @@ -1,43 +1,41 @@ -from nose.tools import assert_equal +import pytest from odin_data.control.ipc_channel import IpcChannel -class TestIpcChannel: - - @classmethod - def setup_class(cls): +rx_endpoint = "inproc://rx_channel" - cls.endpoint = "inproc://rx_channel" - cls.send_channel = IpcChannel(IpcChannel.CHANNEL_TYPE_PAIR) - cls.recv_channel = IpcChannel(IpcChannel.CHANNEL_TYPE_PAIR) - cls.send_channel.bind(cls.endpoint) - cls.recv_channel.connect(cls.endpoint) +@pytest.fixture(scope="class") +def send_channel(): + channel = IpcChannel(IpcChannel.CHANNEL_TYPE_PAIR) + channel.bind(rx_endpoint) + yield channel + channel.close() - @classmethod - def teardown_class(cls): - cls.recv_channel.close() - cls.send_channel.close() +@pytest.fixture(scope="class") +def recv_channel(): + channel = IpcChannel(IpcChannel.CHANNEL_TYPE_PAIR) + channel.connect(rx_endpoint) + yield channel + channel.close() - def test_basic_send_receive(self): +class TestIpcChannel: + def 
test_basic_send_receive(self, send_channel, recv_channel): msg = "This is a test message" - self.send_channel.send(msg) + send_channel.send(msg) - reply = self.recv_channel.recv() - - assert_equal(msg, reply) - assert_equal(type(msg), type(reply)) + reply = recv_channel.recv() + assert msg == reply + assert type(msg) == type(reply) def test_dealer_router(self): - - endpoint = 'inproc://dr_channel' + endpoint = "inproc://dr_channel" msg = "This is a dealer router message" - dealer_indentity = 'test_dealer' + dealer_indentity = "test_dealer" - dealer_channel = IpcChannel(IpcChannel.CHANNEL_TYPE_DEALER, - identity=dealer_indentity) + dealer_channel = IpcChannel(IpcChannel.CHANNEL_TYPE_DEALER, identity=dealer_indentity) router_channel = IpcChannel(IpcChannel.CHANNEL_TYPE_ROUTER) router_channel.bind(endpoint) @@ -46,7 +44,7 @@ def test_dealer_router(self): dealer_channel.send(msg) (recv_identity, reply) = router_channel.recv() - assert_equal(len(reply), 1) - assert_equal(msg, reply[0]) - assert_equal(type(msg), type(reply[0])) - assert_equal(dealer_indentity, recv_identity) + assert len(reply) == 1 + assert msg == reply[0] + assert type(msg) == type(reply[0]) + assert dealer_indentity == recv_identity diff --git a/python/tests/test_ipc_message.py b/python/tests/test_ipc_message.py index a36339a6..4cff9f45 100644 --- a/python/tests/test_ipc_message.py +++ b/python/tests/test_ipc_message.py @@ -1,9 +1,8 @@ -from nose.tools import assert_equals, assert_raises, assert_true, assert_false,\ - assert_equal, assert_not_equal, assert_regexp_matches +import pytest from odin_data.control.ipc_message import IpcMessage, IpcMessageException -def test_valid_ipc_msg_from_string(): +def test_valid_ipc_msg_from_string(): # Instantiate a valid message from a JSON string json_str = """ @@ -24,44 +23,43 @@ def test_valid_ipc_msg_from_string(): the_msg = IpcMessage(from_str=json_str) # Check the message is indeed valid - assert_true(the_msg.is_valid()) + assert the_msg.is_valid() # Check 
that all attributes are as expected - assert_equals(the_msg.get_msg_type(), "cmd") - assert_equals(the_msg.get_msg_val(), "status") - assert_equals(the_msg.get_msg_timestamp(), "2015-01-27T15:26:01.123456") - assert_equals(the_msg.get_msg_id(), 322) + assert the_msg.get_msg_type() == "cmd" + assert the_msg.get_msg_val() == "status" + assert the_msg.get_msg_timestamp() == "2015-01-27T15:26:01.123456" + assert the_msg.get_msg_id() == 322 # Check that all parameters are as expected - assert_equals(the_msg.get_param("paramInt"), 1234) - assert_equals(the_msg.get_param("paramStr"), "testParam") - assert_equals(the_msg.get_param("paramDouble"), 3.1415) + assert the_msg.get_param("paramInt") == 1234 + assert the_msg.get_param("paramStr") == "testParam" + assert the_msg.get_param("paramDouble") == 3.1415 # Check valid message throws an exception on missing parameter - with assert_raises(IpcMessageException) as cm: - missingParam = the_msg.get_param("missingParam") - ex = cm.exception - assert_equals(ex.msg, 'Missing parameter missingParam') + with pytest.raises(IpcMessageException) as excinfo: + _ = the_msg.get_param("missingParam") + assert "Missing parameter missingParam" in str(excinfo.value) # Check valid message can fall back to default value if parameter missing defaultParamValue = 90210 - assert_equals(the_msg.get_param("missingParam", defaultParamValue), defaultParamValue) + assert the_msg.get_param("missingParam", defaultParamValue) == defaultParamValue -def test_empty_ipc_msg_invalid(): +def test_empty_ipc_msg_invalid(): # Instantiate an empty message the_msg = IpcMessage() # Check that the message is not valid - assert_false(the_msg.is_valid()) + assert not the_msg.is_valid() -def test_filled_ipc_msg_valid(): +def test_filled_ipc_msg_valid(): # Instantiate an empty Message the_msg = IpcMessage() # Check that empty message is not valid - assert_false(the_msg.is_valid()) + assert not the_msg.is_valid() # Set the message type, Value and id msg_type = "cmd" @@ 
-72,41 +70,41 @@ def test_filled_ipc_msg_valid(): the_msg.set_msg_id(msg_id) # Check that the message is now valid - assert_true(the_msg.is_valid()) + assert the_msg.is_valid() -def test_create_modify_empty_msg_params(): +def test_create_modify_empty_msg_params(): # Instantiate an empty message the_msg = IpcMessage() # Define and set some parameters - paramInt1 = 1234; - paramInt2 = 901201; - paramInt3 = 4567; + paramInt1 = 1234 + paramInt2 = 901201 + paramInt3 = 4567 paramStr = "paramString" - the_msg.set_param('paramInt1', paramInt1) - the_msg.set_param('paramInt2', paramInt2) - the_msg.set_param('paramInt3', paramInt3) - the_msg.set_param('paramStr', paramStr) + the_msg.set_param("paramInt1", paramInt1) + the_msg.set_param("paramInt2", paramInt2) + the_msg.set_param("paramInt3", paramInt3) + the_msg.set_param("paramStr", paramStr) # Read them back and check they have the correct value - assert_true(the_msg.get_param('paramInt1'), paramInt1) - assert_true(the_msg.get_param('paramInt2'), paramInt2) - assert_true(the_msg.get_param('paramInt3'), paramInt3) - assert_true(the_msg.get_param('paramStr'), paramStr) + assert the_msg.get_param("paramInt1") == paramInt1 + assert the_msg.get_param("paramInt2") == paramInt2 + assert the_msg.get_param("paramInt3") == paramInt3 + assert the_msg.get_param("paramStr") == paramStr # Modify several parameters and check they are still correct - paramInt2 = 228724; - the_msg.set_param('paramInt2', paramInt2) + paramInt2 = 228724 + the_msg.set_param("paramInt2", paramInt2) paramStr = "another string" the_msg.set_param("paramStr", paramStr) - assert_true(the_msg.get_param('paramInt2'), paramInt2) - assert_true(the_msg.get_param('paramStr'), paramStr) + assert the_msg.get_param("paramInt2") == paramInt2 + assert the_msg.get_param("paramStr") == paramStr -def test_round_trip_from_empty_msg(): +def test_round_trip_from_empty_msg(): # Instantiate an empty message the_msg = IpcMessage() @@ -118,16 +116,16 @@ def 
test_round_trip_from_empty_msg(): msg_id = 61616 the_msg.set_msg_id(msg_id) - # Define and set some parameters - paramInt1 = 1234; - paramInt2 = 901201; - paramInt3 = 4567; + # Define and set some parameters + paramInt1 = 1234 + paramInt2 = 901201 + paramInt3 = 4567 paramStr = "paramString" - the_msg.set_param('paramInt1', paramInt1) - the_msg.set_param('paramInt2', paramInt2) - the_msg.set_param('paramInt3', paramInt3) - the_msg.set_param('paramStr', paramStr) + the_msg.set_param("paramInt1", paramInt1) + the_msg.set_param("paramInt2", paramInt2) + the_msg.set_param("paramInt3", paramInt3) + the_msg.set_param("paramStr", paramStr) # Retrieve the encoded version the_msg_encoded = the_msg.encode() @@ -136,17 +134,17 @@ def test_round_trip_from_empty_msg(): msg_from_encoded = IpcMessage(from_str=the_msg_encoded) # Validate the contents of all attributes and parameters of the new message - assert_equal(msg_from_encoded.get_msg_type(), msg_type) - assert_equal(msg_from_encoded.get_msg_val(), msg_val) - assert_equal(msg_from_encoded.get_msg_timestamp(), the_msg.get_msg_timestamp()) - assert_equal(msg_from_encoded.get_msg_id(), the_msg.get_msg_id()) - assert_equal(msg_from_encoded.get_param('paramInt1'), paramInt1) - assert_equal(msg_from_encoded.get_param('paramInt2'), paramInt2) - assert_equal(msg_from_encoded.get_param('paramInt3'), paramInt3) - assert_equal(msg_from_encoded.get_param('paramStr'), paramStr) + assert msg_from_encoded.get_msg_type() == msg_type + assert msg_from_encoded.get_msg_val() == msg_val + assert msg_from_encoded.get_msg_timestamp() == the_msg.get_msg_timestamp() + assert msg_from_encoded.get_msg_id() == the_msg.get_msg_id() + assert msg_from_encoded.get_param("paramInt1") == paramInt1 + assert msg_from_encoded.get_param("paramInt2") == paramInt2 + assert msg_from_encoded.get_param("paramInt3") == paramInt3 + assert msg_from_encoded.get_param("paramStr") == paramStr -def test_round_trip_from_empty_msg_comparison(): +def 
test_round_trip_from_empty_msg_comparison(): # Instantiate an empty message the_msg = IpcMessage() @@ -158,16 +156,16 @@ def test_round_trip_from_empty_msg_comparison(): msg_id = 61616 the_msg.set_msg_id(msg_id) - # Define and set some parameters - paramInt1 = 1234; - paramInt2 = 901201; - paramInt3 = 4567; + # Define and set some parameters + paramInt1 = 1234 + paramInt2 = 901201 + paramInt3 = 4567 paramStr = "paramString" - the_msg.set_param('paramInt1', paramInt1) - the_msg.set_param('paramInt2', paramInt2) - the_msg.set_param('paramInt3', paramInt3) - the_msg.set_param('paramStr', paramStr) + the_msg.set_param("paramInt1", paramInt1) + the_msg.set_param("paramInt2", paramInt2) + the_msg.set_param("paramInt3", paramInt3) + the_msg.set_param("paramStr", paramStr) # Retrieve the encoded version the_msg_encoded = the_msg.encode() @@ -176,12 +174,11 @@ def test_round_trip_from_empty_msg_comparison(): msg_from_encoded = IpcMessage(from_str=the_msg_encoded) # Test that the comparison operators work correctly - assert_true(the_msg == msg_from_encoded) - assert_false(the_msg != msg_from_encoded) + assert the_msg == msg_from_encoded + assert not (the_msg != msg_from_encoded) -def test_invalid_msg_from_string(): - with assert_raises(IpcMessageException) as cm: - invalid_msg = IpcMessage(from_str="{\"wibble\" : \"wobble\" \"shouldnt be here\"}") - ex = cm.exception - assert_regexp_matches(ex.msg, "Illegal message JSON format*") +def test_invalid_msg_from_string(): + with pytest.raises(IpcMessageException) as excinfo: + _ = IpcMessage(from_str='{"wibble" : "wobble" "shouldnt be here"}') + assert "Illegal message JSON format" in str(excinfo.value) diff --git a/python/tests/test_live_view_proxy.py b/python/tests/test_live_view_proxy.py index 332e4e57..734d9b2e 100644 --- a/python/tests/test_live_view_proxy.py +++ b/python/tests/test_live_view_proxy.py @@ -1,238 +1,276 @@ +import logging import sys -from nose.tools import assert_equals, assert_true, assert_false,\ - 
assert_equal, assert_not_equal - +import pytest from tornado.escape import json_encode -from odin_data.control.live_view_proxy_adapter import LiveViewProxyAdapter, LiveViewProxyNode, Frame, \ - DEFAULT_DEST_ENDPOINT, DEFAULT_DROP_WARN_PERCENT, DEFAULT_QUEUE_LENGTH, DEFAULT_SOURCE_ENDPOINT +from odin_data.control.live_view_proxy_adapter import ( + DEFAULT_DEST_ENDPOINT, + DEFAULT_DROP_WARN_PERCENT, + DEFAULT_QUEUE_LENGTH, + DEFAULT_SOURCE_ENDPOINT, + Frame, + LiveViewProxyAdapter, + LiveViewProxyNode, +) -from odin.testing.utils import OdinTestServer +if sys.version_info[0] == 3: # pragma: no cover + from unittest.mock import Mock +else: # pragma: no cover + from mock import Mock -from zmq.error import ZMQError -if sys.version_info[0] == 3: # pragma: no cover - from unittest.mock import Mock, patch -else: # pragma: no cover - from mock import Mock, patch +def log_message_seen(caplog, level, message, when="call"): + for record in caplog.get_records(when): + if record.levelno == level and message in record.getMessage(): + return True + return False -class TestFrame(): - @classmethod - def setup(self): - self.frame_1_header = {"frame_num": 1} - self.frame_1 = Frame([json_encode({"frame_num": 1}), 0]) - self.frame_2 = Frame([json_encode({"frame_num": 2}), 0]) +@pytest.fixture() +def test_frames(): + frames = [Frame([json_encode({"frame_num": 1}), 0]), Frame([json_encode({"frame_num": 2}), 0])] + return frames - def test_frame_init(self): - assert_equal(self.frame_1.num, self.frame_1_header["frame_num"]) - def test_frame_get_header(self): - assert_equal(self.frame_1.get_header(), json_encode(self.frame_1_header)) +class TestFrame: + def test_frame_init(self, test_frames): + assert test_frames[0].num == 1 - def test_frame_get_data(self): - assert_equal(self.frame_1.data, 0) + def test_frame_get_header(self, test_frames): + assert test_frames[0].get_header() == json_encode({"frame_num": test_frames[0].num}) - def test_frame_set_acq(self): - self.frame_1.set_acq(5) - 
assert_equal(self.frame_1.acq_id, 5) + def test_frame_get_data(self, test_frames): + assert test_frames[0].data == 0 - def test_frame_order(self): - assert_true(self.frame_1 < self.frame_2) + def test_frame_set_acq(self, test_frames): + test_frames[0].set_acq(5) + assert test_frames[0].acq_id == 5 - def test_frame_order_acq(self): - self.frame_1.set_acq(2) - self.frame_2.set_acq(1) + def test_frame_order(self, test_frames): + assert test_frames[0] < test_frames[1] - assert_true(self.frame_1 > self.frame_2) + def test_frame_order_acq(self, test_frames): + test_frames[0].set_acq(2) + test_frames[1].set_acq(1) + assert test_frames[0] > test_frames[1] + + +@pytest.fixture() +def proxy_node_test(): + class ProxyNodeTest: + def __init__(self): + self.callback_count = 0 + self.node_1_info = {"name": "test_node_1", "endpoint": "tcp://127.0.0.1:5010"} + self.node_2_info = {"name": "test_node_2", "endpoint": "tcp://127.0.0.1:5011"} + self.node_1 = LiveViewProxyNode( + self.node_1_info["name"], self.node_1_info["endpoint"], 0.5, self.callback + ) + self.node_2 = LiveViewProxyNode( + self.node_2_info["name"], self.node_2_info["endpoint"], 0.5, self.callback + ) + + self.test_frame = Frame([json_encode({"frame_num": 1}), "0"]) + + def callback(self, frame, source): + self.callback_count += 1 + + return ProxyNodeTest() -class TestProxyNode(object): - def callback(self, frame, source): - self.callback_count += 1 - - def setup(self): - - self.callback_count = 0 - self.node_1_info = {"name": "test_node_1", "endpoint": "tcp://127.0.0.1:5010"} - self.node_2_info = {"name": "test_node_2", "endpoint": "tcp://127.0.0.1:5011"} - self.node_1 = LiveViewProxyNode( - self.node_1_info["name"], - self.node_1_info["endpoint"], - 0.5, - self.callback) - self.node_2 = LiveViewProxyNode( - self.node_2_info["name"], - self.node_2_info["endpoint"], - 0.5, - self.callback) - - self.test_frame = Frame([json_encode({"frame_num": 1}), "0"]) - - @classmethod - def teardown(self): - self.callback_count = 
0 - - def test_node_init(self): - assert_equal(self.node_1.name, self.node_1_info["name"]) - assert_equal(self.node_1.endpoint, self.node_1_info["endpoint"]) - - def test_node_reset(self): - self.node_1.received_frame_count = 10 - self.node_1.dropped_frame_count = 5 - self.node_1.has_warned = True - self.node_1.set_reset() - assert_equal(self.node_1.received_frame_count, 0) - assert_equal(self.node_1.dropped_frame_count, 0) - assert_false(self.node_1.has_warned) - - def test_node_dropped_frame(self): - self.node_1.received_frame_count = 10 - for _ in range(9): - self.node_1.dropped_frame() - - assert_equal(self.node_1.dropped_frame_count, 9) - assert_true(self.node_1.has_warned) - - def test_node_callback(self): +class TestProxyNode(object): + def test_node_init(self, proxy_node_test): + assert proxy_node_test.node_1.name == proxy_node_test.node_1_info["name"] + assert proxy_node_test.node_1.endpoint == proxy_node_test.node_1_info["endpoint"] + + def test_node_reset(self, proxy_node_test): + proxy_node_test.node_1.received_frame_count = 10 + proxy_node_test.node_1.dropped_frame_count = 5 + proxy_node_test.node_1.has_warned = True + proxy_node_test.node_1.set_reset() + assert proxy_node_test.node_1.received_frame_count == 0 + assert proxy_node_test.node_1.dropped_frame_count == 0 + assert not proxy_node_test.node_1.has_warned + + def test_node_dropped_frame(self, proxy_node_test): + proxy_node_test.node_1.received_frame_count = 10 + num_dropped = 9 + for _ in range(num_dropped): + proxy_node_test.node_1.dropped_frame() + + assert proxy_node_test.node_1.dropped_frame_count == num_dropped + assert proxy_node_test.node_1.has_warned + + def test_node_callback(self, proxy_node_test): test_header = {"frame_num": 1} for i in range(10): test_header["frame_num"] = i - self.node_1.local_callback([json_encode(test_header), 0]) + proxy_node_test.node_1.local_callback([json_encode(test_header), 0]) test_header["frame_num"] = 1 - self.node_1.local_callback([json_encode(test_header), 
0]) - assert_equal(self.node_1.current_acq, 1) - assert_equal(self.node_1.received_frame_count, 11) - assert_equal(self.node_1.last_frame, 1) - - -class TestLiveViewProxyAdapter(OdinTestServer): - - def setup(self): - self.config_node_list = { - "test_node_1": "tcp://127.0.0.1:5000", - "test_node_2": "tcp://127.0.0.1:5001" - } - self.adapter_config = { - "destination_endpoint": "tcp://127.0.0.1:5021", - "source_endpoints": - "\n{}={},".format(self.config_node_list.keys()[0], self.config_node_list.values()[0]) + - "\n{}={}".format(self.config_node_list.keys()[1], self.config_node_list.values()[1]), - "dropped_frame_warning_cutoff": 0.2, - "queue_length": 15 - } - self.test_frames = [] - for i in range(10): - tmp_frame = Frame([json_encode({"frame_num": i}), "0"]) - self.test_frames.append(tmp_frame) - self.adapter = LiveViewProxyAdapter(**self.adapter_config) - - self.request = Mock - self.request.headers = {'Accept': 'application/json', 'Content-Type': 'application/json'} - self.path = "" - self.has_dropped_frame = False - - def teardown(self): - super(TestLiveViewProxyAdapter, self).teardown_class() - self.adapter.cleanup() - - def dropped_frame(self): - self.has_dropped_frame = True - - def test_adapter_init(self): - assert_equal( - self.adapter.drop_warn_percent, - self.adapter_config["dropped_frame_warning_cutoff"] - ) - assert_equal(self.adapter.dest_endpoint, self.adapter_config["destination_endpoint"]) - assert_true(self.adapter.source_endpoints) - for node in self.adapter.source_endpoints: - assert_true("{}={}".format(node.name, node.endpoint) in self.adapter_config["source_endpoints"]) + proxy_node_test.node_1.local_callback([json_encode(test_header), 0]) + + assert proxy_node_test.node_1.current_acq == 1 + assert proxy_node_test.node_1.received_frame_count == 11 + assert proxy_node_test.node_1.last_frame == 1 + + +@pytest.fixture(scope="class") +def config_node_list(): + return { + "test_node_1": "tcp://127.0.0.1:5000", + "test_node_2": 
"tcp://127.0.0.1:5001", + } + + +@pytest.fixture(scope="class") +def adapter_config(config_node_list): + config = { + "destination_endpoint": "tcp://127.0.0.1:5021", + "source_endpoints": ",\n".join( + ["{}={}".format(key, val) for (key, val) in config_node_list.items()] + ), + "dropped_frame_warning_cutoff": 0.2, + "queue_length": 15, + } + return config + + +@pytest.fixture(scope="class") +def proxy_adapter(adapter_config): + adapter = LiveViewProxyAdapter(**adapter_config) + return adapter + +num_adapter_test_frames = 10 + + +@pytest.fixture() +def adapter_test_frames(): + test_frames = [] + for i in range(num_adapter_test_frames): + test_frames.append(Frame([json_encode({"frame_num": i}), "0"])) + + return test_frames + + +@pytest.fixture() +def proxy_request(): + request = Mock + request.headers = {"Accept": "application/json", "Content-Type": "application/json"} + return request + + +@pytest.fixture() +def frame_source(): + class FrameSource: + def __init__(self): + self.has_dropped_frame = False + + def dropped_frame(self): + self.has_dropped_frame = True + + return FrameSource() + + +class TestLiveViewProxyAdapter: def test_adapter_default_init(self): default_adapter = LiveViewProxyAdapter(**{}) - assert_equal(default_adapter.dest_endpoint, DEFAULT_DEST_ENDPOINT) - assert_equal(default_adapter.drop_warn_percent, DEFAULT_DROP_WARN_PERCENT) - assert_equal(default_adapter.max_queue, DEFAULT_QUEUE_LENGTH) - assert_equal(default_adapter.source_endpoints[0].endpoint, DEFAULT_SOURCE_ENDPOINT) - - def test_adapter_queue(self): - for frame in reversed(self.test_frames): - self.adapter.add_to_queue(frame, self) + assert default_adapter.dest_endpoint == DEFAULT_DEST_ENDPOINT + assert default_adapter.drop_warn_percent == DEFAULT_DROP_WARN_PERCENT + assert default_adapter.max_queue == DEFAULT_QUEUE_LENGTH + assert default_adapter.source_endpoints[0].endpoint == DEFAULT_SOURCE_ENDPOINT + + def test_adapter_init(self, proxy_adapter, adapter_config): + assert 
proxy_adapter.drop_warn_percent == adapter_config["dropped_frame_warning_cutoff"] + assert proxy_adapter.dest_endpoint == adapter_config["destination_endpoint"] + assert len(proxy_adapter.source_endpoints) + for node in proxy_adapter.source_endpoints: + assert "{}={}".format(node.name, node.endpoint) in adapter_config["source_endpoints"] + + def test_adapter_queue(self, proxy_adapter, adapter_test_frames): + for frame in reversed(adapter_test_frames): + proxy_adapter.add_to_queue(frame, self) returned_frames = [] - while not self.adapter.queue.empty(): - returned_frames.append(self.adapter.get_frame_from_queue()) + while not proxy_adapter.queue.empty(): + returned_frames.append(proxy_adapter.get_frame_from_queue()) - assert_equals(self.test_frames, returned_frames) + assert adapter_test_frames == returned_frames - def test_adapter_already_connect_exception(self): - try: - # second adapter instance binding to same endpoint - LiveViewProxyAdapter(**self.adapter_config) - except ZMQError: - assert_false("Live View raised ZMQError when binding second socket to same address.") + def test_adapter_already_connect_exception(self, adapter_config, caplog): + # Second adapter instance binding to same endpoint + second_adapter = LiveViewProxyAdapter(**adapter_config) + assert log_message_seen(caplog, logging.ERROR, "Address already in use") - def test_adapter_bad_config(self): - bad_config = self.adapter_config + def test_adapter_bad_config(self, adapter_config, caplog): + bad_config = adapter_config bad_config["source_endpoints"] = "bad_node=not.a.real.socket" + caplog.clear() LiveViewProxyAdapter(**bad_config) + assert log_message_seen(caplog, logging.ERROR, "Error parsing target list") + bad_config["source_endpoints"] = "not_even_parsable" + caplog.clear() LiveViewProxyAdapter(**bad_config) + assert log_message_seen(caplog, logging.ERROR, "Error parsing target list") + + def test_adapter_get(self, proxy_adapter, proxy_request, config_node_list): + response = 
proxy_adapter.get("", proxy_request) - def test_adapter_get(self): - response = self.adapter.get(self.path, self.request) - assert_equal(response.status_code, 200) - assert_true("target_endpoint" in response.data) - assert_equal(len(response.data["nodes"]), len(self.config_node_list)) + assert response.status_code == 200 + assert "target_endpoint" in response.data + assert len(response.data["nodes"]) == len(config_node_list) for key in response.data["nodes"]: - assert_true(key in self.config_node_list) - assert_equal(response.data["nodes"][key]["endpoint"], self.config_node_list[key]) + assert key in config_node_list + assert response.data["nodes"][key]["endpoint"] == config_node_list[key] - def test_adapter_put_reset(self): - self.request.body = '{"reset": 0}' - self.adapter.last_sent_frame = (1, 2) - self.adapter.dropped_frame_count = 5 + def test_adapter_put_reset(self, proxy_adapter, proxy_request): + proxy_request.body = '{"reset": 0}' + proxy_adapter.last_sent_frame = (1, 2) + proxy_adapter.dropped_frame_count = 5 - response = self.adapter.put(self.path, self.request) - assert_true("last_sent_frame" in response.data) - assert_equal(response.data["last_sent_frame"], (0, 0)) - assert_equal(response.data["dropped_frames"], 0) + response = proxy_adapter.put("", proxy_request) + assert "last_sent_frame" in response.data + assert response.data["last_sent_frame"] == (0, 0) + assert response.data["dropped_frames"] == 0 - def test_adapter_put_invalid(self): - self.request.body = '{"invalid": "Invalid"}' - response = self.adapter.put(self.path, self.request) + def test_adapter_put_invalid(self, proxy_adapter, proxy_request): + proxy_request.body = '{"invalid": "Invalid"}' + response = proxy_adapter.put("", proxy_request) - assert_true("error" in response.data) - assert_equal(response.status_code, 400) + assert "error" in response.data + assert response.status_code == 400 - def test_adapter_fill_queue(self): + def test_adapter_fill_queue(self, adapter_config, 
proxy_adapter, frame_source): frame_list = [] - queue_length = self.adapter_config["queue_length"] + queue_length = adapter_config["queue_length"] for i in range(queue_length + 1): - tmp_frame = Frame([json_encode({"frame_num": i}), "0"]) - frame_list.append(tmp_frame) + frame_list.append(Frame([json_encode({"frame_num": i}), "0"])) + for frame in frame_list[:queue_length]: - self.adapter.add_to_queue(frame, self) - assert_equal(self.adapter.dropped_frame_count, 0) - self.adapter.add_to_queue(frame_list[-1], self) - assert_equal(self.adapter.dropped_frame_count, 1) + proxy_adapter.add_to_queue(frame, frame_source) + assert proxy_adapter.dropped_frame_count == 0 + + proxy_adapter.add_to_queue(frame_list[-1], self) + assert proxy_adapter.dropped_frame_count == 1 + result_frames = [] - while not self.adapter.queue.empty(): - result_frames.append(self.adapter.queue.get_nowait()) - assert_equal(frame_list[1:], result_frames) + while not proxy_adapter.queue.empty(): + result_frames.append(proxy_adapter.queue.get_nowait()) + assert frame_list[1:] == result_frames - def test_adapter_drop_old_frame(self): + def test_adapter_drop_old_frame(self, proxy_adapter, frame_source): frame_new = Frame([json_encode({"frame_num": 5}), "0"]) frame_old = Frame([json_encode({"frame_num": 1}), "0"]) - self.adapter.add_to_queue(frame_new, self) - get_frame = self.adapter.get_frame_from_queue() - assert_equal(frame_new, get_frame) - assert_equal(self.adapter.last_sent_frame, (0, 5)) - self.adapter.add_to_queue(frame_old, self) - assert_true(self.has_dropped_frame) - assert_true(self.adapter.queue.empty()) + + proxy_adapter.add_to_queue(frame_new, frame_source) + get_frame = proxy_adapter.get_frame_from_queue() + assert frame_new == get_frame + assert proxy_adapter.last_sent_frame == (0, 5) + + proxy_adapter.add_to_queue(frame_old, frame_source) + assert frame_source.has_dropped_frame + assert proxy_adapter.queue.empty() diff --git a/python/tests/test_shared_buffer_manager.py 
b/python/tests/test_shared_buffer_manager.py index dacfb1d5..46d9da3d 100644 --- a/python/tests/test_shared_buffer_manager.py +++ b/python/tests/test_shared_buffer_manager.py @@ -1,83 +1,77 @@ -from odin_data.shared_buffer_manager import SharedBufferManager, SharedBufferManagerException -from nose.tools import assert_equal, assert_raises, assert_regexp_matches from struct import Struct +import pytest +from odin_data.shared_buffer_manager import SharedBufferManager, SharedBufferManagerException + shared_mem_name = "TestSharedBuffer" -buffer_size = 1000 -num_buffers = 10 +buffer_size = 1000 +num_buffers = 10 shared_mem_size = buffer_size * num_buffers boost_mmap_mode = True -class TestSharedBufferManager: - - @classmethod - def setup_class(cls): - - # Create a shared buffer manager for use in all tests - cls.shared_buffer_manager = SharedBufferManager( - shared_mem_name, shared_mem_size, - buffer_size, remove_when_deleted=True, boost_mmap_mode=boost_mmap_mode) - - @classmethod - def teardown_class(cls): - pass +@pytest.fixture(scope="class") +def shared_buffer_manager(): + return SharedBufferManager( + shared_mem_name, + shared_mem_size, + buffer_size, + remove_when_deleted=True, + boost_mmap_mode=boost_mmap_mode, + ) - def test_basic_shared_buffer(self): +class TestSharedBufferManager: + def test_basic_shared_buffer(self, shared_buffer_manager): # Test the shared buffer manager confoguration is expected - assert_equal(self.shared_buffer_manager.get_num_buffers(), num_buffers) - assert_equal(self.shared_buffer_manager.get_buffer_size(), buffer_size) - - def test_existing_manager(self): + assert shared_buffer_manager.get_num_buffers() == num_buffers + assert shared_buffer_manager.get_buffer_size() == buffer_size + def test_existing_manager(self, shared_buffer_manager): # Map the existing manager - existing_shared_buffer = SharedBufferManager(shared_mem_name, boost_mmap_mode=boost_mmap_mode) + existing_shared_buffer = SharedBufferManager( + shared_mem_name, 
boost_mmap_mode=boost_mmap_mode + ) # Test that the configuration matches the original - assert_equal(self.shared_buffer_manager.get_manager_id(), existing_shared_buffer.get_manager_id()) - assert_equal(self.shared_buffer_manager.get_num_buffers(), existing_shared_buffer.get_num_buffers()) - assert_equal(self.shared_buffer_manager.get_buffer_size(), existing_shared_buffer.get_buffer_size()) + assert shared_buffer_manager.get_manager_id() == existing_shared_buffer.get_manager_id() + assert shared_buffer_manager.get_num_buffers() == existing_shared_buffer.get_num_buffers() + assert shared_buffer_manager.get_buffer_size() == existing_shared_buffer.get_buffer_size() def test_existing_manager_absent(self): - # Attempt to map a shared buffer manager that doesn't already exist absent_manager_name = "AbsentBufferManager" - with assert_raises(SharedBufferManagerException) as cm: - existing_shared_buffer = SharedBufferManager(absent_manager_name, boost_mmap_mode=boost_mmap_mode) - ex = cm.exception - assert_regexp_matches(ex.msg, "No shared memory exists with the specified name") + with pytest.raises(SharedBufferManagerException) as excinfo: + existing_shared_buffer = SharedBufferManager( + absent_manager_name, boost_mmap_mode=boost_mmap_mode + ) + assert "No shared memory exists with the specified name" in str(excinfo.value) def test_existing_manager_already_exists(self): - # Attempt to create a shared buffer manager that already exists - with assert_raises(SharedBufferManagerException) as cm: - clobbered_shared_buffer = SharedBufferManager(shared_mem_name, - 100, 100, True, boost_mmap_mode=boost_mmap_mode) - ex = cm.exception - assert_regexp_matches(ex.msg, "Shared memory with the specified name already exists") - - def test_illegal_shared_buffer_index(self): - - with assert_raises(SharedBufferManagerException) as cm: - buffer_address = self.shared_buffer_manager.get_buffer_address(-1) - ex = cm.exception - assert_regexp_matches(ex.msg, "Illegal buffer index specified") + 
with pytest.raises(SharedBufferManagerException) as excinfo: + clobbered_shared_buffer = SharedBufferManager( + shared_mem_name, 100, 100, True, boost_mmap_mode=boost_mmap_mode + ) + assert "Shared memory with the specified name already exists" in str(excinfo.value) - with assert_raises(SharedBufferManagerException) as cm: - buffer_address = self.shared_buffer_manager.get_buffer_address(num_buffers) - ex = cm.exception - assert_regexp_matches(ex.msg, "Illegal buffer index specified") + def test_illegal_shared_buffer_index(self, shared_buffer_manager): + with pytest.raises(SharedBufferManagerException) as excinfo: + buffer_address = shared_buffer_manager.get_buffer_address(-1) + assert "Illegal buffer index specified" in str(excinfo.value) - def test_write_and_read_from_buffer(self): + with pytest.raises(SharedBufferManagerException) as excinfo: + buffer_address = shared_buffer_manager.get_buffer_address(num_buffers) + assert "Illegal buffer index specified" in str(excinfo.value) - data_block = Struct('QQQ') + def test_write_and_read_from_buffer(self, shared_buffer_manager): + data_block = Struct("QQQ") - values = (0xdeadbeef, 0x12345678, 0xaaaa5555aaaa5555) + values = (0xDEADBEEF, 0x12345678, 0xAAAA5555AAAA5555) raw_data = data_block.pack(*values) - self.shared_buffer_manager.write_buffer(0, raw_data) + shared_buffer_manager.write_buffer(0, raw_data) - read_raw = self.shared_buffer_manager.read_buffer(0, data_block.size) + read_raw = shared_buffer_manager.read_buffer(0, data_block.size) read_values = data_block.unpack(read_raw) - assert_equal(values, read_values) + assert values == read_values diff --git a/python/tests/test_util.py b/python/tests/test_util.py index 4275ac59..048c8b77 100644 --- a/python/tests/test_util.py +++ b/python/tests/test_util.py @@ -1,17 +1,14 @@ -import unittest - from odin_data.util import remove_prefix, remove_suffix -class UtilTest(unittest.TestCase): - +class TestUtils: def test_remove_prefix(self): test = "config/hdf/frames" expected 
= "hdf/frames" result = remove_prefix(test, "config/") - self.assertEqual(expected, result) + assert expected == result def test_remove_suffix(self): test = "hdf/frames/0" @@ -19,4 +16,4 @@ def test_remove_suffix(self): result = remove_suffix(test, "/0") - self.assertEqual(expected, result) + assert expected == result