Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow OAV to Redis forwarder to use different OAV streams #808

Merged
merged 8 commits into from
Oct 14, 2024
20 changes: 18 additions & 2 deletions src/dodal/devices/oav/oav_to_redis_forwarder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
import uuid
from collections.abc import Awaitable, Callable
from datetime import timedelta
from enum import Enum

import numpy as np
from aiohttp import ClientResponse, ClientSession
from bluesky.protocols import Flyable, Stoppable
from ophyd_async.core import (
AsyncStatus,
DeviceVector,
StandardReadable,
soft_signal_r_and_setter,
soft_signal_rw,
Expand All @@ -30,6 +32,11 @@ async def get_next_jpeg(response: ClientResponse) -> bytes:
return line + await response.content.readuntil(JPEG_STOP_BYTE)


class Source(Enum):
FULL_SCREEN = 0
ROI = 1


class OAVToRedisForwarder(StandardReadable, Flyable, Stoppable):
"""Forwards OAV image data to redis. To use call:

Expand Down Expand Up @@ -59,7 +66,15 @@ def __init__(
redis_db: int which redis database to connect to, defaults to 0
name: str the name of this device
"""
self.stream_url = epics_signal_r(str, f"{prefix}MJPG:MJPG_URL_RBV")
self._sources = DeviceVector(
{
Source.FULL_SCREEN.value: epics_signal_r(
str, f"{prefix}XTAL:MJPG_URL_RBV"
),
Source.ROI.value: epics_signal_r(str, f"{prefix}MJPG:MJPG_URL_RBV"),
}
)
self.selected_source = soft_signal_rw(Source)

with self.add_children_as_readables():
self.uuid, self.uuid_setter = soft_signal_r_and_setter(str)
Expand Down Expand Up @@ -95,7 +110,8 @@ async def _get_frame_and_put_to_redis(self, response: ClientResponse):
async def _open_connection_and_do_function(
self, function_to_do: Callable[[ClientResponse, str | None], Awaitable]
):
stream_url = await self.stream_url.get_value()
source = await self.selected_source.get_value()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could: use asyncio.gather for consecutive reads

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think I can here because the read of stream_url needs the value from the read of source. asyncio.gather will try and read them in parallel

stream_url = await self._sources[source.value].get_value()
async with ClientSession() as session:
async with session.get(stream_url) as response:
await function_to_do(response, stream_url)
Expand Down
17 changes: 11 additions & 6 deletions system_tests/test_oav_to_redis_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from aiohttp.client_exceptions import ClientConnectorError
from ophyd_async.core import DeviceCollector, set_mock_value

from dodal.devices.oav.oav_to_redis_forwarder import OAVToRedisForwarder
from dodal.devices.oav.oav_to_redis_forwarder import OAVToRedisForwarder, Source


def _oav_to_redis_forwarder(mock):
Expand All @@ -28,14 +28,19 @@ def mock_oav_to_redis_forwarder(_, RE):
return _oav_to_redis_forwarder(True)


def _set_url(mock_oav_to_redis_forwarder: OAVToRedisForwarder, url: str):
set_mock_value(
mock_oav_to_redis_forwarder._sources[Source.FULL_SCREEN.value],
url,
)
set_mock_value(mock_oav_to_redis_forwarder.selected_source, Source.FULL_SCREEN)


@pytest.mark.s03 # Doesn't actually depend on s03 but is a system test as it depends on external webpage. See https://github.com/DiamondLightSource/mx-bluesky/issues/183
async def test_given_stream_url_is_not_a_real_webpage_when_kickoff_then_error(
mock_oav_to_redis_forwarder: OAVToRedisForwarder,
):
set_mock_value(
mock_oav_to_redis_forwarder.stream_url,
"http://www.this_is_not_a_valid_webpage.com/",
)
_set_url(mock_oav_to_redis_forwarder, "http://www.this_is_not_a_valid_webpage.com/")
with pytest.raises(ClientConnectorError):
await mock_oav_to_redis_forwarder.kickoff()

Expand All @@ -45,7 +50,7 @@ async def test_given_stream_url_is_real_webpage_but_not_mjpg_when_kickoff_then_e
mock_oav_to_redis_forwarder: OAVToRedisForwarder,
):
URL = "https://www.google.com/"
set_mock_value(mock_oav_to_redis_forwarder.stream_url, URL)
_set_url(mock_oav_to_redis_forwarder, URL)
with pytest.raises(ValueError) as e:
await mock_oav_to_redis_forwarder.kickoff()
assert URL in str(e.value)
Expand Down
40 changes: 33 additions & 7 deletions tests/devices/unit_tests/oav/test_oav_to_redis_forwarder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,29 @@
import io
import pickle
from datetime import timedelta
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import ANY, AsyncMock, MagicMock, patch

