diff --git a/requirements.txt b/requirements.txt index 1c908cb..e860e2e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,5 +5,5 @@ ndicapi>=3.2.6 numpy>=1.11 six>=1.10 -scikit-surgerycore>=0.6.9 +scikit-surgerycore>=0.7.0 pyserial diff --git a/setup.py b/setup.py index 2595bec..3f364b9 100644 --- a/setup.py +++ b/setup.py @@ -53,7 +53,7 @@ 'six>=1.10', 'numpy>=1.11', 'ndicapi>=3.2.6', - 'scikit-surgerycore>=0.6.9', + 'scikit-surgerycore>=0.7.0', 'pyserial' ], ) diff --git a/sksurgerynditracker/nditracker.py b/sksurgerynditracker/nditracker.py index 8f9f490..fcad680 100644 --- a/sksurgerynditracker/nditracker.py +++ b/sksurgerynditracker/nditracker.py @@ -12,7 +12,7 @@ from serial.tools import list_ports #pylint: disable=import-error from six import int2byte -from numpy import full, nan, reshape, transpose +from numpy import full, nan, reshape from sksurgerycore.baseclasses.tracker import SKSBaseTracker import ndicapy from sksurgerynditracker.serial_utils.com_ports import \ @@ -116,7 +116,10 @@ def __init__(self, configuration): ports to probe: - use quaternions: + use quaternions: default is false + + smoothing buffer: specify a buffer over which to average the + tracking, defaults to 1 :raises Exception: IOError, KeyError, OSError """ @@ -449,7 +452,8 @@ def get_frame(self): port_handles = [] time_stamps = [] frame_numbers = [] - tracking = [] + tracking_rots = [] + tracking_trans = [] tracking_quality = [] timestamp = time() @@ -466,21 +470,13 @@ def get_frame(self): descriptor.get("c_str port handle")) if not qtransform == "MISSING" and not qtransform == "DISABLED": tracking_quality.append(qtransform[7]) - if not self.use_quaternions: - transform = transpose( - reshape(ndicapy.ndiTransformToMatrixd(qtransform), - [4, 4])) - else: - transform = reshape(qtransform[0:7], [1, 7]) + transform = reshape(qtransform[0:7], [1, 7]) else: tracking_quality.append(nan) - if not self.use_quaternions: - transform = full((4, 4), nan) - else: - transform = full((1, 7), nan) - - tracking.append(transform) + transform = full((1, 7), nan) + tracking_rots.append(transform[0][0:4]) + tracking_trans.append(transform[0][4:7]) else: for descriptor in self._tool_descriptors: port_handles.append(descriptor.get( @@ -488,13 +484,14 @@ def get_frame(self): time_stamps.append(timestamp) frame_numbers.append(0) tracking_quality.append(0.0) - if not self.use_quaternions: - tracking.append(full((4, 4), nan)) - else: - tracking.append(full((1, 7), nan)) + tracking_rots.append(full((1, 4), nan)) + tracking_trans.append(full((1, 3), nan)) + + self.add_frame_to_buffer(port_handles, time_stamps, frame_numbers, + tracking_rots, tracking_trans, tracking_quality, + rot_is_quaternion = True) - return port_handles, time_stamps, frame_numbers, tracking, \ - tracking_quality + return self.get_smooth_frame(port_handles) def get_tool_descriptions(self): """ Returns the port handles and tool descriptions """ diff --git a/tests/polaris_mocks.py b/tests/polaris_mocks.py index a92a04f..2b5d3a8 100644 --- a/tests/polaris_mocks.py +++ b/tests/polaris_mocks.py @@ -13,14 +13,13 @@ "data/8700339.rom"] } -SETTINGS_POLARIS_QUAT = { - "tracker type": "polaris", - "ports to probe": 20, - "romfiles" : [ - "data/something_else.rom", - "data/8700339.rom"], - "use quaternions": "true" - } +SETTINGS_POLARIS_QUAT = SETTINGS_POLARIS.copy() +SETTINGS_POLARIS_QUAT["use quaternions"] = "true" + +SETTINGS_POLARIS_SMOOTH = SETTINGS_POLARIS.copy() +SETTINGS_POLARIS_QUAT_SMOOTH = SETTINGS_POLARIS_QUAT.copy() +SETTINGS_POLARIS_SMOOTH["smoothing buffer"] = 2 +SETTINGS_POLARIS_QUAT_SMOOTH["smoothing buffer"] = 2 class MockPort: """A fake serial port for ndi""" diff --git a/tests/test_sksurgerynditracker_mockndi_getframe.py b/tests/test_sksurgerynditracker_mockndi_getframe.py index 3814057..f7e5967 100644 --- a/tests/test_sksurgerynditracker_mockndi_getframe.py +++ b/tests/test_sksurgerynditracker_mockndi_getframe.py @@ -6,6 +6,7 @@ from sksurgerynditracker.nditracker import NDITracker from tests.polaris_mocks import SETTINGS_POLARIS, SETTINGS_POLARIS_QUAT, \ + SETTINGS_POLARIS_SMOOTH, SETTINGS_POLARIS_QUAT_SMOOTH, \ mockndiProbe, \ mockndiOpen, mockndiGetError, mockComports, \ mockndiGetPHSRHandle, mockndiVER, \ @@ -156,7 +157,159 @@ def test_getframe_missing(mocker): assert len(port_handles) == 2 assert len(time_stamps) == 2 assert frame_numbers.count(1) == 2 - assert np.all(np.isnan(tracking)) + assert np.any(np.isnan(tracking[0])) + assert np.any(np.isnan(tracking[1])) + assert np.all(np.isnan(tracking_quality)) + + del tracker + +def test_getframe_smooth_mock(mocker): + """ + connects and configures, mocks ndicapy.ndiProbe to pass + reqs: 03, 04 + """ + tracker = None + bxsource = MockBXFrameSource() + ndidevice = MockNDIDevice() + mocker.patch('serial.tools.list_ports.comports', mockComports) + mocker.patch('ndicapy.ndiProbe', mockndiProbe) + mocker.patch('ndicapy.ndiOpen', mockndiOpen) + mocker.patch('ndicapy.ndiCommand', ndidevice.mockndiCommand) + mocker.patch('ndicapy.ndiGetError', mockndiGetError) + mocker.patch('ndicapy.ndiClose') + mocker.patch('ndicapy.ndiGetPHSRNumberOfHandles', + ndidevice.mockndiGetPHSRNumberOfHandles) + mocker.patch('ndicapy.ndiGetPHRQHandle', ndidevice.mockndiGetPHRQHandle) + mocker.patch('ndicapy.ndiPVWRFromFile') + mocker.patch('ndicapy.ndiGetPHSRHandle', mockndiGetPHSRHandle) + mocker.patch('ndicapy.ndiVER', mockndiVER) + mocker.patch('ndicapy.ndiGetBXFrame', bxsource.mockndiGetBXFrame) + mocker.patch('ndicapy.ndiGetBXTransform', bxsource.mockndiGetBXTransform) + + tracker = NDITracker(SETTINGS_POLARIS_SMOOTH) + + bxsource.setdevice(ndidevice) + + (port_handles, time_stamps, frame_numbers, tracking, + tracking_quality ) = tracker.get_frame() + + assert len(port_handles) == 2 + assert len(time_stamps) == 2 + assert frame_numbers.count(1) == 2 + expected_tracking_0 = np.array([[1.,0.,0.,10.], + [0.,1.,0.,-20.], + [0.,0.,1.,5.], + [0.,0.,0.,1.]]) + assert np.array_equal(expected_tracking_0, tracking[0]) + expected_tracking_1 = np.array([[1.,0.,0.,0.], + [0.,1.,0.,0.], + [0.,0.,1.,0.], + [0.,0.,0.,1.]]) + assert np.array_equal(expected_tracking_1, tracking[1]) + assert tracking_quality.count(1.) == 2 + + (port_handles, time_stamps, frame_numbers, tracking, + tracking_quality ) = tracker.get_frame() + + assert len(port_handles) == 2 + assert len(time_stamps) == 2 + assert frame_numbers.count(2) == 2 + expected_tracking_0 = np.array([[1.,0.,0.,15.], + [0.,1.,0.,-30.], + [0.,0.,1.,7.5], + [0.,0.,0.,1.]]) + assert np.array_equal(expected_tracking_0, tracking[0]) + assert np.array_equal(expected_tracking_1, tracking[1]) + assert tracking_quality.count(1.) == 2 + + del tracker + +def test_getframe_smooth_mock_quat(mocker): + """ + Checks that get frame works with quaternions + """ + tracker = None + bxsource = MockBXFrameSource() + ndidevice = MockNDIDevice() + mocker.patch('serial.tools.list_ports.comports', mockComports) + mocker.patch('ndicapy.ndiProbe', mockndiProbe) + mocker.patch('ndicapy.ndiOpen', mockndiOpen) + mocker.patch('ndicapy.ndiCommand', ndidevice.mockndiCommand) + mocker.patch('ndicapy.ndiGetError', mockndiGetError) + mocker.patch('ndicapy.ndiClose') + mocker.patch('ndicapy.ndiGetPHSRNumberOfHandles', + ndidevice.mockndiGetPHSRNumberOfHandles) + mocker.patch('ndicapy.ndiGetPHRQHandle', ndidevice.mockndiGetPHRQHandle) + mocker.patch('ndicapy.ndiPVWRFromFile') + mocker.patch('ndicapy.ndiGetPHSRHandle', mockndiGetPHSRHandle) + mocker.patch('ndicapy.ndiVER', mockndiVER) + mocker.patch('ndicapy.ndiGetBXFrame', bxsource.mockndiGetBXFrame) + mocker.patch('ndicapy.ndiGetBXTransform', bxsource.mockndiGetBXTransform) + + tracker = NDITracker(SETTINGS_POLARIS_QUAT_SMOOTH) + + bxsource.setdevice(ndidevice) + + (port_handles, time_stamps, frame_numbers, tracking, + tracking_quality ) = tracker.get_frame() + + assert len(port_handles) == 2 + assert len(time_stamps) == 2 + assert frame_numbers.count(1) == 2 + expected_tracking_0 = np.array([[1.,0.,0.,0.,10.,-20,5.]]) + assert np.array_equal(expected_tracking_0, tracking[0]) + expected_tracking_1 = np.array([[1.,0.,0.,0.,0.,0.,0.]]) + assert np.array_equal(expected_tracking_1, tracking[1]) + assert tracking_quality.count(1.) == 2 + + (port_handles, time_stamps, frame_numbers, tracking, + tracking_quality ) = tracker.get_frame() + + assert len(port_handles) == 2 + assert len(time_stamps) == 2 + assert frame_numbers.count(2) == 2 + expected_tracking_0 = np.array([[1.,0.,0.,0.,15.,-30,7.5]]) + assert np.array_equal(expected_tracking_0, tracking[0]) + assert np.array_equal(expected_tracking_1, tracking[1]) + assert tracking_quality.count(1.) == 2 + + del tracker + +def test_getframe_missing_smooth(mocker): + """ + connects and configures, mocks ndicapy.ndiProbe to pass + reqs: 03, 04 + """ + tracker = None + bxsource = MockBXFrameSource() + ndidevice = MockNDIDevice() + mocker.patch('serial.tools.list_ports.comports', mockComports) + mocker.patch('ndicapy.ndiProbe', mockndiProbe) + mocker.patch('ndicapy.ndiOpen', mockndiOpen) + mocker.patch('ndicapy.ndiCommand', ndidevice.mockndiCommand) + mocker.patch('ndicapy.ndiGetError', mockndiGetError) + mocker.patch('ndicapy.ndiClose') + mocker.patch('ndicapy.ndiGetPHSRNumberOfHandles', + ndidevice.mockndiGetPHSRNumberOfHandles) + mocker.patch('ndicapy.ndiGetPHRQHandle', ndidevice.mockndiGetPHRQHandle) + mocker.patch('ndicapy.ndiPVWRFromFile') + mocker.patch('ndicapy.ndiGetPHSRHandle', mockndiGetPHSRHandle) + mocker.patch('ndicapy.ndiVER', mockndiVER) + mocker.patch('ndicapy.ndiGetBXFrame', bxsource.mockndiGetBXFrame) + mocker.patch('ndicapy.ndiGetBXTransform', + bxsource.mockndiGetBXTransformMissing) + + tracker = NDITracker(SETTINGS_POLARIS_SMOOTH) + + bxsource.setdevice(ndidevice) + (port_handles, time_stamps, frame_numbers, tracking, + tracking_quality ) = tracker.get_frame() + + assert len(port_handles) == 2 + assert len(time_stamps) == 2 + assert frame_numbers.count(1) == 2 + assert np.any(np.isnan(tracking[0])) + assert np.any(np.isnan(tracking[1])) assert np.all(np.isnan(tracking_quality)) del tracker