diff --git a/python/amici/__init__.py b/python/amici/__init__.py index 183a6255ee..05a9ebae8c 100644 --- a/python/amici/__init__.py +++ b/python/amici/__init__.py @@ -136,6 +136,22 @@ def _imported_from_setup() -> bool: hdf5_enabled = 'readSolverSettingsFromHDF5' in dir() +def _get_ptr(obj: Union[AmiciModel, AmiciExpData, AmiciSolver] + ) -> Union['amici.Model', 'amici.ExpData', 'amici.Solver']: + """ + Convenience wrapper that returns the smart pointer pointee, if applicable + + :param obj: + Potential smart pointer + + :returns: + Non-smart pointer + """ + if isinstance(obj, (amici.ModelPtr, amici.ExpDataPtr, amici.SolverPtr)): + return obj.get() + return obj + + def runAmiciSimulation( model: AmiciModel, solver: AmiciSolver, @@ -159,11 +175,9 @@ def runAmiciSimulation( ReturnData object with simulation results """ - if edata and isinstance(edata, amici.ExpDataPtr): - edata = edata.get() - with capture_cstdout(): - rdata = amici.runAmiciSimulation(solver.get(), edata, model.get()) + rdata = amici.runAmiciSimulation(_get_ptr(solver), _get_ptr(edata), + _get_ptr(model)) return numpy.ReturnDataView(rdata) @@ -177,13 +191,13 @@ def ExpData(*args) -> 'amici.ExpData': """ if isinstance(args[0], ReturnDataView): return amici.ExpData(args[0]['ptr'].get(), *args[1:]) - elif isinstance(args[0], amici.ExpDataPtr): + elif isinstance(args[0], (amici.ExpData, amici.ExpDataPtr)): # the *args[:1] should be empty, but by the time you read this, # the constructor signature may have changed and you are glad this # wrapper did not break. - return amici.ExpData(args[0].get(), *args[1:]) - elif isinstance(args[0], amici.ModelPtr): - return amici.ExpData(args[0].get()) + return amici.ExpData(_get_ptr(args[0]), *args[1:]) + elif isinstance(args[0], (amici.Model, amici.ModelPtr)): + return amici.ExpData(_get_ptr(args[0])) else: return amici.ExpData(*args) @@ -209,9 +223,9 @@ def runAmiciSimulations( """ with capture_cstdout(): edata_ptr_vector = amici.ExpDataPtrVector(edata_list) - rdata_ptr_list = amici.runAmiciSimulations(solver.get(), + rdata_ptr_list = amici.runAmiciSimulations(_get_ptr(solver), edata_ptr_vector, - model.get(), + _get_ptr(model), failfast, num_threads) return [numpy.ReturnDataView(r) for r in rdata_ptr_list] @@ -229,10 +243,7 @@ def readSolverSettingsFromHDF5( :param solver: Solver instance to which settings will be transferred :param location: location of solver settings in hdf5 file """ - if isinstance(solver, amici.SolverPtr): - amici.readSolverSettingsFromHDF5(file, solver.get(), location) - else: - amici.readSolverSettingsFromHDF5(file, solver, location) + amici.readSolverSettingsFromHDF5(file, _get_ptr(solver), location) def writeSolverSettingsToHDF5( @@ -248,10 +259,7 @@ def writeSolverSettingsToHDF5( :param solver: Solver instance from which settings will stored :param location: location of solver settings in hdf5 file """ - if isinstance(solver, amici.SolverPtr): - amici.writeSolverSettingsToHDF5(solver.get(), file, location) - else: - amici.writeSolverSettingsToHDF5(solver, file, location) + amici.writeSolverSettingsToHDF5(_get_ptr(solver), file, location) class add_path: