Skip to content

Commit

Permalink
Merge branch 'develop' into release_0_11_7
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Sep 22, 2020
2 parents fb1a96e + 1ce06e8 commit 3052de7
Showing 1 changed file with 26 additions and 18 deletions.
44 changes: 26 additions & 18 deletions python/amici/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,22 @@ def _imported_from_setup() -> bool:
hdf5_enabled = 'readSolverSettingsFromHDF5' in dir()


def _get_ptr(obj: Union[AmiciModel, AmiciExpData, AmiciSolver]
) -> Union['amici.Model', 'amici.ExpData', 'amici.Solver']:
"""
Convenience wrapper that returns the smart pointer pointee, if applicable
:param obj:
Potential smart pointer
:returns:
Non-smart pointer
"""
if isinstance(obj, (amici.ModelPtr, amici.ExpDataPtr, amici.SolverPtr)):
return obj.get()
return obj


def runAmiciSimulation(
model: AmiciModel,
solver: AmiciSolver,
Expand All @@ -159,11 +175,9 @@ def runAmiciSimulation(
ReturnData object with simulation results
"""

if edata and isinstance(edata, amici.ExpDataPtr):
edata = edata.get()

with capture_cstdout():
rdata = amici.runAmiciSimulation(solver.get(), edata, model.get())
rdata = amici.runAmiciSimulation(_get_ptr(solver), _get_ptr(edata),
_get_ptr(model))
return numpy.ReturnDataView(rdata)


Expand All @@ -177,13 +191,13 @@ def ExpData(*args) -> 'amici.ExpData':
"""
if isinstance(args[0], ReturnDataView):
return amici.ExpData(args[0]['ptr'].get(), *args[1:])
elif isinstance(args[0], amici.ExpDataPtr):
elif isinstance(args[0], (amici.ExpData, amici.ExpDataPtr)):
# the *args[:1] should be empty, but by the time you read this,
# the constructor signature may have changed and you are glad this
# wrapper did not break.
return amici.ExpData(args[0].get(), *args[1:])
elif isinstance(args[0], amici.ModelPtr):
return amici.ExpData(args[0].get())
return amici.ExpData(_get_ptr(args[0]), *args[1:])
elif isinstance(args[0], (amici.Model, amici.ModelPtr)):
return amici.ExpData(_get_ptr(args[0]))
else:
return amici.ExpData(*args)

Expand All @@ -209,9 +223,9 @@ def runAmiciSimulations(
"""
with capture_cstdout():
edata_ptr_vector = amici.ExpDataPtrVector(edata_list)
rdata_ptr_list = amici.runAmiciSimulations(solver.get(),
rdata_ptr_list = amici.runAmiciSimulations(_get_ptr(solver),
edata_ptr_vector,
model.get(),
_get_ptr(model),
failfast,
num_threads)
return [numpy.ReturnDataView(r) for r in rdata_ptr_list]
Expand All @@ -229,10 +243,7 @@ def readSolverSettingsFromHDF5(
:param solver: Solver instance to which settings will be transferred
:param location: location of solver settings in hdf5 file
"""
if isinstance(solver, amici.SolverPtr):
amici.readSolverSettingsFromHDF5(file, solver.get(), location)
else:
amici.readSolverSettingsFromHDF5(file, solver, location)
amici.readSolverSettingsFromHDF5(file, _get_ptr(solver), location)


def writeSolverSettingsToHDF5(
Expand All @@ -248,10 +259,7 @@ def writeSolverSettingsToHDF5(
:param solver: Solver instance from which settings will stored
:param location: location of solver settings in hdf5 file
"""
if isinstance(solver, amici.SolverPtr):
amici.writeSolverSettingsToHDF5(solver.get(), file, location)
else:
amici.writeSolverSettingsToHDF5(solver, file, location)
amici.writeSolverSettingsToHDF5(_get_ptr(solver), file, location)


class add_path:
Expand Down

0 comments on commit 3052de7

Please sign in to comment.