diff --git a/python/amici/sbml_import.py b/python/amici/sbml_import.py index 6d63255733..06db8cf8f2 100644 --- a/python/amici/sbml_import.py +++ b/python/amici/sbml_import.py @@ -86,9 +86,10 @@ def __init__( Arguments: - sbml_source: Either a path to SBML file where the model is specified. - Or a model string as created by sbml.sbmlWriter().writeSBMLToString(). - @type string + sbml_source: Either a path to SBML file where the model is + specified, or a model string as created by + sbml.sbmlWriter().writeSBMLToString() or an instance of + libsbml.Model. @type str show_sbml_warnings: Indicates whether libSBML warnings should be displayed (default = True). @type bool @@ -103,13 +104,15 @@ def __init__( Raises: """ - self.sbml_reader = sbml.SBMLReader() - - if from_file: - sbml_doc = self.sbml_reader.readSBMLFromFile(sbml_source) + if isinstance(sbml_source, sbml.Model): + self.sbml_doc = sbml_source.getSBMLDocument() else: - sbml_doc = self.sbml_reader.readSBMLFromString(sbml_source) - self.sbml_doc = sbml_doc + self.sbml_reader = sbml.SBMLReader() + if from_file: + sbml_doc = self.sbml_reader.readSBMLFromFile(sbml_source) + else: + sbml_doc = self.sbml_reader.readSBMLFromString(sbml_source) + self.sbml_doc = sbml_doc self.show_sbml_warnings = show_sbml_warnings @@ -878,13 +881,17 @@ def processObservables(self, observables, sigmas, noise_distributions): # set cost functions llhYStrings = [] - for y_name in observables: + for y_name in observableNames: llhYStrings.append(noise_distribution_to_cost_function( noise_distributions.get(y_name, 'normal'))) llhYValues = [] - for llhYString, o_sym, m_sym, s_sym in zip(llhYStrings, observableSyms, measurementYSyms, sigmaYSyms): - f = sp.sympify(llhYString(o_sym), locals={str(o_sym): o_sym, str(m_sym): m_sym, str(s_sym): s_sym}) + for llhYString, o_sym, m_sym, s_sym \ + in zip(llhYStrings, observableSyms, + measurementYSyms, sigmaYSyms): + f = sp.sympify(llhYString(o_sym), locals={str(o_sym): o_sym, + str(m_sym): m_sym, + str(s_sym): s_sym}) llhYValues.append(f) llhYValues = sp.Matrix(llhYValues) diff --git a/tests/testMisc.py b/tests/testMisc.py index 7157bae01a..3ab6247e35 100755 --- a/tests/testMisc.py +++ b/tests/testMisc.py @@ -8,6 +8,9 @@ import os import unittest import sympy as sp +import libsbml +from tempfile import TemporaryDirectory + class TestAmiciMisc(unittest.TestCase): """TestCase class various AMICI Python interface functions""" @@ -47,6 +50,18 @@ def test_csc_matrix(self): assert symbolList == ['a0', 'a1', 'a2'] assert str(sparseMatrix) == 'Matrix([[a0, 0], [a1, a2]])' + def test_csc_matrix_empty(self): + """Test sparse CSC matrix creation for empty matrix""" + matrix = sp.Matrix() + symbolColPtrs, symbolRowVals, sparseList, symbolList, sparseMatrix = \ + amici.ode_export.csc_matrix(matrix, 'a') + print(symbolColPtrs, symbolRowVals, sparseList, symbolList, sparseMatrix) + assert symbolColPtrs == [0] + assert symbolRowVals == [] + assert sparseList == sp.Matrix(0, 0, []) + assert symbolList == [] + assert str(sparseMatrix) == 'Matrix(0, 0, [])' + def test_csc_matrix_vector(self): """Test sparse CSC matrix creation from matrix slice""" matrix = sp.Matrix([[1, 0], [2, 3]]) @@ -70,6 +85,29 @@ def test_csc_matrix_vector(self): assert symbolList == ['a2'] assert str(sparseMatrix) == 'Matrix([[0], [a2]])' + def test_sbml2amici_no_observables(self): + """Test model generation works for model without observables""" + + document = libsbml.SBMLDocument(3, 1) + model = document.createModel() + model.setTimeUnits("second") + model.setExtentUnits("mole") + model.setSubstanceUnits('mole') + c1 = model.createCompartment() + c1.setId('C1') + model.addCompartment(c1) + s1 = model.createSpecies() + s1.setId('S1') + s1.setCompartment('C1') + model.addSpecies(s1) + + sbml_importer = amici.sbml_import.SbmlImporter(sbml_source=model, + from_file=False) + tmpdir = TemporaryDirectory() + sbml_importer.sbml2amici(modelName="test", + output_dir=tmpdir.name, + observables=None) + if __name__ == '__main__': suite = unittest.TestSuite() diff --git a/version.txt b/version.txt index 69da6ebcd0..2d993c425b 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.10.6 +0.10.7