diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 068ca91..0e264d4 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -20,10 +20,14 @@ jobs: uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - - name: Install dependencies + - name: Install system dependencies + run: | + sudo apt update + sudo apt install -y swig gcc libpulse-dev portaudio19-dev + - name: Install package run: | python -m pip install --upgrade pip - pip install . -r requirements/test_requirements.txt + pip install .[streaming] -r requirements/test_requirements.txt - name: Run Tests run: | pytest tests diff --git a/Dockerfile b/Dockerfile index a4e5665..d8394da 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,10 +7,13 @@ ENV OVOS_CONFIG_BASE_FOLDER neon ENV OVOS_CONFIG_FILENAME diana.yaml ENV XDG_CONFIG_HOME /config +RUN apt update && apt install -y swig gcc libpulse-dev portaudio19-dev + COPY docker_overlay/ / WORKDIR /app COPY . /app -RUN pip install /app[websocket] + +RUN pip install /app[websocket,streaming] CMD ["python3", "/app/neon_hana/app/__main__.py"] \ No newline at end of file diff --git a/README.md b/README.md index a84a42b..d741025 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,7 @@ hana: enable_email: True # Disabled by default; anyone with access to the API will be able to send emails from the configured address node_username: node_user # Username to authenticate Node API access; leave empty to disable Node API access node_password: node_password # Password associated with node_username + max_streaming_clients: -1 # Maximum audio streaming clients allowed (including 0). Default unset value allows infinite clients ``` It is recommended to generate unique values for configured tokens, these are 32 bytes in hexadecimal representation. diff --git a/docker_overlay/etc/neon/diana.yaml b/docker_overlay/etc/neon/diana.yaml index c8fa873..71b0990 100644 --- a/docker_overlay/etc/neon/diana.yaml +++ b/docker_overlay/etc/neon/diana.yaml @@ -28,4 +28,6 @@ hana: fastapi_summary: "HANA (HTTP API for Neon Applications) is the HTTP component of the Device Independent API for Neon Applications (DIANA)" stt_max_length_encoded: 500000 tts_max_words: 128 - enable_email: False \ No newline at end of file + enable_email: False +vad: + module: ovos-vad-plugin-webrtcvad \ No newline at end of file diff --git a/neon_hana/app/routers/node_server.py b/neon_hana/app/routers/node_server.py index 33fec4a..c7acd3f 100644 --- a/neon_hana/app/routers/node_server.py +++ b/neon_hana/app/routers/node_server.py @@ -26,19 +26,23 @@ from asyncio import Event from signal import signal, SIGINT +from time import sleep from typing import Optional, Union -from fastapi import APIRouter, WebSocket, HTTPException, Request +from fastapi import APIRouter, WebSocket, HTTPException +from ovos_utils import LOG from starlette.websockets import WebSocketDisconnect from neon_hana.app.dependencies import config, client_manager -from neon_hana.mq_websocket_api import MQWebsocketAPI +from neon_hana.mq_websocket_api import MQWebsocketAPI, ClientNotKnown from neon_hana.schema.node_v1 import (NodeAudioInput, NodeGetStt, NodeGetTts, NodeKlatResponse, NodeAudioInputResponse, NodeGetSttResponse, - NodeGetTtsResponse) + NodeGetTtsResponse, CoreWWDetected, + CoreIntentFailure, CoreErrorResponse, + CoreClearData, CoreAlertExpired) node_route = APIRouter(prefix="/node", tags=["node"]) socket_api = MQWebsocketAPI(config) @@ -65,11 +69,67 @@ async def node_v1_endpoint(websocket: WebSocket, token: str): socket_api.handle_client_input(client_in, client_id) except WebSocketDisconnect: disconnect_event.set() + socket_api.end_session(session_id=client_id) + + +@node_route.websocket("/v1/stream") +async def node_v1_stream_endpoint(websocket: WebSocket, token: str): + client_id = client_manager.get_client_id(token) + + if not client_manager.check_connect_stream(): + raise HTTPException(status_code=503, + detail=f"Server is not accepting any more streams") + try: + await websocket.accept() + disconnect_event = Event() + socket_api.new_stream(websocket, client_id) + while not disconnect_event.is_set(): + try: + client_in: bytes = await websocket.receive_bytes() + socket_api.handle_audio_input_stream(client_in, client_id) + except WebSocketDisconnect: + disconnect_event.set() + except ClientNotKnown as e: + LOG.error(e) + raise HTTPException(status_code=401, + detail=f"Client not known ({client_id})") + except Exception as e: + LOG.exception(e) + finally: + client_manager.disconnect_stream() @node_route.get("/v1/doc") async def node_v1_doc(_: Optional[Union[NodeAudioInput, NodeGetStt, NodeGetTts]]) -> \ Optional[Union[NodeKlatResponse, NodeAudioInputResponse, - NodeGetSttResponse, NodeGetTtsResponse]]: + NodeGetSttResponse, NodeGetTtsResponse, + CoreWWDetected, CoreIntentFailure, CoreErrorResponse, + CoreClearData, CoreAlertExpired]]: + """ + The node endpoint (`/node/v1`) accepts and returns JSON objects representing + Messages. All inputs and responses will contain keys: + `msg_type`, `data`, `context`. Only the inputs and responses documented here + are explicitly supported. Other messages sent or received on this socket are + not guaranteed to be stable. + """ + pass + + +@node_route.get("/v1/stream/doc") +async def node_v1_stream_doc(): + """ + The stream endpoint accepts input audio as raw bytes. It expects inputs to + have: + - sample_rate=16000 + - sample_width=2 + - sample_channels=1 + - chunk_size=4096 + + Response audio is WAV audio as raw bytes, with each message containing one + full audio file. A client should queue all responses for playback. + + Any client accessing the stream endpoint (`/node/v1/stream`), must first + establish a connection to the node endpoint (`/node/v1`). + """ pass diff --git a/neon_hana/auth/client_manager.py b/neon_hana/auth/client_manager.py index d1bb6d0..578ea13 100644 --- a/neon_hana/auth/client_manager.py +++ b/neon_hana/auth/client_manager.py @@ -23,6 +23,7 @@ # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from threading import Lock import jwt @@ -53,7 +54,10 @@ def __init__(self, config: dict): self._disable_auth = config.get("disable_auth") self._node_username = config.get("node_username") self._node_password = config.get("node_password") + self._max_streaming_clients = config.get("max_streaming_clients") self._jwt_algo = "HS256" + self._connected_streams = 0 + self._stream_check_lock = Lock() def _create_tokens(self, encode_data: dict) -> dict: # Permissions were not included in old tokens, allow refreshing with @@ -88,6 +92,26 @@ def get_permissions(self, client_id: str) -> ClientPermissions: client = self.authorized_clients[client_id] return ClientPermissions(**client.get('permissions', dict())) + def check_connect_stream(self) -> bool: + """ + Check if a new stream is allowed + """ + with self._stream_check_lock: + if not isinstance(self._max_streaming_clients, int) or \ + self._max_streaming_clients is False or \ + self._max_streaming_clients < 0: + self._connected_streams += 1 + return True + if self._connected_streams >= self._max_streaming_clients: + LOG.warning(f"No more streams allowed ({self._connected_streams})") + return False + self._connected_streams += 1 + return True + + def disconnect_stream(self): + with self._stream_check_lock: + self._connected_streams -= 1 + def check_auth_request(self, client_id: str, username: str, password: Optional[str] = None, origin_ip: str = "127.0.0.1") -> dict: diff --git a/neon_hana/mq_websocket_api.py b/neon_hana/mq_websocket_api.py index 2c65af1..6b9b96e 100644 --- a/neon_hana/mq_websocket_api.py +++ b/neon_hana/mq_websocket_api.py @@ -26,7 +26,9 @@ from asyncio import run, get_event_loop from os import makedirs -from time import time +from queue import Queue +from time import time, sleep +from typing import Optional from fastapi import WebSocket from neon_iris.client import NeonAIClient from ovos_bus_client.message import Message @@ -34,6 +36,12 @@ from ovos_utils import LOG +class ClientNotKnown(RuntimeError): + """ + Exception raised when a client tries to do something before authenticating + """ + + class MQWebsocketAPI(NeonAIClient): def __init__(self, config: dict): """ @@ -58,6 +66,49 @@ def new_connection(self, ws: WebSocket, session_id: str): "socket": ws, "user": self.user_config} + def new_stream(self, ws: WebSocket, session_id: str): + """ + Establish a new streaming connection, associated with an existing session. + @param ws: Client WebSocket that handles byte audio + @param session_id: Session ID the websocket is associated with + """ + timeout = time() + 5 + while session_id not in self._sessions and time() < timeout: + # Handle problem clients that don't explicitly wait for the Node WS + # to connect before starting a stream + sleep(1) + with self._session_lock: + if session_id not in self._sessions: + raise ClientNotKnown(f"Stream cannot be established for {session_id}") + from neon_hana.streaming_client import RemoteStreamHandler, StreamMicrophone + if not self._sessions[session_id].get('stream'): + LOG.info(f"starting stream for session {session_id}") + audio_queue = Queue() + stream = RemoteStreamHandler(StreamMicrophone(audio_queue), session_id, + input_audio_callback=self.handle_client_input, + ww_callback=self.handle_ww_detected, + client_socket=ws) + self._sessions[session_id]['stream'] = stream + try: + stream.start() + except RuntimeError: + pass + + def end_session(self, session_id: str): + """ + End a client connection upon WS disconnection + """ + with self._session_lock: + session: Optional[dict] = self._sessions.pop(session_id, None) + if not session: + LOG.error(f"Ended session is not established {session_id}") + return + stream = session.get('stream') + if stream: + stream.shutdown() + stream.join() + LOG.info(f"Ended stream handler for: {session_id}") + def get_session(self, session_id: str) -> dict: """ Get the latest session context for the given session_id. @@ -114,6 +165,15 @@ def _update_session_data(self, message: Message): if user_config: self._sessions[session_id]['user'] = user_config + def handle_audio_input_stream(self, audio: bytes, session_id: str): + self._sessions[session_id]['stream'].mic.queue.put(audio) + + def handle_ww_detected(self, ww_context: dict, session_id: str): + session = self.get_session(session_id) + message = Message("neon.ww_detected", ww_context, + {"session": session}) + run(self.send_to_client(message)) + def handle_client_input(self, data: dict, session_id: str): """ Handle some client input data. @@ -134,9 +194,16 @@ def handle_klat_response(self, message: Message): Handle a Neon text+audio response to a user input. @param message: `klat.response` message from Neon """ - self._update_session_data(message) - run(self.send_to_client(message)) - LOG.debug(message.context.get("timing")) + try: + self._update_session_data(message) + run(self.send_to_client(message)) + session_id = message.context.get('session', {}).get('session_id') + if stream := self._sessions.get(session_id, {}).get('stream'): + LOG.info("Stream response audio") + stream.on_response_audio(message.data) + LOG.debug(message.context.get("timing")) + except Exception as e: + LOG.exception(e) def handle_complete_intent_failure(self, message: Message): """ diff --git a/neon_hana/schema/node_v1.py b/neon_hana/schema/node_v1.py index 11c2725..913c6b0 100644 --- a/neon_hana/schema/node_v1.py +++ b/neon_hana/schema/node_v1.py @@ -122,3 +122,33 @@ class NodeGetTtsResponse(BaseModel): msg_type: str = "neon.get_tts.response" data: KlatResponseData context: dict + + +class CoreWWDetected(BaseModel): + msg_type: str = "neon.ww_detected" + data: dict + context: dict + + +class CoreIntentFailure(BaseModel): + msg_type: str = "complete.intent.failure" + data: dict + context: dict + + +class CoreErrorResponse(BaseModel): + msg_type: str = "klat.error" + data: dict + context: dict + + +class CoreClearData(BaseModel): + msg_type: str = "neon.clear_data" + data: dict + context: dict + + +class CoreAlertExpired(BaseModel): + msg_type: str = "neon.alert_expired" + data: dict + context: dict diff --git a/neon_hana/streaming_client.py b/neon_hana/streaming_client.py new file mode 100644 index 0000000..c3263a7 --- /dev/null +++ b/neon_hana/streaming_client.py @@ -0,0 +1,107 @@ +import io +from asyncio import run +from base64 import b64encode, b64decode +from typing import Optional, Callable +from mock.mock import Mock +from threading import Thread +from queue import Queue + +from ovos_dinkum_listener.voice_loop import DinkumVoiceLoop +from ovos_dinkum_listener.voice_loop.hotwords import HotwordContainer +from ovos_dinkum_listener.voice_loop.voice_loop import ChunkInfo +from ovos_plugin_manager.templates.microphone import Microphone +from ovos_plugin_manager.vad import OVOSVADFactory +from ovos_utils.fakebus import FakeBus +from speech_recognition import AudioData +from ovos_utils import LOG +from starlette.websockets import WebSocket + + +class StreamMicrophone(Microphone): + def __init__(self, queue: Queue): + self.queue = queue + + def start(self): + pass + + def stop(self): + self.queue.put(None) + + def read_chunk(self) -> Optional[bytes]: + return self.queue.get() + + +class RemoteStreamHandler(Thread): + def __init__(self, mic: StreamMicrophone, session_id: str, + input_audio_callback: Callable, + client_socket: WebSocket, + ww_callback: Callable, lang: str = "en-us"): + Thread.__init__(self) + self.session_id = session_id + self.ww_callback = ww_callback + self.input_audio_callback = input_audio_callback + self.client_socket = client_socket + self.bus = FakeBus() + self.mic = mic + self.lang = lang + self.hotwords = HotwordContainer(self.bus) + self.hotwords.load_hotword_engines() + self.vad = OVOSVADFactory.create() + self.voice_loop = DinkumVoiceLoop(mic=self.mic, + vad=self.vad, + hotwords=self.hotwords, + listenword_audio_callback=self.on_hotword, + hotword_audio_callback=self.on_hotword, + stopword_audio_callback=self.on_hotword, + wakeupword_audio_callback=self.on_hotword, + stt_audio_callback=self.on_input_audio, + stt=Mock(transcribe=Mock(return_value=[])), + fallback_stt=Mock(transcribe=Mock(return_value=[])), + transformers=MockTransformers(), + chunk_callback=self.on_chunk, + speech_seconds=0.5, + num_hotword_keep_chunks=0, + num_stt_rewind_chunks=0) + + def run(self): + self.voice_loop.start() + self.voice_loop.run() + + def on_hotword(self, audio_bytes: bytes, context: dict): + self.lang = context.get("stt_lang") or self.lang + LOG.info(f"Hotword: {context}") + self.ww_callback(context, self.session_id) + + def on_input_audio(self, audio_bytes: bytes, context: dict): + LOG.info(f"Audio: {context}") + audio_data = AudioData(audio_bytes, self.mic.sample_rate, + self.mic.sample_width).get_wav_data() + audio_data = b64encode(audio_data).decode("utf-8") + callback_data = {"type": "neon.audio_input", + "data": {"audio_data": audio_data, "lang": self.lang}} + self.input_audio_callback(callback_data, self.session_id) + + def on_response_audio(self, data: dict): + async def _send_bytes(audio_bytes: bytes): + await self.client_socket.send_bytes(audio_bytes) + + i = 0 + for lang_response in data.get('responses', {}).values(): + for encoded_audio in lang_response.get('audio', {}).values(): + i += 1 + wav_audio_bytes = b64decode(encoded_audio) + LOG.info(f"Sending {len(wav_audio_bytes)} bytes of audio") + run(_send_bytes(wav_audio_bytes)) + LOG.info(f"Sent {i} binary audio response(s)") + + def on_chunk(self, chunk: ChunkInfo): + LOG.debug(f"Chunk: {chunk}") + + def shutdown(self): + self.mic.stop() + self.voice_loop.stop() + + +class MockTransformers(Mock): + def transform(self, chunk): + return chunk, dict() diff --git a/requirements/streaming.txt b/requirements/streaming.txt new file mode 100644 index 0000000..da3ca46 --- /dev/null +++ b/requirements/streaming.txt @@ -0,0 +1,9 @@ +mock~=5.0 +ovos-dinkum-listener~=0.1 +ovos-vad-plugin-webrtcvad~=0.0.1 +ovos-ww-plugin-pocketsphinx~=0.1 +ovos-ww-plugin-vosk~=0.1 +ovos-ww-plugin-precise-lite[tflite]~=0.1 +ovos-ww-plugin-precise~=0.1 +# TODO: numpy patching tflite plugin compat. issue https://github.com/OpenVoiceOS/ovos-ww-plugin-precise-lite/pull/8 +numpy~=1.0 diff --git a/setup.py b/setup.py index b9928ee..85ddbef 100644 --- a/setup.py +++ b/setup.py @@ -74,7 +74,8 @@ def get_requirements(requirements_filename: str): license='BSD-3-Clause', packages=find_packages(), install_requires=get_requirements("requirements.txt"), - extras_require={"websocket": get_requirements("websocket.txt")}, + extras_require={"websocket": get_requirements("websocket.txt"), + "streaming": get_requirements("streaming.txt")}, zip_safe=True, classifiers=[ 'Intended Audience :: Developers', diff --git a/tests/test_auth.py b/tests/test_auth.py index d2fcbef..88422fb 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -202,3 +202,33 @@ def test_client_permissions(self): self.assertFalse(any([v for v in restricted_perms.as_dict().values()])) self.assertIsInstance(permissive_perms.as_dict(), dict) self.assertTrue(all([v for v in permissive_perms.as_dict().values()])) + + def test_stream_connections(self): + # Test configured maximum + self.client_manager._max_streaming_clients = 1 + self.assertTrue(self.client_manager.check_connect_stream()) + self.assertEqual(self.client_manager._connected_streams, 1) + self.assertFalse(self.client_manager.check_connect_stream()) + self.assertFalse(self.client_manager.check_connect_stream()) + self.assertEqual(self.client_manager._connected_streams, 1) + self.client_manager.disconnect_stream() + self.assertEqual(self.client_manager._connected_streams, 0) + + # Test explicitly disabled streaming + self.client_manager._max_streaming_clients = 0 + self.assertFalse(self.client_manager.check_connect_stream()) + + # Test unlimited clients + self.client_manager._max_streaming_clients = None + self.assertTrue(self.client_manager.check_connect_stream()) + self.assertTrue(self.client_manager.check_connect_stream()) + self.assertTrue(self.client_manager.check_connect_stream()) + self.assertEqual(self.client_manager._connected_streams, 3) + + self.client_manager._max_streaming_clients = -1 + self.assertTrue(self.client_manager.check_connect_stream()) + self.assertEqual(self.client_manager._connected_streams, 4) + + self.client_manager._max_streaming_clients = False + self.assertTrue(self.client_manager.check_connect_stream()) + self.assertEqual(self.client_manager._connected_streams, 5)