Skip to content

Commit

Permalink
Improve model extension import
Browse files Browse the repository at this point in the history
Previously, it wasn't possible to import two model modules with the same name.
Now this is at least possible if they are in different locations.
Overwriting and importing a previously imported extension is still not supported.
  • Loading branch information
dweindl committed Nov 27, 2024
1 parent 4ed131b commit 759cf39
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 113 deletions.
15 changes: 6 additions & 9 deletions python/examples/example_splines/ExampleSplines.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"from importlib import import_module\n",
"from shutil import rmtree\n",
"from tempfile import TemporaryDirectory\n",
"from uuid import uuid1\n",
Expand Down Expand Up @@ -57,7 +55,7 @@
" parameters,\n",
" build_dir=build_dir,\n",
" model_name=model_name,\n",
" **kwargs\n",
" **kwargs,\n",
" )\n",
" else:\n",
" build_dir = os.path.join(BUILD_PATH, model_name)\n",
Expand All @@ -67,7 +65,7 @@
" parameters,\n",
" build_dir=build_dir,\n",
" model_name=model_name,\n",
" **kwargs\n",
" **kwargs,\n",
" )\n",
"\n",
"\n",
Expand All @@ -79,7 +77,7 @@
" model_name,\n",
" T=1,\n",
" discard_annotations=False,\n",
" plot=True\n",
" plot=True,\n",
"):\n",
" if parameters is None:\n",
" parameters = {}\n",
Expand All @@ -89,13 +87,12 @@
" )\n",
" sbml_importer.sbml2amici(model_name, build_dir)\n",
" # Import the model module\n",
" sys.path.insert(0, os.path.abspath(build_dir))\n",
" model_module = import_module(model_name)\n",
" model_module = amici.import_model_module(model_name, build_dir)\n",
" # Setup simulation timepoints and parameters\n",
" model = model_module.getModel()\n",
" for name, value in parameters.items():\n",
" model.setParameterByName(name, value)\n",
" if isinstance(T, (int, float)):\n",
" if isinstance(T, int | float):\n",
" T = np.linspace(0, T, 100)\n",
" model.setTimepoints([float(t) for t in T])\n",
" solver = model.getSolver()\n",
Expand Down Expand Up @@ -320,7 +317,7 @@
],
"source": [
"# Finally, we can simulate it in AMICI\n",
"model, rdata = simulate(sbml_model);"
"model, rdata = simulate(sbml_model)"
]
},
{
Expand Down
123 changes: 45 additions & 78 deletions python/sdist/amici/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@
"""

import contextlib
import datetime
import importlib.util
import importlib
import os
import re
import sys
import sysconfig
from pathlib import Path
from types import ModuleType as ModelModule
from types import ModuleType
from typing import Any
from collections.abc import Callable

Expand Down Expand Up @@ -145,6 +144,8 @@ def get_model(self) -> amici.Model:
def get_jax_model(self) -> JAXModel: ...

AmiciModel = Union[amici.Model, amici.ModelPtr]
else:
ModelModule = ModuleType


class add_path:
Expand Down Expand Up @@ -182,6 +183,29 @@ def __exit__(self, exc_type, exc_value, traceback):
sys.path = self.orginal_path


def _module_from_path(module_name: str, module_path: Path | str) -> ModuleType:
"""Import a module from a given path.
Import a module from a given path. The module is not added to
`sys.modules`. The `_self` attribute of the module is set to the module
itself.
:param module_name:
Name of the module.
:param module_path:
Path to the module file. Absolute or relative to the current working
directory.
"""
module_path = Path(module_path).resolve()
if not module_path.is_file():
raise ModuleNotFoundError(f"Module file not found: {module_path}")
spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec)
module._self = module
spec.loader.exec_module(module)
return module


def import_model_module(
module_name: str, module_path: Path | str
) -> ModelModule:
Expand All @@ -195,86 +219,29 @@ def import_model_module(
:return:
The model module
"""
module_path = str(module_path)
model_root = str(module_path)

# ensure we will find the newly created module
importlib.invalidate_caches()

if not os.path.isdir(module_path):
raise ValueError(f"module_path '{module_path}' is not a directory.")

module_path = os.path.abspath(module_path)
ext_suffix = sysconfig.get_config_var("EXT_SUFFIX")
ext_mod_name = f"{module_name}._{module_name}"

# module already loaded?
if (m := sys.modules.get(ext_mod_name)) and m.__file__.endswith(
ext_suffix
):
# this is the c++ extension we can't unload
loaded_file = Path(m.__file__)
needed_file = Path(
module_path,
module_name,
f"_{module_name}{ext_suffix}",
)
# if we import a matlab-generated model where the extension
# is in a different directory
needed_file_matlab = Path(
module_path,
f"_{module_name}{ext_suffix}",
)
if not needed_file.exists():
if needed_file_matlab.exists():
needed_file = needed_file_matlab
else:
raise ModuleNotFoundError(
f"Cannot find extension module for {module_name} in "
f"{module_path}."
)

if not loaded_file.samefile(needed_file):
# this is not the right module, and we can't unload it
raise RuntimeError(
f"Cannot import extension for {module_name} from "
f"{module_path}, because an extension with the same name was "
f"has already been imported from {loaded_file.parent}. "
"Import the module with a different name or restart the "
"Python kernel."
)
# this is the right file, but did it change on disk?
t_imported = m._get_import_time() # noqa: protected-access
t_modified = os.path.getmtime(m.__file__)
if t_imported < t_modified:
t_imp_str = datetime.datetime.fromtimestamp(t_imported).isoformat()
t_mod_str = datetime.datetime.fromtimestamp(t_modified).isoformat()
raise RuntimeError(
f"Cannot import extension for {module_name} from "
f"{module_path}, because an extension in the same location "
f"has already been imported, but the file was modified on "
f"disk. \nImported at {t_imp_str}\nModified at {t_mod_str}.\n"
"Import the module with a different name or restart the "
"Python kernel."
)

# unlike extension modules, Python modules can be unloaded
if module_name in sys.modules:
# if a module with that name is already in sys.modules, we remove it,
# along with all other modules from that package. otherwise, there
# will be trouble if two different models with the same name are to
# be imported.
del sys.modules[module_name]
# collect first, don't delete while iterating
to_unload = {
loaded_module_name
for loaded_module_name in sys.modules.keys()
if loaded_module_name.startswith(f"{module_name}.")
}
for m in to_unload:
del sys.modules[m]

with set_path(module_path):
return importlib.import_module(module_name)
raise ValueError(f"module_path '{model_root}' is not a directory.")

module_path = Path(model_root, module_name, "__init__.py")

# We may want to import a matlab-generated model where the extension
# is in a different directory. This is not a regular use case. It's only
# used in the amici tests and can be removed at any time.
# The models (currently) use the default swig-import and require
# modifying sys.path.
module_path_matlab = Path(model_root, f"{module_name}.py")
if not module_path.is_file() and module_path_matlab.is_file():
with set_path(model_root):
return _module_from_path(module_name, module_path_matlab)

module = _module_from_path(module_name, module_path)
module._self = module
return module


class AmiciVersionError(RuntimeError):
Expand Down
44 changes: 39 additions & 5 deletions python/sdist/amici/__init__.template.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
"""AMICI-generated module for model TPL_MODELNAME"""

import datetime
import os
import sys
from pathlib import Path
from typing import TYPE_CHECKING
import amici


if TYPE_CHECKING:
from amici.jax import JAXModel

Expand All @@ -18,14 +22,44 @@
"version currently installed."
)

