Skip to content

Commit

Permalink
Refractor NEST Server with better error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
babsey committed Sep 25, 2024
1 parent ae09e1d commit 3f51ce6
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 112 deletions.
3 changes: 1 addition & 2 deletions pynest/nest/server/hl_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,10 @@
CORS_ORIGINS,
EXEC_CALL_ENABLED,
_check_security,
get_arguments,
nest_calls,
)
from .hl_api_server_mpi import api_client, do_call, log, mpi_comm
from .hl_api_server_utils import ErrorHandler
from .hl_api_server_utils import ErrorHandler, get_arguments

# This ensures that the logging information shows up in the console running the server,
# even when Flask's event loop is running.
Expand Down
115 changes: 38 additions & 77 deletions pynest/nest/server/hl_api_server_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,31 @@
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.

import ast
import importlib
import inspect
import io
import os
import sys
import time
import traceback

import nest
import RestrictedPython
from nest.lib.hl_api_exceptions import NESTError

from .hl_api_server_utils import get_boolean_environ, get_or_error
from .hl_api_server_utils import (
Capturing,
ErrorHandler,
clean_code,
get_boolean_environ,
get_lineno,
get_modules_from_env,
)

_default_origins = "http://localhost:*,http://127.0.0.1:*"
ACCESS_TOKEN = os.environ.get("NEST_SERVER_ACCESS_TOKEN", "")
AUTH_DISABLED = get_boolean_environ("NEST_SERVER_DISABLE_AUTH")
CORS_ORIGINS = os.environ.get("NEST_SERVER_CORS_ORIGINS", _default_origins).split(",")
EXEC_CALL_ENABLED = get_boolean_environ("NEST_SERVER_ENABLE_EXEC_CALL")
RESTRICTION_DISABLED = get_boolean_environ("NEST_SERVER_DISABLE_RESTRICTION")
MODULES = os.environ.get("NEST_SERVER_MODULES", "import nest")
RESTRICTION_DISABLED = get_boolean_environ("NEST_SERVER_DISABLE_RESTRICTION")

__all__ = [
"nestify",
Expand Down Expand Up @@ -75,31 +79,6 @@ def _check_security():
print("\n - ".join([" "] + msg) + "\n")


class Capturing(list):
"""Monitor stdout contents i.e. print."""

def __enter__(self):
self._stdout = sys.stdout
sys.stdout = self._stringio = io.StringIO()
return self

def __exit__(self, *args):
self.extend(self._stringio.getvalue().splitlines())
del self._stringio # free up some memory
sys.stdout = self._stdout


def clean_code(source):
codes = source.split("\n")
codes_cleaned = [] # noqa
for code in codes:
if code.startswith("import") or code.startswith("from"):
codes_cleaned.append("#" + code)
else:
codes_cleaned.append(code)
return "\n".join(codes_cleaned)


def do_exec(kwargs):
source_code = kwargs.get("source", "")
source_cleaned = clean_code(source_code)
Expand Down Expand Up @@ -132,56 +111,38 @@ def do_exec(kwargs):
return response


def get_arguments(request):
"""Get arguments from the request."""
args, kwargs = [], {}
if request.is_json:
json = request.get_json()
if isinstance(json, str) and len(json) > 0:
args = [json]
elif isinstance(json, list):
args = json
elif isinstance(json, dict):
kwargs = json
if "args" in kwargs:
args = kwargs.pop("args")
elif len(request.form) > 0:
if "args" in request.form:
args = request.form.getlist("args")
else:
kwargs = request.form.to_dict()
elif len(request.args) > 0:
if "args" in request.args:
args = request.args.getlist("args")
else:
kwargs = request.args.to_dict()
return list(args), kwargs
def get_or_error(func):
"""Wrapper to exec function."""

def func_wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)

def get_modules_from_env():
"""Get modules from environment variable NEST_SERVER_MODULES.
except NESTError as err:
error_class = err.errorname + " (NESTError)"
detail = err.errormessage
lineno = get_lineno(err, 1)

This function converts the content of the environment variable NEST_SERVER_MODULES:
to a formatted dictionary for updating the Python `globals`.
except (KeyError, SyntaxError, TypeError, ValueError) as err:
error_class = err.__class__.__name__
detail = err.args[0]
lineno = get_lineno(err, 1)

Here is an example:
`NEST_SERVER_MODULES="import nest; import numpy as np; from numpy import random"`
is converted to the following dictionary:
`{'nest': <module 'nest'> 'np': <module 'numpy'>, 'random': <module 'numpy.random'>}`
"""
modules = {}
try:
parsed = ast.iter_child_nodes(ast.parse(MODULES))
except (SyntaxError, ValueError):
raise SyntaxError("The NEST server module environment variables contains syntax errors.")
for node in parsed:
if isinstance(node, ast.Import):
for alias in node.names:
modules[alias.asname or alias.name] = importlib.import_module(alias.name)
elif isinstance(node, ast.ImportFrom):
for alias in node.names:
modules[alias.asname or alias.name] = importlib.import_module(f"{node.module}.{alias.name}")
return modules
except Exception as err:
error_class = err.__class__.__name__
detail = err.args[0]
lineno = get_lineno(err, -1)