import numpy as np
import pytest
from ophyd_async.core import DeviceCollector, set_mock_value
from PIL import Image

from dodal.devices.oav.oav_to_redis_forwarder import OAVToRedisForwarder, get_next_jpeg
from dodal.devices.oav.oav_to_redis_forwarder import (
OAVToRedisForwarder,
Source,
get_next_jpeg,
)


@pytest.fixture
@patch("dodal.devices.oav.oav_to_redis_forwarder.StrictRedis", new=AsyncMock)
def oav_forwarder(RE):
with DeviceCollector(mock=True):
oav_forwarder = OAVToRedisForwarder("prefix", "host", "password")
set_mock_value(oav_forwarder.stream_url, "test-stream-url")
set_mock_value(
oav_forwarder._sources[Source.FULL_SCREEN.value], "test-full-screen-stream-url"
)
set_mock_value(oav_forwarder._sources[Source.ROI.value], "test-roi-stream-url")
return oav_forwarder


Expand All @@ -30,7 +37,7 @@ def oav_forwarder_with_valid_response(oav_forwarder):
mock_get.return_value.__aenter__.return_value = (mock_response := AsyncMock())
mock_response.content_type = "multipart/x-mixed-replace"
oav_forwarder._get_frame_and_put_to_redis = AsyncMock()
yield oav_forwarder, mock_response
yield oav_forwarder, mock_response, mock_get
client_session_patch.stop()


Expand All @@ -50,7 +57,7 @@ async def test_given_response_is_not_mjpeg_when_oav_forwarder_kicked_off_then_ex
async def test_when_oav_forwarder_kicked_off_then_connection_open_and_data_streamed(
oav_forwarder_with_valid_response,
):
oav_forwarder, mock_response = oav_forwarder_with_valid_response
oav_forwarder, mock_response, _ = oav_forwarder_with_valid_response

await oav_forwarder.kickoff()

Expand All @@ -63,7 +70,7 @@ async def test_when_oav_forwarder_kicked_off_then_connection_open_and_data_strea
async def test_when_oav_forwarder_kicked_off_then_stopped_forwarding_is_stopped(
oav_forwarder_with_valid_response,
):
oav_forwarder, _ = oav_forwarder_with_valid_response
oav_forwarder, _, _ = oav_forwarder_with_valid_response

await oav_forwarder.kickoff()
await oav_forwarder.stop()
Expand All @@ -73,7 +80,7 @@ async def test_when_oav_forwarder_kicked_off_then_stopped_forwarding_is_stopped(
async def test_when_oav_forwarder_kicked_off_then_completed_forwarding_is_stopped(
oav_forwarder_with_valid_response,
):
oav_forwarder, _ = oav_forwarder_with_valid_response
oav_forwarder, _, _ = oav_forwarder_with_valid_response

await oav_forwarder.kickoff()
await oav_forwarder.complete()
Expand Down Expand Up @@ -138,3 +145,22 @@ async def test_when_get_frame_and_put_to_redis_called_then_data_put_in_redis_wit
redis_expire_call = oav_forwarder.redis_client.expire.call_args[0]
assert redis_expire_call[0] == str(SAMPLE_ID)
assert redis_expire_call[1] == timedelta(days=oav_forwarder.DATA_EXPIRY_DAYS)


@pytest.mark.parametrize(
"source, expected_url",
[
(Source.FULL_SCREEN, "test-full-screen-stream-url"),
(Source.ROI, "test-roi-stream-url"),
],
)
async def test_when_different_sources_selected_then_different_urls_used(
oav_forwarder_with_valid_response, source, expected_url
):
oav_forwarder, _, mock_get = oav_forwarder_with_valid_response
oav_forwarder.selected_source.set(source)

await oav_forwarder.kickoff()
await oav_forwarder.complete()

mock_get.assert_called_with(ANY, expected_url)
Loading