from .TPL_MODELNAME import * # noqa: F403, F401
from .TPL_MODELNAME import getModel as get_model # noqa: F401
TPL_MODELNAME = amici._module_from_path(
"TPL_MODELNAME.TPL_MODELNAME", Path(__file__).parent / "TPL_MODELNAME.py"
)
for var in dir(TPL_MODELNAME):
if not var.startswith("_"):
globals()[var] = getattr(TPL_MODELNAME, var)
get_model = TPL_MODELNAME.getModel

try:
# _self: this module; will be set during import
# via amici.import_model_module
TPL_MODELNAME._model_module = _self # noqa: F821
except NameError:
# when the model package is imported via `import`
TPL_MODELNAME._model_module = sys.modules[__name__]

def get_jax_model() -> "JAXModel":
from .jax import JAXModel_TPL_MODELNAME

return JAXModel_TPL_MODELNAME()
def get_jax_model() -> "JAXModel":
# If the model directory was meanwhile overwritten, this would load the
# new version, which would not match the previously imported extension.
# This is not allowed, as it would lead to inconsistencies.
jax_py_file = Path(__file__).parent / "jax.py"
jax_py_file = jax_py_file.resolve()
t_imported = TPL_MODELNAME._get_import_time() # noqa: protected-access
t_modified = os.path.getmtime(jax_py_file)
if t_imported < t_modified:
t_imp_str = datetime.datetime.fromtimestamp(t_imported).isoformat()
t_mod_str = datetime.datetime.fromtimestamp(t_modified).isoformat()
raise RuntimeError(
f"Refusing to import {jax_py_file} which was changed since "
f"TPL_MODELNAME was imported. This is to avoid inconsistencies "
"between the different model implementations.\n"
f"Imported at {t_imp_str}\nModified at {t_mod_str}.\n"
"Import the module with a different name or restart the "
"Python kernel."
)
jax = amici._module_from_path("jax", jax_py_file)
return jax.JAXModel_TPL_MODELNAME()


__version__ = "TPL_PACKAGE_VERSION"
4 changes: 2 additions & 2 deletions python/tests/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,7 @@ def test_handling_of_fixed_time_point_event_triggers():
end
"""
module_name = "test_events_time_based"
with TemporaryDirectory(prefix=module_name, delete=False) as outdir:
with TemporaryDirectory(prefix=module_name) as outdir:
antimony2amici(
ant_model,
model_name=module_name,
Expand Down Expand Up @@ -765,7 +765,7 @@ def test_multiple_event_assignment_with_compartment():
"""
# watch out for too long path names on windows ...
module_name = "tst_mltple_ea_w_cmprtmnt"
with TemporaryDirectory(prefix=module_name, delete=False) as outdir:
with TemporaryDirectory(prefix=module_name) as outdir:
antimony2amici(
ant_model,
model_name=module_name,
Expand Down
Loading

0 comments on commit 759cf39

Please sign in to comment.