Skip to content

Commit

Permalink
Merge pull request #2632 from babsey/server-import-modules
Browse files Browse the repository at this point in the history
Enhance module imports in NEST Server
  • Loading branch information
Helveg authored Jul 31, 2023
2 parents b9514cc + bc0a9e1 commit e8932d7
Showing 1 changed file with 29 additions and 15 deletions.
44 changes: 29 additions & 15 deletions pynest/nest/server/hl_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.

import ast
import importlib
import inspect
import io
Expand Down Expand Up @@ -83,12 +84,16 @@ def do_exec(args, kwargs):
response = dict()
if RESTRICTION_OFF:
with Capturing() as stdout:
exec(source_cleaned, get_globals(), locals_)
globals_ = globals().copy()
globals_.update(get_modules_from_env())
exec(source_cleaned, globals_, locals_)
if len(stdout) > 0:
response["stdout"] = "\n".join(stdout)
else:
code = RestrictedPython.compile_restricted(source_cleaned, "<inline>", "exec") # noqa
exec(code, get_restricted_globals(), locals_)
globals_ = get_restricted_globals()
globals_.update(get_modules_from_env())
exec(code, globals_, locals_)
if "_print" in locals_:
response["stdout"] = "".join(locals_["_print"].txt)

Expand Down Expand Up @@ -247,16 +252,30 @@ def get_arguments(request):
return list(args), kwargs


def get_globals():
"""Get globals for exec function."""
copied_globals = globals().copy()
def get_modules_from_env():
"""Get modules from environment variable NEST_SERVER_MODULES.
# Add modules to copied globals
modlist = [(module, importlib.import_module(module)) for module in MODULES]
modules = dict(modlist)
copied_globals.update(modules)
This function converts the content of the environment variable NEST_SERVER_MODULES:
to a formatted dictionary for updating the Python `globals`.
return copied_globals
Here is an example:
`NEST_SERVER_MODULES="import nest; import numpy as np; from numpy import random"`
is converted to the following dictionary:
`{'nest': <module 'nest'> 'np': <module 'numpy'>, 'random': <module 'numpy.random'>}`
"""
modules = {}
try:
parsed = ast.iter_child_nodes(ast.parse(MODULES))
except (SyntaxError, ValueError):
raise SyntaxError("The NEST server module environment variables contains syntax errors.")
for node in parsed:
if isinstance(node, ast.Import):
for alias in node.names:
modules[alias.asname or alias.name] = importlib.import_module(alias.name)
elif isinstance(node, ast.ImportFrom):
for alias in node.names:
modules[alias.asname or alias.name] = importlib.import_module(f"{node.module}.{alias.name}")
return modules


def get_or_error(func):
Expand Down Expand Up @@ -305,11 +324,6 @@ def getitem(obj, index):
_write_=RestrictedPython.Guards.full_write_guard,
)

# Add modules to restricted globals
modlist = [(module, importlib.import_module(module)) for module in MODULES]
modules = dict(modlist)
restricted_globals.update(modules)

return restricted_globals


Expand Down

0 comments on commit e8932d7

Please sign in to comment.