From 1d35c83aa8423b44f0b714ea9e8da1559424fa93 Mon Sep 17 00:00:00 2001 From: Tom Schierenbeck Date: Wed, 6 Dec 2023 13:08:48 +0100 Subject: [PATCH] Adapted to new interface from probabilistic models. --- example/distributions.ipynb | 72 +++++++++++++++--------------- example/quickstart.ipynb | 16 +++---- requirements.txt | 4 +- src/fglib2/__init__.py | 2 +- src/fglib2/distributions.py | 88 +++---------------------------------- test/test_distributions.py | 48 +++----------------- 6 files changed, 58 insertions(+), 172 deletions(-) diff --git a/example/distributions.ipynb b/example/distributions.ipynb index 3b88716..304f308 100644 --- a/example/distributions.ipynb +++ b/example/distributions.ipynb @@ -16,21 +16,21 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 18, "id": "initial_id", "metadata": { "collapsed": true, "ExecuteTime": { - "end_time": "2023-11-02T15:14:41.667488383Z", - "start_time": "2023-11-02T15:14:41.636990366Z" + "end_time": "2023-12-06T12:08:07.453850144Z", + "start_time": "2023-12-06T12:08:07.428326539Z" } }, "outputs": [ { "data": { - "text/plain": "(Symbolic(name='animal'), Symbolic(name='color'), Integer(name='weight'))" + "text/plain": "(Symbolic(animal), Symbolic(color), Integer(weight))" }, - "execution_count": 9, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -55,13 +55,13 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 19, "outputs": [ { "data": { - "text/plain": "((Symbolic(name='animal'), Symbolic(name='color'), Integer(name='weight')),\n (3, 4, 25))" + "text/plain": "((Symbolic(animal), Symbolic(color), Integer(weight)), (3, 4, 25))" }, - "execution_count": 10, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -77,8 +77,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-02T15:14:41.669845591Z", - "start_time": "2023-11-02T15:14:41.641591695Z" + "end_time": "2023-12-06T12:08:07.455358353Z", + "start_time": "2023-12-06T12:08:07.431559191Z" } }, "id": "b4658690bcc40b12" @@ -100,7 +100,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 20, "outputs": [], "source": [ "from random_events.events import Event\n", @@ -109,8 +109,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-02T15:14:41.728917032Z", - "start_time": "2023-11-02T15:14:41.699385761Z" + "end_time": "2023-12-06T12:08:07.481397647Z", + "start_time": "2023-12-06T12:08:07.440570786Z" } }, "id": "dffc7d500e84b520" @@ -127,13 +127,13 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 21, "outputs": [ { "data": { "text/plain": "0.37218590717704436" }, - "execution_count": 12, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -144,8 +144,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-02T15:14:41.731048688Z", - "start_time": "2023-11-02T15:14:41.699511383Z" + "end_time": "2023-12-06T12:08:07.482987436Z", + "start_time": "2023-12-06T12:08:07.444435556Z" } }, "id": "ee2f2d63802ad2d2" @@ -162,13 +162,13 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 22, "outputs": [ { "data": { - "text/plain": "([{Symbolic(name='animal'): ('Cat',), Symbolic(name='color'): ('brown',), Integer(name='weight'): (5,)}],\n 0.006809245726270245)" + "text/plain": "([{Symbolic(animal): ('Cat',), Symbolic(color): ('brown',), Integer(weight): (5,)}],\n 0.006809245726270245)" }, - "execution_count": 13, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -179,8 +179,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-02T15:14:41.731457742Z", - "start_time": "2023-11-02T15:14:41.699600957Z" + "end_time": "2023-12-06T12:08:07.500720800Z", + "start_time": "2023-12-06T12:08:07.451077955Z" } }, "id": "e2994d27e8474c54" @@ -197,13 +197,13 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 23, "outputs": [ { "data": { - "text/plain": "((Symbolic(name='animal'), Symbolic(name='color')), (3, 4))" + "text/plain": "((Symbolic(animal), Symbolic(color)), (3, 4))" }, - "execution_count": 14, + "execution_count": 23, "metadata": {}, "output_type": "execute_result" } @@ -215,8 +215,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-02T15:14:41.731646143Z", - "start_time": "2023-11-02T15:14:41.699682618Z" + "end_time": "2023-12-06T12:08:07.501022406Z", + "start_time": "2023-12-06T12:08:07.491780538Z" } }, "id": "d5c35229158ee9c2" @@ -243,26 +243,26 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 24, "outputs": [ { "data": { "text/plain": "0.37218590717704436" }, - "execution_count": 15, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "conditional = distribution.conditional(event)\n", - "conditional.probability(event)" + "conditional, probability = distribution.conditional(event)\n", + "probability" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-02T15:14:41.731791084Z", - "start_time": "2023-11-02T15:14:41.699766715Z" + "end_time": "2023-12-06T12:08:07.501267015Z", + "start_time": "2023-12-06T12:08:07.492079689Z" } }, "id": "5f036dc30b5cf18d" @@ -289,7 +289,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 25, "outputs": [ { "name": "stdout", @@ -331,8 +331,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-02T15:14:41.731995643Z", - "start_time": "2023-11-02T15:14:41.699946388Z" + "end_time": "2023-12-06T12:08:07.501511837Z", + "start_time": "2023-12-06T12:08:07.492235984Z" } }, "id": "dcc9ca9e0f922e53" diff --git a/example/quickstart.ipynb b/example/quickstart.ipynb index e2b4e21..ff27e89 100644 --- a/example/quickstart.ipynb +++ b/example/quickstart.ipynb @@ -27,8 +27,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-02T15:13:40.814428028Z", - "start_time": "2023-11-02T15:13:40.754451007Z" + "end_time": "2023-12-06T12:07:52.022289981Z", + "start_time": "2023-12-06T12:07:52.017452534Z" } }, "id": "b857b83f5ae8482c" @@ -63,8 +63,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-02T15:13:40.926252491Z", - "start_time": "2023-11-02T15:13:40.815684343Z" + "end_time": "2023-12-06T12:07:52.127695085Z", + "start_time": "2023-12-06T12:07:52.024695184Z" } }, "id": "97e60dbee16c1dbd" @@ -99,8 +99,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-02T15:13:41.194177397Z", - "start_time": "2023-11-02T15:13:40.926542391Z" + "end_time": "2023-12-06T12:07:52.414076356Z", + "start_time": "2023-12-06T12:07:52.126868633Z" } }, "id": "bc41f2e657f4f715" @@ -174,8 +174,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-02T15:13:41.195205356Z", - "start_time": "2023-11-02T15:13:41.185785904Z" + "end_time": "2023-12-06T12:07:52.414471921Z", + "start_time": "2023-12-06T12:07:52.403649920Z" } }, "id": "47a7412bada04433" diff --git a/requirements.txt b/requirements.txt index 757456c..5be9a84 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ networkx>=3.0 numpy>=1.24.4 -random_events>=1.1.3 +random_events>=1.2.5 tabulate>=0.9.0 -probabilistic-model>=1.1.0 +probabilistic-model>=1.4.13 diff --git a/src/fglib2/__init__.py b/src/fglib2/__init__.py index 7bb021e..bc50bee 100644 --- a/src/fglib2/__init__.py +++ b/src/fglib2/__init__.py @@ -1 +1 @@ -__version__ = '1.1.3' +__version__ = '1.1.4' diff --git a/src/fglib2/distributions.py b/src/fglib2/distributions.py index bd3093a..4e334c0 100644 --- a/src/fglib2/distributions.py +++ b/src/fglib2/distributions.py @@ -8,6 +8,7 @@ import tabulate from probabilistic_model.probabilistic_model import ProbabilisticModel +from typing_extensions import Self class Multinomial(ProbabilisticModel): @@ -15,7 +16,7 @@ class Multinomial(ProbabilisticModel): A multinomial distribution over discrete random variables. """ - variables = Tuple[Discrete] + variables: Tuple[Discrete] """ The variables in the distribution. """ @@ -42,13 +43,6 @@ def __init__(self, variables: Iterable[Discrete], probabilities: Optional[np.nda self.probabilities = probabilities def marginal(self, variables: Iterable[Discrete]) -> 'Multinomial': - """ - Compute the marginal distribution over the given variables. - - :param variables: The variables to keep over. - - :return: The marginal distribution over variables. - """ # calculate which variables to marginalize over as the difference between variables and self.variables axis = tuple(self.variables.index(variable) for variable in self.variables if variable not in variables) @@ -59,10 +53,6 @@ def marginal(self, variables: Iterable[Discrete]) -> 'Multinomial': return Multinomial(variables, probabilities) def _mode(self) -> Tuple[List[EncodedEvent], float]: - """ - Calculate the most likely event. - :return: The mode of the distribution as EncodedEvent and its likelihood. - """ likelihood = np.max(self.probabilities) events = np.transpose(np.asarray(self.probabilities == likelihood).nonzero()) mode = [EncodedEvent(zip(self.variables, event)) for event in events.tolist()] @@ -109,7 +99,7 @@ def __eq__(self, other: 'Multinomial') -> bool: functions are equal and the order of dimensions are equal. """ - return (self.variables == other.variables and + return (isinstance(other, self.__class__) and self.variables == other.variables and self.probabilities.shape == other.probabilities.shape and np.allclose(self.probabilities, other.probabilities)) @@ -129,87 +119,19 @@ def to_tabulate(self) -> str: return tabulate.tabulate(table, headers="firstrow", tablefmt="fancy_grid") - def encode(self, event: Iterable) -> List[int]: - """ - Encode an event into a list of indices within the respective domains. - :param event: The event to encode as a list of elements of the respective variables domains - :return: The encoded event - """ - return [variable.encode(value) for variable, value in zip(self.variables, event)] - - def encode_many(self, events: Iterable[Iterable]) -> List[List[int]]: - """ - Encode multiple events into a list of indices within the respective domains. - :param events: The events to encode as a list of elements of the respective variables domains - :return: The encoded events - """ - return [self.encode(event) for event in events] - - def decode(self, event: Iterable[int]) -> List: - """ - Decode an event from a list of indices to a list of values. - :param event: The event to decode as a list of indices - :return: The decoded event - """ - return [variable.decode(value) for variable, value in zip(self.variables, event)] - - def decode_many(self, events: Iterable[Iterable[int]]) -> List[List]: - """ - Decode multiple events from a list of indices to a list of values. - :param events: The events to decode as a list of indices - :return: The decoded events - """ - return [self.decode(event) for event in events] - def _probability(self, event: EncodedEvent) -> float: - """ - Calculate the probability of an event encoded. - The encoded event has to contain information about all variables in the distribution. - :param event: The event to calculate the probability of. - :return: P(event) - """ indices = tuple(event[variable] for variable in self.variables) return self.probabilities[np.ix_(*indices)].sum() - def probability(self, event: Event) -> float: - """ - Calculate the probability of an event. - :param event: The event to calculate the probability of. - :return: P(event) - """ - event = Event({variable: variable.domain for variable in self.variables}) & event - return self._probability(event.encode()) - def _likelihood(self, event: List[int]) -> float: - """ - Calculate the likelihood of a full evidence query. - The event is a list of indices for the variable values in the same order - :param event: - :return: P(event) - """ return float(self.probabilities[tuple(event)]) - def _conditional(self, event: EncodedEvent) -> 'Multinomial': - """ - Calculate the conditional distribution given an event encoded. - The encoded event has to contain information about all variables in the distribution. - :param event: The event to condition on. - :return: The conditional distribution - """ + def _conditional(self, event: EncodedEvent) -> Tuple[Optional[Self], float]: indices = tuple(event[variable] for variable in self.variables) indices = np.ix_(*indices) probabilities = np.zeros_like(self.probabilities) probabilities[indices] = self.probabilities[indices] - return Multinomial(self.variables, probabilities) - - def conditional(self, event: Event) -> 'Multinomial': - """ - Calculate the conditional distribution given an event. - :param event: The event to condition on - :return: The conditional distribution - """ - event = Event({variable: variable.domain for variable in self.variables}) & event - return self._conditional(event.encode()) + return Multinomial(self.variables, probabilities), self.probabilities[indices].sum() def normalize(self) -> 'Multinomial': """ diff --git a/test/test_distributions.py b/test/test_distributions.py index 2699d6a..b7f21be 100755 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -54,45 +54,6 @@ def test_to_str(self): self.assertTrue(str(distribution)) -class MultinomialEncodingTestCase(unittest.TestCase): - - animal: Symbolic - color: Symbolic - distribution: Multinomial - - @classmethod - def setUpClass(cls): - cls.animal = Symbolic("animal", {"cat", "dog", "mouse"}) - cls.color = Symbolic("color", {"grey", "brown", "black"}) - cls.distribution = Multinomial((cls.animal, cls.color)) - - def test_encode(self): - event = ["cat", "grey"] - self.assertEqual(self.distribution.encode(event), [0, 2]) - - def test_decode(self): - event = [1, 0] - self.assertEqual(self.distribution.decode(event), ["dog", "black"]) - - def test_encode_raises(self): - event = ["bob", "linda"] - with self.assertRaises(ValueError): - self.distribution.encode(event) - - def test_decode_raises(self): - event = [3, 0] - with self.assertRaises(IndexError): - self.distribution.decode(event) - - def test_encode_many(self): - event = [["cat", "grey"], ["dog", "brown"]] - self.assertEqual(self.distribution.encode_many(event), [[0, 2], [1, 1]]) - - def test_decode_many(self): - event = [[0, 2], [1, 1]] - self.assertEqual(self.distribution.decode_many(event), [["cat", "grey"], ["dog", "brown"]]) - - class MultinomialInferenceTestCase(unittest.TestCase): x: Symbolic y: Symbolic @@ -184,14 +145,16 @@ def test_random_probability(self): def test_crafted_conditional(self): event = Event({self.y: (0, 1)}) - conditional = self.crafted_distribution.conditional(event).normalize() + conditional, probability = self.crafted_distribution.conditional(event) + conditional = conditional.normalize() self.assertEqual(conditional.probability(event), 1) self.assertEqual(conditional.probability(Event()), 1.) self.assertEqual(conditional.probability(Event({self.y: 2})), 0.) def test_random_conditional(self): event = Event({self.y: (0, 1)}) - conditional = self.random_distribution.conditional(event).normalize() + conditional, _ = self.random_distribution.conditional(event) + conditional = conditional.normalize() self.assertAlmostEqual(conditional.probability(event), 1) self.assertAlmostEqual(conditional.probability(Event()), 1.) self.assertEqual(conditional.probability(Event({self.y: 2})), 0.) @@ -246,8 +209,9 @@ def test_left_subset_variables(self): event = condition.intersection(Event({self.y: y_value})) # manual result as P(Y,X) = P(X) * P(Y|X) + conditional_result, conditional_probability = result.conditional(condition) manual_result = (adjusted_distribution_x.normalize().probability(condition) * - result.conditional(condition).normalize().probability(event)) + conditional_result.normalize().probability(event)) product_result = result.normalize().probability(event) self.assertAlmostEqual(manual_result, product_result)