diff --git a/.codeclimate.yml b/.codeclimate.yml deleted file mode 100644 index c513319..0000000 --- a/.codeclimate.yml +++ /dev/null @@ -1,29 +0,0 @@ ---- -engines: - duplication: - enabled: true - config: - languages: - - ruby - - javascript - - python - - php - fixme: - enabled: true - radon: - enabled: true -ratings: - paths: - - "**.inc" - - "**.js" - - "**.jsx" - - "**.module" - - "**.php" - - "**.py" - - "**.rb" -exclude_paths: - - "versioneer.py" - - "odin/_version.py" - - "odin/testing/**/*" - - "odin/static/js/bootstrap*/**/*" - - "odin/static/js/jquery*.js" diff --git a/.coveragerc-py27 b/.coveragerc-py27 new file mode 100644 index 0000000..09d6ae1 --- /dev/null +++ b/.coveragerc-py27 @@ -0,0 +1,14 @@ +[run] +omit = + *_version* + src/odin/adapters/async_adapter.py + +[paths] +source= + src/ + .tox/py*/lib/python*/site-packages/ + +[report] +omit = + *_version* + */async_*.py diff --git a/.disable-travis.yml b/.disable-travis.yml index fd2e545..9988641 100644 --- a/.disable-travis.yml +++ b/.disable-travis.yml @@ -3,7 +3,10 @@ language: python sudo: false python: - 2.7 +- 3.6 - 3.7 +- 3.8 +- 3.9 addons: apt: packages: diff --git a/.github/workflows/test_odin_control.yml b/.github/workflows/test_odin_control.yml index 9a4cc9f..b3f23c8 100644 --- a/.github/workflows/test_odin_control.yml +++ b/.github/workflows/test_odin_control.yml @@ -26,7 +26,12 @@ jobs: - name: Merge tox env specific coverage files run: | coverage combine - coverage xml + if [[ "${{ matrix.python-version }}" == 2.7* ]]; then + export COVERAGE_RC=.coveragerc-py27 + else + export COVERAGE_RC=.coveragerc + fi + coverage xml --rcfile=$COVERAGE_RC - name: Upload coverage to Codecov uses: codecov/codecov-action@v1 with: diff --git a/setup.cfg b/setup.cfg index 7ef480b..ead8eb8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,14 +1,3 @@ -[nosetests] -verbosity=2 -nocapture=1 -detailed-errors=1 -with-coverage=1 -cover-package=odin -cover-erase=1 -#debug=nose.loader -#pdb=1 -#pdb-failures=1 - 
[flake8] max-line-length = 100 @@ -18,3 +7,4 @@ style = pep440 versionfile_source = src/odin/_version.py versionfile_build = odin/_version.py tag_prefix= + diff --git a/src/odin/adapters/adapter.py b/src/odin/adapters/adapter.py index c20933d..da9d3e0 100644 --- a/src/odin/adapters/adapter.py +++ b/src/odin/adapters/adapter.py @@ -6,6 +6,7 @@ import logging +from odin.util import wrap_result class ApiAdapter(object): """ @@ -16,11 +17,14 @@ class ApiAdapter(object): implement them, returning an error message and 400 code. """ + is_async = False + def __init__(self, **kwargs): """Initialise the ApiAdapter object. :param kwargs: keyword argument list that is copied into options dictionary """ + super(ApiAdapter, self).__init__() self.name = type(self).__name__ # Load any keyword arguments into the adapter options dictionary @@ -51,6 +55,20 @@ def get(self, path, request): response = "GET method not implemented by {}".format(self.name) return ApiAdapterResponse(response, status_code=400) + def post(self, path, request): + """Handle an HTTP POST request. + + This method is an abstract implementation of the POST request handler for ApiAdapter. + + :param path: URI path of resource + :param request: HTTP request object passed from handler + :return: ApiAdapterResponse container of data, content-type and status_code + """ + logging.debug('POST on path %s from %s: method not implemented by %s', + path, request.remote_ip, self.name) + response = "POST method not implemented by {}".format(self.name) + return ApiAdapterResponse(response, status_code=400) + def put(self, path, request): """Handle an HTTP PUT request. 
@@ -200,9 +218,10 @@ def wrapper(_self, path, request): # Validate the Content-Type header in the request against allowed types if 'Content-Type' in request.headers: if request.headers['Content-Type'] not in oargs: - return ApiAdapterResponse( + response = ApiAdapterResponse( 'Request content type ({}) not supported'.format( request.headers['Content-Type']), status_code=415) + return wrap_result(response, _self.is_async) return func(_self, path, request) return wrapper return decorator @@ -254,10 +273,10 @@ def wrapper(_self, path, request): # If it was not possible to resolve a response type or there was not default # given, return an error code 406 if response_type is None: - return ApiAdapterResponse( + response = ApiAdapterResponse( "Requested content types not supported", status_code=406 ) - + return wrap_result(response, _self.is_async) else: response_type = okwargs['default'] if 'default' in okwargs else 'text/plain' request.headers['Accept'] = response_type diff --git a/src/odin/adapters/async_adapter.py b/src/odin/adapters/async_adapter.py new file mode 100644 index 0000000..25a53f4 --- /dev/null +++ b/src/odin/adapters/async_adapter.py @@ -0,0 +1,123 @@ +""" +odin.adapters.adapter.py - base asynchronous API adapter implmentation for the ODIN server. + +Tim Nicholls, STFC Detector Systems Software Group +""" + +import asyncio +import logging +import inspect + +from odin.adapters.adapter import ApiAdapter, ApiAdapterResponse + + +class AsyncApiAdapter(ApiAdapter): + """ + Asynchronous API adapter base class. + + This class defines the basis for all async API adapters and provides default + methods for the required HTTP verbs in case the derived classes fail to + implement them, returning an error message and 400 code. + """ + + is_async = True + + def __init__(self, **kwargs): + """Initialise the AsyncApiAdapter object. 
+ + :param kwargs: keyword argument list that is copied into options dictionary + """ + super(AsyncApiAdapter, self).__init__(**kwargs) + + def __await__(self): + """Make AsyncApiAdapter objects awaitable. + + This magic method makes the instantiation of AsyncApiAdapter objects awaitable. This allows + any underlying async and awaitable attributes, e.g. an AsyncParameterTree, to be correctly + awaited when the adapter is loaded.""" + async def closure(): + """Await all async attributes of the adapter.""" + awaitable_attrs = [attr for attr in self.__dict__.values() if inspect.isawaitable(attr)] + await asyncio.gather(*awaitable_attrs) + return self + + return closure().__await__() + + async def initialize(self, adapters): + """Initialize the AsyncApiAdapter after it has been registered by the API Route. + + This is an abstract implementation of the initialize mechinism that allows + an adapter to receive a list of loaded adapters, for Inter-adapter communication. + :param adapters: a dictionary of the adapters loaded by the API route. + """ + + pass + + async def cleanup(self): + """Clean up adapter state. + + This is an abstract implementation of the cleanup mechanism provided to allow adapters + to clean up their state (e.g. disconnect cleanly from the device being controlled, set + some status message). + """ + pass + + async def get(self, path, request): + """Handle an HTTP GET request. + + This method is an abstract implementation of the GET request handler for AsyncApiAdapter. 
+ + :param path: URI path of resource + :param request: HTTP request object passed from handler + :return: ApiAdapterResponse container of data, content-type and status_code + """ + logging.debug('GET on path %s from %s: method not implemented by %s', + path, request.remote_ip, self.name) + await asyncio.sleep(0) + response = "GET method not implemented by {}".format(self.name) + return ApiAdapterResponse(response, status_code=400) + + async def post(self, path, request): + """Handle an HTTP POST request. + + This method is an abstract implementation of the POST request handler for AsyncApiAdapter. + + :param path: URI path of resource + :param request: HTTP request object passed from handler + :return: ApiAdapterResponse container of data, content-type and status_code + """ + logging.debug('POST on path %s from %s: method not implemented by %s', + path, request.remote_ip, self.name) + await asyncio.sleep(0) + response = "POST method not implemented by {}".format(self.name) + return ApiAdapterResponse(response, status_code=400) + + async def put(self, path, request): + """Handle an HTTP PUT request. + + This method is an abstract implementation of the PUT request handler for AsyncApiAdapter. + + :param path: URI path of resource + :param request: HTTP request object passed from handler + :return: ApiAdapterResponse container of data, content-type and status_code + """ + logging.debug('PUT on path %s from %s: method not implemented by %s', + path, request.remote_ip, self.name) + await asyncio.sleep(0) + response = "PUT method not implemented by {}".format(self.name) + return ApiAdapterResponse(response, status_code=400) + + async def delete(self, path, request): + """Handle an HTTP DELETE request. + + This method is an abstract implementation of the DELETE request handler for ApiAdapter. 
+ + :param path: URI path of resource + :param request: HTTP request object passed from handler + :return: ApiAdapterResponse container of data, content-type and status_code + """ + logging.debug('DELETE on path %s from %s: method not implemented by %s', + path, request.remote_ip, self.name) + await asyncio.sleep(0) + response = "DELETE method not implemented by {}".format(self.name) + return ApiAdapterResponse(response, status_code=400) diff --git a/src/odin/adapters/async_dummy.py b/src/odin/adapters/async_dummy.py new file mode 100644 index 0000000..753145e --- /dev/null +++ b/src/odin/adapters/async_dummy.py @@ -0,0 +1,202 @@ + +"""Dummy asynchronous adapter classes for the ODIN server. + +The AsyncDummyAdapter class implements a dummy asynchronous adapter for the ODIN server, +demonstrating the basic asymc adapter implementation and providing a loadable adapter for testing. + +Tim Nicholls, STFC Detector Systems Software Group. +""" +import asyncio +import logging +import time +import concurrent.futures + +from odin.adapters.adapter import ApiAdapterResponse, request_types, response_types +from odin.adapters.async_adapter import AsyncApiAdapter +from odin.adapters.async_parameter_tree import AsyncParameterTree +from odin.adapters.base_parameter_tree import ParameterTreeError +from odin.util import decode_request_body, run_in_executor + + +class AsyncDummyAdapter(AsyncApiAdapter): + """Dummy asynchronous adapter class for the ODIN server. + + This dummy adapter implements basic asynchronous operation of an adapter, including use of an + async parameter tree, and async GET and PUT methods. The parameter tree includes sync and async + accessors, which simulate long-running tasks by sleeping, either using native async sleep or + by sleeping in a thread pool executor. This shows that the calling server can remain responsive + during long-running async tasks. + """ + + def __init__(self, **kwargs): + """Intialize the AsyncDummy Adapter object. 
+ + This constructor initializes the AsyncDummyAdapter object, including configuring an async + parameter tree with accessors triggering simulated long-running task (sleep), the duration + and implemntation of which can be selected by configuration parameters. + """ + super(AsyncDummyAdapter, self).__init__(**kwargs) + + # Parse the configuration options to determine the sleep duration and if we are wrapping + # a synchronous sleep in a thread pool executor. + self.async_sleep_duration = float(self.options.get('async_sleep_duration', 2.0)) + self.wrap_sync_sleep = bool(int(self.options.get('wrap_sync_sleep', 0))) + + sleep_mode_msg = 'sync thread pool executor' if self.wrap_sync_sleep else 'native async' + logging.debug("Configuring async sleep task using {} with duration {} secs".format( + sleep_mode_msg, self.async_sleep_duration + )) + + # Initialise counters for the async and sync tasks and a trivial async read/write parameter + self.sync_task_count = 0 + self.async_task_count = 0 + self.async_rw_param = 1234 + + self.param_tree = AsyncParameterTree({ + 'async_sleep_duration': (self.get_async_sleep_duration, None), + 'wrap_sync_sleep': (self.get_wrap_sync_sleep, None), + 'sync_task_count': (lambda: self.sync_task_count, None), + 'async_task_count': (lambda: self.async_task_count, None), + 'async_rw_param': (self.get_async_rw_param, self.set_async_rw_param), + }) + + # Create the thread pool executor + self.executor = concurrent.futures.ThreadPoolExecutor() + + async def initialize(self, adapters): + """Initalize the adapter. + + This dummy method demonstrates that async adapter initialisation can be performed + asynchronously. + + :param adapters: list of adapters loaded into the server + """ + logging.debug("AsyncDummyAdapter initialized with %d adapters", len(adapters)) + await asyncio.sleep(0) + + async def cleanup(self): + """Clean up the adapter. + + This dummy method demonstrates that async adapter cleanup can be performed asynchronously. 
+ """ + logging.debug("AsyncDummyAdapter cleanup called") + await asyncio.sleep(0) + + @response_types('application/json', default='application/json') + async def get(self, path, request): + """Handle an HTTP GET request. + + This method handles an HTTP GET request, returning a JSON response. The parameter tree + data at the specified path is returned in the response. The underlying tree has a mix of + sync and async parameter accessors, and GET requests simulate the concurrent operation of + async adapters by sleeping for specified periods where appropriate. + + :param path: URI path of request + :param request: HTTP request object + :return: an ApiAdapterResponse object containing the appropriate response + """ + try: + response = await self.param_tree.get(path) + status_code = 200 + except ParameterTreeError as param_error: + response = {'error': str(param_error)} + status_code = 400 + + logging.debug("GET on path %s : %s", path, response) + content_type = 'application/json' + + return ApiAdapterResponse(response, content_type=content_type, status_code=status_code) + + @request_types('application/json', 'application/vnd.odin-native') + @response_types('application/json', default='application/json') + async def put(self, path, request): + """Handle an HTTP PUT request. + + This method handles an HTTP PUT request, decoding the request and attempting to set values + in the asynchronous parameter tree as appropriate. 
+ + :param path: URI path of request + :param request: HTTP request object + :return: an ApiAdapterResponse object containing the appropriate response + """ + content_type = 'application/json' + + try: + data = decode_request_body(request) + await self.param_tree.set(path, data) + response = await self.param_tree.get(path) + status_code = 200 + except ParameterTreeError as param_error: + response = {'error': str(param_error)} + status_code = 400 + + return ApiAdapterResponse( + response, content_type=content_type, status_code=status_code + ) + + def sync_task(self): + """Simulate a synchronous long-running task. + + This method simulates a long-running task by sleeping for the configured duration. It + is made aysnchronous by wrapping it in a thread pool exector. + """ + logging.debug("Starting simulated sync task") + self.sync_task_count += 1 + time.sleep(self.async_sleep_duration) + logging.debug("Finished simulated sync task") + + async def async_task(self): + """Simulate a synchronous long-running task. + + This method simulates an async long-running task by performing an asyncio sleep for the + configured duration. + """ + logging.debug("Starting simulated async task") + self.async_task_count += 1 + await asyncio.sleep(self.async_sleep_duration) + logging.debug("Finished simulated async task") + + async def get_async_sleep_duration(self): + """Simulate an async parameter access. + + This method demonstrates an asynchronous parameter access, return the current value of the + async sleep duration parameter passed into the adapter as an option. + """ + logging.debug("Entering async sleep duration get function") + if self.wrap_sync_sleep: + await run_in_executor(self.executor, self.sync_task) + else: + await self.async_task() + + logging.debug("Returning async sleep duration parameter: %f", self.async_sleep_duration) + return self.async_sleep_duration + + def get_wrap_sync_sleep(self): + """Simulate a sync parameter access. 
+ + This method demonstrates a synchronous parameter access, returning the the current value + of the wrap sync sleep parameter passed into the adapter as an option. + """ + logging.debug("Getting wrap sync sleep flag: %s", str(self.wrap_sync_sleep)) + return self.wrap_sync_sleep + + async def get_async_rw_param(self): + """Get the value of the async read/write parameter. + + This async method returns the current value of the async read/write parameter. + + :returns: current value of the async read/write parameter. + """ + await asyncio.sleep(0) + return self.async_rw_param + + async def set_async_rw_param(self, value): + """Set the value of the async read/write parameter. + + This async updates returns the current value of the async read/write parameter. + + :param: new value to set parameter to + """ + await asyncio.sleep(0) + self.async_rw_param = value + diff --git a/src/odin/adapters/async_parameter_tree.py b/src/odin/adapters/async_parameter_tree.py new file mode 100644 index 0000000..e2130da --- /dev/null +++ b/src/odin/adapters/async_parameter_tree.py @@ -0,0 +1,241 @@ +"""async_parameter_tree.py - classes representing an asychronous parameter tree and accessor. + +This module defines a parameter tree and accessor for use in asynchronous API adapters, where +concurrency over blocking operations (e.g. to read/write the value of a parameter from hardware) +is required. + +Tim Nicholls, STFC Detector Systems Software Group. 
+""" + +import asyncio + +from odin.adapters.base_parameter_tree import ( + BaseParameterAccessor, BaseParameterTree, ParameterTreeError +) + +__all__ = ['AsyncParameterAccessor', 'AsyncParameterTree', 'ParameterTreeError'] + +# if sys.version_info < (3,7): +# async_create_task = asyncio.ensure_future +# else: +# async_create_task = asyncio.create_task +try: + async_create_task = asyncio.create_task +except AttributeError: + async_create_task = asyncio.ensure_future + +class AsyncParameterAccessor(BaseParameterAccessor): + """Asynchronous container class representing accessor methods for a parameter. + + This class extends the base parameter accessor class to support asynchronous set and get + accessors for a parameter. Read-only and writeable parameters are supported, and the same + metadata fields are implemented. + + Note that the instantiation of objects of this class MUST be awaited to allow the getter + function to evaluate and record the parameter type in the metadata, e.g. + + accessor = await AsyncParameterAccessor(....) + + Accessors instantiated during the intialisation of an AsyncParameterTree will automatically be + collected and awaited by the tree itself. + """ + + def __init__(self, path, getter=None, setter=None, **kwargs): + """Initialise the AsyncParameterAccessor instance. + + This constructor initialises the AsyncParameterAccessor instance, storing the path of the + parameter, its set/get accessors and setting metadata fields based on the the specified + keyword arguments. 
+ + :param path: path of the parameter within the tree + :param getter: get method for the parameter, or a value if read-only constant + :param setter: set method for the parameter + :param kwargs: keyword argument list for metadata fields to be set; these must be from + the allow list specified in BaseParameterAccessor.allowed_metadata + """ + # Initialise the superclass with the specified arguments + super(AsyncParameterAccessor, self).__init__(path, getter, setter, **kwargs) + + def __await__(self): + """Make AsyncParameterAccessor objects awaitable. + + This magic method makes the instantiation of AsyncParameterAccessor objects awaitable. This + is required since instantiation must call the specified get() method, which is itself async, + in order to resolve the type of the parameter and store that in the metadata. This cannot be + done directly in the constructor. + + :returns: an awaitable future + """ + async def closure(): + """Resolve the parameter type in an async closure.""" + self._type = type(await self.get()) + self.metadata["type"] = self._type.__name__ + return self + + return closure().__await__() + + @staticmethod + async def resolve_coroutine(value): + """Resolve a coroutine and return its value. + + This static convenience method allows an accessor to resolve the output of its getter/setter + functions to avalue if an async coroutine is returned. + + :param value: value or coroutine to resolve + :returns: resolved value + """ + if value and asyncio.iscoroutine(value): + value = await value + + return value + + async def get(self, with_metadata=False): + """Get the value of the parameter. + + This async method returns the value of the parameter, or the value returned by the accessor + getter, if one is defined (i.e. is callable). If the getter is itself async, the value is + resolved by awaiting the returned coroutine. If the with_metadata argument is true, the + value is returned in a dictionary including all metadata for the parameter. 
+ + :param with_metadata: include metadata in the response when set to True + :returns value of the parameter + """ + # Call the superclass get method + value = super(AsyncParameterAccessor, self).get(with_metadata) + + # Resolve and await the returned value, either into the metadata-populated dict or directly + # as the returned value + if with_metadata: + value["value"] = await(self.resolve_coroutine(value["value"])) + else: + value = await self.resolve_coroutine(value) + + return value + + async def set(self, value): + """Set the value of the parameter. + + This async method sets the value of the parameter by calling the set accessor + if defined and callable. The result is awaited if a coroutine is returned. + + :param value: value to set + """ + await self.resolve_coroutine(super(AsyncParameterAccessor, self).set(value)) + + +class AsyncParameterTree(BaseParameterTree): + """Class implementing an asynchronous tree of parameters and their accessors. + + This async class implements an arbitrarily-structured, recursively-managed tree of parameters + and the appropriate accessor methods that are used to read and write those parameters. + + Note that the instantiation of an AsyncParameterTree MUST be awaited by calling code to allow + the type and intial value of each parameter to be resolved, e.g.: + + tree = await AsyncParameterTree(...) + """ + + def __init__(self, tree, mutable=False): + """Initialise the AsyncParameterTree object. + + This constructor recursively initialises the AsyncParameterTree object based on the + specified arguments. The tree initialisation syntax follows that of the BaseParameterTree + implementation. 
+ + :param tree: dict representing the parameter tree + :param mutable: Flag, setting the tree + """ + # Set the accessor class used by this tree to AsyncParameterAccessor + self.accessor_cls = AsyncParameterAccessor + + # Initialise the superclass with the speccified parameters + super(AsyncParameterTree, self).__init__(tree, mutable) + + def __await__(self): + """Make AsyncParameterTree objects awaitable. + + This magic method makes the instantiation of AsyncParameterTree objects awaitable. This + is required since the underlying accessor objects must also be awaited at initialisation + to resolve their type and intial values. This is achieved by traversing the parameter tree + and gathering all awaitable accessor instances and awaiting them. + """ + def get_awaitable_params(node): + """Traverse the parameter tree and build a list of awaitable accessors.""" + awaitable_params = [] + if (isinstance(node, dict)): + for val in node.values(): + if isinstance(val, self.accessor_cls): + awaitable_params.append(val) + else: + awaitable_params.extend(get_awaitable_params(val)) + return awaitable_params + + async def closure(): + """Resolve the parameter tree accessor types in an async closure.""" + await asyncio.gather(*get_awaitable_params(self.tree)) + return self + + return closure().__await__() + + async def get(self, path, with_metadata=False): + """Get the values of parameters in a tree. + + This async method returns the values at and below a specified path in the parameter tree. + This is done by recursively populating the tree with the current values of parameters, + returning the result as a dictionary. + + :param path: path in tree to get parameter values for + :param with_metadata: include metadata in the response when set to True + :returns: dict of parameter tree at the specified path + """ + value = super(AsyncParameterTree, self).get(path, with_metadata) + + async def resolve_value(value): + """Recursively resolve the values of the parameters. 
+ + This inner method recursively decends through the tree of parameters being returned by + the get() call, awaiting any async getter methods. These are done sequentially to allow + the values to be resolved in-place within the tree. + """ + if isinstance(value, dict): + for (k, v) in value.items(): + if asyncio.iscoroutine(v): + value[k] = await v + else: + await resolve_value(v) + + # Resolve values of parameters in the tree + await resolve_value(value) + return value + + async def set(self, path, data): + """Set the values of the parameters in a tree. + + This async method sets the values of parameters in a tree, based on the data passed to it + as a nested dictionary of parameter and value pairs. The updated parameters are merged + into the existing tree recursively. + + :param path: path to set parameters for in the tree + :param data: nested dictionary representing values to update at the path + """ + # Create an empty list of awaitable parameters + self.awaitable_params = [] + + # Call the superclass set method with the specified parameters + super(AsyncParameterTree, self).set(path, data) + + # Await any async set methods in the modified parameters + await asyncio.gather(*self.awaitable_params) + + def _set_node(self, node, data): + """Set the value of a node to the specified data. + + This method sets a specified node to the data supplied. If the setter function for the node + is async, it is added to the list of parameters to be awaited by the set() method. + + :param node: tree node to set value of + :param data: data to node value to + """ + response = node.set(data) + if asyncio.iscoroutine(response): + self.awaitable_params.append(async_create_task(response)) diff --git a/src/odin/adapters/async_proxy.py b/src/odin/adapters/async_proxy.py new file mode 100644 index 0000000..e98f6a8 --- /dev/null +++ b/src/odin/adapters/async_proxy.py @@ -0,0 +1,196 @@ +""" +Asynchronous proxy adapter for use in odin-control. 
+ +This module implements a simple asynchronous proxy adapter, allowing requests to be proxied to +one or more remote HTTP resources, typically further odin-control instances. + +Tim Nicholls, Ashley Neaves STFC Detector Systems Software Group. +""" +import asyncio +import inspect + +from tornado.httpclient import AsyncHTTPClient +from odin.util import decode_request_body +from odin.adapters.adapter import ( + ApiAdapterResponse, request_types, response_types, wants_metadata +) +from odin.adapters.async_adapter import AsyncApiAdapter +from odin.adapters.base_proxy import BaseProxyTarget, BaseProxyAdapter + + +class AsyncProxyTarget(BaseProxyTarget): + """ + Asynchronous proxy adapter target class. + + This class implements an asynchronous proxy target, its parameter tree and associated + status information for use in the ProxyAdapter. + """ + + def __init__(self, name, url, request_timeout): + """ + Initialise the AsyncProxyTarget object. + + This constructor initialises the AsyncProxyTarget, creating an async HTTP client and + delegating the full initialisation to the base class. + + + :param name: name of the proxy target + :param url: URL of the remote target + :param request_timeout: request timeout in seconds + """ + # Create an async HTTP client for use in this target + self.http_client = AsyncHTTPClient() + + # Initialise the base class + super(AsyncProxyTarget, self).__init__(name, url, request_timeout) + + def __await__(self): + """ + Make AsyncProxyTarget objects awaitable. + + This magic method makes the instantation of AsyncProxyTarget objects awaitable. This allows + the async calls to remote_get, used to populate the data and metadata trees from the remote + target, to be awaited. 
+ """ + async def closure(): + """Await the calls to the remote target to populate and data and metadata tress.""" + await self.remote_get() + await self.remote_get(get_metadata=True) + return self + + return closure().__await__() + + async def remote_get(self, path='', get_metadata=False): + """ + Get data from the remote target. + + This async method requests data from the remote target by issuing a GET request to the + target URL, and then updates the local proxy target data and status information according to + the response. The detailed handling of this is implemented by the base class. + + :param path: path to data on remote target + :param get_metadata: flag indicating if metadata is to be requested + """ + await super(AsyncProxyTarget, self).remote_get(path, get_metadata) + + async def remote_set(self, path, data): + """ + Set data on the remote target. + + This async method sends data to the remote target by issuing a PUT request to the target + URL, and then updates the local proxy target data and status information according to the + response. The detailed handling of this is implemented by the base class. + + :param path: path to data on remote target + :param data: data to set on remote target + """ + await super(AsyncProxyTarget, self).remote_set(path, data) + + async def _send_request(self, request, path, get_metadata=False): + """ + Send a request to the remote target and update data. + + This internal async method sends a request to the remote target using the HTTP client + and handles the response, updating target data accordingly. 
+ + :param request: HTTP request to transmit to target + :param path: path of data being updated + :param get_metadata: flag indicating if metadata is to be requested + """ + # Send the request to the remote target, handling any exceptions that occur + try: + response = await self.http_client.fetch(request) + except Exception as fetch_exception: + # Set the response to the exception so it can be handled during response resolution + response = fetch_exception + + # Process the response from the target, updating data as appropriate + self._process_response(response, path, get_metadata) + + +class AsyncProxyAdapter(AsyncApiAdapter, BaseProxyAdapter): + """ + Asynchronous proxy adapter class for odin-control. + + This class implements a proxy adapter, allowing odin-control to forward requests to + other HTTP services. + """ + + def __init__(self, **kwargs): + """ + Initialise the AsyncProxyAdapter. + + This constructor initialises the adapter instance. The base class adapter is initialised + with the keyword arguments and then the proxy targets and paramter tree initialised by the + proxy adapter mixin. + + :param kwargs: keyword arguments specifying options + """ + # Initialise the base class + super(AsyncProxyAdapter, self).__init__(**kwargs) + + # Initialise the proxy targets and parameter trees + self.initialise_proxy(AsyncProxyTarget) + + def __await__(self): + """ + Make AsyncProxyAdapter objects awaitable. + + This magic method makes the instantation of AsyncProxyAdapter objects awaitable. This allows + the async proxy targets to be awaited at initialisation. 
+ """ + + async def closure(): + """Construct a list of awaitable attributes and targets and await initialisation.""" + awaitables = [attr for attr in self.__dict__.values() if inspect.isawaitable(attr)] + awaitables += [target for target in self.targets if inspect.isawaitable(target)] + await asyncio.gather(*awaitables) + return self + + return closure().__await__() + + @response_types('application/json', default='application/json') + async def get(self, path, request): + """ + Handle an HTTP GET request. + + This async method handles an HTTP GET request, returning a JSON response. The request is + passed to the adapter proxy and resolved into responses from the requested proxy targets. + + :param path: URI path of request + :param request: HTTP request object + :return: an ApiAdapterResponse object containing the appropriate response + """ + get_metadata = wants_metadata(request) + + await asyncio.gather(*self.proxy_get(path, get_metadata)) + (response, status_code) = self._resolve_response(path, get_metadata) + + return ApiAdapterResponse(response, status_code=status_code) + + @request_types("application/json", "application/vnd.odin-native") + @response_types('application/json', default='application/json') + async def put(self, path, request): + """ + Handle an HTTP PUT request. + + This async method handles an HTTP PUT request, returning a JSON response. The request is + passed to the adapter proxy to set data on the remote targets and resolved into responses + from those targets. + + :param path: URI path of request + :param request: HTTP request object + :return: an ApiAdapterResponse object containing the appropriate response + """ + # Decode the request body from JSON, handling and returning any errors that occur. 
Otherwise
+        # send the PUT request to the remote target
+        try:
+            body = decode_request_body(request)
+        except (TypeError, ValueError) as type_val_err:
+            response = {'error': 'Failed to decode PUT request body: {}'.format(str(type_val_err))}
+            status_code = 415
+        else:
+            await asyncio.gather(*self.proxy_set(path, body))
+            (response, status_code) = self._resolve_response(path)
+
+        return ApiAdapterResponse(response, status_code=status_code)
diff --git a/src/odin/adapters/base_parameter_tree.py b/src/odin/adapters/base_parameter_tree.py
new file mode 100644
index 0000000..1492354
--- /dev/null
+++ b/src/odin/adapters/base_parameter_tree.py
@@ -0,0 +1,523 @@
+"""base_parameter_tree.py - base classes representing a tree of parameters and accessors.
+
+This module implements an arbitrarily-structured, recursively-managed tree of parameters and
+the appropriate accessor methods that are used to read and write those parameters. Its
+particular use is in the definition of a tree of parameters for an API adapter and help
+interfacing of those to the underlying device or object. These base classes are not intended to be
+used directly, but form the basis for concrete synchronous and asynchronous implementations.
+
+James Hogge, Tim Nicholls, STFC Application Engineering Group.
+"""
+
+
+class ParameterTreeError(Exception):
+    """Simple error class for raising parameter tree exceptions."""
+
+    pass
+
+
+class BaseParameterAccessor(object):
+    """Base container class representing accessor methods for a parameter.
+
+    This base class implements a parameter accessor, providing set and get methods
+    for parameters requiring calls to access them, or simply returning the
+    appropriate value if the parameter is a read-only constant. Parameter accessors also
+    contain metadata fields controlling access to and providing information about the parameter.
+
+    Valid specifiable metadata fields are:
+        min : minimum allowed value for parameter
+        max : maximum allowed value for parameter
+        allowed_values: list of allowed values for parameter
+        name : readable parameter name
+        description: longer description of parameter
+        units: parameter units
+        display_precision: number of decimal places to display for e.g. float types
+
+    The class also maintains the following automatically-populated metadata fields:
+        type: parameter type
+        writeable: is the parameter writable
+    """
+
+    # Valid metadata arguments that can be passed to ParameterAccessor __init__ method.
+    VALID_METADATA_ARGS = (
+        "min", "max", "allowed_values", "name", "description", "units", "display_precision"
+    )
+    # Automatically-populated metadata fields based on inferred type of the parameter and
+    # writeable status depending on specified accessors
+    AUTO_METADATA_FIELDS = ("type", "writeable")
+
+    def __init__(self, path, getter=None, setter=None, **kwargs):
+        """Initialise the BaseParameterAccessor instance.
+
+        This constructor initialises the BaseParameterAccessor instance, storing
+        the path of the parameter, its set/get accessors and setting metadata fields based
+        on the specified keyword arguments
+
+        :param path: path of the parameter within the tree
+        :param getter: get method for the parameter, or a value if read-only constant
+        :param setter: set method for the parameter
+        :param kwargs: keyword argument list for metadata fields to be set; these must be from
+        the allowed list specified in BaseParameterAccessor.VALID_METADATA_ARGS
+        """
+        # Initialise path, getter and setter
+        self.path = path[:-1]
+        self._get = getter
+        self._set = setter
+
+        # Initialize metadata dict
+        self.metadata = {}
+
+        # Check metadata keyword arguments are valid
+        for arg in kwargs:
+            if arg not in BaseParameterAccessor.VALID_METADATA_ARGS:
+                raise ParameterTreeError("Invalid metadata argument: {}".format(arg))
+
+        # Update metadata keywords from arguments
+        self.metadata.update(kwargs)
+
+        # Set the writeable metadata field based on specified accessors
+        if not callable(self._set) and callable(self._get):
+            self.metadata["writeable"] = False
+        else:
+            self.metadata["writeable"] = True
+
+    def get(self, with_metadata=False):
+        """Get the value of the parameter.
+
+        This method returns the value of the parameter, or the value returned
+        by the get accessor if one is defined (i.e. is callable). If the with_metadata argument
+        is true, the value is returned in a dictionary including all metadata for the
+        parameter.
+ + :param with_metadata: include metadata in the response when set to True + :returns value of the parameter + """ + # Determine the value of the parameter by calling the getter or simply from the stored + # value + if callable(self._get): + value = self._get() + else: + value = self._get + + # If metadata is requested, replace the value with a dict containing the value itself + # plus metadata fields + if with_metadata: + value = {"value": value} + value.update(self.metadata) + + return value + + def set(self, value): + """Set the value of the parameter. + + This method sets the value of the parameter by calling the set accessor + if defined and callable, otherwise raising an exception. + + :param value: value to set + """ + # Raise an error if this parameter is not writeable + if not self.metadata["writeable"]: + raise ParameterTreeError("Parameter {} is read-only".format(self.path)) + + # Raise an error of the value to be set is not of the same type as the parameter. If + # the metadata type field is set to None, allow any type to be set, or if the value + # is integer and the parameter is float, also allow as JSON does not differentiate + # numerics in all cases + if self.metadata["type"] != "NoneType" and not isinstance(value, self._type): + if not (isinstance(value, int) and self.metadata["type"] == "float"): + raise ParameterTreeError( + "Type mismatch setting {}: got {} expected {}".format( + self.path, type(value).__name__, self.metadata["type"] + ) + ) + + # Raise an error if allowed_values has been set for this parameter and the value to + # set is not one of them + if "allowed_values" in self.metadata and value not in self.metadata["allowed_values"]: + raise ParameterTreeError( + "{} is not an allowed value for {}".format(value, self.path) + ) + + # Raise an error if the parameter has a mininum value specified in metadata and the + # value to set is below this + if "min" in self.metadata and value < self.metadata["min"]: + raise ParameterTreeError( + 
"{} is below the minimum value {} for {}".format( + value, self.metadata["min"], self.path + ) + ) + + # Raise an error if the parameter has a maximum value specified in metadata and the + # value to set is above this + if "max" in self.metadata and value > self.metadata["max"]: + raise ParameterTreeError( + "{} is above the maximum value {} for {}".format( + value, self.metadata["max"], self.path + ) + ) + + # Set the new parameter value, either by calling the setter or updating the local + # value as appropriate + response = None + if callable(self._set): + response = self._set(value) + elif not callable(self._get): + self._get = value + + return response + + +class BaseParameterTree(object): + """Base class implementing a tree of parameters and their accessors. + + This base class implements an arbitrarily-structured, recursively-managed tree of parameters and + the appropriate accessor methods that are used to read and write those parameters. Its + particular use is in the definition of a tree of parameters for an API adapter and help + interfacing of those to the underlying device or object. + """ + + METADATA_FIELDS = ["name", "description"] + + def __init__(self, tree, mutable=False): + """Initialise the BaseParameterTree object. + + This constructor recursively initialises the BaseParameterTree object, based on the + parameter tree dictionary passed as an argument. This is done recursively, so that a + parameter tree can have arbitrary depth and contain other BaseParameterTree instances + as necessary. + + Initialisation syntax for BaseParameterTree is made by passing a dict representing the tree + as an argument. Children of a node at any level of the tree are described with + dictionaries/lists e.g. 
+ + {"parent" : {"childA" : {...}, "childB" : {...}}} + {"parent" : [{...}, {...}]} + + Leaf nodes can be one of the following formats: + + value - (value,) - (value, {metadata}) + getter - (getter,) - (getter, {metadata}) + (getter, setter) - (getter, setter, {metadata}) + + The following tags will also be treated as metadata: + + name - A printable name for that branch of the tree + description - A printable description for that branch of the tree + + :param tree: dict representing the parameter tree + :param mutable: Flag, setting the tree + """ + # Flag, if set to true, allows nodes to be replaced and new nodes created + self.mutable = mutable + + # list of paths to mutable parts. Not sure this is best solution + self.mutable_paths = [] + + # Recursively check and initialise the tree + self._tree = self._build_tree(tree) + + @property + def tree(self): + """Return tree object for this parameter tree node. + + Used internally for recursive descent of parameter trees. + """ + return self._tree + + def get(self, path, with_metadata=False): + """Get the values of parameters in a tree. + + This method returns the values at and below a specified path in the parameter tree. + This is done by recursively populating the tree with the current values of parameters, + returning the result as a dictionary. 
+ + :param path: path in tree to get parameter values for + :param with_metadata: include metadata in the response when set to True + :returns: dict of parameter tree at the specified path + """ + # Split the path by levels, truncating the last level if path ends in trailing slash + levels = path.split('/') + if levels[-1] == '': + del levels[-1] + + # Initialise the subtree before descent + subtree = self._tree + + # If this is single level path, return the populated tree at the top level + if not levels: + return self._populate_tree(subtree, with_metadata) + + # Descend the specified levels in the path, checking for a valid subtree of the appropriate + # type + for level in levels: + if level in self.METADATA_FIELDS and not with_metadata: + raise ParameterTreeError("Invalid path: {}".format(path)) + try: + if isinstance(subtree, dict): + subtree = subtree[level] + elif isinstance(subtree, self.accessor_cls): + subtree = subtree.get(with_metadata)[level] + else: + subtree = subtree[int(level)] + except (KeyError, ValueError, IndexError): + raise ParameterTreeError("Invalid path: {}".format(path)) + + # Return the populated tree at the appropriate path + return self._populate_tree({levels[-1]: subtree}, with_metadata) + + def set(self, path, data): + """Set the values of the parameters in a tree. + + This method sets the values of parameters in a tree, based on the data passed to it + as a nested dictionary of parameter and value pairs. The updated parameters are merged + into the existing tree recursively. 
+ + :param path: path to set parameters for in the tree + :param data: nested dictionary representing values to update at the path + """ + # Expand out any lists/tuples + data = self._build_tree(data) + + # Get subtree from the node the path points to + levels = path.split('/') + if levels[-1] == '': + del levels[-1] + + merge_parent = None + merge_child = self._tree + + # Descend the tree and validate each element of the path + for level in levels: + if level in self.METADATA_FIELDS: + raise ParameterTreeError("Invalid path: {}".format(path)) + try: + merge_parent = merge_child + if isinstance(merge_child, dict): + merge_child = merge_child[level] + else: + merge_child = merge_child[int(level)] + except (KeyError, ValueError, IndexError): + raise ParameterTreeError("Invalid path: {}".format(path)) + + # Add trailing / to paths where necessary + if path and path[-1] != '/': + path += '/' + + # Merge data with tree + merged = self._merge_tree(merge_child, data, path) + + # Add merged part to tree, either at the top of the tree or at the + # appropriate level speicfied by the path + if not levels: + self._tree = merged + return + if isinstance(merge_parent, dict): + merge_parent[levels[-1]] = merged + else: + merge_parent[int(levels[-1])] = merged + + def delete(self, path=''): + """ + Remove Parameters from a Mutable Tree. + + This method deletes selected parameters from a tree, if that tree has been flagged as + Mutable. 
Deletion of Branch Nodes means all child nodes of that Branch Node are also deleted + + :param path: Path to selected Parameter Node in the tree + """ + if not self.mutable and not any(path.startswith(part) for part in self.mutable_paths): + raise ParameterTreeError("Invalid Delete Attempt: Tree Not Mutable") + + # Split the path by levels, truncating the last level if path ends in trailing slash + levels = path.split('/') + if levels[-1] == '': + del levels[-1] + + subtree = self._tree + + if not levels: + subtree.clear() + return + try: + # Traverse down the path, based on hwo path navigation works in the Set Method above + for level in levels[:-1]: + + # If the subtree is a dict, the subtree is a normal branch, continue traversal. If + # it is not a dict the subtree is a list so the next path is indexed by the level + if isinstance(subtree, dict): + subtree = subtree[level] + else: + subtree = subtree[int(level)] + + # Once at the second to last part of the path, delete whatever comes next + if isinstance(subtree, list): + subtree.pop(int(levels[-1])) + else: + subtree.pop(levels[-1]) + except (KeyError, ValueError, IndexError): + raise ParameterTreeError("Invalid path: {}".format(path)) + + def _build_tree(self, node, path=''): + """Recursively build and expand out a tree or node. + + This internal method is used to recursively build and expand a tree or node, + replacing elements as found with appropriate types, e.g. ParameterAccessor for + a set/get pair, the internal tree of a nested ParameterTree. + + :param node: node to recursively build + :param path: path to node within overall tree + :returns: built node + """ + # If the node is a parameter tree instance, replace with its own built tree + if isinstance(node, type(self)): + if node.mutable: + self.mutable_paths.append(path) + return node.tree # this breaks the mutability of the sub-tree. 
hmm + + # Convert node tuple into the corresponding ParameterAccessor, depending on type of + # fields + if isinstance(node, tuple): + if len(node) == 1: + # Node is (value) + param = self.accessor_cls(path, node[0]) + + elif len(node) == 2: + if isinstance(node[1], dict): + # Node is (value, {metadata}) + param = self.accessor_cls(path, node[0], **node[1]) + else: + # Node is (getter, setter) + param = self.accessor_cls(path, node[0], node[1]) + + elif len(node) == 3 and isinstance(node[2], dict): + # Node is (getter, setter, {metadata}) + param = self.accessor_cls(path, node[0], node[1], **node[2]) + + else: + raise ParameterTreeError("{} is not a valid leaf node".format(repr(node))) + + return param + + # Convert list or non-callable tuple to enumerated dict + if isinstance(node, list): + return [self._build_tree(elem, path=path) for elem in node] + + # Recursively check child elements + if isinstance(node, dict): + return {k: self._build_tree( + v, path=path + str(k) + '/') for k, v in node.items()} + + return node + + def __remove_metadata(self, node): + """Remove metadata fields from a node. + + Used internally to return a parameter tree without metadata fields + + :param node: tree node to return without metadata fields + :returns: generator yeilding items in node minus metadata + """ + for key, val in node.items(): + if key not in self.METADATA_FIELDS: + yield key, val + + def _populate_tree(self, node, with_metadata=False): + """Recursively populate a tree with values. + + This internal method recursively populates the tree with parameter values, or + the results of the accessor getters for nodes. It is called by the get() method to + return the values of parameters in the tree. 
+ + :param node: tree node to populate and return + :param with_metadata: include parameter metadata with the tree + :returns: populated node as a dict + """ + # If this is a branch node recurse down the tree + if isinstance(node, dict): + if with_metadata: + branch = { + k: self._populate_tree(v, with_metadata) for k, v + in node.items() + } + else: + branch = { + k: self._populate_tree(v, with_metadata) for k, v + in self.__remove_metadata(node) + } + return branch + + if isinstance(node, list): + return [self._populate_tree(item, with_metadata) for item in node] + + # If this is a leaf node, check if the leaf is a r/w tuple and substitute the + # read element of that tuple into the node + if isinstance(node, self.accessor_cls): + return node.get(with_metadata) + + return node + + def _merge_tree(self, node, new_data, cur_path): + """Recursively merge a tree with new values. + + This internal method recursively merges a tree with new values. Called by the set() + method, this allows parameters to be updated in place with the specified values, + calling the parameter setter in specified in an accessor. The type of any updated + parameters is checked against the existing parameter type. 
+ + :param node: tree node to populate and return + :param new_data: dict of new data to be merged in at this path in the tree + :param cur_path: current path in the tree + :returns: the update node at this point in the tree + """ + # Recurse down tree if this is a branch node + if isinstance(node, dict) and isinstance(new_data, dict): + try: + update = {} + for k, v in self.__remove_metadata(new_data): + mutable = self.mutable or any( + cur_path.startswith(part) for part in self.mutable_paths + ) + if mutable and k not in node: + node[k] = {} + update[k] = self._merge_tree(node[k], v, cur_path + k + '/') + node.update(update) + return node + except KeyError as key_error: + raise ParameterTreeError( + 'Invalid path: {}{}'.format(cur_path, str(key_error)[1:-1]) + ) + if isinstance(node, list) and isinstance(new_data, dict): + try: + for i, val in enumerate(new_data): + node[i] = self._merge_tree(node[i], val, cur_path + str(i) + '/') + return node + except IndexError as index_error: + raise ParameterTreeError( + 'Invalid path: {}{} {}'.format(cur_path, str(i), str(index_error)) + ) + + # Update the value of the current parameter, calling the set accessor if specified and + # validating the type if necessary. + if isinstance(node, self.accessor_cls): + self._set_node(node, new_data) + else: + # Validate type of new node matches existing + if not self.mutable and type(node) is not type(new_data): + if not any(cur_path.startswith(part) for part in self.mutable_paths): + raise ParameterTreeError('Type mismatch updating {}: got {} expected {}'.format( + cur_path[:-1], type(new_data).__name__, type(node).__name__ + )) + node = new_data + + return node + + def _set_node(self, node, data): + """Set the value of a node to the specified data. + + This method trivially sets a specified node to the data supplied. It is exposed as a method + to allow derived classes to override it and add behaviour as necessary. 
+
+        :param node: tree node to set value of
+        :param data: data to set the node value to
+        """
+        node.set(data)
diff --git a/src/odin/adapters/base_proxy.py b/src/odin/adapters/base_proxy.py
new file mode 100644
index 0000000..2bea874
--- /dev/null
+++ b/src/odin/adapters/base_proxy.py
@@ -0,0 +1,382 @@
+"""
+Base class implementations for the synchronous and asynchronous proxy adapter implementations.
+
+This module contains classes that provide the common behaviour for the implementations of the
+proxy target and adapters.
+
+Tim Nicholls, Ashley Neaves, STFC Detector Systems Software Group.
+"""
+import logging
+import time
+
+import tornado
+import tornado.httpclient
+from tornado.escape import json_encode, json_decode
+
+from odin.adapters.parameter_tree import ParameterTree, ParameterTreeError
+
+
+class TargetDecodeError(Exception):
+    """Simple error class for raising target decode error exceptions."""
+
+    pass
+
+
+class BaseProxyTarget(object):
+    """
+    Proxy target base class.
+
+    This base class provides the core functionality needed for the concrete synchronous and
+    asynchronous implementations. It is not intended to be instantiated directly.
+    """
+
+    def __init__(self, name, url, request_timeout):
+        """
+        Initialise the BaseProxyTarget object.
+
+        Sets up the default state of the base target object, builds the appropriate parameter tree
+        to be handled by the containing adapter and sets up the HTTP client for making requests
+        to the target.
+ + :param name: name of the proxy target + :param url: URL of the remote target + :param request_timeout: request timeout in seconds + """ + self.name = name + self.url = url + self.request_timeout = request_timeout + + # Initialise default state + self.status_code = 0 + self.error_string = 'OK' + self.last_update = 'unknown' + self.data = {} + self.metadata = {} + self.counter = 0 + + # Build a parameter tree representation of the proxy target status + self.status_param_tree = ParameterTree({ + 'url': (lambda: self.url, None), + 'status_code': (lambda: self.status_code, None), + 'error': (lambda: self.error_string, None), + 'last_update': (lambda: self.last_update, None), + }) + + # Build a parameter tree representation of the proxy target data + self.data_param_tree = ParameterTree((lambda: self.data, None)) + self.meta_param_tree = ParameterTree((lambda: self.metadata, None)) + + # Set up default request headers + self.request_headers = { + 'Content-Type': 'application/json', + 'Accept': 'application/json', + } + + def remote_get(self, path='', get_metadata=False): + """ + Get data from the remote target. + + This method requests data from the remote target by issuing a GET request to the target + URL, and then updates the local proxy target data and status information according to the + response. The request is sent to the target by the implementation-specific _send_request + method. 
+
+        :param path: path to data on remote target
+        :param get_metadata: flag indicating if metadata is to be requested
+        """
+        # Create a GET request to send to the target
+        request = tornado.httpclient.HTTPRequest(
+            url=self.url + path,
+            method="GET",
+            headers=self.request_headers.copy(),
+            request_timeout=self.request_timeout
+        )
+        # If metadata is requested, modify the Accept header accordingly
+        if get_metadata:
+            request.headers["Accept"] += ";metadata=True"
+
+        # Send the request to the remote target
+        return self._send_request(request, path, get_metadata)
+
+    def remote_set(self, path, data):
+        """
+        Set data on the remote target.
+
+        This method sends data to the remote target by issuing a PUT request to the target
+        URL, and then updates the local proxy target data and status information according to the
+        response. The request is sent to the target by the implementation-specific _send_request
+        method.
+
+        :param path: path to data on remote target
+        :param data: data to set on remote target
+        """
+        # Encode the request data as JSON if necessary
+        if isinstance(data, dict):
+            data = json_encode(data)
+
+        # Create a PUT request to send to the target
+        request = tornado.httpclient.HTTPRequest(
+            url=self.url + path,
+            method="PUT",
+            body=data,
+            headers=self.request_headers,
+            request_timeout=self.request_timeout
+        )
+
+        # Send the request to the remote target
+        return self._send_request(request, path)
+
+    def _process_response(self, response, path, get_metadata):
+        """
+        Process a response from the remote target.
+
+        This method processes the response of a remote target to a request. The response is used to
+        update the local proxy target data, metadata and status as appropriate. If the request failed
+        the returned exception is decoded and the status updated accordingly.
+ + :param response: HTTP response from the target, or an exception if the response failed + :param path: path of data being updated + :param get_metadata: flag indicating if metadata was requested + """ + # Update the timestamp of the last request in standard format + self.last_update = tornado.httputil.format_timestamp(time.time()) + + # If an HTTP response was received, handle accordingly + if isinstance(response, tornado.httpclient.HTTPResponse): + + # Decode the reponse body, handling errors by re-processing the repsonse as an + # exception. Otherwise, update the target data and status based on the response. + try: + response_body = json_decode(response.body) + except ValueError as decode_error: + error_string = "Failed to decode response body: {}".format(str(decode_error)) + self._process_response(TargetDecodeError(error_string), path, get_metadata) + else: + + # Update status code, errror string and data accordingly + self.status_code = response.code + self.error_string = 'OK' + + # Set a reference to the data or metadata to update as necessary + if get_metadata: + data_ref = self.metadata + else: + data_ref = self.data + + # If a path was specified, parse it and descend to the appropriate location in the + # data struture + if path: + path_elems = path.split('/') + + # Remove empty string caused by trailing slashes + if path_elems[-1] == '': + del path_elems[-1] + + # Traverse down the data tree for each element + for elem in path_elems[:-1]: + data_ref = data_ref[elem] + + # Update the data or metadata with the body of the response + for key in response_body: + new_elem = response_body[key] + data_ref[key] = new_elem + + # Otherwise, handle the exception, updating status information and reporting the error + elif isinstance(response, Exception): + + if isinstance(response, tornado.httpclient.HTTPError): + error_type = "HTTP error" + self.status_code = response.code + self.error_string = response.message + + elif isinstance(response, 
tornado.ioloop.TimeoutError): + error_type = "Timeout" + self.status_code = 408 + self.error_string = str(response) + + elif isinstance(response, IOError): + error_type = "IO error" + self.status_code = 502 + self.error_string = str(response) + + elif isinstance(response, TargetDecodeError): + error_type = "Decode error" + self.status_code = 415 + self.error_string = str(response) + + else: + error_type = "Unknown error" + self.status_code = 500 + self.error_string = str(response) + + logging.error( + "%s: proxy target %s request failed (%d): %s ", + error_type, + self.name, + self.status_code, + self.error_string, + ) + + +class BaseProxyAdapter(object): + """ + Proxy adapter base mixin class. + + This mixin class implements the core functionality required by all concrete proxy adapter + implementations. + """ + TIMEOUT_CONFIG_NAME = 'request_timeout' + TARGET_CONFIG_NAME = 'targets' + + def initialise_proxy(self, proxy_target_cls): + """ + Initialise the proxy. + + This method initialises the proxy. The adapter options are parsed to determine the list + of proxy targets and request timeout, then a proxy target of the specified class is created + for each target. The data, metadata and status structures and parameter trees associated + with each target are created. + + :param proxy_target_cls: proxy target class appropriate for the specific implementation + """ + # Set the HTTP request timeout if present in the options + request_timeout = None + if self.TIMEOUT_CONFIG_NAME in self.options: + try: + request_timeout = float(self.options[self.TIMEOUT_CONFIG_NAME]) + logging.debug('Proxy adapter request timeout set to %f secs', request_timeout) + except ValueError: + logging.error( + "Illegal timeout specified for proxy adapter: %s", + self.options[self.TIMEOUT_CONFIG_NAME] + ) + + # Parse the list of target-URL pairs from the options, instantiating a proxy target of the + # specified type for each target specified. 
+ self.targets = [] + if self.TARGET_CONFIG_NAME in self.options: + for target_str in self.options[self.TARGET_CONFIG_NAME].split(','): + try: + (target, url) = target_str.split('=') + self.targets.append( + proxy_target_cls(target.strip(), url.strip(), request_timeout) + ) + except ValueError: + logging.error("Illegal target specification for proxy adapter: %s", + target_str.strip()) + + # Issue an error message if no targets were loaded + if self.targets: + logging.debug("Proxy adapter with {:d} targets loaded".format(len(self.targets))) + else: + logging.error("Failed to resolve targets for proxy adapter") + + # Build the parameter trees implemented by this adapter for the specified proxy targets + status_dict = {} + tree = {} + meta_tree = {} + + for target in self.targets: + status_dict[target.name] = target.status_param_tree + tree[target.name] = target.data_param_tree + meta_tree[target.name] = target.meta_param_tree + + # Create a parameter tree from the status data for the targets and insert into the + # data and metadata structures + self.status_tree = ParameterTree(status_dict) + tree['status'] = self.status_tree + meta_tree['status'] = self.status_tree.get("", True) + + # Create the data and metadata parameter trees + self.param_tree = ParameterTree(tree) + self.meta_param_tree = ParameterTree(meta_tree) + + def proxy_get(self, path, get_metadata): + """ + Get data from the proxy targets. + + This method gets data from one or more specified targets and returns the responses. 
+
+        :param path: path to data on remote targets
+        :param get_metadata: flag indicating if metadata is to be requested
+        :return: list of target responses
+        """
+        # Resolve the path element and target path
+        path_elem, target_path = self._resolve_path(path)
+
+        # Iterate over the targets and get data if the path matches
+        target_responses = []
+        for target in self.targets:
+            if path_elem == "" or path_elem == target.name:
+                target_responses.append(target.remote_get(target_path, get_metadata))
+
+        return target_responses
+
+    def proxy_set(self, path, data):
+        """
+        Set data on the proxy targets.
+
+        This method sets data on one or more specified targets and returns the responses.
+
+        :param path: path to data on remote targets
+        :param data: data to set on targets
+        :return: list of target responses
+        """
+        # Resolve the path element and target path
+        path_elem, target_path = self._resolve_path(path)
+
+        # Iterate over the targets and set data if the path matches
+        target_responses = []
+        for target in self.targets:
+            if path_elem == '' or path_elem == target.name:
+                target_responses.append(target.remote_set(target_path, data))
+
+        return target_responses
+
+    def _resolve_response(self, path, get_metadata=False):
+        """
+        Resolve the response to a proxy target get or set request.
+
+        This method resolves the appropriate response to a proxy target get or set request. Data
+        or metadata from the specified path is returned, along with an appropriate HTTP status code.
+ + :param path: path to data on remote targets + :param get_metadata: flag indicating if metadata is to be requested + + """ + # Build the response from the adapter parameter trees, matching to the path for one or more + # targets + try: + # If metadata is requested, update the status tree with metadata before returning + # metadata + if get_metadata: + path_elem, _ = self._resolve_path(path) + if path_elem in ("", "status"): + # update status tree with metadata + self.meta_param_tree.set('status', self.status_tree.get("", True)) + response = self.meta_param_tree.get(path) + else: + response = self.param_tree.get(path) + status_code = 200 + except ParameterTreeError as param_tree_err: + response = {'error': str(param_tree_err)} + status_code = 400 + + return (response, status_code) + + @staticmethod + def _resolve_path(path): + """ + Resolve the specified path into a path element and target. + + This method resolves the specified path into a path element and target path. + + :param path: path to data on remote targets + :return: tuple of path element and target path + """ + if "/" in path: + path_elem, target_path = path.split('/', 1) + else: + path_elem = path + target_path = "" + return (path_elem, target_path) diff --git a/src/odin/adapters/dummy.py b/src/odin/adapters/dummy.py index a3c05c8..9da4dbd 100644 --- a/src/odin/adapters/dummy.py +++ b/src/odin/adapters/dummy.py @@ -10,8 +10,6 @@ Tim Nicholls, STFC Application Engineering """ import logging -from concurrent import futures -import time from tornado.ioloop import PeriodicCallback from odin.adapters.adapter import (ApiAdapter, ApiAdapterRequest, @@ -54,6 +52,9 @@ def __init__(self, **kwargs): logging.debug('DummyAdapter loaded') + def initialize(self, adapters): + logging.debug("DummyAdapter initialized with %d adapters", len(adapters)) + def background_task_callback(self): """Run the adapter background task. 
@@ -62,8 +63,8 @@ def background_task_callback(self):
 
         :param task_interval: time to sleep until task is run again
         """
-        logging.debug("%s: background task running, count = %d",
-                      self.name, self.background_task_counter)
+        logging.debug(
+            "%s: background task running, count = %d", self.name, self.background_task_counter)
 
         self.background_task_counter += 1
 
     @response_types('application/json', default='application/json')
diff --git a/src/odin/adapters/parameter_tree.py b/src/odin/adapters/parameter_tree.py
index 79569f8..ecd2061 100644
--- a/src/odin/adapters/parameter_tree.py
+++ b/src/odin/adapters/parameter_tree.py
@@ -1,508 +1,68 @@
-"""ParameterTree - classes representing a tree of parameters and their accessor methods.
+"""parameter_tree.py - classes representing a synchronous parameter tree and accessor.
 
-This module implements an arbitrarily-structured, recursively-managed tree of parameters and
-the appropriate accessor methods that are used to read and write those parameters. Its
-particular use is in the definition of a tree of parameters for an API adapter and help
-interfacing of those to the underlying device or object.
+This module defines a parameter tree and accessor for use in synchronous API adapters, where
+concurrency over blocking operations (e.g. to read/write the value of a parameter from hardware)
+is not required.
 
-James Hogge, Tim Nicholls, STFC Application Engineering Group.
+Tim Nicholls, STFC Detector Systems Software Group.
 """
-import warnings
+from odin.adapters.base_parameter_tree import (
+    BaseParameterAccessor, BaseParameterTree, ParameterTreeError
+)
 
-class ParameterTreeError(Exception):
-    """Simple error class for raising parameter tree parameter tree exceptions."""
+__all__ = ['ParameterAccessor', 'ParameterTree', 'ParameterTreeError']
 
-    pass
 
+class ParameterAccessor(BaseParameterAccessor):
+    """Synchronous container class representing accessor methods for a parameter.
-class ParameterAccessor(object): - """Container class representing accessor methods for a parameter. - - This class implements a parameter accessor, provding set and get methods - for parameters requiring calls to access them, or simply returning the - appropriate value if the parameter is a read-only constant. Parameter accessors also - contain metadata fields controlling access to and providing information about the parameter. - - Valid specifiable metadata fields are: - min : minimum allowed value for parameter - max : maxmium allowed value for parameter - allowed_values: list of allowed values for parameter - name : readable parameter name - description: longer description of parameter - units: parameter units - display_precision: number of decimal places to display for e.g. float types - - The class also maintains the following automatically-populated metadata fields: - type: parameter type - writeable: is the parameter writable + This class extends the base parameter accessor class to support synchronous set and get + accessors for a parameter. Read-only and writeable parameters are supported, and the same + metadata fields are implemented. """ - # Valid metadata arguments that can be passed to ParameterAccess __init__ method. - VALID_METADATA_ARGS = ( - "min", "max", "allowed_values", "name", "description", "units", "display_precision" - ) - # Automatically-populated metadata fields based on inferred type of the parameter and - # writeable status depending on specified accessors - AUTO_METADATA_FIELDS = ("type", "writeable") - def __init__(self, path, getter=None, setter=None, **kwargs): """Initialise the ParameterAccessor instance. 
- This constructor initialises the ParameterAccessor instance, storing - the path of the parameter, its set/get accessors and setting metadata fields based - on the the specified keyword arguments + This constructor initialises the ParameterAccessor instance, storing the path of the + parameter, its set/get accessors and setting metadata fields based on the the specified + keyword arguments. :param path: path of the parameter within the tree :param getter: get method for the parameter, or a value if read-only constant :param setter: set method for the parameter :param kwargs: keyword argument list for metadata fields to be set; these must be from - the allow list specified in ParameterAccessor.allowed_metadata + the allow list specified in BaseParameterAccessor.allowed_metadata """ - # Initialise path, getter and setter - self.path = path[:-1] - self._get = getter - self._set = setter - - # Initialize metadata dict - self.metadata = {} - - # Check metadata keyword arguments are valid - for arg in kwargs: - if arg not in ParameterAccessor.VALID_METADATA_ARGS: - raise ParameterTreeError("Invalid metadata argument: {}".format(arg)) - - # Update metadata keywords from arguments - self.metadata.update(kwargs) + # Initialise the superclass with the specified arguments + super(ParameterAccessor, self).__init__(path, getter, setter, **kwargs) # Save the type of the parameter for type checking self._type = type(self.get()) - # Set type and writeable metadata fields based on specified accessors - self.metadata["type"] = type(self.get()).__name__ - if not callable(self._set) and callable(self._get): - self.metadata["writeable"] = False - else: - self.metadata["writeable"] = True - - def get(self, with_metadata=False): - """Get the value of the parameter. - - This method returns the value of the parameter, or the value returned - by the get accessor if one is defined (i.e. is callable). 
If the with_metadata argument - is true, the value is returned in a dictionary including all metadata for the - parameter. - - :param with_metadata: include metadata in the response when set to True - :returns value of the parameter - """ - # Determine the value of the parameter by calling the getter or simply from the stored - # value - if callable(self._get): - value = self._get() - else: - value = self._get - - # If metadata is requested, replace the value with a dict containing the value itself - # plus metadata fields - if with_metadata: - value = {"value": value} - value.update(self.metadata) - - return value - - def set(self, value): - """Set the value of the parameter. - - This method sets the value of the parameter by calling the set accessor - if defined and callable, otherwise raising an exception. - - :param value: value to set - """ - - # Raise an error if this parameter is not writeable - if not self.metadata["writeable"]: - raise ParameterTreeError("Parameter {} is read-only".format(self.path)) - - # Raise an error of the value to be set is not of the same type as the parameter. 
 If
-        # the metadata type field is set to None, allow any type to be set, or if the value
-        # is integer and the parameter is float, also allow as JSON does not differentiate
-        # numerics in all cases
-        if self.metadata["type"] != "NoneType" and not isinstance(value, self._type):
-            if not (isinstance(value, int) and self.metadata["type"] == "float"):
-                raise ParameterTreeError(
-                    "Type mismatch setting {}: got {} expected {}".format(
-                        self.path, type(value).__name__, self.metadata["type"]
-                    )
-                )
-
-        # Raise an error if allowed_values has been set for this parameter and the value to
-        # set is not one of them
-        if "allowed_values" in self.metadata and value not in self.metadata["allowed_values"]:
-            raise ParameterTreeError(
-                "{} is not an allowed value for {}".format(value, self.path)
-            )
-
-        # Raise an error if the parameter has a mininum value specified in metadata and the
-        # value to set is below this
-        if "min" in self.metadata and value < self.metadata["min"]:
-            raise ParameterTreeError(
-                "{} is below the minimum value {} for {}".format(
-                    value, self.metadata["min"], self.path
-                )
-            )
-
-        # Raise an error if the parameter has a maximum value specified in metadata and the
-        # value to set is above this
-        if "max" in self.metadata and value > self.metadata["max"]:
-            raise ParameterTreeError(
-                "{} is above the maximum value {} for {}".format(
-                    value, self.metadata["max"], self.path
-                )
-            )
-
-        # Set the new parameter value, either by calling the setter or updating the local
-        # value as appropriate
-        if callable(self._set):
-            self._set(value)
-        elif not callable(self._get):
-            self._get = value
+        # Set the type metadata fields based on the resolved type
+        self.metadata["type"] = self._type.__name__
 
 
-class ParameterTree(object):
-    """Class implementing a tree of parameters and their accessors.
+class ParameterTree(BaseParameterTree):
+    """Class implementing a synchronous tree of parameters and their accessors.
 
-    This class implements an arbitrarily-structured, recursively-managed tree of parameters and
-    the appropriate accessor methods that are used to read and write those parameters. Its
-    particular use is in the definition of a tree of parameters for an API adapter and help
-    interfacing of those to the underlying device or object.
+    This class implements an arbitrarily-structured, recursively-managed tree of parameters
+    and the appropriate accessor methods that are used to read and write those parameters.
     """
 
-    METADATA_FIELDS = ["name", "description"]
-
     def __init__(self, tree, mutable=False):
         """Initialise the ParameterTree object.
 
-        This constructor recursively initialises the ParameterTree object, based on the parameter
-        tree dictionary passed as an argument. This is done recursively, so that a parameter tree
-        can have arbitrary depth and contain other ParameterTree instances as necessary.
-
-        Initialisation syntax for ParameterTree is made by passing a dict representing the tree
-        as an argument. Children of a node at any level of the tree are described with
-        dictionaries/lists e.g.
-
-        {"parent" : {"childA" : {...}, "childB" : {...}}}
-        {"parent" : [{...}, {...}]}
-
-        Leaf nodes can be one of the following formats:
-
-        value
-        (value,)
-        (value, {metadata})
-        getter
-        (getter,)
-        (getter, {metadata})
-        (getter, setter)
-        (getter, setter, {metadata})
-
-        The following tags will also be treated as metadata:
-
-        name - A printable name for that branch of the tree
-        description - A printable description for that branch of the tree
+        This constructor recursively initialises the ParameterTree object based on the specified
+        arguments. The tree initialisation syntax follows that of the BaseParameterTree
+        implementation.
 
         :param tree: dict representing the parameter tree
-        :param mutable: Flag, setting the tree
-        """
-        # Flag, if set to true, allows nodes to be replaced and new nodes created
-        self.mutable = mutable
-        # list of paths to mutable parts.
Not sure this is best solution - self.mutable_paths = [] - # Recursively check and initialise the tree - self._tree = self.__recursive_build_tree(tree) - - - @property - def tree(self): - """Return tree object for this parameter tree node. - - Used internally for recursive descent of parameter trees. + :param mutable: Flag, setting the tree """ - return self._tree - - def get(self, path, with_metadata=False): - """Get the values of parameters in a tree. - - This method returns the values at and below a specified path in the parameter tree. - This is done by recursively populating the tree with the current values of parameters, - returning the result as a dictionary. - - :param path: path in tree to get parameter values for - :param with_metadata: include metadata in the response when set to True - :returns: dict of parameter tree at the specified path - """ - # Split the path by levels, truncating the last level if path ends in trailing slash - levels = path.split('/') - if levels[-1] == '': - del levels[-1] - - # Initialise the subtree before descent - subtree = self._tree - - # If this is single level path, return the populated tree at the top level - if not levels: - return self.__recursive_populate_tree(subtree, with_metadata) - - # Descend the specified levels in the path, checking for a valid subtree of the appropriate - # type - for level in levels: - if level in self.METADATA_FIELDS and not with_metadata: - raise ParameterTreeError("Invalid path: {}".format(path)) - try: - if isinstance(subtree, dict): - subtree = subtree[level] - elif isinstance(subtree, ParameterAccessor): - subtree = subtree.get(with_metadata)[level] - else: - subtree = subtree[int(level)] - except (KeyError, ValueError, IndexError): - raise ParameterTreeError("Invalid path: {}".format(path)) - - # Return the populated tree at the appropriate path - return self.__recursive_populate_tree({levels[-1]: subtree}, with_metadata) - - def set(self, path, data): - """Set the values of the 
parameters in a tree. - - This method sets the values of parameters in a tree, based on the data passed to it - as a nested dictionary of parameter and value pairs. The updated parameters are merged - into the existing tree recursively. - - :param path: path to set parameters for in the tree - :param data: nested dictionary representing values to update at the path - """ - # Expand out any lists/tuples - data = self.__recursive_build_tree(data) - - # Get subtree from the node the path points to - levels = path.split('/') - if levels[-1] == '': - del levels[-1] - - merge_parent = None - merge_child = self._tree - - # Descend the tree and validate each element of the path - for level in levels: - if level in self.METADATA_FIELDS: - raise ParameterTreeError("Invalid path: {}".format(path)) - try: - merge_parent = merge_child - if isinstance(merge_child, dict): - merge_child = merge_child[level] - else: - merge_child = merge_child[int(level)] - except (KeyError, ValueError, IndexError): - raise ParameterTreeError("Invalid path: {}".format(path)) - - # Add trailing / to paths where necessary - if path and path[-1] != '/': - path += '/' - - # Merge data with tree - merged = self.__recursive_merge_tree(merge_child, data, path) - - # Add merged part to tree, either at the top of the tree or at the - # appropriate level speicfied by the path - if not levels: - self._tree = merged - return - if isinstance(merge_parent, dict): - merge_parent[levels[-1]] = merged - else: - merge_parent[int(levels[-1])] = merged - - def delete(self, path=''): - """ - Remove Parameters from a Mutable Tree. - - This method deletes selected parameters from a tree, if that tree has been flagged as - Mutable. 
Deletion of Branch Nodes means all child nodes of that Branch Node are also deleted - - :param path: Path to selected Parameter Node in the tree - """ - if not self.mutable and not any(path.startswith(part) for part in self.mutable_paths): - raise ParameterTreeError("Invalid Delete Attempt: Tree Not Mutable") - - # Split the path by levels, truncating the last level if path ends in trailing slash - levels = path.split('/') - if levels[-1] == '': - del levels[-1] - - subtree = self._tree - - if not levels: - subtree.clear() - return - try: - # navigate down the path, based on hwo path navigation works in the Set Method above - for level in levels[:-1]: - # if dict, subtree is normal branch, continue navigation - if isinstance(subtree, dict): - subtree = subtree[level] - else: # if not a dict, but still navigating, it should be a list, so next path is int - subtree = subtree[int(level)] - # once we are at the second to last part of the path, we want to delete whatever comes next - if isinstance(subtree, list): - subtree.pop(int(levels[-1])) - else: - subtree.pop(levels[-1]) - except (KeyError, ValueError, IndexError): - raise ParameterTreeError("Invalid path: {}".format(path)) - - def __recursive_build_tree(self, node, path=''): - """Recursively build and expand out a tree or node. - - This internal method is used to recursively build and expand a tree or node, - replacing elements as found with appropriate types, e.g. ParameterAccessor for - a set/get pair, the internal tree of a nested ParameterTree. - - :param node: node to recursively build - :param path: path to node within overall tree - :returns: built node - """ - - # If the node is a ParameterTree instance, replace with its own built tree - if isinstance(node, ParameterTree): - if node.mutable: - self.mutable_paths.append(path) - return node.tree # this breaks the mutability of the sub-tree. 
hmm - - # Convert node tuple into the corresponding ParameterAccessor, depending on type of - # fields - if isinstance(node, tuple): - if len(node) == 1: - # Node is (value) - param = ParameterAccessor(path, node[0]) - - elif len(node) == 2: - if isinstance(node[1], dict): - # Node is (value, {metadata}) - param = ParameterAccessor(path, node[0], **node[1]) - else: - # Node is (getter, setter) - param = ParameterAccessor(path, node[0], node[1]) - - elif len(node) == 3 and isinstance(node[2], dict): - # Node is (getter, setter, {metadata}) - param = ParameterAccessor(path, node[0], node[1], **node[2]) - - else: - raise ParameterTreeError("{} is not a valid leaf node".format(repr(node))) - - return param - - # Convert list or non-callable tuple to enumerated dict - if isinstance(node, list): - return [self.__recursive_build_tree(elem, path=path) for elem in node] - - # Recursively check child elements - if isinstance(node, dict): - return {k: self.__recursive_build_tree( - v, path=path + str(k) + '/') for k, v in node.items()} - - return node - - def __remove_metadata(self, node): - """Remove metadata fields from a node. - - Used internally to return a parameter tree without metadata fields - - :param node: tree node to return without metadata fields - :returns: generator yeilding items in node minus metadata - """ - for key, val in node.items(): - if key not in self.METADATA_FIELDS: - yield key, val - - def __recursive_populate_tree(self, node, with_metadata=False): - """Recursively populate a tree with values. - - This internal method recursively populates the tree with parameter values, or - the results of the accessor getters for nodes. It is called by the get() method to - return the values of parameters in the tree. 
- - :param node: tree node to populate and return - :param with_metadata: include parameter metadata with the tree - :returns: populated node as a dict - """ - # If this is a branch node recurse down the tree - if isinstance(node, dict): - if with_metadata: - branch = { - k: self.__recursive_populate_tree(v, with_metadata) for k, v - in node.items() - } - else: - branch = { - k: self.__recursive_populate_tree(v, with_metadata) for k, v - in self.__remove_metadata(node) - } - return branch - - if isinstance(node, list): - return [self.__recursive_populate_tree(item, with_metadata) for item in node] - - # If this is a leaf node, check if the leaf is a r/w tuple and substitute the - # read element of that tuple into the node - if isinstance(node, ParameterAccessor): - return node.get(with_metadata) - - return node - - # Replaces values in data_tree with values from new_data - def __recursive_merge_tree(self, node, new_data, cur_path): - """Recursively merge a tree with new values. - - This internal method recursively merges a tree with new values. Called by the set() - method, this allows parameters to be updated in place with the specified values, - calling the parameter setter in specified in an accessor. The type of any updated - parameters is checked against the existing parameter type. 
- - :param node: tree node to populate and return - :param new_data: dict of new data to be merged in at this path in the tree - :param cur_path: current path in the tree - :returns: the update node at this point in the tree - """ - # Recurse down tree if this is a branch node - if isinstance(node, dict) and isinstance(new_data, dict): - try: - update = {} - for k, v in self.__remove_metadata(new_data): - mutable = self.mutable or any(cur_path.startswith(part) for part in self.mutable_paths) - if mutable and k not in node: - node[k] = {} - update[k] = self.__recursive_merge_tree(node[k], v, cur_path + k + '/') - node.update(update) - return node - except KeyError as key_error: - raise ParameterTreeError( - 'Invalid path: {}{}'.format(cur_path, str(key_error)[1:-1]) - ) - if isinstance(node, list) and isinstance(new_data, dict): - try: - for i, val in enumerate(new_data): - node[i] = self.__recursive_merge_tree(node[i], val, cur_path + str(i) + '/') - return node - except IndexError as index_error: - raise ParameterTreeError( - 'Invalid path: {}{} {}'.format(cur_path, str(i), str(index_error)) - ) - - # Update the value of the current parameter, calling the set accessor if specified and - # validating the type if necessary. 
-        if isinstance(node, ParameterAccessor):
-            node.set(new_data)
-        else:
-            # Validate type of new node matches existing
-            if not self.mutable and type(node) is not type(new_data):
-                if not any(cur_path.startswith(part) for part in self.mutable_paths):
-                    raise ParameterTreeError('Type mismatch updating {}: got {} expected {}'.format(
-                        cur_path[:-1], type(new_data).__name__, type(node).__name__
-                    ))
-            node = new_data
+        # Set the accessor class used by this tree to ParameterAccessor
+        self.accessor_cls = ParameterAccessor
 
-        return node
+        # Initialise the superclass with the specified parameters
+        super(ParameterTree, self).__init__(tree, mutable)
diff --git a/src/odin/adapters/proxy.py b/src/odin/adapters/proxy.py
index 6ae8313..745b082 100644
--- a/src/odin/adapters/proxy.py
+++ b/src/odin/adapters/proxy.py
@@ -1,28 +1,21 @@
 """
-Proxy adapter class for the ODIN server.
+Proxy adapter for use in odin-control.
 
-This class implements a simple asynchronous proxy adapter, allowing requests to be proxied to
-one or more remote HTTP resources, typically further ODIN servers.
+This module implements a simple proxy adapter, allowing requests to be proxied to
+one or more remote HTTP resources, typically further odin-control instances.
 
-Tim Nicholls, Adam Neaves STFC Application Engineering Group.
+Tim Nicholls, Ashley Neaves STFC Detector Systems Software Group.
""" -import logging -import time -import tornado -import tornado.httpclient -from tornado.escape import json_encode +from tornado.httpclient import HTTPClient from odin.util import decode_request_body - from odin.adapters.adapter import ( ApiAdapter, ApiAdapterResponse, - request_types, response_types, wants_metadata) -from odin.adapters.parameter_tree import ParameterTree, ParameterTreeError - -TIMEOUT_CONFIG_NAME = 'request_timeout' -TARGET_CONFIG_NAME = 'targets' + request_types, response_types, wants_metadata +) +from odin.adapters.base_proxy import BaseProxyTarget, BaseProxyAdapter -class ProxyTarget(object): +class ProxyTarget(BaseProxyTarget): """ Proxy adapter target class. @@ -32,163 +25,78 @@ class ProxyTarget(object): def __init__(self, name, url, request_timeout): """ - Initalise the ProxyTarget object. + Initialise the ProxyTarget object. - Sets up the default state of the target object, builds the - appropriate parameter tree to be handled by the containing adapter - and sets up the HTTP client for making requests to the target. 
- """ - self.name = name - self.url = url - self.request_timeout = request_timeout - - # Initialise default state - self.status_code = 0 - self.error_string = 'OK' - self.last_update = 'unknown' - self.data = {} - self.metadata = {} - self.counter = 0 - - # Build a parameter tree representation of the proxy target status - self.status_param_tree = ParameterTree({ - 'url': (lambda: self.url, None), - 'status_code': (lambda: self.status_code, None), - 'error': (lambda: self.error_string, None), - 'last_update': (lambda: self.last_update, None), - }) - - # Build a parameter tree representation of the proxy target data - self.data_param_tree = ParameterTree((lambda: self.data, None)) - self.meta_param_tree = ParameterTree((lambda: self.metadata, None)) - - # Create an HTTP client instance and set up default request headers - self.http_client = tornado.httpclient.HTTPClient() - self.request_headers = { - 'Content-Type': 'application/json', - 'Accept': 'application/json', - } - self.remote_get() # init the data tree - self.remote_get(get_metadata=True) # init the metadata - - def update(self, request, path, get_metadata=False): - """ - Update the Proxy Target `ParameterTree` with data from the proxied adapter, - after issuing a GET or a PUT request to it. It also updates the status code - and error string if the HTTP request fails. + This constructor initialises the ProxyTarget, creating a HTTP client and delegating + the full initialisation to the base class. 
+ + :param name: name of the proxy target + :param url: URL of the remote target + :param request_timeout: request timeout in seconds """ + # Create an async HTTP client for use in this target + self.http_client = HTTPClient() - try: - # Request data to/from the target - response = self.http_client.fetch(request) + # Initialise the base class + super(ProxyTarget, self).__init__(name, url, request_timeout) - # Update status code and data accordingly - self.status_code = response.code - self.error_string = 'OK' - response_body = tornado.escape.json_decode(response.body) - - except tornado.httpclient.HTTPError as http_err: - # Handle HTTP errors, updating status information and reporting error - self.status_code = http_err.code - self.error_string = http_err.message - logging.error( - "HTTP Error: Proxy target %s fetch failed: %d %s Request: %s", - self.name, - self.status_code, - self.error_string, - request.body - ) - self.last_update = tornado.httputil.format_timestamp(time.time()) - return - except tornado.ioloop.TimeoutError as time_err: - self.status_code = 408 - self.error_string = str(time_err) - logging.error( - "Timeout Error: Proxy Target %s fetch failed: %d %s", - self.name, - self.status_code, - self.error_string - ) - self.last_update = tornado.httputil.format_timestamp(time.time()) - return - except IOError as other_err: - self.status_code = 502 - self.error_string = str(other_err) - logging.error( - "IO Error: Proxy Target %s fetch failed: %d %s", - self.name, - self.status_code, - self.error_string - ) - self.last_update = tornado.httputil.format_timestamp(time.time()) - return - - if get_metadata: - data_ref = self.metadata - else: - data_ref = self.data # reference for modification - if path: - # if the path exists, we need to split it so we can navigate the data - path_elems = path.split('/') - if path_elems[-1] == '': # remove empty string caused by trailing slashes - del path_elems[-1] - for elem in path_elems[:-1]: - # for each element, traverse 
down the data tree - data_ref = data_ref[elem] - - for key in response_body: - new_elem = response_body[key] - data_ref[key] = new_elem - - # Update the timestamp of the last request in standard format - self.last_update = tornado.httputil.format_timestamp(time.time()) + # Initialise the data and metadata trees from the remote target + self.remote_get() + self.remote_get(get_metadata=True) def remote_get(self, path='', get_metadata=False): """ Get data from the remote target. - This method updates the local proxy target with new data by - issuing a GET request to the target URL, and then updates the proxy - target data and status information according to the response. - """ + This method updates the local proxy target with new data by issuing a GET request to the + target URL, and then updates the local proxy target data and status information according to + the response. The detailed handling of this is implemented by the base class. - # create request to PUT data, send to the target - request = tornado.httpclient.HTTPRequest( - url=self.url + path, - method="GET", - headers=self.request_headers.copy(), - request_timeout=self.request_timeout - ) - if get_metadata: - request.headers["Accept"] += ";metadata=True" - self.update(request, path, get_metadata) + :param path: path to data on remote target + :param get_metadata: flag indicating if metadata is to be requested + """ + super(ProxyTarget, self).remote_get(path, get_metadata) def remote_set(self, path, data): """ Set data on the remote target. - This method updates the local proxy target with new datat by - issuing a PUT request to the target URL, and then updates the proxy - target data and status information according to the response. + This method sends data to the remote target by issuing a PUT request to the target + URL, and then updates the local proxy target data and status information according to the + response. The detailed handling of this is implemented by the base class. 
+ + :param path: path to data on remote target + :param data: data to set on remote target + """ + super(ProxyTarget, self).remote_set(path, data) + + def _send_request(self, request, path, get_metadata=False): + """ + Send a request to the remote target and update data. + + This internal method sends a request to the remote target using the HTTP client + and handles the response, updating target data accordingly. + + :param request: HTTP request to transmit to target + :param path: path of data being updated + :param get_metadata: flag indicating if metadata is to be requested """ - # create request to PUT data, send to the target - if isinstance(data, dict): - data = json_encode(data) - request = tornado.httpclient.HTTPRequest( - url=self.url + path, - method="PUT", - body=data, - headers=self.request_headers, - request_timeout=self.request_timeout - ) - self.update(request, path) - - -class ProxyAdapter(ApiAdapter): + # Send the request to the remote target, handling any exceptions that occur + try: + response = self.http_client.fetch(request) + except Exception as fetch_exception: + # Set the response to the exception so it can be handled during response resolution + response = fetch_exception + + # Process the response from the target, updating data as appropriate + self._process_response(response, path, get_metadata) + + +class ProxyAdapter(ApiAdapter, BaseProxyAdapter): """ - Proxy adapter class for ODIN server. + Proxy adapter class for odin-control. - This class implements a proxy adapter, allowing ODIN server to forward requests to + This class implements a proxy adapter, allowing odin-control to forward requests to other HTTP services. """ @@ -196,100 +104,34 @@ def __init__(self, **kwargs): """ Initialise the ProxyAdapter. - This constructor initialises the adapter instance, parsing configuration - options out of the keyword arguments it is passed. A ProxyTarget object is - instantiated for each target specified in the options. 
+ This constructor initialises the adapter instance. The base class adapter is initialised + with the keyword arguments and then the proxy targets and paramter tree initialised by the + proxy adapter mixin. - :param kwargs: keyword arguments specifying options + :param kwargs: keyword arguments specifying options """ - - # Initialise base class + # Initialise the base class super(ProxyAdapter, self).__init__(**kwargs) - # Set the HTTP request timeout if present in the options - request_timeout = None - if TIMEOUT_CONFIG_NAME in self.options: - try: - request_timeout = float(self.options[TIMEOUT_CONFIG_NAME]) - logging.debug('ProxyAdapter request timeout set to %f secs', request_timeout) - except ValueError: - logging.error( - "Illegal timeout specified for ProxyAdapter: %s", - self.options[TIMEOUT_CONFIG_NAME] - ) - - # Parse the list of target-URL pairs from the options, instantiating a ProxyTarget - # object for each target specified. - self.targets = [] - if TARGET_CONFIG_NAME in self.options: - for target_str in self.options[TARGET_CONFIG_NAME].split(','): - try: - (target, url) = target_str.split('=') - self.targets.append(ProxyTarget(target.strip(), url.strip(), request_timeout)) - except ValueError: - logging.error("Illegal target specification for ProxyAdapter: %s", - target_str.strip()) - - # Issue an error message if no targets were loaded - if self.targets: - logging.debug("ProxyAdapter with {:d} targets loaded".format(len(self.targets))) - else: - logging.error("Failed to resolve targets for ProxyAdapter") - - status_dict = {} - # Construct the parameter tree returned by this adapter - tree = {} - meta_tree = {} - for target in self.targets: - status_dict[target.name] = target.status_param_tree - - tree[target.name] = target.data_param_tree - meta_tree[target.name] = target.meta_param_tree - - self.status_tree = ParameterTree(status_dict) - tree['status'] = self.status_tree - meta_tree['status'] = self.status_tree.get("", True) - - self.param_tree = 
ParameterTree(tree) - self.meta_param_tree = ParameterTree(meta_tree) + # Initialise the proxy targets and parameter trees + self.initialise_proxy(ProxyTarget) @response_types('application/json', default='application/json') def get(self, path, request): """ Handle an HTTP GET request. - This method handles an HTTP GET request, returning a JSON response. + This method handles an HTTP GET request, returning a JSON response. The request is + passed to the adapter proxy and resolved into responses from the requested proxy targets. :param path: URI path of request :param request: HTTP request object :return: an ApiAdapterResponse object containing the appropriate response """ - get_metadata = wants_metadata(request) - # Update the target specified in the path, or all targets if none specified - if "/" in path: - path_elem, target_path = path.split('/', 1) - else: - path_elem = path - target_path = "" - for target in self.targets: - if path_elem == "" or path_elem == target.name: - target.remote_get(target_path, get_metadata) - # Build the response from the adapter parameter tree - try: - if get_metadata: - if path_elem == "" or path_elem == "status": - # update status tree with metadata - self.meta_param_tree.set('status', self.status_tree.get("", True)) - response = self.meta_param_tree.get(path) - - else: - response = self.param_tree.get(path) - status_code = 200 - except ParameterTreeError as param_tree_err: - response = {'error': str(param_tree_err)} - status_code = 400 + self.proxy_get(path, get_metadata) + (response, status_code) = self._resolve_response(path, get_metadata) return ApiAdapterResponse(response, status_code=status_code) @@ -299,32 +141,23 @@ def put(self, path, request): """ Handle an HTTP PUT request. - This method handles an HTTP PUT request, returning a JSON response. + This method handles an HTTP PUT request, returning a JSON response. 
The request is + passed to the adapter proxy to set data on the remote targets and resolved into responses + from those targets. :param path: URI path of request :param request: HTTP request object :return: an ApiAdapterResponse object containing the appropriate response """ - # Update the target specified in the path, or all targets if none specified - + # Decode the request body from JSON, handling and returning any errors that occur. Otherwise + # send the PUT request to the remote target try: - body = decode_request_body(request) # ensure request body is JSON. Will throw a TypeError if not - if "/" in path: - path_elem, target_path = path.split('/', 1) - else: - path_elem = path - target_path = "" - for target in self.targets: - if path_elem == '' or path_elem == target.name: - target.remote_set(target_path, body) - - response = self.param_tree.get(path) - status_code = 200 - except ParameterTreeError as param_tree_err: - response = {'error': str(param_tree_err)} - status_code = 400 + body = decode_request_body(request) except (TypeError, ValueError) as type_val_err: response = {'error': 'Failed to decode PUT request body: {}'.format(str(type_val_err))} status_code = 415 + else: + self.proxy_set(path, body) + (response, status_code) = self._resolve_response(path) return ApiAdapterResponse(response, status_code=status_code) diff --git a/src/odin/adapters/system_status.py b/src/odin/adapters/system_status.py index 83e422f..e731f11 100644 --- a/src/odin/adapters/system_status.py +++ b/src/odin/adapters/system_status.py @@ -294,7 +294,6 @@ def monitor_processes(self): num_processes_old = len(self._processes[process_name]) self._processes[process_name] = self.find_processes(process_name) - if len(self._processes[process_name]) != num_processes_old: self._log.debug( "Number of processes named %s is now %d", diff --git a/src/odin/async_util.py b/src/odin/async_util.py new file mode 100644 index 0000000..bd50b03 --- /dev/null +++ b/src/odin/async_util.py @@ -0,0 +1,56 
@@ +"""Odin server asyncio utility functions. + +This module implements asyncio-based utility functions needed in odin-control when using +asynchronous adapters. + +Tim Nicholls, STFC Detector System Software Group. +""" +import asyncio + + +def wrap_async(object): + """Wrap an object in an async future. + + This function wraps an object in an async future and is called from wrap_result when + async objects are wrapped in python 3. A future is created, its result set to the + object passed in, and returned to the caller. + + :param object: object to wrap in a future + :return: a Future with object as its result + """ + future = asyncio.Future() + future.set_result(object) + return future + + +def get_async_event_loop(): + """Get the asyncio event loop. + + This function obtains and returns the current asyncio event loop. If no loop is present, a new + one is created and set as the event loop. + + :return: an asyncio event loop + """ + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + return loop + + +def run_async(func, *args, **kwargs): + """Run an async function synchronously in an event loop. + + This function can be used to run an async function synchronously, i.e. without the need for an + await() call. The function is run on an asyncio event loop and the result is returned. 
+
+    :param func: async function to run
+    :param args: positional arguments to function
+    :param kwargs: keyword arguments to function
+    :return: result of function
+    """
+    loop = get_async_event_loop()
+    result = loop.run_until_complete(func(*args, **kwargs))
+    return result
diff --git a/src/odin/http/handlers/__init__.py b/src/odin/http/handlers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/odin/http/handlers/api.py b/src/odin/http/handlers/api.py
new file mode 100644
index 0000000..cb17c3a
--- /dev/null
+++ b/src/odin/http/handlers/api.py
@@ -0,0 +1,57 @@
+"""Synchronous API handler for the ODIN server.
+
+This module implements the synchronous API handler used by the ODIN server to pass
+API calls to synchronous adapters.
+
+Tim Nicholls, STFC Detector Systems Software Group.
+"""
+
+from odin.http.handlers.base import BaseApiHandler, validate_api_request, API_VERSION
+
+
+class ApiHandler(BaseApiHandler):
+    """Class for handling synchronous API requests.
+
+    This class handles synchronous API requests, that is when the ODIN server is being
+    used with Tornado and python versions incompatible with native async behaviour.
+    """
+
+    @validate_api_request(API_VERSION)
+    def get(self, subsystem, path=''):
+        """Handle an API GET request.
+
+        :param subsystem: subsystem element of URI, defining adapter to be called
+        :param path: remaining URI path to be passed to adapter method
+        """
+        response = self.route.adapter(subsystem).get(path, self.request)
+        self.respond(response)
+
+    @validate_api_request(API_VERSION)
+    def post(self, subsystem, path=''):
+        """Handle an API POST request.
+
+        :param subsystem: subsystem element of URI, defining adapter to be called
+        :param path: remaining URI path to be passed to adapter method
+        """
+        response = self.route.adapter(subsystem).post(path, self.request)
+        self.respond(response)
+
+    @validate_api_request(API_VERSION)
+    def put(self, subsystem, path=''):
+        """Handle an API PUT request.
+
+        :param subsystem: subsystem element of URI, defining adapter to be called
+        :param path: remaining URI path to be passed to adapter method
+        """
+        response = self.route.adapter(subsystem).put(path, self.request)
+        self.respond(response)
+
+    @validate_api_request(API_VERSION)
+    def delete(self, subsystem, path=''):
+        """Handle an API DELETE request.
+
+        :param subsystem: subsystem element of URI, defining adapter to be called
+        :param path: remaining URI path to be passed to adapter method
+        """
+        response = self.route.adapter(subsystem).delete(path, self.request)
+        self.respond(response)
diff --git a/src/odin/http/handlers/async_api.py b/src/odin/http/handlers/async_api.py
new file mode 100644
index 0000000..852aca9
--- /dev/null
+++ b/src/odin/http/handlers/async_api.py
@@ -0,0 +1,75 @@
+"""Asynchronous API handler for the ODIN server.
+
+This module implements the asynchronous API handler used by the ODIN server to pass
+API calls to asynchronous adapters.
+
+Tim Nicholls, STFC Detector Systems Software Group.
+"""
+
+from odin.http.handlers.base import BaseApiHandler, validate_api_request, API_VERSION
+
+
+class AsyncApiHandler(BaseApiHandler):
+    """Class for handling asynchronous API requests.
+
+    This class handles asynchronous API requests, that is when the ODIN server is being
+    used with Tornado and python versions that implement native async behaviour.
+    """
+
+    @validate_api_request(API_VERSION)
+    async def get(self, subsystem, path=''):
+        """Handle an API GET request.
+
+        :param subsystem: subsystem element of URI, defining adapter to be called
+        :param path: remaining URI path to be passed to adapter method
+        """
+        adapter = self.route.adapter(subsystem)
+        if adapter.is_async:
+            response = await adapter.get(path, self.request)
+        else:
+            response = adapter.get(path, self.request)
+
+        self.respond(response)
+
+    @validate_api_request(API_VERSION)
+    async def post(self, subsystem, path=''):
+        """Handle an API POST request.
+ + :param subsystem: subsystem element of URI, defining adapter to be called + :param path: remaining URI path to be passed to adapter method + """ + adapter = self.route.adapter(subsystem) + if adapter.is_async: + response = await adapter.post(path, self.request) + else: + response = adapter.post(path, self.request) + + self.respond(response) + + @validate_api_request(API_VERSION) + async def put(self, subsystem, path=''): + """Handle an API PUT request. + + :param subsystem: subsystem element of URI, defining adapter to be called + :param path: remaining URI path to be passed to adapter method + """ + adapter = self.route.adapter(subsystem) + if adapter.is_async: + response = await adapter.put(path, self.request) + else: + response = adapter.put(path, self.request) + self.respond(response) + + @validate_api_request(API_VERSION) + async def delete(self, subsystem, path=''): + """Handle an API DELETE request. + + :param subsystem: subsystem element of URI, defining adapter to be called + :param path: remaining URI path to be passed to adapter method + """ + adapter = self.route.adapter(subsystem) + if adapter.is_async: + response = await adapter.delete(path, self.request) + else: + response = adapter.delete(path, self.request) + self.respond(response) diff --git a/src/odin/http/handlers/base.py b/src/odin/http/handlers/base.py new file mode 100644 index 0000000..af03ac9 --- /dev/null +++ b/src/odin/http/handlers/base.py @@ -0,0 +1,130 @@ +"""Base API handler for the ODIN server. + +This module implements the base API handler functionality from which both the concrete +synchronous and asynchronous API handler implementations inherit. + +Tim Nicholls, STFC Detector Systems Software Group. 
+"""
+
+import tornado.web
+
+from odin.adapters.adapter import ApiAdapterResponse
+from odin.util import wrap_result
+API_VERSION = 0.1
+
+
+class ApiError(Exception):
+    """Simple exception class for API-related errors."""
+
+
+def validate_api_request(required_version):
+    """Validate an API request to the ApiHandler.
+
+    This decorator checks that API version in the URI of a request is correct and that the
+    subsystem is registered with the application dispatcher; responds with a 400 error if not
+    """
+    def decorator(func):
+        def wrapper(_self, *args, **kwargs):
+            # Extract version as first argument
+            version = args[0]
+            subsystem = args[1]
+            rem_args = args[2:]
+            if version != str(required_version):
+                _self.respond(ApiAdapterResponse(
+                    "API version {} is not supported".format(version),
+                    status_code=400))
+                return wrap_result(None)
+            if not _self.route.has_adapter(subsystem):
+                _self.respond(ApiAdapterResponse(
+                    "No API adapter registered for subsystem {}".format(subsystem),
+                    status_code=400))
+                return wrap_result(None)
+            return func(_self, subsystem, *rem_args, **kwargs)
+        return wrapper
+    return decorator
+
+
+class BaseApiHandler(tornado.web.RequestHandler):
+    """API handler to transform requests into appropriate adapter calls.
+
+    This handler maps incoming API requests into the appropriate calls to methods
+    in registered adapters. HTTP GET, POST, PUT and DELETE verbs are supported. The class
+    also enforces a uniform response with the appropriate Content-Type header.
+    """
+
+    def __init__(self, *args, **kwargs):
+        """Construct the BaseApiHandler object.
+
+        This method just calls the base class constructor and sets the route object to None.
+        """
+        self.route = None
+        super(BaseApiHandler, self).__init__(*args, **kwargs)
+
+    def initialize(self, route):
+        """Initialize the API handler.
+ + :param route: ApiRoute object calling the handler (allows adapters to be resolved) + """ + self.route = route + + def respond(self, response): + """Respond to an API request. + + This method transforms an ApiAdapterResponse object into the appropriate request handler + response, setting the HTTP status code and content type for a response to an API request + and validating the content of the response against the appropriate type. + + :param response: ApiAdapterResponse object containing response + """ + self.set_status(response.status_code) + self.set_header('Content-Type', response.content_type) + + data = response.data + + if response.content_type == 'application/json': + if not isinstance(response.data, (str, dict)): + raise ApiError( + 'A response with content type application/json must have str or dict data' + ) + + self.write(data) + + def get(self, subsystem, path=''): + """Handle an API GET request. + + This is an abstract method which must be implemented by derived classes. + + :param subsystem: subsystem element of URI, defining adapter to be called + :param path: remaining URI path to be passed to adapter method + """ + raise NotImplementedError() + + def post(self, subsystem, path=''): + """Handle an API POST request. + + This is an abstract method which must be implemented by derived classes. + + :param subsystem: subsystem element of URI, defining adapter to be called + :param path: remaining URI path to be passed to adapter method + """ + raise NotImplementedError() + + def put(self, subsystem, path=''): + """Handle an API PUT request. + + This is an abstract method which must be implemented by derived classes. + + :param subsystem: subsystem element of URI, defining adapter to be called + :param path: remaining URI path to be passed to adapter method + """ + raise NotImplementedError() + + def delete(self, subsystem, path=''): + """Handle an API DELETE request. + + This is an abstract method which must be implemented by derived classes. 
+ + :param subsystem: subsystem element of URI, defining adapter to be called + :param path: remaining URI path to be passed to adapter method + """ + raise NotImplementedError() diff --git a/src/odin/http/routes/api.py b/src/odin/http/routes/api.py index 0349fd1..cb9d7cd 100644 --- a/src/odin/http/routes/api.py +++ b/src/odin/http/routes/api.py @@ -13,43 +13,13 @@ import tornado.web from odin.http.routes.route import Route -from odin.adapters.adapter import ApiAdapterResponse - -_api_version = 0.1 - - -def validate_api_request(required_version): - """Validate an API request to the ApiHandler. - - This decorator checks that API version in the URI of a requst is correct and that the subsystem - is registered with the application dispatcher; responds with a 400 error if not - """ - def decorator(func): - def wrapper(_self, *args, **kwargs): - # Extract version as first argument - version = args[0] - subsystem = args[1] - rem_args = args[2:] - if version != str(required_version): - _self.respond(ApiAdapterResponse( - "API version {} is not supported".format(version), - status_code=400) - ) - elif not _self.route.has_adapter(subsystem): - _self.respond(ApiAdapterResponse( - "No API adapter registered for subsystem {}".format(subsystem), - status_code=400) - ) - else: - return func(_self, subsystem, *rem_args, **kwargs) - return wrapper - return decorator - - -class ApiError(Exception): - """Simple exception class for API-related errors.""" - - pass +from odin.util import PY3 +from odin.http.handlers.base import ApiError, API_VERSION +if PY3: + from odin.http.handlers.async_api import AsyncApiHandler as ApiHandler + from odin.async_util import run_async +else: + from odin.http.handlers.api import ApiHandler class ApiVersionHandler(tornado.web.RequestHandler): @@ -67,7 +37,7 @@ def get(self): self.write('Requested content types not supported') return - self.write(json.dumps({'api': _api_version})) + self.write(json.dumps({'api': API_VERSION})) class 
ApiAdapterListHandler(tornado.web.RequestHandler): @@ -91,9 +61,8 @@ def get(self, version): :param version: API version """ - # Validate the API version explicity - can't use the validate_api_request decorator here - if version != str(_api_version): + if version != str(API_VERSION): self.set_status(400) self.write("API version {} is not supported".format(version)) return @@ -108,74 +77,6 @@ def get(self, version): self.write({'adapters': [adapter for adapter in self.route.adapters]}) -class ApiHandler(tornado.web.RequestHandler): - """API handler to transform requests into appropriate adapter calls. - - This handler maps incoming API requests into the appropriate calls to methods - in registered adapters. HTTP GET, PUT and DELETE verbs are supported. The class - also enforces a uniform response with the appropriate Content-Type header. - """ - - def initialize(self, route): - """Initialize the API handler. - - :param route: ApiRoute object calling the handler (allows adapters to be resolved) - """ - self.route = route - - @validate_api_request(_api_version) - def get(self, subsystem, path=''): - """Handle an API GET request. - - :param subsystem: subsystem element of URI, defining adapter to be called - :param path: remaining URI path to be passed to adapter method - """ - response = self.route.adapter(subsystem).get(path, self.request) - self.respond(response) - - @validate_api_request(_api_version) - def put(self, subsystem, path=''): - """Handle an API PUT request. - - :param subsystem: subsystem element of URI, defining adapter to be called - :param path: remaining URI path to be passed to adapter method - """ - response = self.route.adapter(subsystem).put(path, self.request) - self.respond(response) - - @validate_api_request(_api_version) - def delete(self, subsystem, path=''): - """Handle an API DELETE request. 
- - :param subsystem: subsystem element of URI, defining adapter to be called - :param path: remaining URI path to be passed to adapter method - """ - response = self.route.adapter(subsystem).delete(path, self.request) - self.respond(response) - - def respond(self, response): - """Respond to an API request. - - This method transforms an ApiAdapterResponse object into the appropriate request handler - response, setting the HTTP status code and content type for a response to an API request - and validating the content of the response against the appropriate type. - - :param response: ApiAdapterResponse object containing response - """ - self.set_status(response.status_code) - self.set_header('Content-Type', response.content_type) - - data = response.data - - if response.content_type == 'application/json': - if not isinstance(response.data, (str, dict)): - raise ApiError( - 'A response with content type application/json must have str or dict data' - ) - - self.write(data) - - class ApiRoute(Route): """ApiRoute - API route object used to map handlers onto adapter for API calls.""" @@ -218,7 +119,11 @@ def register_adapter(self, adapter_config, fail_ok=True): try: adapter_module = importlib.import_module(module_name) adapter_class = getattr(adapter_module, class_name) - self.adapters[adapter_config.name] = adapter_class(**adapter_config.options()) + if PY3 and adapter_class.is_async: + adapter = run_async(adapter_class, **adapter_config.options()) + else: + adapter = adapter_class(**adapter_config.options()) + self.adapters[adapter_config.name] = adapter except (ImportError, AttributeError) as e: logging.error( @@ -256,7 +161,11 @@ def cleanup_adapters(self): """ for adapter_name, adapter in self.adapters.items(): try: - getattr(adapter, 'cleanup')() + cleanup_method = getattr(adapter, 'cleanup') + if PY3 and adapter.is_async: + run_async(cleanup_method) + else: + cleanup_method() except AttributeError: logging.debug("Adapter %s has no cleanup method", adapter_name) @@ 
-268,6 +177,10 @@ def initialize_adapters(self): """ for adapter_name, adapter in self.adapters.items(): try: - getattr(adapter, 'initialize')(self.adapters) + initialize_method = getattr(adapter, 'initialize') + if PY3 and adapter.is_async: + run_async(initialize_method, self.adapters) + else: + initialize_method(self.adapters) except AttributeError: - logging.debug("Adapter %s has no Initialize method", adapter_name) + logging.debug("Adapter %s has no initialize method", adapter_name) diff --git a/src/odin/util.py b/src/odin/util.py index 839b62d..2ab32da 100644 --- a/src/odin/util.py +++ b/src/odin/util.py @@ -3,10 +3,17 @@ This module implements utility methods for Odin Server. """ import sys + +from tornado import version_info from tornado.escape import json_decode +from tornado.ioloop import IOLoop PY3 = sys.version_info >= (3,) +if PY3: + from odin.async_util import get_async_event_loop, wrap_async + unicode = str + def decode_request_body(request): """Extract the body from a request. @@ -54,3 +61,50 @@ def convert_unicode_to_string(obj): return obj.encode("utf-8") # Obj is none of the above, just return it return obj + + +def wrap_result(result, is_async=True): + """ + Conditionally wrap a result in an aysncio Future if being used in async code on python 3. + + This is to allow common functions for e.g. request validation, to be used in both + async and sync code across python variants. + + :param is_async: optional flag for if desired outcome is a result wrapped in a future + + :return: either the result or a Future wrapping the result + """ + if is_async and PY3: + return wrap_async(result) + else: + return result + + +def run_in_executor(executor, func, *args): + """ + Run a function asynchronously in an executor. + + This method extends the behaviour of Tornado IOLoop equivalent to allow nested task execution + without having to modify the underlying asyncio loop creation policy on python 3. 
If the + current execution context does not have a valid IO loop, a new one will be created and used. + The method returns a tornado Future instance, allowing it to be awaited in an async method where + applicable. + + :param executor: a concurrent.futures.Executor instance to run the task in + :param func: the function to execute + :param arg: list of arguments to pass to the function + + :return: a Future wrapping the task + """ + # In python 3, try to get the current asyncio event loop, otherwise create a new one + if PY3: + get_async_event_loop() + + # Run the function in the specified executor, handling tornado version 4 where there was no + # run_in_executor implementation + if version_info[0] <= 4: + future = executor.submit(func, *args) + else: + future = IOLoop.current().run_in_executor(executor, func, *args) + + return future diff --git a/tests/adapters/test_adapter.py b/tests/adapters/test_adapter.py index 7ab12a7..a553b23 100644 --- a/tests/adapters/test_adapter.py +++ b/tests/adapters/test_adapter.py @@ -41,6 +41,15 @@ def test_adapter_get(self, test_api_adapter): assert response.data == 'GET method not implemented by ApiAdapter' assert response.status_code == 400 + def test_adapter_post(self, test_api_adapter): + """ + Test the the adapter responds to a GET request correctly by returning a 400 code and + appropriate message. This is due to the base adapter not implementing the methods. 
+ """ + response = test_api_adapter.adapter.post(test_api_adapter.path, test_api_adapter.request) + assert response.data == 'POST method not implemented by ApiAdapter' + assert response.status_code == 400 + def test_adapter_put(self, test_api_adapter): """ Test the the adapter responds to a PUT request correctly by returning a 400 code and @@ -207,6 +216,8 @@ def __init__(self): self.response_type_json = 'application/json' self.response_data_json = {'response': 'JSON response'} + self.is_async = False + @request_types('application/json', 'text/plain') @response_types('application/json', 'text/plain', default='application/json') def decorated_method(self, path, request): diff --git a/tests/adapters/test_async_adapter_py3.py b/tests/adapters/test_async_adapter_py3.py new file mode 100644 index 0000000..aae0c9c --- /dev/null +++ b/tests/adapters/test_async_adapter_py3.py @@ -0,0 +1,94 @@ +import sys + +import pytest + +if sys.version_info[0] < 3: + pytest.skip("Skipping async tests", allow_module_level=True) +else: + from odin.adapters.async_adapter import AsyncApiAdapter + from unittest.mock import Mock + +class AsyncApiAdapterTestFixture(object): + + def __init__(self): + self.adapter_options = { + 'test_option_float' : 1.234, + 'test_option_str' : 'value', + 'test_option_int' : 4567. + } + self.adapter = AsyncApiAdapter(**self.adapter_options) + self.path = '/api/async_path' + self.request = Mock() + self.request.headers = {'Accept': '*/*', 'Content-Type': 'text/plain'} + +@pytest.fixture(scope="class") +def test_async_api_adapter(): + test_async_api_adapter = AsyncApiAdapterTestFixture() + yield test_async_api_adapter + +class TestAsyncApiAdapter(): + """Class to test the AsyncApiAdapter object.""" + + @pytest.mark.asyncio + async def test_async_adapter_get(self, test_async_api_adapter): + """ + Test the the adapter responds to a GET request correctly by returning a 400 code and + appropriate message. This is due to the base adapter not implementing the methods. 
+ """ + response = await test_async_api_adapter.adapter.get( + test_async_api_adapter.path, test_async_api_adapter.request) + assert response.data == 'GET method not implemented by AsyncApiAdapter' + assert response.status_code == 400 + + @pytest.mark.asyncio + async def test_async_adapter_post(self, test_async_api_adapter): + """ + Test the the adapter responds to a POST request correctly by returning a 400 code and + appropriate message. This is due to the base adapter not implementing the methods. + """ + response = await test_async_api_adapter.adapter.post( + test_async_api_adapter.path, test_async_api_adapter.request) + assert response.data == 'POST method not implemented by AsyncApiAdapter' + assert response.status_code == 400 + + @pytest.mark.asyncio + async def test_async_adapter_put(self, test_async_api_adapter): + """ + Test the the adapter responds to a PUT request correctly by returning a 400 code and + appropriate message. This is due to the base adapter not implementing the methods. + """ + response = await test_async_api_adapter.adapter.put( + test_async_api_adapter.path, test_async_api_adapter.request) + assert response.data == 'PUT method not implemented by AsyncApiAdapter' + assert response.status_code == 400 + + @pytest.mark.asyncio + async def test_adapter_delete(self, test_async_api_adapter): + """ + Test the the adapter responds to a DELETE request correctly by returning a 400 code and + appropriate message. This is due to the base adapter not implementing the methods. 
+ """ + response = await test_async_api_adapter.adapter.delete( + test_async_api_adapter.path, test_async_api_adapter.request) + assert response.data == 'DELETE method not implemented by AsyncApiAdapter' + assert response.status_code == 400 + + @pytest.mark.asyncio + async def test_adapter_initialize(self, test_async_api_adapter): + """Test the the adapter initialize function runs without error.""" + raised = False + try: + await test_async_api_adapter.adapter.initialize(None) + except: + raised = True + assert not raised + + @pytest.mark.asyncio + async def test_adapter_cleanup(self, test_async_api_adapter): + """Test the the adapter cleanup function runs without error.""" + raised = False + try: + await test_async_api_adapter.adapter.cleanup() + except: + raised = True + assert not raised diff --git a/tests/adapters/test_async_dummy_py3.py b/tests/adapters/test_async_dummy_py3.py new file mode 100644 index 0000000..96f6d1e --- /dev/null +++ b/tests/adapters/test_async_dummy_py3.py @@ -0,0 +1,100 @@ +import sys + +import pytest + +if sys.version_info[0] < 3: + pytest.skip("Skipping async tests", allow_module_level=True) +else: + import asyncio + from odin.adapters.async_dummy import AsyncDummyAdapter + from unittest.mock import Mock + from tests.async_utils import AwaitableTestFixture, asyncio_fixture_decorator + + +class AsyncDummyAdapterTestFixture(AwaitableTestFixture): + """Container class used in fixtures for testing the AsyncDummyAdapter.""" + def __init__(self, wrap_sync_sleep=False): + """ + Initialise the adapter and associated test objects. + + The wrap_sync_sleep argument steers the adapter options, controlling how + the simulated task is executed, either wrapping a synchronous function + or using native asyncio sleep. 
+ """ + + super(AsyncDummyAdapterTestFixture, self).__init__(AsyncDummyAdapter) + + self.adapter_options = { + 'wrap_sync_sleep': wrap_sync_sleep, + 'async_sleep_duration': 0.1 + } + self.adapter = AsyncDummyAdapter(**self.adapter_options) + self.path = '' + self.bad_path = 'missing/path' + self.rw_path = 'async_rw_param' + self.request = Mock() + self.request.body = '{}' + self.request.headers = {'Accept': 'application/json', 'Content-Type': 'application/json'} + +@pytest.fixture(scope="class") +def event_loop(): + """Redefine the pytest.asyncio event loop fixture to have class scope.""" + loop = asyncio.get_event_loop() + yield loop + loop.close() + +@asyncio_fixture_decorator(scope='class', params=[True, False], ids=['wrapped', 'native']) +async def test_dummy_adapter(request): + """ + Parameterised test fixture for use with AsyncDummyAdapter tests. The fixture + parameters generate tests using this fixture for both wrapped and native async task + simulation. + """ + test_dummy_adapter = await AsyncDummyAdapterTestFixture(request.param) + adapters = [test_dummy_adapter.adapter] + await test_dummy_adapter.adapter.initialize(adapters) + yield test_dummy_adapter + await test_dummy_adapter.adapter.cleanup() + + +@pytest.mark.asyncio +class TestAsyncDummyAdapter(): + + async def test_adapter_get(self, test_dummy_adapter): + + response = await test_dummy_adapter.adapter.get( + test_dummy_adapter.path, test_dummy_adapter.request) + assert isinstance(response.data, dict) + assert response.status_code == 200 + + async def test_adapter_get_bad_path(self, test_dummy_adapter): + + expected_response = {'error': 'Invalid path: {}'.format(test_dummy_adapter.bad_path)} + response = await test_dummy_adapter.adapter.get( + test_dummy_adapter.bad_path, test_dummy_adapter.request) + assert response.data == expected_response + assert response.status_code == 400 + + async def test_adapter_put(self, test_dummy_adapter): + + rw_request = Mock() + rw_request.headers = 
test_dummy_adapter.request.headers + rw_request.body = 4567 + + await test_dummy_adapter.adapter.put(test_dummy_adapter.rw_path, rw_request) + + response = await test_dummy_adapter.adapter.get( + test_dummy_adapter.rw_path, test_dummy_adapter.request) + + assert isinstance(response.data, dict) + assert response.data[test_dummy_adapter.rw_path] == rw_request.body + assert response.status_code == 200 + + async def test_adapter_put_bad_path(self, test_dummy_adapter): + + expected_response = {'error': 'Invalid path: {}'.format(test_dummy_adapter.bad_path)} + response = await test_dummy_adapter.adapter.put( + test_dummy_adapter.bad_path, test_dummy_adapter.request + ) + assert response.data == expected_response + assert response.status_code == 400 \ No newline at end of file diff --git a/tests/adapters/test_async_parameter_tree_py3.py b/tests/adapters/test_async_parameter_tree_py3.py new file mode 100644 index 0000000..998e94b --- /dev/null +++ b/tests/adapters/test_async_parameter_tree_py3.py @@ -0,0 +1,1138 @@ +"""Test the AsyncParameterTree classes. + +This module implements unit test cases for the AsyncParameterAccessor and AsyncParameterTree +classes. + +Tim Nicholls, STFC Detector Systems Software Group. 
+""" + +import asyncio +import math +import sys + +from copy import deepcopy + +import pytest + +if sys.version_info[0] < 3: + pytest.skip("Skipping async tests", allow_module_level=True) +else: + + from tests.async_utils import AwaitableTestFixture, asyncio_fixture_decorator + from odin.adapters.async_parameter_tree import ( + AsyncParameterAccessor, AsyncParameterTree, ParameterTreeError + ) + + +class AsyncParameterAccessorTestFixture(AwaitableTestFixture): + """Test fixture of AsyncParameterAccessor test cases.""" + def __init__(self): + + super(AsyncParameterAccessorTestFixture, self).__init__(AsyncParameterAccessor) + + self.static_rw_path = 'static_rw' + self.static_rw_value = 2.76923 + self.static_rw_accessor = AsyncParameterAccessor( + self.static_rw_path + '/', self.static_rw_value + ) + + self.sync_ro_value = 1234 + self.sync_ro_path = 'sync_ro' + self.sync_ro_accessor = AsyncParameterAccessor( + self.sync_ro_path + '/', self.sync_ro_get + ) + + self.sync_rw_value = 'foo' + self.sync_rw_path = 'sync_rw' + self.sync_rw_accessor = AsyncParameterAccessor( + self.sync_rw_path + '/', self.sync_rw_get, self.sync_rw_set + ) + + self.async_ro_value = 5593 + self.async_ro_path = 'async_ro' + self.async_ro_accessor = AsyncParameterAccessor( + self.async_ro_path + '/', self.async_ro_get + ) + + self.async_rw_value = math.pi + self.async_rw_path = 'async_rw' + self.async_rw_accessor = AsyncParameterAccessor( + self.async_rw_path + '/', self.async_rw_get, self.async_rw_set + ) + + self.md_param_path ='mdparam' + self.md_param_value = 456 + self.md_param_metadata = { + 'min' : 100, + 'max' : 1000, + "allowed_values": [100, 123, 456, 789, 1000], + "name": "Test Parameter", + "description": "This is a test parameter", + "units": "furlongs/fortnight", + "display_precision": 0, + } + self.md_accessor = AsyncParameterAccessor( + self.md_param_path + '/', self.async_md_get, self.async_md_set, **self.md_param_metadata + ) + + self.md_minmax_path = 'minmaxparam' + 
self.md_minmax_value = 500 + self.md_minmax_metadata = { + 'min': 100, + 'max': 1000 + } + self.md_minmax_accessor = AsyncParameterAccessor( + self.md_minmax_path + '/', self.async_md_minmax_get, self.async_md_minmax_set, + **self.md_minmax_metadata + ) + + def sync_ro_get(self): + return self.sync_ro_value + + def sync_rw_get(self): + return self.sync_rw_value + + def sync_rw_set(self, value): + self.sync_rw_value = value + + async def async_ro_get(self): + await asyncio.sleep(0) + return self.async_ro_value + + async def async_rw_get(self): + await asyncio.sleep(0) + return self.async_rw_value + + async def async_rw_set(self, value): + await asyncio.sleep(0) + self.async_rw_value = value + + async def async_md_get(self): + await asyncio.sleep(0) + return self.md_param_value + + async def async_md_set(self, value): + await asyncio.sleep(0) + self.async_md_param_value = value + + async def async_md_minmax_get(self): + await asyncio.sleep(0) + return self.md_minmax_value + + async def async_md_minmax_set(self, value): + await asyncio.sleep(0) + self.md_minmax_value = value + +@pytest.fixture(scope="class") +def event_loop(): + """Redefine the pytest.asyncio event loop fixture to have class scope.""" + loop = asyncio.get_event_loop() + yield loop + loop.close() + +@asyncio_fixture_decorator(scope="class") +async def test_param_accessor(): + """Test fixture used in testing ParameterAccessor behaviour.""" + test_param_accessor = await AsyncParameterAccessorTestFixture() + yield test_param_accessor + +@pytest.mark.asyncio +class TestAsyncParameterAccessor(): + """Class to test AsyncParameterAccessor behaviour""" + + async def test_static_rw_accessor_get(self, test_param_accessor): + """Test that a static RW accessor get call returns the correct value.""" + value = await test_param_accessor.static_rw_accessor.get() + assert value == test_param_accessor.static_rw_value + + async def test_static_rw_accessor_set(self, test_param_accessor): + """Test that a static RW 
accessor set call sets the correct value.""" + old_val = test_param_accessor.static_rw_value + new_val = 1.234 + await test_param_accessor.static_rw_accessor.set(new_val) + value = await test_param_accessor.static_rw_accessor.get() + assert value == new_val + + await test_param_accessor.static_rw_accessor.set(old_val) + + async def test_sync_ro_accessor_get(self, test_param_accessor): + """Test that a synchronous callable RO accessor get call returns the correct value.""" + value = await test_param_accessor.sync_ro_accessor.get() + assert value == test_param_accessor.sync_ro_value + + async def test_sync_ro_accessor_set(self, test_param_accessor): + """Test that a synchronous callable RO accessor set call raises an error.""" + new_val = 91265 + with pytest.raises(ParameterTreeError) as excinfo: + await test_param_accessor.sync_ro_accessor.set(new_val) + + assert "Parameter {} is read-only".format(test_param_accessor.sync_ro_path) \ + in str(excinfo.value) + + async def test_sync_rw_accessor_get(self, test_param_accessor): + """Test that a synchronous callable RW accessor returns the correct value.""" + value = await test_param_accessor.sync_rw_accessor.get() + assert value == test_param_accessor.sync_rw_value + + async def test_sync_rw_accessor_set(self, test_param_accessor): + """Test that a synchronous callable RW accessor set call sets the correct value.""" + old_val = test_param_accessor.sync_rw_value + new_val = 'bar' + await test_param_accessor.sync_rw_accessor.set(new_val) + value = await test_param_accessor.sync_rw_accessor.get() + assert value == new_val + + await test_param_accessor.sync_rw_accessor.set(old_val) + + async def test_async_ro_accessor_get(self, test_param_accessor): + """Test that an asynchronous callable RO accessor get call returns the correct value.""" + value = await test_param_accessor.async_ro_accessor.get() + assert value == test_param_accessor.async_ro_value + + async def test_async_ro_accessor_set(self, test_param_accessor): + 
"""Test that an asynchronous callable RO accessor set call raises an error.""" + new_val = 91265 + with pytest.raises(ParameterTreeError) as excinfo: + await test_param_accessor.async_ro_accessor.set(new_val) + + assert "Parameter {} is read-only".format(test_param_accessor.async_ro_path) \ + in str(excinfo.value) + + async def test_async_rw_accessor_get(self, test_param_accessor): + """Test that an asynchronous callable RW accessor get returns the correct value.""" + value = await test_param_accessor.async_rw_accessor.get() + assert value == test_param_accessor.async_rw_value + + async def test_async_rw_accessor_set(self, test_param_accessor): + """Test that an asynchronous callable RW accessor sets the correct value.""" + old_val = test_param_accessor.async_rw_value + new_val = old_val * 2 + await test_param_accessor.async_rw_accessor.set(new_val) + value = await test_param_accessor.async_rw_accessor.get() + assert value == new_val + + async def test_static_rw_accessor_default_metadata(self, test_param_accessor): + """Test that a static RW accessor has the appropriate default metadata.""" + param = await test_param_accessor.static_rw_accessor.get(with_metadata=True) + assert(isinstance(param, dict)) + assert param['value'] == test_param_accessor.static_rw_value + assert param['type'] == type(test_param_accessor.static_rw_value).__name__ + assert param['writeable'] == True + + async def test_sync_ro_accessor_default_metadata(self, test_param_accessor): + """Test that a synchronous callable RO accesor has the appropriate default metadata.""" + param = await test_param_accessor.sync_ro_accessor.get(with_metadata=True) + assert param['value'] == test_param_accessor.sync_ro_value + assert param['type'] == type(test_param_accessor.sync_ro_value).__name__ + assert param['writeable'] == False + + async def test_sync_rw_accessor_default_metadata(self, test_param_accessor): + """Test that a synchronous callable RW accesor has the appropriate default metadata.""" + param = 
await test_param_accessor.sync_rw_accessor.get(with_metadata=True) + assert param['value'] == test_param_accessor.sync_rw_value + assert param['type'] == type(test_param_accessor.sync_rw_value).__name__ + assert param['writeable'] == True + + async def test_sync_ro_accessor_default_metadata(self, test_param_accessor): + """Test that a synchronous callable RO accesor has the appropriate default metadata.""" + param = await test_param_accessor.sync_ro_accessor.get(with_metadata=True) + assert param['value'] == test_param_accessor.sync_ro_value + assert param['type'] == type(test_param_accessor.sync_ro_value).__name__ + assert param['writeable'] == False + + async def test_async_rw_accessor_default_metadata(self, test_param_accessor): + """Test that an asynchronous callable RW accesor has the appropriate default metadata.""" + param = await test_param_accessor.async_rw_accessor.get(with_metadata=True) + assert param['value'] == test_param_accessor.async_rw_value + assert param['type'] == type(test_param_accessor.async_rw_value).__name__ + assert param['writeable'] == True + + async def test_async_ro_accessor_default_metadata(self, test_param_accessor): + """Test that an asynchronous callable RO accesor has the appropriate default metadata.""" + param = await test_param_accessor.async_ro_accessor.get(with_metadata=True) + assert param['value'] == test_param_accessor.async_ro_value + assert param['type'] == type(test_param_accessor.async_ro_value).__name__ + assert param['writeable'] == False + + async def test_metadata_param_accessor_metadata(self, test_param_accessor): + """Test that a parameter accessor has the correct metadata fields.""" + param = await test_param_accessor.md_accessor.get(with_metadata=True) + for md_field in test_param_accessor.md_param_metadata: + assert md_field in param + assert param[md_field] == test_param_accessor.md_param_metadata[md_field] + assert param['value'] == test_param_accessor.md_param_value + assert param['type'] == 
type(test_param_accessor.md_param_value).__name__ + assert param['writeable'] == True + + async def test_param_accessor_bad_metadata_arg(self, test_param_accessor): + """Test that a parameter accessor with a bad metadata argument raises an error.""" + bad_metadata_argument = 'foo' + bad_metadata = {bad_metadata_argument: 'bar'} + with pytest.raises(ParameterTreeError) as excinfo: + _ = await AsyncParameterAccessor( + test_param_accessor.static_rw_path + '/', + test_param_accessor.static_rw_value, **bad_metadata + ) + + assert "Invalid metadata argument: {}".format(bad_metadata_argument) \ + in str(excinfo.value) + + async def test_param_accessor_set_type_mismatch(self, test_param_accessor): + """ + Test that setting the value of a parameter accessor with the incorrected type raises + an error. + """ + bad_value = 'bar' + bad_value_type = type(bad_value).__name__ + + with pytest.raises(ParameterTreeError) as excinfo: + await test_param_accessor.async_rw_accessor.set(bad_value) + + assert "Type mismatch setting {}: got {} expected {}".format( + test_param_accessor.async_rw_path, bad_value_type, + type(test_param_accessor.async_rw_value).__name__ + ) in str(excinfo.value) + + async def test_param_accessor_bad_allowed_value(self, test_param_accessor): + """ + Test the setting the value of a parameter accessor to a disallowed value raises an error. + """ + bad_value = 222 + with pytest.raises(ParameterTreeError) as excinfo: + await test_param_accessor.md_accessor.set(bad_value) + + assert "{} is not an allowed value for {}".format( + bad_value, test_param_accessor.md_param_path + ) in str(excinfo.value) + + async def test_param_accessor_value_below_min(self, test_param_accessor): + """ + Test that setting the value of a parameter accessor below the minimum allowed raises an + error. 
+ """ + bad_value = 1 + with pytest.raises(ParameterTreeError) as excinfo: + await test_param_accessor.md_minmax_accessor.set(bad_value) + + assert "{} is below the minimum value {} for {}".format( + bad_value, test_param_accessor.md_minmax_metadata['min'], + test_param_accessor.md_minmax_path + ) in str(excinfo.value) + + async def test_param_accessor_value_above_max(self, test_param_accessor): + """ + Test that setting the value of a parameter accessor above the maximum allowed raises an + error. + """ + bad_value = 100000 + with pytest.raises(ParameterTreeError) as excinfo: + await test_param_accessor.md_minmax_accessor.set(bad_value) + + assert "{} is above the maximum value {} for {}".format( + bad_value, test_param_accessor.md_minmax_metadata['max'], + test_param_accessor.md_minmax_path + ) in str(excinfo.value) + + +class AsyncParameterTreeTestFixture(AwaitableTestFixture): + """Container class for use in fixtures testing AsyncParameterTree.""" + + def __init__(self): + + super(AsyncParameterTreeTestFixture, self).__init__(AsyncParameterTree) + + self.int_value = 1234 + self.float_value = 3.1415 + self.bool_value = True + self.str_value = 'theString' + self.list_values = list(range(4)) + + self.simple_dict = { + 'intParam': self.int_value, + 'floatParam': self.float_value, + 'boolParam': self.bool_value, + 'strParam': self.str_value, + } + + self.accessor_params = { + 'one': 1, + 'two': 2, + 'pi': 3.14 + } + self.simple_tree = AsyncParameterTree(self.simple_dict) + + # Set up nested dict of parameters for a more complex tree + self.nested_dict = self.simple_dict.copy() + self.nested_dict['branch'] = { + 'branchIntParam': 4567, + 'branchStrParam': 'theBranch', + } + self.nested_tree = AsyncParameterTree(self.nested_dict) + + self.complex_tree = AsyncParameterTree({ + 'intParam': self.int_value, + 'callableRoParam': (lambda: self.int_value, None), + 'callableAccessorParam': (self.get_accessor_param, None), + 'listParam': self.list_values, + 'branch': 
AsyncParameterTree(deepcopy(self.nested_dict)), + }) + + self.list_tree = AsyncParameterTree({ + 'main' : [ + self.simple_dict.copy(), + list(self.list_values) + ] + }) + + self.simple_list_tree = AsyncParameterTree({ + 'list_param': [10, 11, 12, 13] + }) + + async def async_ro_get(self): + await asyncio.sleep(0) + return self.async_ro_value + + async def nested_async_ro_get(self): + await asyncio.sleep(0) + return self.nested_async_ro_value + + async def get_accessor_param(self): + await asyncio.sleep(0) + return self.accessor_params + + +@asyncio_fixture_decorator(scope="class") +async def test_param_tree(): + """Test fixture used in testing AsyncParameterTree behaviour.""" + test_param_accessor = await AsyncParameterTreeTestFixture() + yield test_param_accessor + + +@pytest.mark.asyncio +class TestAsyncParameterTree(): + + async def test_simple_tree_returns_dict(self, test_param_tree): + """Test the get on a simple tree returns a dict.""" + dt_vals = await test_param_tree.simple_tree.get('') + assert dt_vals, test_param_tree.simple_dict + assert True + + async def test_simple_tree_single_values(self, test_param_tree): + """Test that getting single values from a simple tree returns the correct values.""" + dt_int_val = await test_param_tree.simple_tree.get('intParam') + assert dt_int_val['intParam'] == test_param_tree.int_value + + dt_float_val = await test_param_tree.simple_tree.get('floatParam') + assert dt_float_val['floatParam'] == test_param_tree.float_value + + dt_bool_val = await test_param_tree.simple_tree.get('boolParam') + assert dt_bool_val['boolParam'] == test_param_tree.bool_value + + dt_str_val = await test_param_tree.simple_tree.get('strParam') + assert dt_str_val['strParam'] == test_param_tree.str_value + + async def test_simple_tree_missing_value(self, test_param_tree): + """Test that getting a missing value from a simple tree raises an error.""" + with pytest.raises(ParameterTreeError) as excinfo: + await 
test_param_tree.simple_tree.get('missing')
+
+        assert 'Invalid path: missing' in str(excinfo.value)
+
+    async def test_nested_tree_returns_nested_dict(self, test_param_tree):
+        """Test that getting a nested tree returns a dict."""
+        nested_dt_vals = await test_param_tree.nested_tree.get('')
+        assert nested_dt_vals == test_param_tree.nested_dict
+
+    async def test_nested_tree_branch_returns_dict(self, test_param_tree):
+        """Test that getting a tree from within a nested tree returns a dict."""
+        branch_vals = await test_param_tree.nested_tree.get('branch')
+        assert branch_vals['branch'] == test_param_tree.nested_dict['branch']
+
+    async def test_nested_tree_trailing_slash(self, test_param_tree):
+        """Test that getting a tree with trailing slash returns the correct dict."""
+        branch_vals = await test_param_tree.nested_tree.get('branch/')
+        assert branch_vals['branch'] == test_param_tree.nested_dict['branch']
+
+    async def test_set_with_extra_branch_paths(self, test_param_tree):
+        """
+        Test that modifying a branch in a tree with extra parameters raises an error.
+        """
+        branch_data = deepcopy(test_param_tree.nested_dict['branch'])
+        branch_data['extraParam'] = 'oops'
+
+        with pytest.raises(ParameterTreeError) as excinfo:
+            await test_param_tree.complex_tree.set('branch', branch_data)
+
+        assert 'Invalid path' in str(excinfo.value)
+
+    async def test_complex_tree_calls_leaf_nodes(self, test_param_tree):
+        """
+        Test that accessing values in a complex tree returns the correct values for
+        static and callable parameters.
+ """ + complex_vals = await test_param_tree.complex_tree.get('') + assert complex_vals['intParam'] == test_param_tree.int_value + assert complex_vals['callableRoParam'] == test_param_tree.int_value + + async def test_complex_tree_access_list_param(self, test_param_tree): + """Test that getting a list parameter from a complex tree returns the appropriate values.""" + list_param_vals = await test_param_tree.complex_tree.get('listParam') + assert list_param_vals['listParam'] == test_param_tree.list_values + + async def test_complex_tree_callable_readonly(self, test_param_tree): + """ + Test that attempting to set the value of a RO callable parameter in a tree raises an + error. + """ + with pytest.raises(ParameterTreeError) as excinfo: + await test_param_tree.complex_tree.set('callableRoParam', 1234) + + assert 'Parameter callableRoParam is read-only' in str(excinfo.value) + + async def test_complex_tree_set_invalid_path(self, test_param_tree): + """ + Test that attempting to set the value of an element in a complex tree on a path + that does not exist raises an error. 
+        """
+        invalid_path = 'invalidPath/toNothing'
+
+        with pytest.raises(ParameterTreeError) as excinfo:
+            await test_param_tree.complex_tree.set(invalid_path, 0)
+
+        assert 'Invalid path: {}'.format(invalid_path) in str(excinfo.value)
+
+    async def test_complex_tree_set_top_level(self, test_param_tree):
+        """Test that setting the top level of a complex tree correctly sets all values."""
+        complex_vals = await test_param_tree.complex_tree.get('')
+        complex_vals_copy = deepcopy(complex_vals)
+        del complex_vals_copy['callableRoParam']
+        del complex_vals_copy['callableAccessorParam']
+
+        await test_param_tree.complex_tree.set('', complex_vals_copy)
+        complex_vals2 = await test_param_tree.complex_tree.get('')
+        assert complex_vals == complex_vals2
+
+    async def test_complex_tree_inject_spurious_dict(self, test_param_tree):
+        """
+        Test that attempting to inject a dictionary into the position of a non-dict parameter
+        raises an error.
+        """
+        param_data = {'intParam': 9876}
+
+        with pytest.raises(ParameterTreeError) as excinfo:
+            await test_param_tree.complex_tree.set('intParam', param_data)
+
+        assert 'Type mismatch updating intParam' in str(excinfo.value)
+
+    async def test_list_tree_get_indexed(self, test_param_tree):
+        """
+        Test that it is possible to get a value by index from a list parameter.
+        """
+        ret = await test_param_tree.list_tree.get("main/1")
+        assert ret == {'1':test_param_tree.list_values}
+
+    async def test_list_tree_set_indexed(self, test_param_tree):
+        """
+        Test that it is possible to set a value by index on a list parameter.
+        """
+        await test_param_tree.list_tree.set("main/1/2", 7)
+        assert await test_param_tree.list_tree.get("main/1/2") == {'2': 7}
+
+    async def test_list_tree_set_from_root(self, test_param_tree):
+        """Test that it is possible to set a list tree from its root."""
+        tree_data = {
+            'main' : [
+                {
+                    'intParam': 0,
+                    'floatParam': 0.00,
+                    'boolParam': False,
+                    'strParam': "test",
+                },
+                [1,2,3,4]
+            ]
+        }
+
+        await test_param_tree.list_tree.set("",tree_data)
+        assert await test_param_tree.list_tree.get("main") == tree_data
+
+    async def test_list_tree_from_dict(self, test_param_tree):
+        """Test that a list tree can be set with a dict of index/values."""
+        new_list_param = {0: 0, 1: 1, 2: 2, 3: 3}
+        await test_param_tree.simple_list_tree.set('list_param', new_list_param)
+        result = await test_param_tree.simple_list_tree.get('list_param')
+        assert result['list_param']== list(new_list_param.values())
+
+
+    async def test_list_tree_from_dict_bad_index(self, test_param_tree):
+        """
+        Test that setting a list tree from a dict with an index outside the current range
+        raises an error.
+ """ + new_list_param = {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5} + with pytest.raises(ParameterTreeError) as excinfo: + await test_param_tree.simple_list_tree.set('list_param', new_list_param) + + assert "Invalid path: list_param/4 list index out of range" in str(excinfo.value) + + async def test_bad_tuple_node_raises_error(self, test_param_tree): + """Test that constructing a parameter tree with an immutable tuple raises an error.""" + bad_node = 'bad' + bad_data = tuple(range(4)) + bad_tree = { + bad_node: bad_data + } + with pytest.raises(ParameterTreeError) as excinfo: + tree = AsyncParameterTree(bad_tree) + + assert "not a valid leaf node" in str(excinfo.value) + + +class AsyncRwParameterTreeTestFixture(AwaitableTestFixture): + """Container class for use in async read-write parameter tree test fixtures.""" + + def __init__(self): + + super(AsyncRwParameterTreeTestFixture, self).__init__(AsyncParameterTree) + + self.int_rw_param = 4576 + self.int_ro_param = 255374 + self.int_rw_value = 9876 + self.int_wo_param = 0 + + self.rw_value_set_called = False + + self.nested_rw_param = 53.752 + self.nested_ro_value = 9.8765 + + nested_tree = AsyncParameterTree({ + 'nestedRwParam': (self.nestedRwParamGet, self.nestedRwParamSet), + 'nestedRoParam': self.nested_ro_value + }) + + self.rw_callable_tree = AsyncParameterTree({ + 'intCallableRwParam': (self.intCallableRwParamGet, self.intCallableRwParamSet), + 'intCallableRoParam': (self.intCallableRoParamGet, None), + 'intCallableWoParam': (None, self.intCallableWoParamSet), + 'intCallableRwValue': (self.int_rw_value, self.intCallableRwValueSet), + 'branch': nested_tree + }) + + async def intCallableRwParamSet(self, value): + await asyncio.sleep(0) + self.int_rw_param = value + + async def intCallableRwParamGet(self): + await asyncio.sleep(0) + return self.int_rw_param + + async def intCallableRoParamGet(self): + await asyncio.sleep(0) + return self.int_ro_param + + async def intCallableWoParamSet(self, value): + await 
asyncio.sleep(0) + self.int_wo_param = value + + async def intCallableRwValueSet(self, value): + await asyncio.sleep(0) + self.rw_value_set_called = True + + async def nestedRwParamSet(self, value): + await asyncio.sleep(0) + self.nested_rw_param = value + + async def nestedRwParamGet(self): + await asyncio.sleep(0) + return self.nested_rw_param + + +@asyncio_fixture_decorator(scope="class") +async def test_rw_tree(): + """Test fixture for use in testing read-write parameter trees.""" + test_rw_tree = await AsyncRwParameterTreeTestFixture() + yield test_rw_tree + + +@pytest.mark.asyncio +class TestAsyncRwParameterTree(): + """Class to test behaviour of async read-write parameter trees.""" + + async def test_rw_tree_simple_get_values(self, test_rw_tree): + """Test getting simple values from a RW tree returns the correct values.""" + dt_rw_int_param = await test_rw_tree.rw_callable_tree.get('intCallableRwParam') + assert dt_rw_int_param['intCallableRwParam'] == test_rw_tree.int_rw_param + + dt_ro_int_param = await test_rw_tree.rw_callable_tree.get('intCallableRoParam') + assert dt_ro_int_param['intCallableRoParam'] == test_rw_tree.int_ro_param + + dt_rw_int_value = await test_rw_tree.rw_callable_tree.get('intCallableRwValue') + assert dt_rw_int_value['intCallableRwValue'] == test_rw_tree.int_rw_value + + async def test_rw_tree_simple_set_value(self, test_rw_tree): + """Test that setting a value in a RW tree updates and returns the correct value.""" + new_int_value = 91210 + await test_rw_tree.rw_callable_tree.set('intCallableRwParam', new_int_value) + + dt_rw_int_param = await test_rw_tree.rw_callable_tree.get('intCallableRwParam') + assert dt_rw_int_param['intCallableRwParam'] == new_int_value + + async def test_rw_tree_set_ro_param(self, test_rw_tree): + """Test that attempting to set a RO parameter raises an error.""" + with pytest.raises(ParameterTreeError) as excinfo: + await test_rw_tree.rw_callable_tree.set('intCallableRoParam', 0) + + assert 'Parameter 
intCallableRoParam is read-only' in str(excinfo.value) + + async def test_rw_callable_tree_set_wo_param(self, test_rw_tree): + """Test that setting a write-only parameter (!!) sets the correct value.""" + new_value = 1234 + await test_rw_tree.rw_callable_tree.set('intCallableWoParam', new_value) + assert test_rw_tree.int_wo_param == new_value + + async def test_rw_callable_tree_set_rw_value(self, test_rw_tree): + """Test that setting a callable RW value calls the appropriate set method.""" + new_value = 1234 + await test_rw_tree.rw_callable_tree.set('intCallableRwValue', new_value) + assert test_rw_tree.rw_value_set_called + + async def test_rw_callable_nested_param_get(self, test_rw_tree): + """Test the getting a nested callable RW parameter returns the correct value.""" + dt_nested_param = await test_rw_tree.rw_callable_tree.get('branch/nestedRwParam') + assert dt_nested_param['nestedRwParam'] == test_rw_tree.nested_rw_param + + async def test_rw_callable_nested_param_set(self, test_rw_tree): + """Test that setting a nested callable RW parameter sets the correct value.""" + new_float_value = test_rw_tree.nested_rw_param + 2.3456 + await test_rw_tree.rw_callable_tree.set('branch/nestedRwParam', new_float_value) + assert test_rw_tree.nested_rw_param == new_float_value + + async def test_rw_callable_nested_tree_set(self, test_rw_tree): + """Test the setting a value within a callable nested tree updated the value correctly.""" + result = await test_rw_tree.rw_callable_tree.get('branch') + nested_branch = result['branch'] + new_rw_param_val = 45.876 + nested_branch['nestedRwParam'] = new_rw_param_val + await test_rw_tree.rw_callable_tree.set('branch', nested_branch) + result = await test_rw_tree.rw_callable_tree.get('branch') + assert result['branch']['nestedRwParam'], new_rw_param_val + + async def test_rw_callable_nested_tree_set_trailing_slash(self, test_rw_tree): + """ + Test that setting a callable nested tree with a trailing slash in the path + sets the value 
correctly. + """ + result = await test_rw_tree.rw_callable_tree.get('branch/') + nested_branch = result['branch'] + new_rw_param_val = 24.601 + nested_branch['nestedRwParam'] = new_rw_param_val + await test_rw_tree.rw_callable_tree.set('branch/', nested_branch) + result = await test_rw_tree.rw_callable_tree.get('branch/') + assert result['branch']['nestedRwParam'] == new_rw_param_val + + +class AsyncParameterTreeMetadataTestFixture(AwaitableTestFixture): + """Container class for use in test fixtures testing parameter tree metadata.""" + + def __init__(self): + + super(AsyncParameterTreeMetadataTestFixture, self).__init__(AsyncParameterTree) + + self.int_rw_param = 100 + self.float_ro_param = 4.6593 + self.int_ro_param = 1000 + self.int_enum_param = 0 + self.int_enum_param_allowed_values = [0, 1, 2, 3, 5, 8, 13] + + self.int_rw_param_metadata = { + "min": 0, + "max": 1000, + "units": "arbitrary", + "name": "intCallableRwParam", + "description": "A callable integer RW parameter" + } + + self.metadata_tree_dict = { + 'name': 'Metadata Tree', + 'description': 'A paramter tree to test metadata', + 'floatRoParam': (self.floatRoParamGet,), + 'intRoParam': (self.intRoParamGet, {"units": "seconds"}), + 'intCallableRwParam': ( + self.intCallableRwParamGet, self.intCallableRwParamSet, self.int_rw_param_metadata + ), + 'intEnumParam': (0, {"allowed_values": self.int_enum_param_allowed_values}), + 'valueParam': (24601,), + 'minNoMaxParam': (1, {'min': 0}) + } + self.metadata_tree = AsyncParameterTree(self.metadata_tree_dict) + + def intCallableRwParamSet(self, value): + self.int_rw_param = value + + def intCallableRwParamGet(self): + return self.int_rw_param + + def floatRoParamGet(self): + return self.float_ro_param + + def intRoParamGet(self): + return self.int_ro_param + + +@asyncio_fixture_decorator(scope="class") +async def test_tree_metadata(): + """Test fixture for use in testing parameter tree metadata.""" + test_tree_metadata = await 
AsyncParameterTreeMetadataTestFixture()
+    yield test_tree_metadata
+
+@pytest.mark.asyncio
+class TestAsyncParameterTreeMetadata():
+
+    async def test_callable_rw_param_metadata(self, test_tree_metadata):
+        """Test that getting a RW parameter with metadata returns the appropriate metadata."""
+        int_param_with_metadata = await test_tree_metadata.metadata_tree.get(
+            "intCallableRwParam",with_metadata=True)
+        result = await test_tree_metadata.metadata_tree.get("intCallableRwParam")
+        int_param = result["intCallableRwParam"]
+
+        expected_metadata = test_tree_metadata.int_rw_param_metadata
+        expected_metadata["value"] = int_param
+        expected_metadata["type"] = 'int'
+        expected_metadata["writeable"] = True
+        expected_param = {"intCallableRwParam" : expected_metadata}
+
+        assert int_param_with_metadata == expected_param
+
+    async def test_get_filters_tree_metadata(self, test_tree_metadata):
+        """
+        Test that attempting to get a metadata field for a parameter as if it was a path itself
+        raises an error.
+        """
+        metadata_path = "name"
+        with pytest.raises(ParameterTreeError) as excinfo:
+            await test_tree_metadata.metadata_tree.get(metadata_path)
+
+        assert "Invalid path: {}".format(metadata_path) in str(excinfo.value)
+
+    async def test_set_tree_rejects_metadata(self, test_tree_metadata):
+        """
+        Test that attempting to set a metadata field as if it was a parameter raises an error.
+ """ + metadata_path = "name" + with pytest.raises(ParameterTreeError) as excinfo: + await test_tree_metadata.metadata_tree.set(metadata_path, "invalid") + + assert "Invalid path: {}".format(metadata_path) in str(excinfo.value) + + async def test_enum_param_allowed_values(self, test_tree_metadata): + """Test that setting an enumerated parameter with an allowed value succeeds.""" + for value in test_tree_metadata.int_enum_param_allowed_values: + await test_tree_metadata.metadata_tree.set("intEnumParam", value) + result = await test_tree_metadata.metadata_tree.get("intEnumParam") + set_value = result["intEnumParam"] + assert value == set_value + + async def test_enum_param_bad_value(self, test_tree_metadata): + """ + Test that attempting to set a disallowed value for an enumerated parameter raises an error. + """ + bad_value = test_tree_metadata.int_enum_param_allowed_values[-1] + 1 + with pytest.raises(ParameterTreeError) as excinfo: + await test_tree_metadata.metadata_tree.set("intEnumParam", bad_value) + + assert "{} is not an allowed value".format(bad_value) in str(excinfo.value) + + async def test_ro_param_has_writeable_metadata_field(self, test_tree_metadata): + """Test that a RO parameter has the writeable metadata field set to false.""" + ro_param = await test_tree_metadata.metadata_tree.get("floatRoParam", with_metadata=True) + assert ro_param["floatRoParam"]["writeable"] == False + + async def test_ro_param_not_writeable(self, test_tree_metadata): + """Test that attempting to write to a RO parameter with metadata raises an error.""" + with pytest.raises(ParameterTreeError) as excinfo: + await test_tree_metadata.metadata_tree.set("floatRoParam", 3.141275) + assert "Parameter {} is read-only".format("floatRoParam") in str(excinfo.value) + + async def test_value_param_writeable(self, test_tree_metadata): + """Test that a value parameter is writeable and has the correct metadata flag.""" + new_value = 90210 + await 
test_tree_metadata.metadata_tree.set("valueParam", new_value) + result = await test_tree_metadata.metadata_tree.get("valueParam", with_metadata=True) + set_param = result["valueParam"] + assert set_param["value"] == new_value + assert set_param["writeable"] == True + + async def test_rw_param_min_no_max(self, test_tree_metadata): + """Test that a parameter with a minimum but no maximum works as expected.""" + new_value = 2 + await test_tree_metadata.metadata_tree.set("minNoMaxParam", new_value) + result = await test_tree_metadata.metadata_tree.get("minNoMaxParam", with_metadata=True) + set_param = result["minNoMaxParam"] + assert set_param["value"] == new_value + assert set_param["writeable"] == True + + async def test_rw_param_below_min_value(self, test_tree_metadata): + """ + Test that attempting to set a value for a RW parameter below the specified minimum + raises an error. + """ + low_value = -1 + with pytest.raises(ParameterTreeError) as excinfo: + await test_tree_metadata.metadata_tree.set("intCallableRwParam", low_value) + + assert "{} is below the minimum value {} for {}".format( + low_value, test_tree_metadata.int_rw_param_metadata["min"], + "intCallableRwParam") in str(excinfo.value) + + async def test_rw_param_above_max_value(self, test_tree_metadata): + """ + Test that attempting to set a value for a RW parameter above the specified maximum + raises an error. 
+ """ + high_value = 100000 + with pytest.raises(ParameterTreeError) as excinfo: + await test_tree_metadata.metadata_tree.set("intCallableRwParam", high_value) + + assert "{} is above the maximum value {} for {}".format( + high_value, test_tree_metadata.int_rw_param_metadata["max"], + "intCallableRwParam") in str(excinfo.value) + + +class AsyncParameterTreeMutableTestFixture(AwaitableTestFixture): + + def __init__(self): + + super(AsyncParameterTreeMutableTestFixture, self).__init__(AsyncParameterTree) + + self.read_value = 64 + self.write_value = "test" + + self.param_tree_dict = { + 'extra': 'wibble', + 'bonus': 'win', + 'nest': { + 'double_nest': { + 'nested_val': 125, + 'dont_touch': "let me stay!", + 'write': (self.get_write, self.set_write) + }, + 'list': [0, 1, {'list_test': "test"}, 3] + }, + 'read': (self.get_read,), + 'empty': {} + } + + self.param_tree = AsyncParameterTree(self.param_tree_dict) + self.param_tree.mutable = True + + async def get_read(self): + return self.read_value + + async def get_write(self): + return self.write_value + + async def set_write(self, data): + self.write_value = data + +@asyncio_fixture_decorator() +async def test_tree_mutable(): + """Test fixture for use in testing parameter tree metadata.""" + test_tree_mutable = await AsyncParameterTreeMutableTestFixture() + yield test_tree_mutable + +@pytest.mark.asyncio +class TestAsyncParamTreeMutable(): + """Class to test the behaviour of mutable async parameter trees""" + + async def test_mutable_put_differnt_data_type(self, test_tree_mutable): + + new_data = 75 + await test_tree_mutable.param_tree.set('bonus', new_data) + val = await test_tree_mutable.param_tree.get('bonus') + assert val['bonus'] == new_data + + async def test_mutable_put_new_branch_node(self, test_tree_mutable): + + new_node = {"new": 65} + await test_tree_mutable.param_tree.set('extra', new_node) + + val = await test_tree_mutable.param_tree.get('extra') + assert val['extra'] == new_node + + async def 
test_mutable_put_new_sibling_node(self, test_tree_mutable): + + new_node = {'new': 65} + path = 'nest' + + await test_tree_mutable.param_tree.set(path, new_node) + val = await test_tree_mutable.param_tree.get(path) + assert 'new' in val[path] + + async def test_mutable_put_overwrite_param_accessor_read_only(self, test_tree_mutable): + + new_node = {"Node": "Broke Accessor"} + with pytest.raises(ParameterTreeError) as excinfo: + await test_tree_mutable.param_tree.set('read', new_node) + + assert "is read-only" in str(excinfo.value) + + async def test_mutable_put_overwrite_param_accessor_read_write(self, test_tree_mutable): + + new_node = {"Node": "Broke Accessor"} + path = 'nest/double_nest/write' + + with pytest.raises(ParameterTreeError) as excinfo: + await test_tree_mutable.param_tree.set(path, new_node) + + assert "Type mismatch setting" in str(excinfo.value) + + async def test_mutable_put_replace_nested_path(self, test_tree_mutable): + + new_node = {"double_nest": 294} + path = 'nest' + + await test_tree_mutable.param_tree.set(path, new_node) + val = await test_tree_mutable.param_tree.get(path) + assert val[path]['double_nest'] == new_node['double_nest'] + + async def test_mutable_put_merge_nested_path(self, test_tree_mutable): + + new_node = { + "double_nest": { + 'nested_val': { + "additional_val": "New value Here!", + "add_int": 648 + } + } + } + path = 'nest' + + await test_tree_mutable.param_tree.set(path, new_node) + val = await test_tree_mutable.param_tree.get(path) + assert val[path]['double_nest']['nested_val'] == new_node['double_nest']['nested_val'] + assert 'dont_touch' in val[path]['double_nest'] + + async def test_mutable_delete_method(self, test_tree_mutable): + + path = 'nest/double_nest' + + test_tree_mutable.param_tree.delete(path) + tree = await test_tree_mutable.param_tree.get('') + assert 'double_nest' not in tree['nest'] + with pytest.raises(ParameterTreeError) as excinfo: + await test_tree_mutable.param_tree.get(path) + + assert "Invalid 
path" in str(excinfo.value) + + async def test_mutable_delete_immutable_tree(self, test_tree_mutable): + + test_tree_mutable.param_tree.mutable = False + + with pytest.raises(ParameterTreeError) as excinfo: + path = 'nest/double_nest' + await test_tree_mutable.param_tree.delete(path) + + assert "Invalid Delete Attempt" in str(excinfo.value) + + async def test_mutable_delete_entire_tree(self, test_tree_mutable): + + path = '' + + test_tree_mutable.param_tree.delete(path) + val = await test_tree_mutable.param_tree.get(path) + assert not val + + async def test_mutable_delete_invalid_path(self, test_tree_mutable): + + path = 'nest/not_real' + + with pytest.raises(ParameterTreeError) as excinfo: + await test_tree_mutable.param_tree.delete(path) + + assert "Invalid path" in str(excinfo.value) + + async def test_mutable_delete_from_list(self, test_tree_mutable): + + path = 'nest/list/3' + + test_tree_mutable.param_tree.delete(path) + val = await test_tree_mutable.param_tree.get('nest/list') + assert '3' not in val['list'] + + async def test_mutable_delete_from_dict_in_list(self, test_tree_mutable): + path = 'nest/list/2/list_test' + + test_tree_mutable.param_tree.delete(path) + val = await test_tree_mutable.param_tree.get('nest/list') + assert {'list_test': "test"} not in val['list'] + + async def test_mutable_nested_tree_in_immutable_tree(self, test_tree_mutable): + + new_tree = await AsyncParameterTree({ + 'immutable_param': "Hello", + "nest": { + "tree": test_tree_mutable.param_tree + } + }) + + new_node = {"new": 65} + path = 'nest/tree/extra' + await new_tree.set(path, new_node) + val = await new_tree.get(path) + assert val['extra'] == new_node + + async def test_mutable_nested_tree_external_change(self, test_tree_mutable): + + new_tree = await AsyncParameterTree({ + 'immutable_param': "Hello", + "tree": test_tree_mutable.param_tree + }) + + new_node = {"new": 65} + path = 'tree/extra' + await test_tree_mutable.param_tree.set('extra', new_node) + val = await 
new_tree.get(path) + assert val['extra'] == new_node + + async def test_mutable_nested_tree_delete(self, test_tree_mutable): + + new_tree = await AsyncParameterTree({ + 'immutable_param': "Hello", + "tree": test_tree_mutable.param_tree + }) + + path = 'tree/bonus' + new_tree.delete(path) + + tree = await new_tree.get('') + + assert 'bonus' not in tree['tree'] + + with pytest.raises(ParameterTreeError) as excinfo: + await test_tree_mutable.param_tree.get(path) + + assert "Invalid path" in str(excinfo.value) + + async def test_mutable_nested_tree_root_tree_not_affected(self, test_tree_mutable): + + new_tree = await AsyncParameterTree({ + 'immutable_param': "Hello", + "nest": { + "tree": test_tree_mutable.param_tree + } + }) + + new_node = {"new": 65} + path = 'immutable_param' + + with pytest.raises(ParameterTreeError) as excinfo: + await new_tree.set(path, new_node) + + assert "Type mismatch" in str(excinfo.value) + + async def test_mutable_add_to_empty_dict(self, test_tree_mutable): + + new_node = {"new": 65} + path = 'empty' + await test_tree_mutable.param_tree.set(path, new_node) + val = await test_tree_mutable.param_tree.get(path) + assert val[path] == new_node \ No newline at end of file diff --git a/tests/adapters/test_async_proxy_py3.py b/tests/adapters/test_async_proxy_py3.py new file mode 100644 index 0000000..b37921a --- /dev/null +++ b/tests/adapters/test_async_proxy_py3.py @@ -0,0 +1,395 @@ +""" +Unit tests for the odin-control AsyncProxyAdapter class. + +Tim Nicholls, STFC Detector Systems Software Group. 
+""" + +import logging +import sys +from io import StringIO + +import pytest +from zmq import proxy + +if sys.version_info[0] < 3: + pytest.skip("Skipping async tests", allow_module_level=True) +else: + import asyncio + from tornado.ioloop import TimeoutError + from tornado.httpclient import HTTPResponse + from odin.adapters.async_proxy import AsyncProxyTarget, AsyncProxyAdapter + from unittest.mock import Mock + from tests.adapters.test_proxy import ProxyTestHandler, ProxyTargetTestFixture, ProxyTestServer + from odin.util import convert_unicode_to_string + from tests.utils import log_message_seen + from tests.async_utils import AwaitableTestFixture, asyncio_fixture_decorator + try: + from unittest.mock import AsyncMock + except ImportError: + from tests.async_utils import AsyncMock + + +@pytest.fixture +def test_proxy_target(): + test_proxy_target = ProxyTargetTestFixture(AsyncProxyTarget) + yield test_proxy_target + test_proxy_target.stop() + + +class TestAsyncProxyTarget(): + """Test cases for the AsyncProxyTarget class.""" + + @pytest.mark.asyncio + async def test_async_proxy_target_init(self, test_proxy_target): + """Test the proxy target is correctly initialised.""" + assert test_proxy_target.proxy_target.name == test_proxy_target.name + assert test_proxy_target.proxy_target.url == test_proxy_target.url + assert test_proxy_target.proxy_target.request_timeout == test_proxy_target.request_timeout + + @pytest.mark.asyncio + async def test_async_proxy_target_remote_get(self, test_proxy_target): + """Test the that remote GET to a proxy target succeeds.""" + test_proxy_target.proxy_target.last_update = '' + + await test_proxy_target.proxy_target.remote_get() + assert test_proxy_target.proxy_target.data == ProxyTestHandler.param_tree.get("") + assert test_proxy_target.proxy_target.status_code == 200 + assert test_proxy_target.proxy_target.last_update != '' + + def test_async_proxy_target_param_tree_get(self, test_proxy_target): + """Test that a proxy target get 
returns a parameter tree.""" + param_tree = test_proxy_target.proxy_target.status_param_tree.get('') + for tree_element in ['url', 'status_code', 'error', 'last_update']: + assert tree_element in param_tree + + @pytest.mark.asyncio + async def test_async_proxy_target_http_get_error_404(self, test_proxy_target): + """Test that a proxy target GET to a bad URL returns a 404 not found error.""" + bad_url = test_proxy_target.url + 'notfound/' + proxy_target = await AsyncProxyTarget( + test_proxy_target.name, bad_url, test_proxy_target.request_timeout + ) + await proxy_target.remote_get('notfound') + + assert proxy_target.status_code == 404 + assert 'Not Found' in proxy_target.error_string + + @pytest.mark.asyncio + async def test_async_proxy_target_timeout_error(self, test_proxy_target): + """Test that a proxy target GET request that times out is handled correctly""" + mock_fetch = Mock() + mock_fetch.side_effect = TimeoutError('timeout') + proxy_target = await AsyncProxyTarget( + test_proxy_target.name, test_proxy_target.url, test_proxy_target.request_timeout + ) + proxy_target.http_client.fetch = mock_fetch + + await proxy_target.remote_get() + + assert proxy_target.status_code == 408 + assert 'timeout' in proxy_target.error_string + + @pytest.mark.asyncio + async def test_async_proxy_target_io_error(self, test_proxy_target): + """Test that a proxy target GET request to a non-existing server returns a 502 error.""" + bad_url = 'http://127.0.0.1:{}'.format(test_proxy_target.port + 1) + proxy_target = await AsyncProxyTarget( + test_proxy_target.name, bad_url, test_proxy_target.request_timeout + ) + await proxy_target.remote_get() + + assert proxy_target.status_code == 502 + assert 'Connection refused' in proxy_target.error_string + + @pytest.mark.asyncio + async def test_async_proxy_target_unknown_error(self, test_proxy_target): + """Test that a proxy target GET request handles an unknown exception returning a 500 error.""" + mock_fetch = Mock() + mock_fetch.side_effect 
= ValueError('value error') + proxy_target = await AsyncProxyTarget( + test_proxy_target.name, test_proxy_target.url, test_proxy_target.request_timeout + ) + proxy_target.http_client.fetch = mock_fetch + + await proxy_target.remote_get() + + assert proxy_target.status_code == 500 + assert 'value error' in proxy_target.error_string + + @pytest.mark.asyncio + async def test_async_proxy_target_traps_decode_error(self, test_proxy_target): + """Test that a proxy target correctly traps errors decoding a non-JSON response body.""" + mock_fetch = AsyncMock() + mock_fetch.return_value = HTTPResponse(Mock(), 200, buffer=StringIO(u"wibble")) + + proxy_target = await AsyncProxyTarget( + test_proxy_target.name, test_proxy_target.url, test_proxy_target.request_timeout + ) + proxy_target.http_client.fetch = mock_fetch + + await proxy_target.remote_get() + + assert proxy_target.status_code == 415 + assert "Failed to decode response body" in proxy_target.error_string + + +class AsyncProxyAdapterTestFixture(AwaitableTestFixture): + """Container class used in fixtures for testing async proxy adapters.""" + + def __init__(self): + + super(AsyncProxyAdapterTestFixture, self).__init__(AsyncProxyAdapter) + + """Initialise the fixture, setting up the AsyncProxyAdapter with the correct configuration.""" + self.num_targets = 2 + + self.test_servers = [] + self.ports = [] + self.target_config = "" + + # Launch the appropriate number of target test servers."""" + for _ in range(self.num_targets): + + test_server = ProxyTestServer() + self.test_servers.append(test_server) + self.ports.append(test_server.port) + + self.target_config = ','.join([ + "node_{}=http://127.0.0.1:{}/".format(tgt, port) for (tgt, port) in enumerate(self.ports) + ]) + + self.adapter_kwargs = { + 'targets': self.target_config, + 'request_timeout': 1.0, + } + self.adapter = AsyncProxyAdapter(**self.adapter_kwargs) + + self.path = '' + self.request = Mock() + self.request.headers = {'Accept': 'application/json', 
'Content-Type': 'application/json'} + self.request.body = '{"pi":2.56}' + + def __del__(self): + """Ensure the test servers are stopped on deletion.""" + self.stop() + + def stop(self): + """Stop the proxied test servers, ensuring any client connections to them are closed.""" + for target in self.adapter.targets: + target.http_client.close() + for test_server in self.test_servers: + test_server.stop() + + def clear_access_counts(self): + """Clear the access counters in all test servers.""" + for test_server in self.test_servers: + test_server.clear_access_count() + +@pytest.fixture(scope="class") +def event_loop(): + """Redefine the pytest.asyncio event loop fixture to have class scope.""" + loop = asyncio.get_event_loop() + yield loop + loop.close() + +@asyncio_fixture_decorator(scope='class') +async def async_proxy_adapter_test(): + async_proxy_adapter_test = await AsyncProxyAdapterTestFixture() + adapters = [async_proxy_adapter_test] + await async_proxy_adapter_test.adapter.initialize([adapters]) + yield async_proxy_adapter_test + await async_proxy_adapter_test.adapter.cleanup() + +class TestAsyncProxyAdapter(): + + def test_adapter_loaded(self, async_proxy_adapter_test): + assert len(async_proxy_adapter_test.adapter.targets) == async_proxy_adapter_test.num_targets + + @pytest.mark.asyncio + async def test_adapter_get(self, async_proxy_adapter_test): + """ + Test that a GET request to the proxy adapter returns the appropriate data for all + defined proxied targets. 
+ """ + response = await async_proxy_adapter_test.adapter.get( + async_proxy_adapter_test.path, async_proxy_adapter_test.request) + + assert 'status' in response.data + + assert len(response.data) == async_proxy_adapter_test.num_targets + 1 + + for tgt in range(async_proxy_adapter_test.num_targets): + node_str = 'node_{}'.format(tgt) + assert node_str in response.data + assert response.data[node_str], ProxyTestHandler.data + + @pytest.mark.asyncio + async def test_adapter_get_metadata(self, async_proxy_adapter_test): + request = async_proxy_adapter_test.request + request.headers['Accept'] = "{};{}".format(request.headers['Accept'], "metadata=True") + response = await async_proxy_adapter_test.adapter.get(async_proxy_adapter_test.path, request) + + assert "status" in response.data + for target in range(async_proxy_adapter_test.num_targets): + node_str = 'node_{}'.format(target) + assert node_str in response.data + assert "one" in response.data[node_str] + assert "type" in response.data[node_str]['one'] + + @pytest.mark.asyncio + async def test_adapter_get_status_metadata(self, async_proxy_adapter_test): + request = async_proxy_adapter_test.request + request.headers['Accept'] = "{};{}".format(request.headers['Accept'], "metadata=True") + response = await async_proxy_adapter_test.adapter.get(async_proxy_adapter_test.path, request) + + assert 'status' in response.data + assert 'node_0' in response.data['status'] + assert 'type' in response.data['status']['node_0']['error'] + + @pytest.mark.asyncio + async def test_adapter_put(self, async_proxy_adapter_test): + """ + Test that a PUT request to the proxy adapter returns the appropriate data for all + defined proxied targets. 
+ """ + response = await async_proxy_adapter_test.adapter.put( + async_proxy_adapter_test.path, async_proxy_adapter_test.request) + + assert 'status' in response.data + + assert len(response.data) == async_proxy_adapter_test.num_targets + 1 + + for tgt in range(async_proxy_adapter_test.num_targets): + node_str = 'node_{}'.format(tgt) + assert node_str in response.data + assert convert_unicode_to_string(response.data[node_str]) == ProxyTestHandler.param_tree.get("") + + @pytest.mark.asyncio + async def test_adapter_get_proxy_path(self, async_proxy_adapter_test): + """Test that a GET to a sub-path within a targer succeeds and return the correct data.""" + node = async_proxy_adapter_test.adapter.targets[0].name + path = "more/even_more" + response = await async_proxy_adapter_test.adapter.get( + "{}/{}".format(node, path), async_proxy_adapter_test.request) + + assert response.data["even_more"] == ProxyTestHandler.data["more"]["even_more"] + assert async_proxy_adapter_test.adapter.param_tree.get('')['status'][node]['status_code'] == 200 + + @pytest.mark.asyncio + async def test_adapter_get_proxy_path_trailing_slash(self, async_proxy_adapter_test): + """ + Test that a PUT to a sub-path with a trailing slash in the URL within a targer succeeds + and returns the correct data. + """ + node = async_proxy_adapter_test.adapter.targets[0].name + path = "more/even_more/" + response = await async_proxy_adapter_test.adapter.get( + "{}/{}".format(node, path), async_proxy_adapter_test.request) + + assert response.data["even_more"] == ProxyTestHandler.data["more"]["even_more"] + assert async_proxy_adapter_test.adapter.param_tree.get('')['status'][node]['status_code'] == 200 + + @pytest.mark.asyncio + async def test_adapter_put_proxy_path(self, async_proxy_adapter_test): + """ + Test that a PUT to a sub-path without a trailing slash in the URL within a targer succeeds + and returns the correct data. 
+ """ + node = async_proxy_adapter_test.adapter.targets[0].name + path = "more" + async_proxy_adapter_test.request.body = '{"replace": "been replaced"}' + response = await async_proxy_adapter_test.adapter.put( + "{}/{}".format(node, path), async_proxy_adapter_test.request) + + assert async_proxy_adapter_test.adapter.param_tree.get('')['status'][node]['status_code'] == 200 + assert convert_unicode_to_string(response.data["more"]["replace"]) == "been replaced" + + @pytest.mark.asyncio + async def test_adapter_get_bad_path(self, async_proxy_adapter_test): + """Test that a GET to a bad path within a target returns the appropriate error.""" + missing_path = 'missing/path' + response = await async_proxy_adapter_test.adapter.get(missing_path, async_proxy_adapter_test.request) + + assert 'error' in response.data + assert 'Invalid path: {}'.format(missing_path) == response.data['error'] + + @pytest.mark.asyncio + async def test_adapter_put_bad_path(self, async_proxy_adapter_test): + """Test that a PUT to a bad path within a target returns the appropriate error.""" + missing_path = 'missing/path' + response = await async_proxy_adapter_test.adapter.put(missing_path, async_proxy_adapter_test.request) + + assert 'error' in response.data + assert 'Invalid path: {}'.format(missing_path) == response.data['error'] + + @pytest.mark.asyncio + async def test_adapter_put_bad_type(self, async_proxy_adapter_test): + """Test that a PUT request with an inappropriate type returns the appropriate error.""" + async_proxy_adapter_test.request.body = "bad_body" + response = await async_proxy_adapter_test.adapter.put( + async_proxy_adapter_test.path, async_proxy_adapter_test.request) + + assert 'error' in response.data + assert 'Failed to decode PUT request body:' in response.data['error'] + + @pytest.mark.asyncio + async def test_adapter_bad_timeout(self, async_proxy_adapter_test, caplog): + """Test that a bad timeout specified for the proxy adatper yields a logged error message.""" + 
bad_timeout = 'not_timeout' + _ = await AsyncProxyAdapter(request_timeout=bad_timeout) + + assert log_message_seen(caplog, logging.ERROR, + 'Illegal timeout specified for proxy adapter: {}'.format(bad_timeout)) + + @pytest.mark.asyncio + async def test_adapter_bad_target_spec(self, caplog): + """ + Test that an incorrectly formatted target specified passed to a proxy adapter yields a + logged error message. + """ + bad_target_spec = 'bad_target_1,bad_target_2' + _ = await AsyncProxyAdapter(targets=bad_target_spec) + + assert log_message_seen(caplog, logging.ERROR, + "Illegal target specification for proxy adapter: bad_target_1") + + @pytest.mark.asyncio + async def test_adapter_no_target_spec(self, caplog): + """ + Test that a proxy adapter instantiated with no target specifier yields a logged + error message. + """ + _ = await AsyncProxyAdapter() + + assert log_message_seen(caplog, logging.ERROR, + "Failed to resolve targets for proxy adapter") + + @pytest.mark.asyncio + async def test_adapter_get_access_count(self, async_proxy_adapter_test): + """ + Test that requests via the proxy adapter correctly increment the access counters in the + target test servers. + """ + async_proxy_adapter_test.clear_access_counts() + + _ = await async_proxy_adapter_test.adapter.get( + async_proxy_adapter_test.path, async_proxy_adapter_test.request + ) + + access_counts = [server.get_access_count() for server in async_proxy_adapter_test.test_servers] + assert access_counts == [1]*async_proxy_adapter_test.num_targets + + @pytest.mark.asyncio + async def test_adapter_counter_get_single_node(self, async_proxy_adapter_test): + """ + Test that a requested to a single target in the proxy adapter only accesses that target, + increasing the access count appropriately. 
+ """ + path = async_proxy_adapter_test.path + 'node_{}'.format(async_proxy_adapter_test.num_targets-1) + + async_proxy_adapter_test.clear_access_counts() + response = await async_proxy_adapter_test.adapter.get(path, async_proxy_adapter_test.request) + access_counts = [server.get_access_count() for server in async_proxy_adapter_test.test_servers] + + assert path in response.data + assert sum(access_counts) == 1 diff --git a/tests/adapters/test_proxy.py b/tests/adapters/test_proxy.py index e7998c0..b68892f 100644 --- a/tests/adapters/test_proxy.py +++ b/tests/adapters/test_proxy.py @@ -7,11 +7,13 @@ import threading import logging import time +from io import StringIO import pytest -from tornado.testing import AsyncHTTPTestCase, bind_unused_port +from tornado.testing import bind_unused_port from tornado.ioloop import IOLoop +from tornado.httpclient import HTTPResponse from tornado.web import Application, RequestHandler from tornado.httpserver import HTTPServer import tornado.gen @@ -125,7 +127,7 @@ def clear_access_count(self): class ProxyTargetTestFixture(object): """Container class used in fixtures for testing ProxyTarget.""" - def __init__(self): + def __init__(self, proxy_target_cls): """Initialise the fixture, starting the test server and defining a target.""" self.test_server = ProxyTestServer() self.port = self.test_server.port @@ -134,7 +136,7 @@ def __init__(self): self.url = 'http://127.0.0.1:{}/'.format(self.port) self.request_timeout = 0.1 - self.proxy_target = ProxyTarget(self.name, self.url, self.request_timeout) + self.proxy_target = proxy_target_cls(self.name, self.url, self.request_timeout) def __del__(self): """Ensure test server is stopped on deletion.""" @@ -149,7 +151,7 @@ def stop(self): @pytest.fixture def test_proxy_target(): """Fixture used in ProxyTarget test cases.""" - test_proxy_target = ProxyTargetTestFixture() + test_proxy_target = ProxyTargetTestFixture(ProxyTarget) yield test_proxy_target test_proxy_target.stop() @@ -189,7 +191,7 @@ 
def test_proxy_target_http_get_error_404(self, test_proxy_target): assert 'Not Found' in proxy_target.error_string def test_proxy_target_timeout_error(self, test_proxy_target): - """Test that a porxy target GET request that times out is handled correctly""" + """Test that a proxy target GET request that times out is handled correctly""" mock_fetch = Mock() mock_fetch.side_effect = tornado.ioloop.TimeoutError('timeout') proxy_target = ProxyTarget(test_proxy_target.name, test_proxy_target.url, @@ -201,7 +203,7 @@ def test_proxy_target_timeout_error(self, test_proxy_target): assert proxy_target.status_code == 408 assert 'timeout' in proxy_target.error_string - def test_proxy_target_other_error(self, test_proxy_target): + def test_proxy_target_io_error(self, test_proxy_target): """Test that a proxy target GET request to a non-existing server returns a 502 error.""" bad_url = 'http://127.0.0.1:{}'.format(test_proxy_target.port + 1) proxy_target = ProxyTarget(test_proxy_target.name, bad_url, @@ -211,6 +213,34 @@ def test_proxy_target_other_error(self, test_proxy_target): assert proxy_target.status_code == 502 assert 'Connection refused' in proxy_target.error_string + def test_proxy_target_unknown_error(self, test_proxy_target): + """Test that a proxy target GET request handles an unknown exception returning a 500 error.""" + mock_fetch = Mock() + mock_fetch.side_effect = ValueError('value error') + proxy_target = ProxyTarget( + test_proxy_target.name, test_proxy_target.url, test_proxy_target.request_timeout + ) + proxy_target.http_client.fetch = mock_fetch + + proxy_target.remote_get() + + assert proxy_target.status_code == 500 + assert 'value error' in proxy_target.error_string + + def test_proxy_target_traps_decode_error(self, test_proxy_target): + """Test that a proxy target correctly traps errors decoding a non-JSON response body.""" + mock_fetch = Mock() + mock_fetch.return_value = HTTPResponse(Mock(), 200, buffer=StringIO(u"wibble")) + + proxy_target = ProxyTarget( 
+ test_proxy_target.name, test_proxy_target.url, test_proxy_target.request_timeout + ) + proxy_target.http_client.fetch = mock_fetch + + proxy_target.remote_get() + print(proxy_target.status_code, proxy_target.error_string) + assert proxy_target.status_code == 415 + assert "Failed to decode response body" in proxy_target.error_string class ProxyAdapterTestFixture(): """Container class used in fixtures for testing proxy adapters.""" @@ -402,7 +432,7 @@ def test_adapter_bad_timeout(self, proxy_adapter_test, caplog): _ = ProxyAdapter(request_timeout=bad_timeout) assert log_message_seen(caplog, logging.ERROR, - 'Illegal timeout specified for ProxyAdapter: {}'.format(bad_timeout)) + 'Illegal timeout specified for proxy adapter: {}'.format(bad_timeout)) def test_adapter_bad_target_spec(self, proxy_adapter_test, caplog): """ @@ -412,18 +442,18 @@ def test_adapter_bad_target_spec(self, proxy_adapter_test, caplog): bad_target_spec = 'bad_target_1,bad_target_2' _ = ProxyAdapter(targets=bad_target_spec) - assert log_message_seen(caplog, logging.ERROR, - "Illegal target specification for ProxyAdapter: bad_target_1") + assert log_message_seen(caplog, logging.ERROR, + "Illegal target specification for proxy adapter: bad_target_1") def test_adapter_no_target_spec(self, caplog): """ - Test that a proxy adapter instantiated with no target specifier yields a logged + Test that a proxy adapter instantiated with no target specifier yields a logged error message. 
""" _ = ProxyAdapter() - assert log_message_seen(caplog, logging.ERROR, - "Failed to resolve targets for ProxyAdapter") + assert log_message_seen(caplog, logging.ERROR, + "Failed to resolve targets for proxy adapter") def test_adapter_get_access_count(self, proxy_adapter_test): """ @@ -447,6 +477,6 @@ def test_adapter_counter_get_single_node(self, proxy_adapter_test): proxy_adapter_test.clear_access_counts() response = proxy_adapter_test.adapter.get(path, proxy_adapter_test.request) access_counts = [server.get_access_count() for server in proxy_adapter_test.test_servers] - + assert path in response.data assert sum(access_counts) == 1 diff --git a/tests/async_utils.py b/tests/async_utils.py new file mode 100644 index 0000000..92b3b52 --- /dev/null +++ b/tests/async_utils.py @@ -0,0 +1,72 @@ +""" +ijtiset +""" +import asyncio +from unittest.mock import NonCallableMock, CallableMixin + +import pytest +import pytest_asyncio +try: + asyncio_fixture_decorator = pytest_asyncio.fixture +except AttributeError: + asyncio_fixture_decorator = pytest.fixture + +class AwaitableTestFixture(object): + """Class implementing an awaitable test fixture.""" + def __init__(self, awaitable_cls=None): + self.awaitable_cls = awaitable_cls + + def __await__(self): + + async def closure(): + awaitables = [attr for attr in self.__dict__.values() if isinstance( + attr, self.awaitable_cls + )] + await asyncio.gather(*awaitables) + return self + + return closure().__await__() + +class AsyncCallableMixin(CallableMixin): + + def __init__(_mock_self, *args, **kwargs): + super().__init__(*args, **kwargs) + _mock_self.aenter_return_value = _mock_self + _mock_self._await_count = 0 + + def __call__(_mock_self, *args, **kwargs): + async def wrapper(): + _mock_self._await_count += 1 + _mock_self._mock_check_sig(*args, **kwargs) + return _mock_self._mock_call(*args, **kwargs) + return wrapper() + + async def __aenter__(_mock_self): + return _mock_self.aenter_return_value + + async def 
__aexit__(_mock_self, exc_type, exc_val, exc_tb): + pass + + def assert_awaited(_mock_self): + if _mock_self._await_count == 0: + raise AssertionError("Expected mock to have been awaited.") + + def assert_awaited_once(_mock_self): + if _mock_self._await_count != 1: + msg = ( + "Expected mock to have been awaited once. " + "Awaited {} times.".format(_mock_self._await_count) + ) + raise AssertionError(msg) + + @property + def await_count(_mock_self): + return _mock_self._await_count + + +class AsyncMock(AsyncCallableMixin, NonCallableMock): + """ + Drop-in replacement for absence of AsyncMock in python < 3.8. + + Based on https://github.com/timsavage/asyncmock + """ diff --git a/tests/config/test_async.cfg b/tests/config/test_async.cfg new file mode 100644 index 0000000..0afb0a9 --- /dev/null +++ b/tests/config/test_async.cfg @@ -0,0 +1,19 @@ +[server] +debug_mode = 1 +http_port = 8888 +http_addr = 127.0.0.1 +static_path = ./static +adapters = async, dummy + +[tornado] +logging = debug + +[adapter.async] +module = odin.adapters.async_dummy.AsyncDummyAdapter +async_sleep_duration = 1.5 +wrap_sync_sleep = 1 + +[adapter.dummy] +module = odin.adapters.dummy.DummyAdapter +background_task_enable = 1 +background_task_interval = 1.0 diff --git a/tests/config/test_async_proxy.cfg b/tests/config/test_async_proxy.cfg new file mode 100644 index 0000000..2b2abea --- /dev/null +++ b/tests/config/test_async_proxy.cfg @@ -0,0 +1,16 @@ +[server] +debug_mode = 1 +http_port = 8889 +http_addr = 127.0.0.1 +static_path = static +adapters = proxy + +[tornado] +logging = debug + +[adapter.proxy] +module = odin.adapters.async_proxy.AsyncProxyAdapter +targets = + node_1 = http://127.0.0.1:8888/api/0.1/system_info/, + node_2 = http://127.0.0.1:8887/api/0.1/system_info/ +request_timeout = 2.0 diff --git a/tests/config/test_proxy.cfg b/tests/config/test_proxy.cfg index 87e6a42..f70c391 100644 --- a/tests/config/test_proxy.cfg +++ b/tests/config/test_proxy.cfg @@ -1,7 +1,7 @@ [server] debug_mode 
= 1 http_port = 8889 -http_addr = 0.0.0.0 +http_addr = 127.0.0.1 static_path = static adapters = proxy diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..995f60a --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,8 @@ +import sys + +collect_ignore = [] + +if sys.version_info[0] < 3: + collect_ignore_glob = ["*_py3.py"] +else: + collect_ignore_glob = ["*_py2.py"] diff --git a/tests/handlers/__init__.py b/tests/handlers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/handlers/fixtures.py b/tests/handlers/fixtures.py new file mode 100644 index 0000000..0cfba72 --- /dev/null +++ b/tests/handlers/fixtures.py @@ -0,0 +1,108 @@ +import sys +import json + +import pytest + +if sys.version_info[0] == 3: # pragma: no cover + from unittest.mock import Mock + import asyncio + async_allowed = True +else: # pragma: no cover + from mock import Mock + async_allowed = False + +from odin.http.handlers.base import BaseApiHandler, API_VERSION, ApiError, validate_api_request +from odin.http.routes.api import ApiHandler +from odin.adapters.adapter import ApiAdapterResponse +from odin.config.parser import AdapterConfig +from odin.util import wrap_result + + +class TestHandler(object): + """Class to create appropriate mocked objects to allow the ApiHandler to be tested.""" + + def __init__(self, handler_cls, async_adapter=async_allowed): + """Initialise the TestHandler.""" + # Initialise attribute to receive output of patched write() method + self.write_data = None + + # Create mock Tornado application and request objects for ApiHandler initialisation + self.app = Mock() + self.app.ui_methods = {} + self.request = Mock() + + # Create mock responses for ApiHandler test cases + self.json_dict_response = Mock() + self.json_dict_response.status_code = 200 + self.json_dict_response.content_type = 'application/json' + self.json_dict_response.data = {'response': 'is_json'} + + self.json_str_response = Mock() + 
self.json_str_response.status_code = 200 + self.json_str_response.content_type = 'application/json' + self.json_str_response.data = json.dumps(self.json_dict_response.data) + + # Create a mock route and a default adapter for a subsystem + self.route = Mock() + self.subsystem = 'default' + self.path = 'default/path' + self.route.adapters = {} + self.route.adapter = lambda subsystem: self.route.adapters[subsystem] + self.route.has_adapter = lambda subsystem: subsystem in self.route.adapters + + # Create a mock API adapter that returns appropriate responses + api_adapter_mock = Mock() + api_adapter_mock.is_async = async_adapter + api_adapter_mock.get.return_value = wrap_result(self.json_dict_response, async_adapter) + api_adapter_mock.post.return_value = wrap_result(self.json_dict_response, async_adapter) + api_adapter_mock.put.return_value = wrap_result(self.json_dict_response, async_adapter) + api_adapter_mock.delete.return_value = wrap_result(self.json_dict_response, async_adapter) + self.route.adapters[self.subsystem] = api_adapter_mock + + # Create the handler and mock its write method with the local version + self.handler = handler_cls(self.app, self.request, route=self.route) + self.handler.write = self.mock_write + self.handler.dummy_get = self.dummy_get + + self.respond = self.handler.respond + + def mock_write(self, chunk): + """Mock write function to be used with the handler.""" + if isinstance(chunk, dict): + self.write_data = json.dumps(chunk) + else: + self.write_data = chunk + + @validate_api_request(API_VERSION) + def dummy_get(self, subsystem, path=''): + """Dummy HTTP GET verb method to allow the request validation decorator to be tested.""" + response = ApiAdapterResponse( + {'subsystem': subsystem, 'path': path }, + content_type='application/json', + status_code=200 + ) + self.respond(response) + +if async_allowed: + fixture_params = [True, False] + fixture_ids = ["async", "sync"] +else: + fixture_params = [False] + fixture_ids = ["sync"] + 
+@pytest.fixture(scope="class", params=fixture_params, ids=fixture_ids) +def test_api_handler(request): + """ + Parameterised test fixture for testing the APIHandler class. + + The fixture parameters and id lists are set depending on whether async code is + allowed on the current platform (e.g. python 2 vs 3). + """ + test_api_handler = TestHandler(ApiHandler, request.param) + yield test_api_handler + +@pytest.fixture(scope="class") +def test_base_handler(): + """Test fixture for testing the BaseHandler class.""" + test_base_handler = TestHandler(BaseApiHandler) + yield test_base_handler diff --git a/tests/handlers/test_api_py2.py b/tests/handlers/test_api_py2.py new file mode 100644 index 0000000..4fdff92 --- /dev/null +++ b/tests/handlers/test_api_py2.py @@ -0,0 +1,59 @@ +import sys +import json + +import pytest + +if sys.version_info[0] == 3: # pragma: no cover + from unittest.mock import Mock +else: # pragma: no cover + from mock import Mock + +from odin.http.routes.api import ApiRoute, ApiHandler, ApiError, API_VERSION +from odin.config.parser import AdapterConfig +from tests.handlers.fixtures import test_api_handler + +class TestApiHandler(object): + + def test_handler_valid_get(self, test_api_handler): + """Test that the handler creates a valid status and response to a GET request.""" + test_api_handler.handler.get(str(API_VERSION), + test_api_handler.subsystem, test_api_handler.path) + assert test_api_handler.handler.get_status() == 200 + assert json.loads(test_api_handler.write_data) == test_api_handler.json_dict_response.data + + def test_handler_valid_post(self, test_api_handler): + """Test that the handler creates a valid status and response to a POST request.""" + test_api_handler.handler.post(str(API_VERSION), + test_api_handler.subsystem, test_api_handler.path) + assert test_api_handler.handler.get_status() == 200 + assert json.loads(test_api_handler.write_data) == test_api_handler.json_dict_response.data + + def test_handler_valid_put(self, 
test_api_handler): + """Test that the handler creates a valid status and response to a PUT request.""" + test_api_handler.handler.put(str(API_VERSION), + test_api_handler.subsystem, test_api_handler.path) + assert test_api_handler.handler.get_status() == 200 + assert json.loads(test_api_handler.write_data) == test_api_handler.json_dict_response.data + + def test_handler_valid_delete(self, test_api_handler): + """Test that the handler creates a valid status and response to a DELETE request.""" + test_api_handler.handler.delete(str(API_VERSION), + test_api_handler.subsystem, test_api_handler.path) + assert test_api_handler.handler.get_status() == 200 + assert json.loads(test_api_handler.write_data) == test_api_handler.json_dict_response.data + + def test_bad_api_version(self, test_api_handler): + """Test that a bad API version in a GET call to the handler yields an error.""" + bad_version = 0.1234 + test_api_handler.handler.get(str(bad_version), + test_api_handler.subsystem, test_api_handler.path) + assert test_api_handler.handler.get_status() == 400 + assert "API version {} is not supported".format(bad_version) in test_api_handler.write_data + + def test_bad_subsystem(self, test_api_handler): + """Test that a bad subsystem in a GET call to the handler yields an error.""" + bad_subsystem = 'missing' + test_api_handler.handler.get(str(API_VERSION), bad_subsystem, test_api_handler.path) + assert test_api_handler.handler.get_status() == 400 + assert "No API adapter registered for subsystem {}".format(bad_subsystem) \ + in test_api_handler.write_data diff --git a/tests/handlers/test_api_py3.py b/tests/handlers/test_api_py3.py new file mode 100644 index 0000000..073ca10 --- /dev/null +++ b/tests/handlers/test_api_py3.py @@ -0,0 +1,65 @@ +import sys +import json + +import pytest + +if sys.version_info[0] == 3: # pragma: no cover + from unittest.mock import Mock +else: # pragma: no cover + from mock import Mock + pytest.skip("Skipping async tests", allow_module_level=True) + 
+from odin.http.handlers.base import BaseApiHandler, API_VERSION +from tests.handlers.fixtures import test_api_handler + +class TestApiHandler(object): + + @pytest.mark.asyncio + async def test_handler_valid_get(self, test_api_handler): + """Test that the handler creates a valid status and response to a GET request.""" + await test_api_handler.handler.get(str(API_VERSION), + test_api_handler.subsystem, test_api_handler.path) + assert test_api_handler.handler.get_status() == 200 + assert json.loads(test_api_handler.write_data) == test_api_handler.json_dict_response.data + + @pytest.mark.asyncio + async def test_handler_valid_post(self, test_api_handler): + """Test that the handler creates a valid status and response to a POST request.""" + await test_api_handler.handler.post(str(API_VERSION), + test_api_handler.subsystem, test_api_handler.path) + assert test_api_handler.handler.get_status() == 200 + assert json.loads(test_api_handler.write_data) == test_api_handler.json_dict_response.data + + @pytest.mark.asyncio + async def test_handler_valid_put(self, test_api_handler): + """Test that the handler creates a valid status and response to a PUT request.""" + await test_api_handler.handler.put(str(API_VERSION), + test_api_handler.subsystem, test_api_handler.path) + assert test_api_handler.handler.get_status() == 200 + assert json.loads(test_api_handler.write_data) == test_api_handler.json_dict_response.data + + @pytest.mark.asyncio + async def test_handler_valid_delete(self, test_api_handler): + """Test that the handler creates a valid status and response to a DELETE request.""" + await test_api_handler.handler.delete(str(API_VERSION), + test_api_handler.subsystem, test_api_handler.path) + assert test_api_handler.handler.get_status() == 200 + assert json.loads(test_api_handler.write_data) == test_api_handler.json_dict_response.data + + @pytest.mark.asyncio + async def test_bad_api_version(self, test_api_handler): + """Test that a bad API version in a GET call to the 
handler yields an error.""" + bad_version = 0.1234 + await test_api_handler.handler.get(str(bad_version), + test_api_handler.subsystem, test_api_handler.path) + assert test_api_handler.handler.get_status() == 400 + assert "API version {} is not supported".format(bad_version) in test_api_handler.write_data + + @pytest.mark.asyncio + async def test_bad_subsystem(self, test_api_handler): + """Test that a bad subsystem in a GET call to the handler yields an error.""" + bad_subsystem = 'missing' + await test_api_handler.handler.get(str(API_VERSION), bad_subsystem, test_api_handler.path) + assert test_api_handler.handler.get_status() == 400 + assert "No API adapter registered for subsystem {}".format(bad_subsystem) \ + in test_api_handler.write_data diff --git a/tests/handlers/test_base.py b/tests/handlers/test_base.py new file mode 100644 index 0000000..4fcbd92 --- /dev/null +++ b/tests/handlers/test_base.py @@ -0,0 +1,123 @@ +import sys +import json + +import pytest + +if sys.version_info[0] == 3: # pragma: no cover + from unittest.mock import Mock +else: # pragma: no cover + from mock import Mock + + +from odin.http.handlers.base import BaseApiHandler, API_VERSION, ApiError, validate_api_request +from odin.adapters.adapter import ApiAdapterResponse +from tests.handlers.fixtures import test_base_handler + + +class TestBaseApiHandler(object): + """Test cases for the BaseApiHandler class.""" + + def test_handler_initializes_route(self, test_base_handler): + """ + Check that the handler route has been set, i.e that that handler has its + initialize method called. 
+ """ + assert test_base_handler.handler.route == test_base_handler.route + + def test_handler_response_json_str(self, test_base_handler): + """Test that the handler respond correctly deals with a string response.""" + test_base_handler.handler.respond(test_base_handler.json_str_response) + assert test_base_handler.write_data == test_base_handler.json_str_response.data + + def test_handler_response_json_dict(self, test_base_handler): + """Test that the handler respond correctly deals with a dict response.""" + test_base_handler.handler.respond(test_base_handler.json_dict_response) + assert test_base_handler.write_data == test_base_handler.json_str_response.data + + def test_handler_respond_valid_json(self, test_base_handler): + """Test that the base handler respond method handles a valid JSON ApiAdapterResponse.""" + data = {'valid': 'json', 'value': 1.234} + valid_response = ApiAdapterResponse(data, content_type="application/json") + test_base_handler.handler.respond(valid_response) + assert test_base_handler.handler.get_status() == 200 + assert json.loads(test_base_handler.write_data) == data + + def test_handler_respond_invalid_json(self, test_base_handler): + """ + Test that the base handler respond method raises an ApiError when passed + an invalid response. 
+ """ + invalid_response = ApiAdapterResponse(1234, content_type="application/json") + with pytest.raises(ApiError) as excinfo: + test_base_handler.handler.respond(invalid_response) + + assert 'A response with content type application/json must have str or dict data' \ + in str(excinfo.value) + + def test_handler_get(self, test_base_handler): + """Test that the base handler get method raises a not implemented error.""" + with pytest.raises(NotImplementedError): + test_base_handler.handler.get( + test_base_handler.subsystem, test_base_handler.path) + + def test_handler_post(self, test_base_handler): + """Test that the base handler post method raises a not implemented error.""" + with pytest.raises(NotImplementedError): + test_base_handler.handler.post( + test_base_handler.subsystem, test_base_handler.path) + + def test_handler_put(self, test_base_handler): + """Test that the base handler put method raises a not implemented error.""" + with pytest.raises(NotImplementedError): + test_base_handler.handler.put( + test_base_handler.subsystem, test_base_handler.path) + + def test_handler_delete(self, test_base_handler): + """Test that the base handler delete method raises a not implemented error.""" + with pytest.raises(NotImplementedError): + test_base_handler.handler.delete( + test_base_handler.subsystem, test_base_handler.path) + + +class TestHandlerRequestValidation(): + """Test cases for the validate_api_request decorator.""" + + def test_invalid_api_request_version(self, test_base_handler): + """ + Check that a request with an invalid API version is intercepted by the decorator + and returns an appropriate HTTP response. 
+ """ + bad_version = 0.1234 + test_base_handler.handler.dummy_get( + str(bad_version), test_base_handler.subsystem, test_base_handler.path + ) + assert test_base_handler.handler.get_status() == 400 + assert "API version {} is not supported".format(bad_version) in test_base_handler.write_data + + def test_invalid_subsystem_request(self, test_base_handler): + """ + Check that a request with an invalid subsystem, i.e. one which does not have an + adapter registered, is intercepted by the decorator and returns an appropriate + HTTP response. + """ + bad_subsystem = 'bad_subsys' + test_base_handler.handler.dummy_get( + str(API_VERSION), bad_subsystem, test_base_handler.path + ) + assert test_base_handler.handler.get_status() == 400 + assert "No API adapter registered for subsystem {}".format(bad_subsystem) \ + in test_base_handler.write_data + + def test_valid_request(self, test_base_handler): + """ + Check that a request with a valid API version and subsystem is not intercepted by + the decorator and calls the verb method correctly. 
+ """ + test_base_handler.handler.dummy_get( + str(API_VERSION), test_base_handler.subsystem, test_base_handler.path + ) + assert test_base_handler.handler.get_status() == 200 + + + + diff --git a/tests/routes/test_api.py b/tests/routes/test_api.py index 69f9ff7..06d69c9 100644 --- a/tests/routes/test_api.py +++ b/tests/routes/test_api.py @@ -8,7 +8,7 @@ else: # pragma: no cover from mock import Mock -from odin.http.routes.api import ApiRoute, ApiHandler, ApiError, _api_version +from odin.http.routes.api import ApiRoute, ApiHandler, ApiError, API_VERSION from odin.config.parser import AdapterConfig @pytest.fixture(scope="class") @@ -86,119 +86,3 @@ def test_register_adapter_no_initialize(self, test_api_route): assert not raised -class ApiTestHandler(object): - """Class to create appropriate mocked objects to allow the ApiHandler to be tested.""" - - def __init__(self): - """Initialise the ApiTestHandler.""" - # Initialise attribute to receive output of patched write() method - self.write_data = None - - # Create mock Tornado application and request objects for ApiHandler initialisation - self.app = Mock() - self.app.ui_methods = {} - self.request = Mock() - - # Create mock responses for ApiHandler test cases - self.json_dict_response = Mock() - self.json_dict_response.status_code = 200 - self.json_dict_response.content_type = 'application/json' - self.json_dict_response.data = {'response': 'is_json'} - - self.json_str_response = Mock() - self.json_str_response.status_code = 200 - self.json_str_response.content_type = 'application/json' - self.json_str_response.data = json.dumps(self.json_dict_response.data) - - # Create a mock route and a default adapter for a subsystem - self.route = Mock() - self.subsystem = 'default' - self.route.adapters = {} - self.route.adapter = lambda subsystem: self.route.adapters[subsystem] - self.route.has_adapter = lambda subsystem: subsystem in self.route.adapters - - self.route.adapters[self.subsystem] = Mock() - 
self.route.adapters[self.subsystem].get.return_value = self.json_dict_response - self.route.adapters[self.subsystem].put.return_value = self.json_dict_response - self.route.adapters[self.subsystem].delete.return_value = self.json_dict_response - - # Create the handler and mock its write method with the local version - self.handler = ApiHandler(self.app, self.request, route=self.route) - self.handler.write = self.mock_write - - self.path = 'default/path' - - def mock_write(self, chunk): - """Mock write function to be used with the handler.""" - if isinstance(chunk, dict): - self.write_data = json.dumps(chunk) - else: - self.write_data = chunk - -@pytest.fixture(scope="class") -def test_api_handler(): - """Simple test fixture that creates a test API handler.""" - test_api_handler = ApiTestHandler() - yield test_api_handler - -class TestApiHandler(object): - - def test_handler_valid_get(self, test_api_handler): - """Test that the handler creates a valid status and response to a GET request.""" - test_api_handler.handler.get(str(_api_version), - test_api_handler.subsystem, test_api_handler.path) - assert test_api_handler.handler.get_status() == 200 - assert json.loads(test_api_handler.write_data) == test_api_handler.json_dict_response.data - - def test_handler_valid_put(self, test_api_handler): - """Test that the handler creates a valid status and response to a PUT request.""" - test_api_handler.handler.put(str(_api_version), - test_api_handler.subsystem, test_api_handler.path) - assert test_api_handler.handler.get_status() == 200 - assert json.loads(test_api_handler.write_data) == test_api_handler.json_dict_response.data - - def test_handler_valid_delete(self, test_api_handler): - """Test that the handler creates a valid status and response to a PUT request.""" - test_api_handler.handler.delete(str(_api_version), - test_api_handler.subsystem, test_api_handler.path) - assert test_api_handler.handler.get_status() == 200 - assert json.loads(test_api_handler.write_data) 
== test_api_handler.json_dict_response.data - - def test_bad_api_version(self, test_api_handler): - """Test that a bad API version in a GET call to the handler yields an error.""" - bad_version = 0.1234 - test_api_handler.handler.get(str(bad_version), - test_api_handler.subsystem, test_api_handler.path) - assert test_api_handler.handler.get_status() == 400 - assert "API version {} is not supported".format(bad_version) in test_api_handler.write_data - - def test_bad_subsystem(self, test_api_handler): - """Test that a bad subsystem in a GET call to the handler yields an error.""" - bad_subsystem = 'missing' - test_api_handler.handler.get(str(_api_version), bad_subsystem, test_api_handler.path) - assert test_api_handler.handler.get_status() == 400 - assert "No API adapter registered for subsystem {}".format(bad_subsystem) \ - in test_api_handler.write_data - - def test_handler_response_json_str(self, test_api_handler): - """Test that the handler respond correctly deals with a string response.""" - test_api_handler.handler.respond(test_api_handler.json_str_response) - assert test_api_handler.write_data == test_api_handler.json_str_response.data - - def test_handler_response_json_dict(self, test_api_handler): - """Test that the handler respond correctly deals with a dict response.""" - test_api_handler.handler.respond(test_api_handler.json_dict_response) - assert test_api_handler.write_data == test_api_handler.json_str_response.data - - def test_handler_response_json_bad_type(self, test_api_handler): - """Test that the handler raises an error if an incorrect type of response is returned.""" - bad_response = Mock() - bad_response.status_code = 200 - bad_response.content_type = 'application/json' - bad_response.data = 1234 - - with pytest.raises(ApiError) as excinfo: - test_api_handler.handler.respond(bad_response) - - assert 'A response with content type application/json must have str or dict data' \ - in str(excinfo.value) diff --git a/tests/routes/test_async_api_py3.py 
b/tests/routes/test_async_api_py3.py new file mode 100644 index 0000000..066e40c --- /dev/null +++ b/tests/routes/test_async_api_py3.py @@ -0,0 +1,56 @@ +import sys + +import pytest + +if sys.version_info[0] < 3: + pytest.skip("Skipping async tests", allow_module_level=True) +else: + try: + from unittest.mock import AsyncMock + except ImportError: + from tests.async_utils import AsyncMock + + +from odin.http.routes.api import ApiRoute +from odin.config.parser import AdapterConfig + +class ApiRouteAsyncTestFixture(object): + + def __init__(self): + + self.route = ApiRoute() + self.adapter_name = 'async_dummy' + self.adapter_module = 'odin.adapters.async_dummy.AsyncDummyAdapter' + self.adapter_config = AdapterConfig(self.adapter_name, self.adapter_module) + + self.route.register_adapter(self.adapter_config) + + self.initialize_mock = AsyncMock() + self.route.adapters[self.adapter_name].initialize = self.initialize_mock + + self.cleanup_mock = AsyncMock() + self.route.adapters[self.adapter_name].cleanup = self.cleanup_mock + +@pytest.fixture(scope="class") +def test_api_route_async(): + """Test fixture used in testing ApiRoute behaviour with async adapters""" + + test_api_route_async = ApiRouteAsyncTestFixture() + yield test_api_route_async + + +class TestApiRouteAsync(object): + + def test_register_async_adapter(self, test_api_route_async): + + assert test_api_route_async.route.has_adapter('async_dummy') + + def test_initialize_async_adapter(self, test_api_route_async): + + test_api_route_async.route.initialize_adapters() + test_api_route_async.initialize_mock.assert_awaited_once() + + def test_cleanup_async_adapter(self, test_api_route_async): + + test_api_route_async.route.cleanup_adapters() + test_api_route_async.cleanup_mock.assert_awaited_once() \ No newline at end of file diff --git a/tests/test_util.py b/tests/test_util.py index 8f4bbe3..41f396f 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,8 +1,13 @@ import sys import pytest +import time 
+import concurrent.futures +import tornado.concurrent +from tornado import version_info if sys.version_info[0] == 3: # pragma: no cover from unittest.mock import Mock + import asyncio else: # pragma: no cover from mock import Mock @@ -64,3 +69,55 @@ def test_convert_unicode_to_string_mixed_recursion(self): 'list': ['unicode string', "normal string"] } assert result == expected_result + + @pytest.mark.parametrize("is_async", [True, False], ids=["async", "sync"]) + def test_wrap_result(self, is_async): + """Test that the wrap_result utility correctly wraps results in a future when needed.""" + result = 321 + wrapped_result = util.wrap_result(result, is_async) + if sys.version_info[0] == 3 and is_async: + assert isinstance(wrapped_result, asyncio.Future) + assert wrapped_result.result() == result + else: + assert wrapped_result == result + + def test_run_in_executor(self): + """Test that the run_in_executor utility can correctly nest asynchronous tasks.""" + # Container for task results modified by inner functions + task_result = { + 'count': 0, + 'outer_completed': False, + 'inner_completed': False, + } + + def nested_task(num_loops): + """Simple task that loops and increments a counter before completing.""" + for _ in range(num_loops): + time.sleep(0.01) + task_result['count'] += 1 + task_result['inner_completed'] = True + + def outer_task(num_loops): + """Outer task that launches another task on an executor.""" + util.run_in_executor(executor, nested_task, num_loops) + task_result['outer_completed'] = True + + executor = concurrent.futures.ThreadPoolExecutor() + + num_loops = 10 + future = util.run_in_executor(executor, outer_task, num_loops) + + wait_count = 0 + while not task_result['inner_completed'] and wait_count < 100: + time.sleep(0.01) + wait_count += 1 + + if version_info[0] <= 4: + future_type = concurrent.futures.Future + else: + future_type = tornado.concurrent.Future + + assert isinstance(future, future_type) + assert task_result['inner_completed'] is 
True + assert task_result['count'] == num_loops + assert task_result['outer_completed'] is True diff --git a/tests/test_util_py3.py b/tests/test_util_py3.py new file mode 100644 index 0000000..b98fb5e --- /dev/null +++ b/tests/test_util_py3.py @@ -0,0 +1,76 @@ +import sys +import pytest +import time +import concurrent.futures + +import pytest_asyncio + +from odin import util +from odin import async_util + +if sys.version_info[0] < 3: + pytest.skip("Skipping async tests", allow_module_level=True) + +import asyncio + + +class TestUtilAsync(): + + @pytest.mark.asyncio + async def test_wrap_result(self): + """Test that the wrap_result utility correctly wraps results in a future when needed.""" + result = 321 + wrapped = util.wrap_result(result, True) + await wrapped + assert isinstance(wrapped, asyncio.Future) + assert wrapped.result() == result + + @pytest.mark.asyncio + async def test_wrap_async(self): + """Test that the wrap_async function correctly wraps results in a future.""" + result = 987 + wrapped = async_util.wrap_async(result) + await wrapped + assert isinstance(wrapped, asyncio.Future) + assert wrapped.result() == result + + @pytest.mark.asyncio + async def test_run_in_executor(self): + """ + Test that the run_in_executor utility runs a background task asynchronously and returns + an awaitable future. 
+ """ + task_result = { + 'count': 0, + 'completed': False + } + + def task_func(num_loops): + """Simple task that loops and increments a counter before completing.""" + for _ in range(num_loops): + time.sleep(0.01) + task_result['count'] += 1 + task_result['completed'] = True + + executor = concurrent.futures.ThreadPoolExecutor() + + num_loops = 10 + await util.run_in_executor(executor, task_func, num_loops) + + wait_count = 0 + while not task_result['completed'] and wait_count < 100: + asyncio.sleep(0.01) + wait_count += 1 + + assert task_result['completed'] == True + assert task_result['count'] == num_loops + + def test_run_async(self): + + async def async_increment(value): + await asyncio.sleep(0) + return value + 1 + + value = 5 + result = async_util.run_async(async_increment, value) + assert result == value + 1 \ No newline at end of file diff --git a/tox.ini b/tox.ini index bb50c8c..b456747 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,4 @@ -# tox (https://tox.readthedocs.io/) is a tool for running tests -# in multiple virtualenvs. This configuration file will run the -# test suite on all supported python versions. To use it, "pip install tox" -# and then run "tox" from this directory. +# tox test configuration for odin-control [tox] envlist = clean,py27-tornado{4,5},py{36,37,38,39}-tornado{5,6},report @@ -27,7 +24,8 @@ deps = setenv = py{27,36,37,38,39}: COVERAGE_FILE=.coverage.{envname} commands = - pytest --cov=odin --cov-report=term-missing {posargs:-vv} + py{27,36}: pytest --cov=odin --cov-report=term-missing {posargs:-vv} + py{37,38,39}: pytest --cov=odin --cov-report=term-missing --asyncio-mode=strict {posargs:-vv} depends = py{27,36,37,38,39}: clean report: py{27,36,37,38,39}