From 53ae4c17c9579d7a5950e889b5fea65bb105b559 Mon Sep 17 00:00:00 2001 From: Mark Stephenson Date: Fri, 27 Sep 2024 16:21:26 -0600 Subject: [PATCH] Issue #0: Optimize communication when all satellites are talking --- src/bsk_rl/comm/communication.py | 30 ++++++++++++++++++----- src/bsk_rl/data/base.py | 4 +++ tests/unittest/comm/test_communication.py | 29 +++++++++++++++++++--- 3 files changed, 54 insertions(+), 9 deletions(-) diff --git a/src/bsk_rl/comm/communication.py b/src/bsk_rl/comm/communication.py index 759cc8d..9d68bce 100644 --- a/src/bsk_rl/comm/communication.py +++ b/src/bsk_rl/comm/communication.py @@ -2,11 +2,13 @@ import logging from abc import ABC, abstractmethod +from copy import copy from itertools import combinations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import numpy as np from scipy.sparse.csgraph import connected_components +from scipy.special import comb from bsk_rl.sim.dyn import LOSCommDynModel from bsk_rl.utils.functional import Resetable @@ -60,17 +62,33 @@ def communicate(self) -> None: return communication_pairs = self.communication_pairs() + if len(communication_pairs) > 0: logger.info( f"Communicating data between {len(communication_pairs)} pairs of satellites" ) - for sat_1, sat_2 in communication_pairs: - sat_1.data_store.stage_communicated_data(sat_2.data_store.data) - sat_2.data_store.stage_communicated_data(sat_1.data_store.data) + if len(communication_pairs) == comb(len(self.satellites), 2): + self._communicate_all() + self.last_communication_time = self.satellites[0].simulator.sim_time + else: + for sat_1, sat_2 in communication_pairs: + sat_1.data_store.stage_communicated_data(sat_2.data_store.data) + sat_2.data_store.stage_communicated_data(sat_1.data_store.data) + for satellite in self.satellites: + satellite.data_store.update_with_communicated_data() + self.last_communication_time = self.satellites[0].simulator.sim_time + + def _communicate_all(self): + """Optimized communication between all pairs of satellites.""" + logger.info("Optimizing data communication between all pairs of satellites") + + data_type = self.satellites[0].data_store.data.__class__ + final_data = data_type() + for satellite in self.satellites: + final_data += satellite.data_store.data for satellite in self.satellites: - satellite.data_store.update_with_communicated_data() - self.last_communication_time = self.satellites[0].simulator.sim_time + satellite.data_store.data = copy(final_data) class NoCommunication(CommunicationMethod): diff --git a/src/bsk_rl/data/base.py b/src/bsk_rl/data/base.py index b88b1c1..1910ff7 100644 --- a/src/bsk_rl/data/base.py +++ b/src/bsk_rl/data/base.py @@ -29,6 +29,10 @@ def __add__(self, other: "Data") -> "Data": """Define the combination of two units of data.""" pass + def __copy__(self) -> "Data": + """Create a shallow copy of the data.""" + return self.__class__() + self + class DataStore(ABC): """Base class for satellite data logging.""" diff --git a/tests/unittest/comm/test_communication.py b/tests/unittest/comm/test_communication.py index 6e51123..f4370cf 100644 --- a/tests/unittest/comm/test_communication.py +++ b/tests/unittest/comm/test_communication.py @@ -15,7 +15,7 @@ @patch.multiple(CommunicationMethod, __abstractmethods__=set()) class TestCommunicationMethod: def test_communicate(self): - mock_sats = [MagicMock(), MagicMock()] + mock_sats = [MagicMock(), MagicMock(), MagicMock()] mock_sats[0].simulator.sim_time = 0.0 comms = CommunicationMethod() comms.last_communication_time = 0.0 @@ -34,7 +34,7 @@ def test_communicate(self): sat.data_store.update_with_communicated_data.assert_called_once() def test_min_period_elapsed(self): - mock_sats = [MagicMock(), MagicMock()] + mock_sats = [MagicMock(), MagicMock(), MagicMock()] comms = CommunicationMethod(min_period=1.0) comms.link_satellites(mock_sats) comms.communication_pairs = MagicMock( @@ -47,7 +47,7 @@ def test_min_period_elapsed(self): sat.data_store.update_with_communicated_data.assert_called_once() def test_min_period_not_elapsed(self): - mock_sats = [MagicMock(), MagicMock()] + mock_sats = [MagicMock(), MagicMock(), MagicMock()] comms = CommunicationMethod(min_period=1.0) comms.link_satellites(mock_sats) comms.communication_pairs = MagicMock( @@ -59,6 +59,29 @@ def test_min_period_not_elapsed(self): for sat in mock_sats: sat.data_store.update_with_communicated_data.assert_not_called() + def test_override_communicate_all(self): + mock_sats = [MagicMock() for i in range(3)] + comms = FreeCommunication() + comms.last_communication_time = 0.0 + mock_sats[0].simulator.sim_time = 1.0 + comms.link_satellites(mock_sats) + comms._communicate_all = MagicMock() + comms.communicate() + comms._communicate_all.assert_called_once() + + def test_override_communicate_all_nocall(self): + mock_sats = [MagicMock() for i in range(3)] + comms = CommunicationMethod() + comms.communication_pairs = MagicMock( + return_value=[(mock_sats[1], mock_sats[0])] + ) + comms.last_communication_time = 0.0 + mock_sats[0].simulator.sim_time = 1.0 + comms.link_satellites(mock_sats) + comms._communicate_all = MagicMock() + comms.communicate() + comms._communicate_all.assert_not_called() + class TestNoCommunication: def test_communicate(self):