From 589da9ba4beaca19cbf592aed51d98ea499387cc Mon Sep 17 00:00:00 2001 From: Daniel McKnight <34697904+NeonDaniel@users.noreply.github.com> Date: Tue, 21 May 2024 13:05:18 -0700 Subject: [PATCH] Implement websocket client API (#20) # Description Implements a websocket API for a Node client Adds `ClientPermissions` object to define per-client permissions # Issues Closes #6 # Other Notes Example client implementation: https://github.com/NeonGeckoCom/neon-nodes/pull/14 --------- Co-authored-by: Daniel McKnight --- Dockerfile | 2 +- README.md | 3 +- neon_hana/app/__init__.py | 2 + neon_hana/app/routers/node_server.py | 75 ++++++++++ neon_hana/auth/client_manager.py | 52 ++++++- neon_hana/auth/permissions.py | 43 ++++++ neon_hana/mq_websocket_api.py | 203 +++++++++++++++++++++++++++ neon_hana/schema/node_v1.py | 124 ++++++++++++++++ requirements/test_requirements.txt | 4 +- requirements/websocket.txt | 2 + setup.py | 1 + tests/test_auth.py | 68 ++++++++- tests/test_mq_service_api.py | 32 +++++ tests/test_mq_websocket_api.py | 32 +++++ 14 files changed, 636 insertions(+), 7 deletions(-) create mode 100644 neon_hana/app/routers/node_server.py create mode 100644 neon_hana/auth/permissions.py create mode 100644 neon_hana/mq_websocket_api.py create mode 100644 neon_hana/schema/node_v1.py create mode 100644 requirements/websocket.txt create mode 100644 tests/test_mq_service_api.py create mode 100644 tests/test_mq_websocket_api.py diff --git a/Dockerfile b/Dockerfile index f8095c0..a4e5665 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,6 +11,6 @@ COPY docker_overlay/ / WORKDIR /app COPY . /app -RUN pip install /app +RUN pip install /app[websocket] CMD ["python3", "/app/neon_hana/app/__main__.py"] \ No newline at end of file diff --git a/README.md b/README.md index 54cf77c..a84a42b 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,8 @@ hana: stt_max_length_encoded: 500000 # Arbitrary limit that is larger than any expected voice command tts_max_words: 128 # Arbitrary limit that is longer than any default LLM token limit enable_email: True # Disabled by default; anyone with access to the API will be able to send emails from the configured address - + node_username: node_user # Username to authenticate Node API access; leave empty to disable Node API access + node_password: node_password # Password associated with node_username ``` It is recommended to generate unique values for configured tokens, these are 32 bytes in hexadecimal representation. diff --git a/neon_hana/app/__init__.py b/neon_hana/app/__init__.py index bcaeeec..9fd7f1d 100644 --- a/neon_hana/app/__init__.py +++ b/neon_hana/app/__init__.py @@ -33,6 +33,7 @@ from neon_hana.app.routers.mq_backend import mq_route from neon_hana.app.routers.auth import auth_route from neon_hana.app.routers.util import util_route +from neon_hana.app.routers.node_server import node_route from neon_hana.version import __version__ @@ -47,5 +48,6 @@ def create_app(config: dict): app.include_router(mq_route) app.include_router(llm_route) app.include_router(util_route) + app.include_router(node_route) return app diff --git a/neon_hana/app/routers/node_server.py b/neon_hana/app/routers/node_server.py new file mode 100644 index 0000000..33fec4a --- /dev/null +++ b/neon_hana/app/routers/node_server.py @@ -0,0 +1,75 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Development System +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2021 Neongecko.com Inc. +# BSD-3 +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from asyncio import Event +from signal import signal, SIGINT +from typing import Optional, Union + +from fastapi import APIRouter, WebSocket, HTTPException, Request +from starlette.websockets import WebSocketDisconnect + +from neon_hana.app.dependencies import config, client_manager +from neon_hana.mq_websocket_api import MQWebsocketAPI + +from neon_hana.schema.node_v1 import (NodeAudioInput, NodeGetStt, + NodeGetTts, NodeKlatResponse, + NodeAudioInputResponse, + NodeGetSttResponse, + NodeGetTtsResponse) +node_route = APIRouter(prefix="/node", tags=["node"]) + +socket_api = MQWebsocketAPI(config) +signal(SIGINT, socket_api.shutdown) + + +@node_route.websocket("/v1") +async def node_v1_endpoint(websocket: WebSocket, token: str): + client_id = client_manager.get_client_id(token) + if not client_manager.validate_auth(token, client_id): + raise HTTPException(status_code=403, + detail="Invalid or expired token.") + if not client_manager.get_permissions(client_id).node: + raise HTTPException(status_code=401, + detail=f"Client not authorized for node access " + f"({client_id})") + await websocket.accept() + disconnect_event = Event() + + socket_api.new_connection(websocket, client_id) + while not disconnect_event.is_set(): + try: + client_in: dict = await websocket.receive_json() + socket_api.handle_client_input(client_in, client_id) + except WebSocketDisconnect: + disconnect_event.set() + + +@node_route.get("/v1/doc") +async def node_v1_doc(_: Optional[Union[NodeAudioInput, NodeGetStt, + NodeGetTts]]) -> \ + Optional[Union[NodeKlatResponse, NodeAudioInputResponse, + NodeGetSttResponse, NodeGetTtsResponse]]: + pass diff --git a/neon_hana/auth/client_manager.py b/neon_hana/auth/client_manager.py index ab0caaa..d1bb6d0 100644 --- a/neon_hana/auth/client_manager.py +++ b/neon_hana/auth/client_manager.py @@ -31,9 +31,12 @@ from fastapi import Request, HTTPException from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from jwt import DecodeError +from ovos_utils import LOG from token_throttler import TokenThrottler, TokenBucket from token_throttler.storage import RuntimeStorage +from neon_hana.auth.permissions import ClientPermissions + class ClientManager: def __init__(self, config: dict): @@ -48,9 +51,15 @@ def __init__(self, config: dict): self._rpm = config.get("requests_per_minute", 60) self._auth_rpm = config.get("auth_requests_per_minute", 6) self._disable_auth = config.get("disable_auth") + self._node_username = config.get("node_username") + self._node_password = config.get("node_password") self._jwt_algo = "HS256" def _create_tokens(self, encode_data: dict) -> dict: + # Permissions were not included in old tokens, allow refreshing with + # default permissions + encode_data.setdefault("permissions", ClientPermissions().as_dict()) + token_expiration = encode_data['expire'] token = jwt.encode(encode_data, self._access_secret, self._jwt_algo) encode_data['expire'] = time() + self._refresh_token_lifetime @@ -59,13 +68,38 @@ def _create_tokens(self, encode_data: dict) -> dict: # TODO: Store refresh token on server to allow invalidating clients return {"username": encode_data['username'], "client_id": encode_data['client_id'], + "permissions": encode_data['permissions'], "access_token": token, "refresh_token": refresh, "expiration": token_expiration} + def get_permissions(self, client_id: str) -> ClientPermissions: + """ + Get ClientPermissions model for the given client_id + @param client_id: Client ID to get permissions for + @return: ClientPermissions object for the specified client + """ + if self._disable_auth: + LOG.debug("Auth disabled, allow full client permissions") + return ClientPermissions(assist=True, backend=True, node=True) + if client_id not in self.authorized_clients: + LOG.warning(f"{client_id} not known to this server") + return ClientPermissions(assist=False, backend=False, node=False) + client = self.authorized_clients[client_id] + return ClientPermissions(**client.get('permissions', dict())) + def check_auth_request(self, client_id: str, username: str, password: Optional[str] = None, - origin_ip: str = "127.0.0.1"): + origin_ip: str = "127.0.0.1") -> dict: + """ + Authenticate and Authorize a new client connection with the specified + username, password, and origin IP address. + @param client_id: Client ID of the connection to auth + @param username: Supplied username to authenticate + @param password: Supplied password to authenticate + @param origin_ip: Origin IP address of request + @return: response tokens, permissions, and other metadata + """ if client_id in self.authorized_clients: print(f"Using cached client: {self.authorized_clients[client_id]}") return self.authorized_clients[client_id] @@ -84,13 +118,19 @@ def check_auth_request(self, client_id: str, username: str, detail=f"Too many auth requests from: " f"{origin_ip}. Wait {wait_time}s.") + node_access = False if username != "guest": # TODO: Validate password here pass + if all((self._node_username, username == self._node_username, + password == self._node_password)): + node_access = True + permissions = ClientPermissions(node=node_access) expiration = time() + self._access_token_lifetime encode_data = {"client_id": client_id, "username": username, "password": password, + "permissions": permissions.as_dict(), "expire": expiration} auth = self._create_tokens(encode_data) self.authorized_clients[client_id] = auth @@ -125,6 +165,15 @@ def check_refresh_request(self, access_token: str, refresh_token: str, new_auth = self._create_tokens(encode_data) return new_auth + def get_client_id(self, token: str) -> str: + """ + Extract the client_id from a JWT token + @param token: JWT token to parse + @return: client_id associated with token + """ + auth = jwt.decode(token, self._access_secret, self._jwt_algo) + return auth['client_id'] + def validate_auth(self, token: str, origin_ip: str) -> bool: if not self.rate_limiter.get_all_buckets(origin_ip): self.rate_limiter.add_bucket(origin_ip, @@ -142,6 +191,7 @@ def validate_auth(self, token: str, origin_ip: str) -> bool: if auth['expire'] < time(): self.authorized_clients.pop(auth['client_id'], None) return False + self.authorized_clients[auth['client_id']] = auth return True except DecodeError: # Invalid token supplied diff --git a/neon_hana/auth/permissions.py b/neon_hana/auth/permissions.py new file mode 100644 index 0000000..287cc38 --- /dev/null +++ b/neon_hana/auth/permissions.py @@ -0,0 +1,43 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Development System +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2021 Neongecko.com Inc. +# BSD-3 +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from dataclasses import dataclass, asdict + + +@dataclass +class ClientPermissions: + """ + Data class representing permissions of a particular client connection. + """ + assist: bool = True + backend: bool = True + node: bool = False + + def as_dict(self) -> dict: + """ + Get a dict representation of this instance. + """ + return asdict(self) diff --git a/neon_hana/mq_websocket_api.py b/neon_hana/mq_websocket_api.py new file mode 100644 index 0000000..2c65af1 --- /dev/null +++ b/neon_hana/mq_websocket_api.py @@ -0,0 +1,203 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Development System +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2021 Neongecko.com Inc. +# BSD-3 +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from asyncio import run, get_event_loop +from os import makedirs +from time import time +from fastapi import WebSocket +from neon_iris.client import NeonAIClient +from ovos_bus_client.message import Message +from threading import RLock +from ovos_utils import LOG + + +class MQWebsocketAPI(NeonAIClient): + def __init__(self, config: dict): + """ + Creates an MQWebsocketAPI to serve multiple client WS connections. + """ + mq_config = config.get("MQ") or dict() + config_dir = "/tmp/hana" + makedirs(config_dir, exist_ok=True) + NeonAIClient.__init__(self, mq_config, config_dir=config_dir) + self._sessions = dict() + self._session_lock = RLock() + self._client = "neon_node_websocket" + + def new_connection(self, ws: WebSocket, session_id: str): + """ + Record a new client connection to associate the WebSocket with the + session_id for response routing. + @param ws: Client WebSocket object + @param session_id: Session ID of the client + """ + self._sessions[session_id] = {"session": {"session_id": session_id}, + "socket": ws, + "user": self.user_config} + + def get_session(self, session_id: str) -> dict: + """ + Get the latest session context for the given session_id. + @param session_id: Session ID to get context for + @return: dict context for the given session_id (may be empty) + """ + with self._session_lock: + sess = dict(self._sessions.get(session_id, {}).get("session", {})) + return sess + + def get_user_config(self, session_id: str) -> dict: + """ + Get a dict user configuration for the given session_id + @param session_id: Session to get user configuration for + @return: dict user configuration + """ + with self._session_lock: + config = dict(self._sessions.get(session_id, {}).get("user") or + self.user_config) + return config + + def _get_message_context(self, message: Message, session_id: str) -> dict: + """ + Build message context for a Node input message. + @param message: Input message to include context from + @param session_id: Session ID associated with the message + @return: dict context for this input + """ + user_config = self.get_user_config(session_id) + default_context = {"client_name": self.client_name, + "client": self._client, + "ident": str(time()), + "username": user_config['user']['username'], + "user_profiles": [user_config], + "neon_should_respond": True, + "timing": dict(), + "mq": {"routing_key": self.uid, + "message_id": self.connection. + create_unique_id()}} + return {**message.context, **default_context} + + def _update_session_data(self, message: Message): + """ + Update the local session data and user profile from the latest response + message's context. + @param message: Response message containing updated context + """ + session_data = message.context.get('session') + if session_data: + user_config = message.context.get('user_profiles', [None])[0] + session_id = session_data.get('session_id') + with self._session_lock: + self._sessions[session_id]['session'] = session_data + if user_config: + self._sessions[session_id]['user'] = user_config + + def handle_client_input(self, data: dict, session_id: str): + """ + Handle some client input data. + @param data: Decoded input from client WebSocket + @param session_id: Session ID associated with the client connection + """ + # Handle `Message.serialize` data sent over WS in addition to proper + # dict representations + data['msg_type'] = data.pop("type", data.get("msg_type")) + message = Message(**data) + message.context = self._get_message_context(message, session_id) + message.context["session"] = self.get_session(session_id) + # Send raw message, skipping any validation by iris + self._send_message(message) + + def handle_klat_response(self, message: Message): + """ + Handle a Neon text+audio response to a user input. + @param message: `klat.response` message from Neon + """ + self._update_session_data(message) + run(self.send_to_client(message)) + LOG.debug(message.context.get("timing")) + + def handle_complete_intent_failure(self, message: Message): + """ + Handle a Neon error response to a user input. + @param message: `complete.intent.failure` message from Neon + """ + self._update_session_data(message) + run(self.send_to_client(message)) + + def handle_api_response(self, message: Message): + """ + Handle a Neon API response to an input. + @param message: `.response` message from Neon + """ + if message.msg_type == "neon.audio_input.response": + LOG.info(message.data.get("transcripts")) + LOG.debug(message.context.get("timing")) + run(self.send_to_client(message)) + + def handle_error_response(self, message: Message): + """ + Handle an MQ error response to a user input. + @param message: `klat.error` response message + """ + run(self.send_to_client(message)) + + def clear_caches(self, message: Message): + """ + Handle a Neon request to clear cached data. + @param message: `neon.clear_data` message from Neon + """ + run(self.send_to_client(message)) + + def clear_media(self, message: Message): + """ + Handle a Neon request to clear media data. + @param message: `neon.clear_data` message from Neon + """ + run(self.send_to_client(message)) + + def handle_alert(self, message: Message): + """ + Handle an expired alert from Neon. + @param message: `neon.alert_expired` message from Neon + """ + run(self.send_to_client(message)) + + async def send_to_client(self, message: Message): + """ + Asynchronously forward a message from Neon/MQ to a WebSocket client. + @param message: Message to forward to a WebSocket client + """ + # TODO: Drop context? + session_id = message.context["session"]["session_id"] + await self._sessions[session_id]["socket"].send_text(message.serialize()) + + def shutdown(self, *_, **__): + """ + Shutdown the event loop and prepare this object for destruction. + """ + loop = get_event_loop() + loop.call_soon_threadsafe(loop.stop) + LOG.info("Stopped Event Loop") + super().shutdown() diff --git a/neon_hana/schema/node_v1.py b/neon_hana/schema/node_v1.py new file mode 100644 index 0000000..11c2725 --- /dev/null +++ b/neon_hana/schema/node_v1.py @@ -0,0 +1,124 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Development System +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2021 Neongecko.com Inc. +# BSD-3 +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from pydantic import BaseModel, Field +from typing import Optional, Dict, List, Literal +from neon_hana.schema.node_model import NodeData + + +class NodeInputContext(BaseModel): + node_data: Optional[NodeData] = Field(description="Node Data") + + +class AudioInputData(BaseModel): + audio_data: str = Field(description="base64-encoded audio") + lang: str = Field(description="BCP-47 language code") + + +class TextInputData(BaseModel): + text: str = Field(description="String text input") + lang: str = Field(description="BCP-47 language code") + + +class UtteranceInputData(BaseModel): + utterances: List[str] = Field(description="List of input utterance(s)") + lang: str = Field(description="BCP-47 language") + + +class KlatResponse(BaseModel): + sentence: str = Field(description="Text response") + audio: dict = {Field(description="Audio Gender", + type=Literal["male", "female"]): + Field(description="b64-encoded audio", type=str)} + + +class TtsResponse(KlatResponse): + translated: bool = Field(description="True if sentence was translated") + phonemes: List[str] = Field(description="List of phonemes") + genders: List[str] = Field(description="List of audio genders") + + +class KlatResponseData(BaseModel): + responses: dict = {Field(type=str, + description="BCP-47 language"): KlatResponse} + + +class NodeAudioInput(BaseModel): + msg_type: str = "neon.audio_input" + data: AudioInputData + context: NodeInputContext + + +class NodeTextInput(BaseModel): + msg_type: str = "recognizer_loop:utterance" + data: UtteranceInputData + context: NodeInputContext + + +class NodeGetStt(BaseModel): + msg_type: str = "neon.get_stt" + data: AudioInputData + context: NodeInputContext + + +class NodeGetTts(BaseModel): + msg_type: str = "neon.get_tts" + data: TextInputData + context: NodeInputContext + + +class NodeKlatResponse(BaseModel): + msg_type: str = "klat.response" + data: dict = {Field(type=str, description="BCP-47 language"): KlatResponse} + context: dict + + +class NodeAudioInputResponse(BaseModel): + msg_type: str = "neon.audio_input.response" + data: dict = {"parser_data": Field(description="Dict audio parser data", + type=dict), + "transcripts": Field(description="Transcribed text", + type=List[str]), + "skills_recv": Field(description="Skills service acknowledge", + type=bool)} + context: dict + + +class NodeGetSttResponse(BaseModel): + msg_type: str = "neon.get_stt.response" + data: dict = {"parser_data": Field(description="Dict audio parser data", + type=dict), + "transcripts": Field(description="Transcribed text", + type=List[str]), + "skills_recv": Field(description="Skills service acknowledge", + type=bool)} + context: dict + + +class NodeGetTtsResponse(BaseModel): + msg_type: str = "neon.get_tts.response" + data: KlatResponseData + context: dict diff --git a/requirements/test_requirements.txt b/requirements/test_requirements.txt index 68e751a..98ff01c 100644 --- a/requirements/test_requirements.txt +++ b/requirements/test_requirements.txt @@ -1,2 +1,4 @@ pytest -mock \ No newline at end of file +mock +neon-iris~=0.1 +websockets~=12.0 \ No newline at end of file diff --git a/requirements/websocket.txt b/requirements/websocket.txt new file mode 100644 index 0000000..99b6eb0 --- /dev/null +++ b/requirements/websocket.txt @@ -0,0 +1,2 @@ +neon-iris~=0.1,>=0.1.1a5 +websockets~=12.0 \ No newline at end of file diff --git a/setup.py b/setup.py index 892e74c..b9928ee 100644 --- a/setup.py +++ b/setup.py @@ -74,6 +74,7 @@ def get_requirements(requirements_filename: str): license='BSD-3-Clause', packages=find_packages(), install_requires=get_requirements("requirements.txt"), + extras_require={"websocket": get_requirements("websocket.txt")}, zip_safe=True, classifiers=[ 'Intended Audience :: Developers', diff --git a/tests/test_auth.py b/tests/test_auth.py index 5bec3d9..d2fcbef 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -60,6 +60,8 @@ def test_check_auth_request(self): self.assertEqual(auth_resp_2['username'], 'guest') self.assertEqual(auth_resp_2['client_id'], client_2) + # TODO: Test permissions + # Check auth already authorized self.assertEqual(auth_resp_2, self.client_manager.check_auth_request(**request_2)) @@ -77,7 +79,8 @@ def test_validate_auth(self): expired_token = self.client_manager._create_tokens( {"client_id": invalid_client, "username": "test", - "password": "test", "expire": time()})['access_token'] + "password": "test", "expire": time(), + "permissions": {}})['access_token'] self.assertFalse(self.client_manager.validate_auth(expired_token, "127.0.0.1")) @@ -93,7 +96,8 @@ def test_check_refresh_request(self): tokens = self.client_manager._create_tokens({"client_id": valid_client, "username": "test", "password": "test", - "expire": time()}) + "expire": time(), + "permissions": {}}) self.assertEqual(tokens['client_id'], valid_client) # Test invalid refresh token @@ -133,10 +137,68 @@ def test_check_refresh_request(self): tokens = self.client_manager._create_tokens({"client_id": valid_client, "username": "test", "password": "test", - "expire": time()}) + "expire": time(), + "permissions": {}}) with self.assertRaises(HTTPException) as e: self.client_manager.check_refresh_request(tokens['access_token'], tokens['refresh_token'], tokens['client_id']) self.assertEqual(e.exception.status_code, 401) self.client_manager._refresh_token_lifetime = real_refresh + + def test_get_permissions(self): + from neon_hana.auth.permissions import ClientPermissions + + node_user = "node_test" + rest_user = "rest_user" + self.client_manager._node_username = node_user + self.client_manager._node_password = node_user + + rest_resp = self.client_manager.check_auth_request(rest_user, rest_user) + node_resp = self.client_manager.check_auth_request(node_user, node_user, + node_user) + node_fail = self.client_manager.check_auth_request("node_fail", + node_user, rest_user) + + rest_cid = rest_resp['client_id'] + node_cid = node_resp['client_id'] + fail_cid = node_fail['client_id'] + + permissive = ClientPermissions(True, True, True) + no_node = ClientPermissions(True, True, False) + no_perms = ClientPermissions(False, False, False) + + # Auth disabled, returns all True + self.client_manager._disable_auth = True + self.assertEqual(self.client_manager.get_permissions(rest_cid), + permissive) + self.assertEqual(self.client_manager.get_permissions(node_cid), + permissive) + self.assertEqual(self.client_manager.get_permissions(rest_cid), + permissive) + self.assertEqual(self.client_manager.get_permissions(fail_cid), + permissive) + self.assertEqual(self.client_manager.get_permissions("fake_user"), + permissive) + + # Auth enabled + self.client_manager._disable_auth = False + self.assertEqual(self.client_manager.get_permissions(rest_cid), no_node) + self.assertEqual(self.client_manager.get_permissions(node_cid), + permissive) + self.assertEqual(self.client_manager.get_permissions(fail_cid), no_node) + self.assertEqual(self.client_manager.get_permissions("fake_user"), + no_perms) + + def test_client_permissions(self): + from neon_hana.auth.permissions import ClientPermissions + default_perms = ClientPermissions() + restricted_perms = ClientPermissions(False, False, False) + permissive_perms = ClientPermissions(True, True, True) + self.assertIsInstance(default_perms.as_dict(), dict) + for v in default_perms.as_dict().values(): + self.assertIsInstance(v, bool) + self.assertIsInstance(restricted_perms.as_dict(), dict) + self.assertFalse(any([v for v in restricted_perms.as_dict().values()])) + self.assertIsInstance(permissive_perms.as_dict(), dict) + self.assertTrue(all([v for v in permissive_perms.as_dict().values()])) diff --git a/tests/test_mq_service_api.py b/tests/test_mq_service_api.py new file mode 100644 index 0000000..8d1fa2d --- /dev/null +++ b/tests/test_mq_service_api.py @@ -0,0 +1,32 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Development System +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2021 Neongecko.com Inc. +# BSD-3 +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import unittest + + +class TestMqServiceApi(unittest.TestCase): + from neon_hana.mq_service_api import MQServiceManager + # TODO diff --git a/tests/test_mq_websocket_api.py b/tests/test_mq_websocket_api.py new file mode 100644 index 0000000..6377b5d --- /dev/null +++ b/tests/test_mq_websocket_api.py @@ -0,0 +1,32 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Development System +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2021 Neongecko.com Inc. +# BSD-3 +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import unittest + + +class TestMqServiceApi(unittest.TestCase): + from neon_hana.mq_websocket_api import MQWebsocketAPI + # TODO