Skip to content

Commit

Permalink
Added probabilistic model interface for factor graphs
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsch420 committed Dec 7, 2023
1 parent 1d35c83 commit dba38ae
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 3 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ numpy>=1.24.4
random_events>=1.2.5
tabulate>=0.9.0
probabilistic-model>=1.4.13
typing-extensions
28 changes: 25 additions & 3 deletions src/fglib2/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
from abc import ABC
from typing import List, Optional
from typing import Tuple, Iterable
from typing_extensions import Self

import networkx as nx
import numpy as np

from .distributions import Multinomial
from random_events.variables import Discrete
from random_events.events import Event
from random_events.variables import Discrete, Variable
from random_events.events import Event, EncodedEvent
from probabilistic_model.probabilistic_model import ProbabilisticModel


class Node(ABC):
Expand Down Expand Up @@ -219,7 +221,7 @@ def __repr__(self):
return str(self)


class FactorGraph(nx.Graph):
class FactorGraph(nx.Graph, ProbabilisticModel):
"""
A factor graph.
Expand Down Expand Up @@ -470,3 +472,23 @@ def reset(self):
for edge in self.edges:
self.edges[edge]['edge'].variable_to_factor = None
self.edges[edge]['edge'].factor_to_variable = None

def _likelihood(self, event: Iterable) -> float:
raise NotImplementedError("Implement this method according to its docstring here"
"https://probabilistic-model.readthedocs.io/en/latest/autoapi/probabilistic_model/probabilistic_model/index.html#")

def _probability(self, event: EncodedEvent) -> float:
raise NotImplementedError("Implement this method according to its docstring here"
"https://probabilistic-model.readthedocs.io/en/latest/autoapi/probabilistic_model/probabilistic_model/index.html#")

def _mode(self) -> Tuple[Iterable[EncodedEvent], float]:
raise NotImplementedError("Implement this method according to its docstring here"
"https://probabilistic-model.readthedocs.io/en/latest/autoapi/probabilistic_model/probabilistic_model/index.html#")

def _conditional(self, event: EncodedEvent) -> Tuple[Optional[Self], float]:
raise NotImplementedError("Implement this method according to its docstring here"
"https://probabilistic-model.readthedocs.io/en/latest/autoapi/probabilistic_model/probabilistic_model/index.html#")

def marginal(self, variables: Iterable[Variable]) -> Optional[Self]:
raise NotImplementedError("Implement this method according to its docstring here"
"https://probabilistic-model.readthedocs.io/en/latest/autoapi/probabilistic_model/probabilistic_model/index.html#")
10 changes: 10 additions & 0 deletions test/test_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import unittest


class MyTestCase(unittest.TestCase):
def test_something(self):
self.assertEqual(True, False) # add assertion here


if __name__ == '__main__':
unittest.main()

0 comments on commit dba38ae

Please sign in to comment.