Skip to content

Commit

Permalink
Add activity parameter events
Browse files Browse the repository at this point in the history
  • Loading branch information
cmutel committed Nov 15, 2024
1 parent 438926e commit 8fca464
Show file tree
Hide file tree
Showing 6 changed files with 933 additions and 34 deletions.
60 changes: 44 additions & 16 deletions bw2data/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@
from bw2data import config, databases, get_activity, projects
from bw2data.backends.schema import ExchangeDataset
from bw2data.signals import (
on_activity_parameter_recalculate,
on_activity_parameter_recalculate_exchanges,
on_activity_parameter_update_formula_activity_parameter_name,
on_activity_parameter_update_formula_database_parameter_name,
on_activity_parameter_update_formula_project_parameter_name,
on_database_parameter_recalculate,
on_database_parameter_update_formula_database_parameter_name,
on_database_parameter_update_formula_project_parameter_name,
Expand Down Expand Up @@ -673,7 +678,7 @@ def _static_dependencies(group):
return result

@staticmethod
def insert_dummy(group, activity):
def insert_dummy(group: str, activity: tuple, signal: bool = True):
code, database = activity[1], activity[0]
if (
not ActivityParameter.select()
Expand All @@ -684,13 +689,13 @@ def insert_dummy(group, activity):
)
.count()
):
ActivityParameter.create(
ActivityParameter(
group=group,
name="__dummy_{}__".format(uuid.uuid4().hex),
code=code,
database=database,
amount=0,
)
).save(signal=signal)

@staticmethod
def expired(group):
Expand Down Expand Up @@ -797,7 +802,7 @@ def is_dependency_within_group(name, group, include_order=False):
return True if name in names else False

@staticmethod
def recalculate(group):
def recalculate(group: str, signal: bool = True):
"""Recalculate all values for activity parameters in this group, and update their underlying `Activity` and `Exchange` values."""
# Start by traversing and updating the list of dependencies
if not ActivityParameter.expired(group):
Expand Down Expand Up @@ -847,10 +852,13 @@ def recalculate(group):
Group.get(name=group).freshen()
ActivityParameter.expire_downstream(group)

ActivityParameter.recalculate_exchanges(group)
ActivityParameter.recalculate_exchanges(group, signal=False)

if signal:
on_activity_parameter_recalculate.send(name=group)

@staticmethod
def recalculate_exchanges(group):
def recalculate_exchanges(group: str, signal: bool = True):
"""Recalculate formulas for all parameterized exchanges in group ``group``."""
if ActivityParameter.expired(group):
return ActivityParameter.recalculate(group)
Expand All @@ -862,14 +870,17 @@ def recalculate_exchanges(group):
for obj in ParameterizedExchange.select().where(ParameterizedExchange.group == group):
exc = ExchangeDataset.get(id=obj.exchange)
exc.data["amount"] = interpreter(obj.formula)
exc.save()
exc.save(signal=False)

databases.set_dirty(ActivityParameter.get(group=group).database)

if signal:
on_activity_parameter_recalculate_exchanges.send(name=group)

def save(self, *args, **kwargs):
"""Save this model instance"""
Group.get_or_create(name=self.group)[0].expire()
super(ActivityParameter, self).save(*args, **kwargs)
super().save(*args, **kwargs)

def is_deletable(self):
"""Perform a test to see if the current parameter can be deleted."""
Expand Down Expand Up @@ -901,7 +912,7 @@ def is_dependent_on(name, group):
return False

@classmethod
def update_formula_project_parameter_name(cls, old, new):
def update_formula_project_parameter_name(cls, old: str, new: str, signal: bool = True):
"""Performs an update of the formula of relevant parameters.
This method specifically targets project parameters used in activity
Expand Down Expand Up @@ -940,11 +951,16 @@ def update_formula_project_parameter_name(cls, old, new):
)
cls.bulk_update(data, fields=[cls.formula], batch_size=50)
for param_exc in exchanges:
param_exc.save()
param_exc.save(signal=False)
Group.update(fresh=False).where(Group.name << groups).execute()

if signal:
on_activity_parameter_update_formula_project_parameter_name.send(
old={"old": old}, new={"new": new}
)

@classmethod
def update_formula_database_parameter_name(cls, old, new):
def update_formula_database_parameter_name(cls, old: str, new: str, signal: bool = True):
"""Performs an update of the formula of relevant parameters.
This method specifically targets database parameters used in activity
Expand Down Expand Up @@ -983,11 +999,18 @@ def update_formula_database_parameter_name(cls, old, new):
)
cls.bulk_update(data, fields=[cls.formula], batch_size=50)
for param_exc in exchanges:
param_exc.save()
param_exc.save(signal=False)
Group.update(fresh=False).where(Group.name << groups).execute()

if signal:
on_activity_parameter_update_formula_database_parameter_name.send(
old={"old": old}, new={"new": new}
)

