Skip to content

Commit

Permalink
Fix equality and isinstance checks for CFunction and derived classes (
Browse files Browse the repository at this point in the history
#22)

* Fix equality and isinstance checks for `CFunction` and derived classes

* Add Delay and RateOf classes
* Fix comparisons and isinstance checks
* Add test / doc

Closes #21
  • Loading branch information
dweindl authored Oct 11, 2024
1 parent 847a3ef commit ea4517f
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 9 deletions.
102 changes: 93 additions & 9 deletions sbmlmath/cfunction.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
"""Handling of ``<csymbol>`` functions"""

from __future__ import annotations

from sympy.core.function import UndefinedFunction

__all__ = ["CFunction", "delay", "rate_of"]
__all__ = ["CFunction", "delay", "rate_of", "Delay", "RateOf"]

DEF_URL_BASE = "http://www.sbml.org/sbml/symbols/"
DEF_URL_RATE_OF = DEF_URL_BASE + "rateOf"
DEF_URL_DELAY = DEF_URL_BASE + "delay"


class CFunction(UndefinedFunction):
Expand All @@ -17,20 +23,29 @@ class CFunction(UndefinedFunction):
See also https://www.w3.org/TR/MathML2/chapter4.html#contm.deffun.
"""

DEFINITION_URL = None
_cache = {}
_definition_url_to_derived_class = {}

def __new__(
cls,
cls: type[CFunction],
*args,
definition_url: str,
definition_url: str = None,
encoding: str = "text",
**kwargs,
):
definition_url = definition_url or cls.DEFINITION_URL
if definition_url is None:
raise ValueError("definition_url must be provided")

# Cache instances.
# If not done: (CFunction("A", definition_url="x")
# == CFunction("A", definition_url="y")) == False
if not (name := kwargs.get("name")):
if not len(args):
raise ValueError("name argument must be provided")
name = args[0]

cache_key = (name, definition_url, encoding)
if cached := cls._cache.get(cache_key):
return cached
Expand All @@ -47,15 +62,84 @@ def __new__(

return obj

def __eq__(self, other):
if not isinstance(other, CFunction):
return False

# if they represent the same value, they are equal
return self.definition_url == other.definition_url

def __hash__(self):
return hash(
(
self.__class__.__name__,
self.DEFINITION_URL,
self.definition_url,
self.name,
)
)

@classmethod
def register_subclass(cls, derived_class: type[CFunction]):
cls._definition_url_to_derived_class[derived_class.DEFINITION_URL] = (
derived_class
)


# Derived classes for specific SBML functions
class Delay(CFunction):
"""Produces a SBML ``delay()`` function.
Usually, it's preferable to use the :func:`delay` function.
This class can be used if a *delay* function with a different name is
needed.
Examples:
>>> from sympy.abc import a
>>> my_delay = Delay("my_delay")
>>> my_delay(a)
my_delay(a)
>>> delay(a) == my_delay(a)
True
"""

DEFINITION_URL = DEF_URL_DELAY

def __new__(cls, *args, **kwargs):
return super().__new__(cls, *args, **kwargs)


CFunction.register_subclass(Delay)


class RateOf(CFunction):
"""Produces a SBML ``rateOf()`` function.
Usually, it's preferable to use the :func:`rate_of` function.
This class can be used if a *rateOf* function with a different name is
needed.
Examples:
>>> from sympy.abc import a
>>> my_rate_of = RateOf("my_rate_of")
>>> my_rate_of(a)
my_rate_of(a)
>>> rate_of(a) == my_rate_of(a)
True
"""

DEFINITION_URL = DEF_URL_RATE_OF

def __new__(cls, *args, **kwargs):
return super().__new__(cls, *args, **kwargs)


CFunction.register_subclass(RateOf)

# SBML-defined functions

#: The SBML ``delay()`` function.
delay = CFunction(
"delay", definition_url="http://www.sbml.org/sbml/symbols/delay"
)
delay = Delay("delay")

#: The SBML ``rateOf()`` function.
rate_of = CFunction(
"rateOf", definition_url="http://www.sbml.org/sbml/symbols/rateOf"
)
rate_of = RateOf("rateOf")
26 changes: 26 additions & 0 deletions tests/test_csymbol.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import sympy as sp
from sympy.core.function import UndefinedFunction

from sbmlmath.cfunction import *
from sbmlmath.cfunction import DEF_URL_DELAY, DEF_URL_RATE_OF
from sbmlmath.csymbol import *


Expand Down Expand Up @@ -31,3 +34,26 @@ def test_time_symbol():
assert (2 * t1).has(TimeSymbol) is True
assert (2 * t1).has(CSymbol) is True
assert (sp.sympify("l * e * e * t")).has(TimeSymbol) is False


def test_cfunction():
rate_of = CFunction("rateOf", definition_url=DEF_URL_RATE_OF)
# test that the cache works
assert rate_of is CFunction("rateOf", definition_url=DEF_URL_RATE_OF)
# test that definition URL is considered, not only the name
assert rate_of is not CFunction("rateOf", definition_url=DEF_URL_DELAY)

assert rate_of == CFunction("rateOf", definition_url=DEF_URL_RATE_OF)
assert rate_of != CFunction("rateOf", definition_url=DEF_URL_DELAY)
assert isinstance(rate_of, CFunction)
assert isinstance(rate_of, RateOf)
assert not isinstance(rate_of, delay)
a, b = sp.symbols("a b")
assert (rate_of(a) * 4).has(rate_of) is True
assert (rate_of(a) * 4).has(RateOf("rAteOf")) is True
assert rate_of(a) == rate_of(a)
assert rate_of(a) != rate_of(b)
assert (rate_of(a) * 4).has(RateOf) is False

assert UndefinedFunction(rate_of.name) != rate_of
assert UndefinedFunction(rate_of.name)(a).has(rate_of) is False

0 comments on commit ea4517f

Please sign in to comment.