for line in traceback.format_exception(*sys.exc_info()):
print(line, flush=True)

if lineno == -1:
message = "%s: %s" % (error_class, detail)
else:
message = "%s at line %d: %s" % (error_class, lineno, detail)
raise ErrorHandler(message, lineno)

return func_wrapper


def get_restricted_globals():
Expand Down
113 changes: 80 additions & 33 deletions pynest/nest/server/hl_api_server_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,29 @@
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.


import ast
import importlib
import io
import os
import sys
import traceback

from nest.lib.hl_api_exceptions import NESTError
MODULES = os.environ.get("NEST_SERVER_MODULES", "import nest")


class Capturing(list):
"""Monitor stdout contents i.e. print."""

def __enter__(self):
self._stdout = sys.stdout
sys.stdout = self._stringio = io.StringIO()
return self

def __exit__(self, *args):
self.extend(self._stringio.getvalue().splitlines())
del self._stringio # free up some memory
sys.stdout = self._stdout


class ErrorHandler(Exception):
Expand All @@ -47,6 +65,43 @@ def to_dict(self):
return rv


def clean_code(source):
codes = source.split("\n")
codes_cleaned = [] # noqa
for code in codes:
if code.startswith("import") or code.startswith("from"):
codes_cleaned.append("#" + code)
else:
codes_cleaned.append(code)
return "\n".join(codes_cleaned)


def get_arguments(request):
"""Get arguments from the request."""
args, kwargs = [], {}
if request.is_json:
json = request.get_json()
if isinstance(json, str) and len(json) > 0:
args = [json]
elif isinstance(json, list):
args = json
elif isinstance(json, dict):
kwargs = json
if "args" in kwargs:
args = kwargs.pop("args")
elif len(request.form) > 0:
if "args" in request.form:
args = request.form.getlist("args")
else:
kwargs = request.form.to_dict()
elif len(request.args) > 0:
if "args" in request.args:
args = request.args.getlist("args")
else:
kwargs = request.args.to_dict()
return list(args), kwargs


def get_boolean_environ(env_key, default_value="false"):
env_value = os.environ.get(env_key, default_value)
return env_value.lower() in ["yes", "true", "t", "1"]
Expand All @@ -65,35 +120,27 @@ def get_lineno(err, tb_idx):
return lineno


def get_or_error(func):
"""Wrapper to exec function."""

def func_wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)

except NESTError as err:
error_class = err.errorname + " (NESTError)"
detail = err.errormessage
lineno = get_lineno(err, 1)

except (KeyError, SyntaxError, TypeError, ValueError) as err:
error_class = err.__class__.__name__
detail = err.args[0]
lineno = get_lineno(err, 1)

except Exception as err:
error_class = err.__class__.__name__
detail = err.args[0]
lineno = get_lineno(err, -1)

for line in traceback.format_exception(*sys.exc_info()):
print(line, flush=True)

if lineno == -1:
message = "%s: %s" % (error_class, detail)
else:
message = "%s at line %d: %s" % (error_class, lineno, detail)
raise ErrorHandler(message, lineno)

return func_wrapper
def get_modules_from_env():
"""Get modules from environment variable NEST_SERVER_MODULES.
This function converts the content of the environment variable NEST_SERVER_MODULES:
to a formatted dictionary for updating the Python `globals`.
Here is an example:
`NEST_SERVER_MODULES="import nest; import numpy as np; from numpy import random"`
is converted to the following dictionary:
`{'nest': <module 'nest'> 'np': <module 'numpy'>, 'random': <module 'numpy.random'>}`
"""
modules = {}
try:
parsed = ast.iter_child_nodes(ast.parse(MODULES))
except (SyntaxError, ValueError):
raise SyntaxError("The NEST server module environment variables contains syntax errors.")
for node in parsed:
if isinstance(node, ast.Import):
for alias in node.names:
modules[alias.asname or alias.name] = importlib.import_module(alias.name)
elif isinstance(node, ast.ImportFrom):
for alias in node.names:
modules[alias.asname or alias.name] = importlib.import_module(f"{node.module}.{alias.name}")
return modules

0 comments on commit 3f51ce6

Please sign in to comment.