@classmethod
def update_formula_activity_parameter_name(cls, old, new, include_order=False):
def update_formula_activity_parameter_name(
cls, old: str, new: str, include_order: bool = False, signal: bool = True
):
"""Performs an update of the formula of relevant parameters.
This method specifically targets activity parameters used in activity
Expand Down Expand Up @@ -1017,9 +1040,14 @@ def update_formula_activity_parameter_name(cls, old, new, include_order=False):
)
cls.bulk_update(data, fields=[cls.formula], batch_size=50)
for param_exc in exchanges:
param_exc.save()
param_exc.save(signal=False)
Group.update(fresh=False).where(Group.name << groups).execute()

if signal:
on_activity_parameter_update_formula_activity_parameter_name.send(
old={"old": old}, new={"new": new, "include_order": include_order}
)

@classmethod
def create_table(cls):
super(ActivityParameter, cls).create_table()
Expand All @@ -1044,7 +1072,7 @@ def dict(self):
return obj


class ParameterizedExchange(Model):
class ParameterizedExchange(SnowflakeIDBaseClass):
group = TextField()
exchange = IntegerField(unique=True)
formula = TextField()
Expand All @@ -1057,7 +1085,7 @@ def create_table(cls):

def save(self, *args, **kwargs):
Group.get_or_create(name=self.group)[0].expire()
super(ParameterizedExchange, self).save(*args, **kwargs)
super().save(*args, **kwargs)
# Push the changed formula to the Exchange.
exc = ExchangeDataset.get_or_none(id=self.exchange)
if exc and exc.data.get("formula") != self.formula:
Expand Down
34 changes: 34 additions & 0 deletions bw2data/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,9 @@ def signal_dispatcher_generic_no_diff(
signal_dispatcher_on_database_parameter = partial(
signal_dispatcher_generic_no_diff, prefix="database_parameter", obj_type="database_parameter"
)
signal_dispatcher_on_activity_parameter = partial(
signal_dispatcher_generic_no_diff, prefix="activity_parameter", obj_type="activity_parameter"
)


def signal_dispatcher_on_database_write(sender, name: str) -> int:
Expand Down Expand Up @@ -678,6 +681,15 @@ def signal_dispatcher_on_update_formula_parameter_name(
signal_dispatcher_on_database_parameter_update_formula_database_parameter_name = partial(
signal_dispatcher_on_update_formula_parameter_name, kind="database", extra="database_"
)
signal_dispatcher_on_activity_parameter_update_formula_project_parameter_name = partial(
signal_dispatcher_on_update_formula_parameter_name, kind="activity", extra="project_"
)
signal_dispatcher_on_activity_parameter_update_formula_database_parameter_name = partial(
signal_dispatcher_on_update_formula_parameter_name, kind="activity", extra="database_"
)
signal_dispatcher_on_activity_parameter_update_formula_activity_parameter_name = partial(
signal_dispatcher_on_update_formula_parameter_name, kind="activity", extra="activity_"
)

# `.connect()` directly just fails silently...
signal_dispatcher_on_activity_database_change = partial(
Expand All @@ -697,6 +709,12 @@ def signal_dispatcher_on_update_formula_parameter_name(
signal_dispatcher_on_database_parameter_recalculate = partial(
signal_dispatcher_on_database_parameter, verb="recalculate"
)
signal_dispatcher_on_activity_parameter_recalculate = partial(
signal_dispatcher_on_activity_parameter, verb="recalculate"
)
signal_dispatcher_on_activity_parameter_recalculate_exchanges = partial(
signal_dispatcher_on_activity_parameter, verb="recalculate_exchanges"
)

projects = ProjectManager()
bw2signals.signaleddataset_on_save.connect(signal_dispatcher)
Expand All @@ -713,6 +731,13 @@ def signal_dispatcher_on_update_formula_parameter_name(
bw2signals.on_database_parameter_recalculate.connect(
signal_dispatcher_on_database_parameter_recalculate
)
bw2signals.on_activity_parameter_recalculate.connect(
signal_dispatcher_on_activity_parameter_recalculate
)
bw2signals.on_activity_parameter_recalculate_exchanges.connect(
signal_dispatcher_on_activity_parameter_recalculate_exchanges
)

bw2signals.on_project_parameter_update_formula_parameter_name.connect(
signal_dispatcher_on_project_parameter_update_formula_parameter_name
)
Expand All @@ -722,6 +747,15 @@ def signal_dispatcher_on_update_formula_parameter_name(
bw2signals.on_database_parameter_update_formula_database_parameter_name.connect(
signal_dispatcher_on_database_parameter_update_formula_database_parameter_name
)
bw2signals.on_activity_parameter_update_formula_project_parameter_name.connect(
signal_dispatcher_on_activity_parameter_update_formula_project_parameter_name
)
bw2signals.on_activity_parameter_update_formula_database_parameter_name.connect(
signal_dispatcher_on_activity_parameter_update_formula_database_parameter_name
)
bw2signals.on_activity_parameter_update_formula_activity_parameter_name.connect(
signal_dispatcher_on_activity_parameter_update_formula_activity_parameter_name
)


@wrapt.decorator
Expand Down
53 changes: 39 additions & 14 deletions bw2data/revisions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from bw2data.backends.utils import dict_as_activitydataset, dict_as_exchangedataset
from bw2data.database import DatabaseChooser
from bw2data.errors import DifferentObjects, IncompatibleClasses
from bw2data.parameters import DatabaseParameter, ParameterBase, ProjectParameter
from bw2data.parameters import ActivityParameter, DatabaseParameter, ParameterBase, ProjectParameter
from bw2data.signals import SignaledDataset
from bw2data.snowflake_ids import snowflake_id_generator
from bw2data.utils import get_node
Expand Down Expand Up @@ -326,7 +326,9 @@ def project_parameter_recalculate(cls, revision_data: dict) -> None:

@classmethod
def project_parameter_update_formula_parameter_name(cls, revision_data: dict) -> None:
cls.ORM_CLASS.update_formula_parameter_name(signal=False, **cls._unwrap_diff_dict(revision_data))
cls.ORM_CLASS.update_formula_parameter_name(
signal=False, **cls._unwrap_diff_dict(revision_data)
)


class RevisionedDatabaseParameter(RevisionedParameter):
Expand All @@ -339,30 +341,51 @@ def database_parameter_recalculate(cls, revision_data: dict) -> None:

@classmethod
def database_parameter_update_formula_project_parameter_name(cls, revision_data: dict) -> None:
print(revision_data)
cls.ORM_CLASS.update_formula_project_parameter_name(signal=False, **cls._unwrap_diff_dict(revision_data))
cls.ORM_CLASS.update_formula_project_parameter_name(
signal=False, **cls._unwrap_diff_dict(revision_data)
)

@classmethod
def database_parameter_update_formula_database_parameter_name(cls, revision_data: dict) -> None:
cls.ORM_CLASS.update_formula_database_parameter_name(signal=False, **cls._unwrap_diff_dict(revision_data))
cls.ORM_CLASS.update_formula_database_parameter_name(
signal=False, **cls._unwrap_diff_dict(revision_data)
)


class RevisionedActivityParameter(RevisionedParameter):
KEYS = ("id", "database", "name", "formula", "amount", "data")
ORM_CLASS = DatabaseParameter
KEYS = ("id", "group", "database", "code", "name", "formula", "amount", "data")
ORM_CLASS = ActivityParameter

@classmethod
def database_parameter_recalculate(cls, revision_data: dict) -> None:
cls.ORM_CLASS.recalculate(database=revision_data["id"], signal=False)
def activity_parameter_recalculate(cls, revision_data: dict) -> None:
cls.ORM_CLASS.recalculate(group=revision_data["id"], signal=False)

@classmethod
def database_parameter_update_formula_project_parameter_name(cls, revision_data: dict) -> None:
print(revision_data)
cls.ORM_CLASS.update_formula_project_parameter_name(signal=False, **cls._unwrap_diff_dict(revision_data))
def activity_parameter_recalculate_exchanges(cls, revision_data: dict) -> None:
cls.ORM_CLASS.recalculate_exchanges(group=revision_data["id"], signal=False)

@classmethod
def database_parameter_update_formula_database_parameter_name(cls, revision_data: dict) -> None:
cls.ORM_CLASS.update_formula_database_parameter_name(signal=False, **cls._unwrap_diff_dict(revision_data))
def activity_parameter_update_formula_project_parameter_name(cls, revision_data: dict) -> None:
cls.ORM_CLASS.update_formula_project_parameter_name(
signal=False, **cls._unwrap_diff_dict(revision_data)
)

@classmethod
def activity_parameter_update_formula_database_parameter_name(cls, revision_data: dict) -> None:
cls.ORM_CLASS.update_formula_database_parameter_name(
signal=False, **cls._unwrap_diff_dict(revision_data)
)

@classmethod
def activity_parameter_update_formula_activity_parameter_name(cls, revision_data: dict) -> None:
dct = {
"old": revision_data["delta"]["dictionary_item_removed"]["root['old']"],
"new": revision_data["delta"]["dictionary_item_added"]["root['new']"],
"include_order": revision_data["delta"]["dictionary_item_added"][
"root['include_order']"
],
}
cls.ORM_CLASS.update_formula_activity_parameter_name(signal=False, **dct)


class RevisionedNode(RevisionedORMProxy):
Expand Down Expand Up @@ -429,12 +452,14 @@ def handle(cls, revision_data: dict) -> None:
ExchangeDataset: "lci_edge",
ProjectParameter: "project_parameter",
DatabaseParameter: "database_parameter",
ActivityParameter: "activity_parameter",
}
REVISIONED_LABEL_AS_OBJECT = {
"lci_node": RevisionedNode,
"lci_edge": RevisionedEdge,
"lci_database": RevisionedDatabase,
"project_parameter": RevisionedProjectParameter,
"database_parameter": RevisionedDatabaseParameter,
"activity_parameter": RevisionedActivityParameter,
}
REVISIONS_OBJECT_AS_LABEL = {v: k for k, v in REVISIONED_LABEL_AS_OBJECT.items()}
Loading

0 comments on commit 8fca464

Please sign in to comment.