diff --git a/.ruff.toml b/.ruff.toml index 39bec903..7ba5db15 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -1 +1 @@ -ignore-init-module-imports = true \ No newline at end of file +ignore-init-module-imports = true diff --git a/pyproject.toml b/pyproject.toml index 24b0d74b..0829a54b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,36 +1,37 @@ [build-system] - requires = ["setuptools", "setuptools-scm"] - build-backend = "setuptools.build_meta" +requires = ["setuptools", "setuptools-scm"] +build-backend = "setuptools.build_meta" [project] - name = "bsk_rl" - version = "0.0.0" - authors = [ - {name = "Adam Herrmann", email = "adam.herrmann@colorado.edu"}, - {name = "Mark Stephenson", email = "mark.a.stephenson@colorado.edu"}, - ] - description = "RL environments and tools for spacecraft autonomy research, built on Basilisk. Developed by the AVS Lab." - readme = "README.md" - requires-python = ">=3.9.0" - license = {text = "MIT"} - dependencies = [ - "deap==1.3.3", - "Deprecated", - "gymnasium", - "matplotlib", - "numpy", - "pandas", - "pettingzoo", - "pytest", - "pytest-cov", - "pytest-repeat", - "requests", - "scikit-learn", - "scipy", - "stable-baselines3", - "tensorflow", - "torch", - ] +name = "bsk_rl" +version = "0.0.0" +authors = [ + { name = "Adam Herrmann", email = "adam.herrmann@colorado.edu" }, + { name = "Mark Stephenson", email = "mark.a.stephenson@colorado.edu" }, +] +description = "RL environments and tools for spacecraft autonomy research, built on Basilisk. Developed by the AVS Lab." +readme = "README.md" +requires-python = ">=3.9.0" +license = { text = "MIT" } +dependencies = [ + "deap==1.3.3", + "Deprecated", + "gymnasium", + "matplotlib", + "numpy", + "pandas", + "pettingzoo", + "pytest", + "pytest-cov", + "pytest-repeat", + "requests", + "ruff>=0.1.9", + "scikit-learn", + "scipy", + "stable-baselines3", + "tensorflow", + "torch", +] [project.scripts] - finish_install = "bsk_rl.finish_install:pck_install" \ No newline at end of file +finish_install = "bsk_rl.finish_install:pck_install" diff --git a/src/bsk_rl/envs/general_satellite_tasking/.ruff.toml b/src/bsk_rl/envs/general_satellite_tasking/.ruff.toml new file mode 100644 index 00000000..844e81f5 --- /dev/null +++ b/src/bsk_rl/envs/general_satellite_tasking/.ruff.toml @@ -0,0 +1,11 @@ +extend-safe-fixes = ["D"] + +[lint.pydocstyle] +convention = "google" + +[lint] +select = ["D", "W", "D401", "D404"] +ignore = ["D104"] + +[lint.pycodestyle] +max-doc-length = 88 diff --git a/src/bsk_rl/envs/general_satellite_tasking/gym_env.py b/src/bsk_rl/envs/general_satellite_tasking/gym_env.py index fae80735..67fd5fbc 100644 --- a/src/bsk_rl/envs/general_satellite_tasking/gym_env.py +++ b/src/bsk_rl/envs/general_satellite_tasking/gym_env.py @@ -1,3 +1,5 @@ +"""General Satellite Tasking is a framework for satellite tasking RL environments.""" + import functools import logging import os @@ -30,6 +32,28 @@ class GeneralSatelliteTasking(Env, Generic[SatObs, SatAct]): + """A Gymnasium environment adaptable to a wide range satellite tasking problems. + + These problems involve satellite(s) being tasked to complete tasks and maintain + aliveness. These tasks often include rewards for data collection. The environment + can be configured for any collection of satellites, including heterogenous + constellations. Other configurable aspects are environment features (e.g. + imaging targets), data collection and recording, and intersatellite + communication of data. + + The state space is a tuple containing the state of each satellite. Actions are + assigned as a tuple of actions, one per satellite. + + The preferred method of instantiating this environment is to make the + "GeneralSatelliteTasking-v1" environment and pass a kwargs dict with the + environment configuration. In some cases (e.g. the multiprocessed Gymnasium + vector environment), it is necessary for compatibility to instead register a new + environment using the GeneralSatelliteTasking class and a kwargs dict. See + examples/general_satellite_tasking for examples of environment configuration. + + New environments should be built using this framework. + """ + def __init__( self, satellites: Union[Satellite, list[Satellite]], @@ -47,25 +71,7 @@ def __init__( log_dir: Optional[str] = None, render_mode=None, ) -> None: - """A Gymnasium environment adaptable to a wide range satellite tasking problems - that involve satellite(s) being tasked to complete tasks and maintain aliveness. - These tasks often include rewards for data collection. The environment can be - configured for any collection of satellites, including heterogenous - constellations. Other configurable aspects are environment features (e.g. - imaging targets), data collection and recording, and intersatellite - communication of data. - - The state space is a tuple containing the state of each satellite. Actions are - assigned as a tuple of actions, one per satellite. - - The preferred method of instantiating this environment is to make the - "GeneralSatelliteTasking-v1" environment and pass a kwargs dict with the - environment configuration. In some cases (e.g. the multiprocessed Gymnasium - vector environment), it is necessary for compatibility to instead register a new - environment using the GeneralSatelliteTasking class and a kwargs dict. See - examples/general_satellite_tasking for examples of environment configuration. - - New environments should be built using this framework. + """Construct the GeneralSatelliteTasking environment. Args: satellites: Satellites(s) to be simulated. @@ -206,10 +212,11 @@ def reset( return observation, info def delete_simulator(self): - """Delete Basilisk objects. Only self.simulator contains strong references to - BSK models, so deleting it will delete all Basilisk objects. Enable debug-level - logging to verify that the simulator, FSW, dynamics, and environment models are - all deleted on reset. + """Delete Basilisk objects. + + Only self.simulator contains strong references to BSK models, so deleting it + will delete all Basilisk objects. Enable debug-level logging to verify that the + simulator, FSW, dynamics, and environment models are all deleted on reset. """ try: del self.simulator @@ -264,7 +271,7 @@ def _get_truncated(self) -> bool: @property def action_space(self) -> spaces.Space[MultiSatAct]: - """Compose satellite action spaces + """Compose satellite action spaces. Returns: Joint action space @@ -273,8 +280,9 @@ def action_space(self) -> spaces.Space[MultiSatAct]: @property def observation_space(self) -> spaces.Space[MultiSatObs]: - """Compose satellite observation spaces. Note: calls reset(), which can be - expensive, to determine observation size. + """Compose satellite observation spaces. + + Note: calls reset(), which can be expensive, to determine observation size. Returns: Joint observation space @@ -319,10 +327,10 @@ def _step(self, actions: MultiSatAct) -> None: def step( self, actions: MultiSatAct ) -> tuple[MultiSatObs, float, bool, bool, dict[str, Any]]: - """Propagate the simulation, update information, and get rewards + """Propagate the simulation, update information, and get rewards. Args: - Joint action for satellites + actions: Joint action for satellites Returns: observation, reward, terminated, truncated, info @@ -343,22 +351,24 @@ def step( return observation, reward, terminated, truncated, info def render(self) -> None: # pragma: no cover - """No rendering implemented""" + """No rendering implemented.""" return None def close(self) -> None: - """Try to cleanly delete everything""" + """Try to cleanly delete everything.""" if self.simulator is not None: del self.simulator class SingleSatelliteTasking(GeneralSatelliteTasking, Generic[SatObs, SatAct]): - """A special case of the GeneralSatelliteTasking for one satellite. For - compatibility with standard training APIs, actions and observations are directly + """A special case of the GeneralSatelliteTasking for one satellite. + + For compatibility with standard training APIs, actions and observations are directly exposed for the single satellite and are not wrapped in a tuple. """ def __init__(self, *args, **kwargs) -> None: + """Construct the SingleSatelliteTasking environment.""" super().__init__(*args, **kwargs) if not len(self.satellites) == 1: raise ValueError( @@ -367,21 +377,22 @@ def __init__(self, *args, **kwargs) -> None: @property def action_space(self) -> spaces.Space[SatAct]: - """Return the single satellite action space""" + """Return the single satellite action space.""" return self.satellite.action_space @property def observation_space(self) -> spaces.Box: - """Return the single satellite observation space""" + """Return the single satellite observation space.""" super().observation_space return self.satellite.observation_space @property def satellite(self) -> Satellite: + """Satellite being tasked.""" return self.satellites[0] def step(self, action) -> tuple[Any, float, bool, bool, dict[str, Any]]: - """Task the satellite with a single action""" + """Task the satellite with a single action.""" return super().step([action]) def _get_obs(self) -> Any: @@ -396,12 +407,13 @@ class MultiagentSatelliteTasking( def reset( self, seed: int | None = None, options=None ) -> tuple[MultiSatObs, dict[str, Any]]: + """Reset the environment and return PettingZoo Parallel API format.""" self.newly_dead = [] return super().reset(seed, options) @property def agents(self) -> list[AgentID]: - """Agents currently in the environment""" + """Agents currently in the environment.""" truncated = super()._get_truncated() return [ satellite.id @@ -411,7 +423,7 @@ def agents(self) -> list[AgentID]: @property def num_agents(self) -> int: - """Number of agents currently in the environment""" + """Number of agents currently in the environment.""" return len(self.agents) @property @@ -421,7 +433,7 @@ def possible_agents(self) -> list[AgentID]: @property def max_num_agents(self) -> int: - """Maximum number of agents possible in the environment""" + """Maximum number of agents possible in the environment.""" return len(self.possible_agents) @property @@ -431,7 +443,7 @@ def previously_dead(self) -> list[AgentID]: @property def observation_spaces(self) -> dict[AgentID, spaces.Box]: - """Return the observation space for each agent""" + """Return the observation space for each agent.""" return { agent: obs_space for agent, obs_space in zip(self.possible_agents, super().observation_space) @@ -439,12 +451,12 @@ def observation_spaces(self) -> dict[AgentID, spaces.Box]: @functools.lru_cache(maxsize=None) def observation_space(self, agent: AgentID) -> spaces.Space[SatObs]: - """Return the observation space for a certain agent""" + """Return the observation space for a certain agent.""" return self.observation_spaces[agent] @property def action_spaces(self) -> dict[AgentID, spaces.Space[SatAct]]: - """Return the action space for each agent""" + """Return the action space for each agent.""" return { agent: act_space for agent, act_space in zip(self.possible_agents, super().action_space) @@ -452,11 +464,11 @@ def action_spaces(self) -> dict[AgentID, spaces.Space[SatAct]]: @functools.lru_cache(maxsize=None) def action_space(self, agent: AgentID) -> spaces.Space[SatAct]: - """Return the action space for a certain agent""" + """Return the action space for a certain agent.""" return self.action_spaces[agent] def _get_obs(self) -> dict[AgentID, SatObs]: - """Format the observation per the PettingZoo Parallel API""" + """Format the observation per the PettingZoo Parallel API.""" return { agent: satellite.get_obs() for agent, satellite in zip(self.possible_agents, self.satellites) @@ -464,7 +476,7 @@ def _get_obs(self) -> dict[AgentID, SatObs]: } def _get_reward(self) -> dict[AgentID, float]: - """Format the reward per the PettingZoo Parallel API""" + """Format the reward per the PettingZoo Parallel API.""" reward = deepcopy(self.reward_dict) for agent, satellite in zip(self.possible_agents, self.satellites): if not satellite.is_alive(): @@ -478,7 +490,7 @@ def _get_reward(self) -> dict[AgentID, float]: return reward def _get_terminated(self) -> dict[AgentID, bool]: - """Format terminations per the PettingZoo Parallel API""" + """Format terminations per the PettingZoo Parallel API.""" if self.terminate_on_time_limit and super()._get_truncated(): return { agent: True @@ -493,7 +505,7 @@ def _get_terminated(self) -> dict[AgentID, bool]: } def _get_truncated(self) -> dict[AgentID, bool]: - """Format truncations per the PettingZoo Parallel API""" + """Format truncations per the PettingZoo Parallel API.""" truncated = super()._get_truncated() return { agent: truncated @@ -502,7 +514,7 @@ def _get_truncated(self) -> dict[AgentID, bool]: } def _get_info(self) -> dict[AgentID, dict]: - """Format info per the PettingZoo Parallel API""" + """Format info per the PettingZoo Parallel API.""" info = super()._get_info() for agent in self.possible_agents: if agent in self.previously_dead: @@ -519,7 +531,7 @@ def step( dict[AgentID, bool], dict[AgentID, dict], ]: - """Step the environment and return PettingZoo Parallel API format""" + """Step the environment and return PettingZoo Parallel API format.""" logger.info("=== STARTING STEP ===") previous_alive = self.agents diff --git a/src/bsk_rl/envs/general_satellite_tasking/scenario/communication.py b/src/bsk_rl/envs/general_satellite_tasking/scenario/communication.py index af0c82bd..f43718b8 100644 --- a/src/bsk_rl/envs/general_satellite_tasking/scenario/communication.py +++ b/src/bsk_rl/envs/general_satellite_tasking/scenario/communication.py @@ -1,3 +1,5 @@ +"""Communication of data between satellites.""" + import logging from abc import ABC, abstractmethod from itertools import combinations @@ -15,22 +17,26 @@ class CommunicationMethod(ABC): + """Base class for defining data sharing between satellites.""" + def __init__(self, satellites: Optional[list["Satellite"]] = None) -> None: - """Base class for defining data sharing between satellites. Subclasses implement - a way of determining which pairs of satellites share data.""" + """Construct base communication class. + + Subclasses implement a way of determining which pairs of satellites share data. + """ self.satellites = satellites def reset(self) -> None: - """Called after simulator initialization""" + """Reset communication after simulator initialization.""" pass @abstractmethod # pragma: no cover def _communication_pairs(self) -> list[tuple["Satellite", "Satellite"]]: - """List pair of satellite that should share data""" + """List pair of satellite that should share data.""" pass def communicate(self) -> None: - """Share data between paired satellites""" + """Share data between paired satellites.""" for sat_1, sat_2 in self._communication_pairs(): sat_1.data_store.stage_communicated_data(sat_2.data_store.data) sat_2.data_store.stage_communicated_data(sat_1.data_store.data) @@ -39,24 +45,31 @@ def communicate(self) -> None: class NoCommunication(CommunicationMethod): - """Implements no communication between satellite""" + """Implements no communication between satellite.""" def _communication_pairs(self) -> list[tuple["Satellite", "Satellite"]]: return [] class FreeCommunication(CommunicationMethod): - """Implements communication between satellites at every step""" + """Implements communication between satellites at every step.""" def _communication_pairs(self) -> list[tuple["Satellite", "Satellite"]]: return list(combinations(self.satellites, 2)) class LOSCommunication(CommunicationMethod): + """Implements communication between satellites with a direct line-of-sight. + # TODO only communicate data from before latest LOS time + """ + def __init__(self, satellites: list["Satellite"]) -> None: - """Implements communication between satellites that have a direct line of - sight""" + """Construct line-of-sigh communication management. + + Args: + satellites: List of satellites to communicate between. + """ super().__init__(satellites) for satellite in self.satellites: if not issubclass(satellite.dyn_type, LOSCommDynModel): @@ -66,6 +79,7 @@ def __init__(self, satellites: list["Satellite"]) -> None: ) def reset(self) -> None: + """Add loggers to satellites to track line-of-sight communication.""" super().reset() self.los_logs = {} @@ -94,6 +108,7 @@ def _communication_pairs(self) -> list[tuple["Satellite", "Satellite"]]: return pairs def communicate(self) -> None: + """Clear line-of-sight communication logs once communicated.""" super().communicate() for sat_1, logs in self.los_logs.items(): for sat_2, logger in logs.items(): @@ -101,8 +116,11 @@ def communicate(self) -> None: class MultiDegreeCommunication(CommunicationMethod): - """Compose with another type to have multi-degree communications. For example, if - a <-> b and b <-> c, also communicate between a <-> c""" + """Compose with another type to use multi-degree communications. + + For example, if a <-> b and b <-> c, multidegree communication will also communicate + between a <-> c. + """ def _communication_pairs(self) -> list[tuple["Satellite", "Satellite"]]: graph = np.zeros((len(self.satellites), len(self.satellites)), dtype=bool) @@ -118,4 +136,9 @@ def _communication_pairs(self) -> list[tuple["Satellite", "Satellite"]]: class LOSMultiCommunication(MultiDegreeCommunication, LOSCommunication): + """Multidegree line of sight communication. + + Composes MultiDegreeCommunication with LOSCommunication. + """ + pass diff --git a/src/bsk_rl/envs/general_satellite_tasking/scenario/data.py b/src/bsk_rl/envs/general_satellite_tasking/scenario/data.py index 91ec01a1..98f00474 100644 --- a/src/bsk_rl/envs/general_satellite_tasking/scenario/data.py +++ b/src/bsk_rl/envs/general_satellite_tasking/scenario/data.py @@ -1,3 +1,5 @@ +"""Data logging, management, and reward calculation.""" + import logging from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Callable, Optional @@ -19,19 +21,24 @@ class DataType(ABC): - """Base class for units of satellite data""" + """Base class for units of satellite data.""" @abstractmethod # pragma: no cover def __add__(self, other: "DataType") -> "DataType": - """Define the combination of two units of data""" + """Define the combination of two units of data.""" pass class DataStore(ABC): + """Base class for satellite data logging. + + One DataStore is created per satellite. + """ + DataType: type[DataType] # Define the unit of data used by the DataStore def __init__(self, data_manager: "DataManager", satellite: "Satellite") -> None: - """Base class for satellite data logging; one created per satellite + """Construct a DataStore base class. Args: data_manager: Simulation data manager to report back to @@ -46,23 +53,25 @@ def __init__(self, data_manager: "DataManager", satellite: "Satellite") -> None: self.new_data = self.DataType() def _initialize_knowledge(self, env_features: "EnvironmentFeatures") -> None: - """Establish knowledge about the world known to the satellite. Defaults to - knowing everything about the environment.""" + """Establish knowledge about the world known to the satellite. + + Defaults to knowing everything about the environment. + """ self.env_knowledge = env_features def _clear_logs(self) -> None: - """If necessary, clear any loggers""" + """If necessary, clear any loggers.""" pass def _get_log_state(self) -> LogStateType: - """Pull information for current data contribution e.g. sensor readings""" + """Pull information for current data contribution e.g. sensor readings.""" pass @abstractmethod # pragma: no cover def _compare_log_states( self, old_state: LogStateType, new_state: LogStateType ) -> "DataType": - """Generate a unit of data based on previous step and current step logs + """Generate a unit of data based on previous step and current step logs. Args: old_state: A previous result of _get_log_state() @@ -74,7 +83,7 @@ def _compare_log_states( pass def internal_update(self) -> "DataType": - """Update the data store based on collected information + """Update the data store based on collected information. Returns: New data from the previous step @@ -92,7 +101,7 @@ def internal_update(self) -> "DataType": return new_data def stage_communicated_data(self, external_data: "DataType") -> None: - """Prepare data to be added from another source, but don't add it yet + """Prepare data to be added from another source, but don't add it yet. Args: external_data: Data from another satellite to be added @@ -100,7 +109,7 @@ def stage_communicated_data(self, external_data: "DataType") -> None: self.staged_data.append(external_data) def communication_update(self) -> None: - """Update the data store from staged data + """Update the data store from staged data. Args: external_data (DataType): Data collected by another satellite @@ -111,12 +120,14 @@ def communication_update(self) -> None: class DataManager(ABC): + """Base class for simulation-wide data management.""" + DataStore: type[DataStore] # type of DataStore managed by the DataManager def __init__(self, env_features: Optional["EnvironmentFeatures"] = None) -> None: - """Base class for simulation-wide data management; handles data recording and - rewarding. - TODO: allow for creation/composition of multiple managers + """Construct base class to handle data recording and rewarding. + + TODO: allow for creation/composition of multiple managers. Args: env_features: Information about the environment that can be collected as @@ -126,17 +137,18 @@ def __init__(self, env_features: Optional["EnvironmentFeatures"] = None) -> None self.DataType = self.DataStore.DataType def reset(self) -> None: + """Refresh data and cumulative reward for a new episode.""" self.data = self.DataType() self.cum_reward = {} def create_data_store(self, satellite: "Satellite") -> None: - """Create a data store for a satellite""" + """Create a data store for a satellite.""" satellite.data_store = self.DataStore(self, satellite) self.cum_reward[satellite.id] = 0.0 @abstractmethod # pragma: no cover def _calc_reward(self, new_data_dict: dict[str, DataType]) -> dict[str, float]: - """Calculate step reward based on all satellite data from a step + """Calculate step reward based on all satellite data from a step. Args: new_data_dict: Satellite-DataType pairs of new data from a step @@ -147,7 +159,7 @@ def _calc_reward(self, new_data_dict: dict[str, DataType]) -> dict[str, float]: pass def reward(self, new_data_dict: dict[str, DataType]) -> dict[str, float]: - """Calls _calc_reward and logs cumulative reward""" + """Call _calc_reward and log cumulative reward.""" reward = self._calc_reward(new_data_dict) for satellite_id, sat_reward in reward.items(): self.cum_reward[satellite_id] += sat_reward @@ -159,11 +171,16 @@ def reward(self, new_data_dict: dict[str, DataType]) -> dict[str, float]: # No Data # ########### class NoData(DataType): + """DataType for no data.""" + def __add__(self, other): + """Add nothing to nothing.""" return self.__class__() class NoDataStore(DataStore): + """DataStore for no data.""" + DataType = NoData def _compare_log_states(self, old_state, new_state): @@ -171,9 +188,12 @@ def _compare_log_states(self, old_state, new_state): class NoDataManager(DataManager): + """DataManager for no data.""" + DataStore = NoDataStore def _calc_reward(self, new_data_dict): + """Reward nothing.""" return {sat: 0.0 for sat in new_data_dict.keys()} @@ -183,10 +203,12 @@ def _calc_reward(self, new_data_dict): class UniqueImageData(DataType): + """DataType for unique images of targets.""" + def __init__( self, imaged: Optional[list["Target"]] = None, duplicates: int = 0 ) -> None: - """DataType to log unique imaging + """Construct unit of data to record unique images. Args: imaged: List of targets that are known to be imaged. @@ -199,6 +221,14 @@ def __init__( self.duplicates = duplicates + len(imaged) - len(self.imaged) def __add__(self, other: "UniqueImageData") -> "UniqueImageData": + """Combine two units of data. + + Args: + other: Another unit of data to combine with this one. + + Returns: + Combined unit of data. + """ imaged = list(set(self.imaged + other.imaged)) duplicates = ( self.duplicates @@ -211,10 +241,12 @@ def __add__(self, other: "UniqueImageData") -> "UniqueImageData": class UniqueImageStore(DataStore): + """DataStore for unique images of targets.""" + DataType = UniqueImageData def _get_log_state(self) -> np.ndarray: - """Log the instantaneous storage unit state at the end of each step + """Log the instantaneous storage unit state at the end of each step. Returns: array: storedData from satellite storage unit @@ -226,8 +258,7 @@ def _get_log_state(self) -> np.ndarray: def _compare_log_states( self, old_state: np.ndarray, new_state: np.ndarray ) -> UniqueImageData: - """Checks two storage unit logs for an increase in logged data to identify new - images + """Check for an increase in logged data to identify new images. Args: old_state: older storedData from satellite storage unit @@ -252,6 +283,8 @@ def _compare_log_states( class UniqueImagingManager(DataManager): + """DataManager for rewarding unique images.""" + DataStore = UniqueImageStore def __init__( @@ -259,7 +292,7 @@ def __init__( env_features: Optional["EnvironmentFeatures"] = None, reward_fn: Callable = lambda p: p, ) -> None: - """DataManager for rewarding unique images + """DataManager for rewarding unique images. Args: env_features: DataManager.env_features @@ -271,7 +304,7 @@ def __init__( def _calc_reward( self, new_data_dict: dict[str, UniqueImageData] ) -> dict[str, float]: - """Reward new each unique image once using self.reward_fn() + """Reward new each unique image once using self.reward_fn(). Args: new_data_dict: Record of new images for each satellite @@ -391,14 +424,16 @@ def _calc_reward( # self.data += new_data # return reward -################# -# Nadir Pointing# -################# +################## +# Nadir Pointing # +################## class NadirScanningTimeData(DataType): + """DataType for time spent scanning nadir.""" + def __init__(self, scanning_time: float = 0.0) -> None: - """DataType to log data generated scanning nadir + """DataType to log data generated scanning nadir. Args: scanning_time: Time spent scanning nadir @@ -406,18 +441,19 @@ def __init__(self, scanning_time: float = 0.0) -> None: self.scanning_time = scanning_time def __add__(self, other: "NadirScanningTimeData") -> "NadirScanningTimeData": - """Define the combination of two units of data""" + """Define the combination of two units of data.""" scanning_time = self.scanning_time + other.scanning_time return self.__class__(scanning_time) class ScanningNadirTimeStore(DataStore): + """DataStore for time spent scanning nadir.""" + DataType = NadirScanningTimeData def _get_log_state(self) -> LogStateType: - """Returns the amount of data stored in the storage unit.""" - + """Return the amount of data stored in the storage unit.""" storage_unit = self.satellite.dynamics.storageUnit.storageUnitDataOutMsg.read() stored_amount = storage_unit.storageLevel @@ -427,8 +463,7 @@ def _get_log_state(self) -> LogStateType: def _compare_log_states( self, old_state: float, new_state: float ) -> "NadirScanningTimeData": - """Generate a unit of data based on previous step and current step stored - data amount. + """Generate a unit of data based on change in stored data amount. Args: old_state: Previous amount of data in the storage unit @@ -448,18 +483,21 @@ def _compare_log_states( class NadirScanningManager(DataManager): + """DataManager for rewarding time spent scanning nadir.""" + DataStore = ScanningNadirTimeStore # type of DataStore managed by the DataManager def __init__( self, env_features: Optional["EnvironmentFeatures"] = None, - reward_fn: Callable = None, + reward_fn: Optional[Callable] = None, ) -> None: - """ + """Construct a data manager for nadir scanning. + Args: env_features: Information about the environment that can be collected as data - reward_fn: Reward as function of time spend pointing nadir + reward_fn: Reward as function of time spend pointing nadir. """ super().__init__(env_features) if reward_fn is None: @@ -474,7 +512,7 @@ def reward_fn(p): def _calc_reward( self, new_data_dict: dict[str, "NadirScanningTimeData"] ) -> dict[str, float]: - """Calculate step reward based on all satellite data from a step + """Calculate step reward based on all satellite data from a step. Args: new_data_dict (dict): Satellite-DataType of new data from a step diff --git a/src/bsk_rl/envs/general_satellite_tasking/scenario/environment_features.py b/src/bsk_rl/envs/general_satellite_tasking/scenario/environment_features.py index ba84b771..79c8c862 100644 --- a/src/bsk_rl/envs/general_satellite_tasking/scenario/environment_features.py +++ b/src/bsk_rl/envs/general_satellite_tasking/scenario/environment_features.py @@ -1,3 +1,5 @@ +"""Environment features define data available for satellites to collect.""" + import logging import os import sys @@ -12,20 +14,23 @@ class EnvironmentFeatures(ABC): - """Base environment feature class""" + """Base environment feature class.""" def reset(self) -> None: # pragma: no cover - """Reset environment features""" + """Reset environment features.""" pass class Target: + """Ground target with associated value.""" + def __init__(self, name: str, location: Iterable[float], priority: float) -> None: - """Representation of a ground target + """Construct a Target. + Args: name: Identifier; does not need to be unique location: PCPF location [m] - priority: Value metric + priority: Value metric. """ self.name = name self.location = np.array(location) @@ -33,7 +38,11 @@ def __init__(self, name: str, location: Iterable[float], priority: float) -> Non @property def id(self) -> str: - """str: Unique human-readable identifier""" + """Get unique human-readable identifier. + + Returns: + Unique human-readable identifier. + """ try: return self._id except AttributeError: @@ -41,20 +50,31 @@ def id(self) -> str: return self._id def __hash__(self) -> int: + """Hash target by unique id.""" return hash((self.id)) def __repr__(self) -> str: + """Get string representation of target. + + Use target.id for a unique string identifier. + + Returns: + Target string + """ return f"Target({self.name})" class StaticTargets(EnvironmentFeatures): + """Environment with targets distributed uniformly.""" + def __init__( self, n_targets: Union[int, tuple[int, int]], priority_distribution: Optional[Callable] = None, radius: float = orbitalMotion.REQ_EARTH * 1e3, ) -> None: - """Environment with a set number of evenly-distributed static targets. + """Construct an environment with evenly-distributed static targets. + Args: n_targets: Number (or range) of targets to generate priority_distribution: Function for generating target priority. @@ -68,6 +88,7 @@ def __init__( self.targets = [] def reset(self) -> None: + """Regenerate target set for new episode.""" if isinstance(self._n_targets, int): self.n_targets = self._n_targets else: @@ -76,7 +97,7 @@ def reset(self) -> None: self.regenerate_targets() def regenerate_targets(self) -> None: - """Regenerate targets uniformly""" + """Regenerate targets uniformly.""" self.targets = [] for i in range(self.n_targets): x = np.random.normal(size=3) @@ -89,7 +110,8 @@ def regenerate_targets(self) -> None: def lla2ecef(lat: float, long: float, radius: float): - """ + """Project LLA to Earth Centered, Earth Fixed location. + Args: lat: [deg] long: [deg] @@ -103,6 +125,8 @@ def lla2ecef(lat: float, long: float, radius: float): class CityTargets(StaticTargets): + """Environment with targets distributed around population centers.""" + def __init__( self, n_targets: Union[int, tuple[int, int]], @@ -111,7 +135,8 @@ def __init__( priority_distribution: Optional[Callable] = None, radius: float = orbitalMotion.REQ_EARTH * 1e3, ) -> None: - """Environment with a set number of static targets around population centers. + """Construct environment with of static targets around population centers. + Args: n_targets: Number of targets to generate n_select_from: Generate targets from the top n most populous. @@ -126,7 +151,7 @@ def __init__( self.location_offset = location_offset def regenerate_targets(self) -> None: - """Regenerate targets based on cities""" + """Regenerate targets based on cities.""" self.targets = [] cities = pd.read_csv( os.path.join( @@ -157,12 +182,11 @@ def regenerate_targets(self) -> None: class UniformNadirFeature(EnvironmentFeatures): - """ - Defines a nadir target center at the center of the planet. - """ + """Defines a nadir target center at the center of the planet.""" def __init__(self, value_per_second: float = 1.0) -> None: - """ " + """Construct uniform data over the surface of the planet. + Args: value_per_second: Amount of reward per second imaging nadir. """ diff --git a/src/bsk_rl/envs/general_satellite_tasking/scenario/sat_actions.py b/src/bsk_rl/envs/general_satellite_tasking/scenario/sat_actions.py index 29d483ed..a50ac081 100644 --- a/src/bsk_rl/envs/general_satellite_tasking/scenario/sat_actions.py +++ b/src/bsk_rl/envs/general_satellite_tasking/scenario/sat_actions.py @@ -1,3 +1,5 @@ +"""Satellite action types can be used to add actions to the agents.""" + from copy import deepcopy from typing import Any, Optional, Union @@ -19,10 +21,14 @@ class SatAction(Satellite): class DiscreteSatAction(SatAction): + """Base satellite subclass for composing discrete actions.""" + def __init__(self, *args, **kwargs) -> None: - """Base satellite subclass for composing discrete actions. Actions are added to - the satellite for each DiscreteSatAction subclass, and can be accessed by index - in order added.""" + """Construct satellite with discrete actions. + + Actions are added to the satellite for each DiscreteSatAction subclass, and can + be accessed by index in order added. + """ super().__init__(*args, **kwargs) self.action_list = [] self.action_map = {} @@ -57,8 +63,15 @@ def add_action( self.action_list.append(bind(self, deepcopy(act_i))) def generate_indexed_action(self, act_fn, index: int): - """Create a indexed action function from an action function that takes an index - as an argument""" + """Create an indexed action function. + + Makes an indexed action function from an action function that takes an index + as an argument. + + Args: + act_fn: Action function to index. + index: Index to pass to act_fn. + """ def act_i(self, prev_action_key=None) -> Any: return getattr(self, act_fn.__name__)( @@ -68,7 +81,7 @@ def act_i(self, prev_action_key=None) -> Any: return act_i def set_action(self, action: int): - """Function called by the environment when setting action""" + """Call action function my index.""" self._disable_timed_terminal_event() self.prev_action_key = self.action_list[action]( prev_action_key=self.prev_action_key @@ -81,16 +94,30 @@ def action_space(self) -> spaces.Discrete: def fsw_action_gen(fsw_action: str, action_duration: float = 1e9) -> type: + """Generate an action class for a FSW @action. + + Args: + fsw_action: Function name of FSW action. + action_duration: Time to task action for. + + Returns: + Satellite action class with fsw_action action. + """ + @configurable class FSWAction(DiscreteSatAction): def __init__( self, *args, action_duration: float = action_duration, **kwargs ) -> None: - """Discrete action to perform a fsw action; typically this is a function - decorated by @action + """Discrete action to perform a fsw action. + + Typically this is includes a function decorated by @action. Args: action_duration: Time to act when action selected. [s] + args: Passed through to satellite + kwargs: Passed through to satellite + """ super().__init__(*args, **kwargs) setattr(self, fsw_action + "_duration", action_duration) @@ -136,11 +163,15 @@ def act(self, prev_action_key=None) -> str: @configurable class ImagingActions(DiscreteSatAction, ImagingSatellite): + """Satellite subclass to add upcoming target imaging to action space.""" + def __init__(self, *args, n_ahead_act=10, **kwargs) -> None: """Discrete action to image upcoming targets. Args: n_ahead_act: Number of actions to include in action space. + args: Passed through to satellite + kwargs: Passed through to satellite """ super().__init__(*args, **kwargs) self.add_action(self.image, n_actions=n_ahead_act, act_name="image") @@ -150,6 +181,7 @@ def image(self, target: Union[int, Target, str], prev_action_key=None) -> str: Args: target: Target, in terms of upcoming index, Target, or ID, + prev_action_key: Previous action key Returns: Target ID @@ -166,8 +198,10 @@ def image(self, target: Union[int, Target, str], prev_action_key=None) -> str: return target.id def set_action(self, action: Union[int, Target, str]): - """Allow the satellite to be tasked by Target or target id, in addition to - index""" + """Allow the satellite to be tasked by Target or target id. + + Allows for additional tasking modes in addition to action index-based tasking. + """ self._disable_image_event() if isinstance(action, (Target, str)): self.prev_action_key = self.image(action, self.prev_action_key) diff --git a/src/bsk_rl/envs/general_satellite_tasking/scenario/sat_observations.py b/src/bsk_rl/envs/general_satellite_tasking/scenario/sat_observations.py index 5c55cd6b..5280c6f4 100644 --- a/src/bsk_rl/envs/general_satellite_tasking/scenario/sat_observations.py +++ b/src/bsk_rl/envs/general_satellite_tasking/scenario/sat_observations.py @@ -1,3 +1,5 @@ +"""Satellite observation types can be used to add information to the observation.""" + from copy import deepcopy from typing import Any, Callable, Optional, Union @@ -18,11 +20,15 @@ @configurable class SatObservation(Satellite): + """Base satellite subclass for composing observations.""" + def __init__(self, *args, obs_type: type = np.ndarray, **kwargs) -> None: """Satellite subclass for composing observations. Args: obs_type: Datatype of satellite's returned observation + args: Passed through to satellite + kwargs: Passed through to satellite """ super().__init__(*args, **kwargs) self.obs_type = obs_type @@ -32,8 +38,10 @@ def __init__(self, *args, obs_type: type = np.ndarray, **kwargs) -> None: @property def obs_dict(self): - """Human-readable observation format. Cached so only computed once per - timestep.""" + """Human-readable observation format. + + Cached so only computed once per timestep. + """ if ( self.obs_dict_cache is None or self.simulator.sim_time != self.obs_cache_time @@ -44,16 +52,16 @@ def obs_dict(self): @property def obs_ndarray(self): - """Numpy vector observation format""" + """Numpy vector observation format.""" return vectorize_nested_dict(self.obs_dict) @property def obs_list(self): - """List observation format""" + """List observation format.""" return list(self.obs_ndarray) def get_obs(self) -> Union[dict, np.ndarray, list]: - """Update the observation""" + """Update the observation.""" if self.obs_type is dict: return self.obs_dict elif self.obs_type is np.ndarray: @@ -64,7 +72,7 @@ def get_obs(self) -> Union[dict, np.ndarray, list]: raise ValueError(f"Invalid observation type: {self.obs_type}") def add_to_observation(self, obs_element: Callable) -> None: - """Add a function to be called when constructing observations + """Add a function to be called when constructing observations. Args: obs_element: Callable to be observed @@ -74,6 +82,8 @@ def add_to_observation(self, obs_element: Callable) -> None: @configurable class NormdPropertyState(SatObservation): + """Satellite subclass to add satellites properties to the observation.""" + def __init__( self, *args, obs_properties: list[dict[str, Any]] = [], **kwargs ) -> None: @@ -86,8 +96,9 @@ def __init__( [dict(prop="prop_name", module="fsw"/"dynamics"/None, norm=1.0)] If module is not specified or None, the source of the property is inferred. If norm is not specified, it is set to 1.0 (no normalization). + args: Passed through to satellite + kwargs: Passed through to satellite """ - super().__init__(*args, **kwargs) for obs_prop in obs_properties: @@ -102,9 +113,6 @@ def add_prop_function( prop: Property to query module: Module (dynamics or fsw) that holds the property. Can be inferred. norm: Value to normalize property by. Defaults to 1.0. - - Returns: - _type_: _description_ """ if module is not None: @@ -129,32 +137,39 @@ def prop_fn(self): @configurable class TimeState(SatObservation): + """Satellite subclass to add simulation time to the observation.""" + def __init__(self, *args, normalization_time: Optional[float] = None, **kwargs): - """Adds the sim time to the observation state. Automatically normalizes to the - sim duration. + """Add the sim time to the observation state. + + Automatically normalizes to the sim duration. Args: normalization_time: Time to normalize by. If None, is set to simulation duration + args: Passed through to satellite + kwargs: Passed through to satellite """ - super().__init__(*args, **kwargs) self.normalization_time = normalization_time self.add_to_observation(self.normalized_time) def reset_post_sim(self): - """Autodetect normalization time""" + """Autodetect normalization time.""" super().reset_post_sim() if self.normalization_time is None: self.normalization_time = self.simulator.time_limit def normalized_time(self): + """Return time normalized by normalization_time.""" assert self.normalization_time is not None return self.simulator.sim_time / self.normalization_time @configurable class TargetState(SatObservation, ImagingSatellite): + """Satellite subclass to add upcoming target information to the observation.""" + def __init__( self, *args, @@ -162,7 +177,7 @@ def __init__( target_properties: Optional[list[dict[str, Any]]] = None, **kwargs, ): - """Adds information about upcoming targets to the observation state. + """Add information about upcoming targets to the observation state. Args: n_ahead_observe: Number of upcoming targets to consider. @@ -174,6 +189,8 @@ def __init__( - window_open - window_mid - window_close + args: Passed through to satellite + kwargs: Passed through to satellite """ super().__init__(*args, n_ahead_observe=n_ahead_observe, **kwargs) if target_properties is None: @@ -189,7 +206,9 @@ def __init__( self.target_obs_generator(target_properties) def target_obs_generator(self, target_properties): - """Generate the target_obs function from the target_properties spec and add it + """Generate the target_obs function. + + Generates the observation function from the target_properties spec and add it to the observation. """ @@ -231,18 +250,22 @@ def target_obs(self): @configurable class EclipseState(SatObservation): + """Satellite subclass to add upcoming eclipse information to the observation.""" + def __init__(self, *args, orbit_period=5700, **kwargs): - """Adds a tuple of the orbit-normalized next orbit start and end. + """Add a tuple of the orbit-normalized next orbit start and end. Args: orbit_period: Normalization factor for eclipse time. + args: Passed through to satellite + kwargs: Passed through to satellite """ - super().__init__(*args, **kwargs) self.orbit_period_eclipse_norm = orbit_period self.add_to_observation(self.eclipse_state) def eclipse_state(self): + """Return tuple of normalized next eclipse start and end.""" eclipse_start, eclipse_end = self.trajectory.next_eclipse( self.simulator.sim_time ) @@ -254,6 +277,8 @@ def eclipse_state(self): @configurable class GroundStationState(SatObservation, AccessSatellite): + """Satellite subclass to add ground station information to the observation.""" + def __init__( self, *args, @@ -261,18 +286,21 @@ def __init__( downlink_window_properties: Optional[list[dict[str, Any]]] = None, **kwargs, ): - """Adds information about upcoming downlink opportunities to the observation - state. + """Add information about downlink opportunities to the observation state. Args: - n_ahead_observe: Number of upcoming downlink opportunities to consider. - target_properties: List of properties to include in the observation in the - format [dict(prop="prop_name", norm=norm)]. If norm is not specified, it - is set to 1.0 (no normalization). Properties to choose from: + n_ahead_observe_downlinks: Number of upcoming downlink opportunities to + consider. + downlink_window_properties: List of properties to include in the observation + in the format [dict(prop="prop_name", norm=norm)]. If norm is not + specified, it is set to 1.0 (no normalization). Properties to choose + from: - location - window_open - window_mid - window_close + args: Passed through to satellite + kwargs: Passed through to satellite """ super().__init__(*args, **kwargs) if downlink_window_properties is None: @@ -285,7 +313,7 @@ def __init__( ) def reset_post_sim(self) -> None: - """Add downlink ground stations to be considered by the access checker""" + """Add downlink ground stations to be considered by the access checker.""" for ground_station in self.simulator.environment.groundStations: self.add_location_for_access_checking( object=ground_station.ModelTag, @@ -299,9 +327,11 @@ def ground_station_obs_generator( self, downlink_window_properties: list[dict[str, Any]], n_ahead_observe_downlinks: int, - ): - """Generate the ground_station_obs function from the downlink_window_properties - spec and add it to the observation. + ) -> None: + """Generate the ground_station_obs function. + + Generates an obs function from the downlink_window_properties spec and adds it + to the observation. """ def ground_station_obs(self): diff --git a/src/bsk_rl/envs/general_satellite_tasking/scenario/satellites.py b/src/bsk_rl/envs/general_satellite_tasking/scenario/satellites.py index 0bde02e8..7426ed74 100644 --- a/src/bsk_rl/envs/general_satellite_tasking/scenario/satellites.py +++ b/src/bsk_rl/envs/general_satellite_tasking/scenario/satellites.py @@ -1,3 +1,5 @@ +"""Satellites are the agents in the environment.""" + import bisect import inspect import logging @@ -38,12 +40,14 @@ class Satellite(ABC): + """Abstract base class for satellites.""" + dyn_type: type["DynamicsModel"] # Type of dynamics model used by this satellite fsw_type: type["FSWModel"] # Type of FSW model used by this satellite @classmethod def default_sat_args(cls, **kwargs) -> dict[str, Any]: - """Compile default arguments for FSW and dynamics models + """Compile default arguments for FSW and dynamics models. Returns: default arguments for satellite models @@ -71,7 +75,7 @@ def __init__( variable_interval: bool = True, **kwargs, ) -> None: - """Base satellite constructor + """Construct base satellite. Args: name: identifier for satellite; does not need to be unique @@ -79,6 +83,7 @@ def __init__( key: function}, where function is called at reset to set the value (used for randomization). variable_interval: Stop simulation at terminal events + kwargs: Ignored """ self.name = name self.logger = logging.getLogger(__name__).getChild(self.name) @@ -95,18 +100,18 @@ def __init__( @property def id(self) -> str: - """Unique human-readable identifier""" + """Unique human-readable identifier.""" return f"{self.name}_{id(self)}" def _generate_sat_args(self) -> None: - """Instantiate sat_args from any randomizers in provided sat_args""" + """Instantiate sat_args from any randomizers in provided sat_args.""" self.sat_args = { k: v if not callable(v) else v() for k, v in self.sat_args_generator.items() } self.logger.debug(f"Satellite initialized with {self.sat_args}") def reset_pre_sim(self) -> None: - """Called in environment reset, before simulator initialization""" + """Reset during environment reset, before simulator initialization.""" self.info = [] self.requires_retasking = True self._generate_sat_args() @@ -123,7 +128,9 @@ def reset_pre_sim(self) -> None: self._timed_terminal_event_name = None def set_simulator(self, simulator: "Simulator"): - """Sets the simulator for models; called during simulator initialization + """Set the simulator for models. + + Called during simulator initialization. Args: simulator: Basilisk simulator @@ -131,7 +138,7 @@ def set_simulator(self, simulator: "Simulator"): self.simulator = proxy(simulator) def set_dynamics(self, dyn_rate: float) -> "DynamicsModel": - """Create dynamics model; called during simulator initialization + """Create dynamics model; called during simulator initialization. Args: dyn_rate: rate for dynamics simulation [s] @@ -144,7 +151,7 @@ def set_dynamics(self, dyn_rate: float) -> "DynamicsModel": return dynamics def set_fsw(self, fsw_rate: float) -> "FSWModel": - """Create flight software model; called during simulator initialization + """Create flight software model; called during simulator initialization. Args: fsw_rate: rate for FSW simulation [s] @@ -157,12 +164,12 @@ def set_fsw(self, fsw_rate: float) -> "FSWModel": return fsw def reset_post_sim(self) -> None: - """Called in environment reset, after simulator initialization""" + """Reset in environment reset, after simulator initialization.""" pass @property def observation_space(self) -> spaces.Box: - """Observation space for single satellite, determined from observation + """Observation space for single satellite, determined from observation. Returns: gymanisium observation space @@ -174,7 +181,7 @@ def observation_space(self) -> spaces.Box: @property @abstractmethod # pragma: no cover def action_space(self) -> spaces.Space: - """Action space for single satellite + """Action space for single satellite. Returns: gymanisium action space @@ -182,11 +189,12 @@ def action_space(self) -> spaces.Space: pass def is_alive(self, log_failure=False) -> bool: - """Check if the satellite is violating any requirements from dynamics or FSW - models + """Check if the satellite is violating any aliveness requirements. + + Checkes aliveness checkers in dynamics and FSW models. Returns: - is alive + is_alive """ return self.dynamics.is_alive(log_failure=log_failure) and self.fsw.is_alive( log_failure=log_failure @@ -194,14 +202,14 @@ def is_alive(self, log_failure=False) -> bool: @property def _satellite_command(self) -> str: - """Generate string that refers to self in simBase""" + """Generate string that refers to self in simBase.""" return ( "[satellite for satellite in self.satellites " + f"if satellite.id=='{self.id}'][0]" ) def _info_command(self, info: str) -> str: - """Generate command to log to info from an event + """Generate command to log to info from an event. Args: info: information to log; cannot include `'` or `"` @@ -212,7 +220,7 @@ def _info_command(self, info: str) -> str: return self._satellite_command + f".log_info('{info}')" def log_info(self, info: Any) -> None: - """Record information at the current time + """Record information at the current time. Args: info: Information to log @@ -221,13 +229,14 @@ def log_info(self, info: Any) -> None: self.logger.info(f"{info}") def _update_timed_terminal_event( - self, t_close: float, info: str = "", extra_actions=[] + self, t_close: float, info: str = "", extra_actions: list[str] = [] ) -> None: - """Create a simulator event that causes the simulation to stop at a certain time + """Create a simulator event that stops the simulation a certain time. Args: t_close: Termination time [s] info: Additional identifying info to log at terminal time + extra_actions: Additional actions to perform at terminal time """ self._disable_timed_terminal_event() self.log_info(f"setting timed terminal event at {t_close:.1f}") @@ -251,8 +260,7 @@ def _update_timed_terminal_event( self.simulator.eventMap[self._timed_terminal_event_name].eventActive = True def _disable_timed_terminal_event(self) -> None: - """Turn off simulator termination due to this satellite's window close - checker""" + """Turn off simulator termination due to window close checker.""" if ( self._timed_terminal_event_name is not None and self._timed_terminal_event_name in self.simulator.eventMap @@ -261,7 +269,7 @@ def _disable_timed_terminal_event(self) -> None: @abstractmethod # pragma: no cover def get_obs(self) -> SatObs: - """Construct the satellite's observation + """Construct the satellite's observation. Returns: satellite observation @@ -270,8 +278,9 @@ def get_obs(self) -> SatObs: @abstractmethod # pragma: no cover def set_action(self, action: int) -> None: - """Enables certain processes in the simulator to command the satellite task. - Should call an @action from FSW, among other things. + """Enable certain processes in the simulator to command the satellite task. + + Should call an @action from FSW, among other things. Args: action: action index @@ -280,6 +289,8 @@ def set_action(self, action: int) -> None: class AccessSatellite(Satellite): + """Satellite that can detect access opportunities for ground locations.""" + def __init__( self, *args, @@ -288,8 +299,7 @@ def __init__( access_dist_threshold: float = 4e6, **kwargs, ) -> None: - """Satellite that can detect access opportunities for ground locations with - elevation constraints. + """Construct an AccessSatellite. Args: generation_duration: Duration to calculate additional imaging windows for @@ -299,6 +309,8 @@ def __init__( [s] access_dist_threshold: Distance bound [m] for evaluating imaging windows more exactly. 4e6 will capture >10 elevation windows for a 500 km orbit. + args: Passed through to Satellite constructor + kwargs: Passed through to Satellite constructor """ super().__init__(*args, **kwargs) self.generation_duration = generation_duration @@ -306,6 +318,7 @@ def __init__( self.access_dist_threshold = access_dist_threshold def reset_pre_sim(self) -> None: + """Reset satellite window calculations and lists.""" super().reset_pre_sim() self.opportunities: list[dict] = [] self.window_calculation_time = 0 @@ -318,8 +331,10 @@ def add_location_for_access_checking( min_elev: float, type: str, ) -> None: - """Adds a location to be included in window calculations. Note that this - location will only be included in future calls to calculate_additional_windows. + """Add a location to be included in window calculations. + + Note that this location will only be included in future calls to + calculate_additional_windows. Args: object: Object to add window for @@ -332,6 +347,7 @@ def add_location_for_access_checking( self.locations_for_access_checking.append(location_dict) def reset_post_sim(self) -> None: + """Handle initial window calculations for new simulation.""" super().reset_post_sim() if self.initial_generation_duration is None: if self.simulator.time_limit == float("inf"): @@ -404,8 +420,11 @@ def _find_elevation_roots( min_elev: float, window: tuple[float, float], ): - """Find exact times where the satellite's elevation relative to a target is - equal to the minimum elevation.""" + """Find times where the elevation is equal to the minimum elevation. + + Finds exact times where the satellite's elevation relative to a target is + equal to the minimum elevation. + """ def root_fn(t): return elevation(position_interp(t), location) - min_elev @@ -422,9 +441,11 @@ def root_fn(t): def _find_candidate_windows( location: np.ndarray, times: np.ndarray, positions: np.ndarray, threshold: float ) -> list[tuple[float, float]]: - """Find `times` where a window is plausible; i.e. where a `positions` point is - within `threshold` of `location`. Too big of a dt in times may miss windows or - produce bad results.""" + """Find `times` where a window is plausible. + + i.e. where a `positions` point is within `threshold` of `location`. Too big of + a dt in times may miss windows or produce bad results. + """ close_times = np.linalg.norm(positions - location, axis=1) < threshold close_indices = np.where(close_times)[0] groups = np.split(close_indices, np.where(np.diff(close_indices) != 1)[0] + 1) @@ -442,8 +463,7 @@ def _refine_window( candidate_window: tuple[float, float], computation_window: tuple[float, float], ) -> list[tuple[float, float]]: - """Detect if an exact window has been truncated by the edge of the coarse - window.""" + """Detect if an exact window has been truncated by a coarse window.""" endpoints = list(endpoints) # Filter endpoints that are too close @@ -474,13 +494,14 @@ def _add_window( type: str, merge_time: Optional[float] = None, ): - """ + """Add an opportunity window. + Args: object: Object to add window for new_window: New window for target type: Type of window being added merge_time: Time at which merges with existing windows will occur. If None, - check all windows for merges + check all windows for merges. """ if new_window[0] == merge_time or merge_time is None: for opportunity in self.opportunities: @@ -517,7 +538,7 @@ def opportunities_dict( types: Optional[Union[str, list[str]]] = None, filter: list = [], ) -> dict[Any, list[tuple[float, float]]]: - """Dictionary of opportunities that maps objects to lists of windows. + """Make dictionary of opportunities that maps objects to lists of windows. Args: types: Types of opportunities to include. If None, include all types. @@ -543,8 +564,9 @@ def upcoming_opportunities_dict( types: Optional[Union[str, list[str]]] = None, filter: list = [], ) -> dict[Any, list[tuple[float, float]]]: - """Dictionary of opportunities that maps objects to lists of windows that have - not yet closed. + """Get dictionary of opportunities. + + Maps objects to lists of windows that have not yet closed. Args: types: Types of opportunities to include. If None, include all types. @@ -570,7 +592,7 @@ def next_opportunities_dict( types: Optional[Union[str, list[str]]] = None, filter: list = [], ) -> dict[Any, tuple[float, float]]: - """Dictionary of opportunities that maps objects to the next open window. + """Make dictionary of opportunities that maps objects to the next open windows. Args: types: Types of opportunities to include. If None, include all types. @@ -636,6 +658,8 @@ def find_next_opportunities( class ImagingSatellite(AccessSatellite): + """Satellite with agile imaging capabilities.""" + dyn_type = dynamics.ImagingDynModel fsw_type = fsw.ImagingFSWModel @@ -644,8 +668,9 @@ def __init__( *args, **kwargs, ) -> None: - """Satellite with agile imaging capabilities. Can stop the simulation when a - target is imaged or missed. + """Construct an ImagingSatellite. + + Can stop the simulation when a target is imaged or missed. """ super().__init__(*args, **kwargs) self.fsw: ImagingSatellite.fsw_type @@ -653,7 +678,7 @@ def __init__( self.data_store: UniqueImageStore def reset_pre_sim(self) -> None: - """Set the buffer parameters based on computed windows""" + """Set the buffer parameters based on computed windows.""" super().reset_pre_sim() self.sat_args["transmitterNumBuffers"] = len( self.data_store.env_knowledge.targets @@ -666,7 +691,7 @@ def reset_pre_sim(self) -> None: self.missed = 0 def reset_post_sim(self) -> None: - """Handle initial_generation_duration setting and calculate windows""" + """Handle initial_generation_duration setting and calculate windows.""" for target in self.data_store.env_knowledge.targets: self.add_location_for_access_checking( object=target, @@ -684,7 +709,7 @@ def _get_imaged_filter(self): @property def windows(self) -> dict[Target, list[tuple[float, float]]]: - """Access windows via dict of targets -> list of windows""" + """Access windows via dict of targets -> list of windows.""" return self.opportunities_dict(types="target", filter=self._get_imaged_filter()) @property @@ -708,8 +733,9 @@ def next_windows(self) -> dict[Target, tuple[float, float]]: def upcoming_targets( self, n: int, pad: bool = True, max_lookahead: int = 100 ) -> list[Target]: - """Find the n nearest targets. Targets are sorted by window close time; - currently open windows are included. + """Find the n nearest targets. + + Targets are sorted by window close time; currently open windows are included. Args: n: number of windows to look ahead @@ -732,8 +758,9 @@ def upcoming_targets( ] def _update_image_event(self, target: Target) -> None: - """Create a simulator event that causes the simulation to stop when a target is - imaged + """Create a simulator event that terminates on imaging. + + Causes the simulation to stop when a target is imaged. Args: target: Target expected to be imaged @@ -772,7 +799,7 @@ def _update_image_event(self, target: Target) -> None: self.simulator.eventMap[self._image_event_name].eventActive = True def _disable_image_event(self) -> None: - """Turn off simulator termination due to this satellite's imaging checker""" + """Turn off simulator termination due to this satellite's imaging checker.""" if ( self._image_event_name is not None and self._image_event_name in self.simulator.eventMap @@ -780,8 +807,9 @@ def _disable_image_event(self) -> None: self.simulator.delete_event(self._image_event_name) def parse_target_selection(self, target_query: Union[int, Target, str]): - """Identify a target based on upcoming target index, Target object, or target - id. + """Identify a target from a query. + + Parses an upcoming target index, Target object, or target id. Args: target_query: Taret upcoming index, object, or id. @@ -802,7 +830,7 @@ def parse_target_selection(self, target_query: Union[int, Target, str]): return target def enable_target_window(self, target: Target): - """Enable the next window close event for target""" + """Enable the next window close event for target.""" self._update_image_event(target) next_window = self.next_windows[target] self.log_info( @@ -815,7 +843,7 @@ def enable_target_window(self, target: Target): ) def task_target_for_imaging(self, target: Target): - """Task the satellite to image a target + """Task the satellite to image a target. Args: target: Selected target @@ -830,11 +858,15 @@ def task_target_for_imaging(self, target: Target): ### Convenience Types ### ######################### class SteeringImagerSatellite(ImagingSatellite): + """Convenience type for an imaging satellite with MRP steering.""" + dyn_type = dynamics.FullFeaturedDynModel fsw_type = fsw.SteeringImagerFSWModel class FBImagerSatellite(ImagingSatellite): + """Convenience type for an imaging satellite with feedback control.""" + dyn_type = dynamics.FullFeaturedDynModel fsw_type = fsw.ImagingFSWModel @@ -853,6 +885,8 @@ class FBImagerSatellite(ImagingSatellite): class DoNothingSatellite(sa.DriftAction, so.TimeState): + """Convenience type for a satellite that does nothing.""" + dyn_type = dynamics.BasicDynamicsModel fsw_type = fsw.BasicFSWModel @@ -872,6 +906,8 @@ class ImageAheadSatellite( ), SteeringImagerSatellite, ): + """Convenience type for a satellite with common features enabled.""" + pass @@ -892,4 +928,6 @@ class FullFeaturedSatellite( ), SteeringImagerSatellite, ): + """Convenience type for a satellite with common features enabled.""" + pass diff --git a/src/bsk_rl/envs/general_satellite_tasking/simulation/dynamics.py b/src/bsk_rl/envs/general_satellite_tasking/simulation/dynamics.py index 6d010923..e8265dd6 100644 --- a/src/bsk_rl/envs/general_satellite_tasking/simulation/dynamics.py +++ b/src/bsk_rl/envs/general_satellite_tasking/simulation/dynamics.py @@ -1,3 +1,6 @@ +"""Basilisk dynamics models.""" + + from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Iterable, Optional @@ -43,9 +46,15 @@ class DynamicsModel(ABC): + """Abstract Basilisk dynamics model. + + One DynamicsModel is instantiated for each satellite in the environment each time a + new simulator is created. + """ + @classmethod @property - def requires_env(cls) -> list[type["EnvironmentModel"]]: + def _requires_env(cls) -> list[type["EnvironmentModel"]]: """Define minimum EnvironmentModels for compatibility.""" return [] @@ -56,17 +65,18 @@ def __init__( priority: int = 200, **kwargs, ) -> None: - """Base DynamicsModel + """Construct a base dynamics model. Args: satellite: Satellite modelled by this model dyn_rate: Rate of dynamics simulation [s] priority: Model priority. + kwargs: Ignored """ self.satellite = satellite self.logger = self.satellite.logger.getChild(self.__class__.__name__) - for required in self.requires_env: + for required in self._requires_env: if not issubclass(type(self.simulator.environment), required): raise TypeError( f"{self.simulator.environment} must be a subclass of {required} to " @@ -87,15 +97,17 @@ def __init__( @property def simulator(self) -> "Simulator": + """Reference to the episode simulator.""" return self.satellite.simulator @property def environment(self) -> "EnvironmentModel": + """Reference to the episode environment model.""" return self.simulator.environment @abstractmethod # pragma: no cover def _init_dynamics_objects(self, **kwargs) -> None: - """Caller for all dynamics object initialization""" + """Caller for all dynamics object initialization.""" pass def is_alive(self, log_failure=False) -> bool: @@ -107,52 +119,60 @@ def is_alive(self, log_failure=False) -> bool: return check_aliveness_checkers(self, log_failure=log_failure) def reset_for_action(self) -> None: - """Called whenever a FSW @action is called""" + """Reset whenever a FSW @action is called.""" pass def __del__(self): + """Log when dynamics are deleted.""" self.logger.debug("Basilisk dynamics deleted") class BasicDynamicsModel(DynamicsModel): - """Minimal set of Basilisk dynamics objects""" + """Basic Dynamics model with minimum necessary Basilisk components.""" @classmethod @property - def requires_env(cls) -> list[type["EnvironmentModel"]]: + def _requires_env(cls) -> list[type["EnvironmentModel"]]: return [environment.BasicEnvironmentModel] @property def sigma_BN(self): + """Body attitude MRP relative to inertial frame.""" return self.scObject.scStateOutMsg.read().sigma_BN @property def BN(self): + """Body relative to inertial frame rotation matrix.""" return RigidBodyKinematics.MRP2C(self.sigma_BN) @property def omega_BN_B(self): + """Body rate relative to inertial frame in body frame [rad/s].""" return self.scObject.scStateOutMsg.read().omega_BN_B @property def BP(self): + """Body relative to planet freame rotation matrix.""" return np.matmul(self.BN, self.environment.PN.T) @property def r_BN_N(self): + """Body position relative to inertial origin in inertial frame [m].""" return self.scObject.scStateOutMsg.read().r_BN_N @property def r_BN_P(self): + """Body position relative to inertial origin in planet frame [m].""" return np.matmul(self.environment.PN, self.r_BN_N) @property def v_BN_N(self): + """Body velocity relative to inertial origin in inertial frame [m/s].""" return self.scObject.scStateOutMsg.read().v_BN_N @property def v_BN_P(self): - """P-frame derivative of r_BN""" + """P-frame derivative of r_BN.""" omega_NP_P = np.matmul(self.environment.PN, -self.environment.omega_PN_N) return np.matmul(self.environment.PN, self.v_BN_N) + np.cross( omega_NP_P, self.r_BN_P @@ -160,24 +180,29 @@ def v_BN_P(self): @property def omega_BP_P(self): + """Body angular velocity relative to planet frame in plant frame [rad/s].""" omega_BN_N = np.matmul(self.BN.T, self.omega_BN_B) omega_BP_N = omega_BN_N - self.environment.omega_PN_N return np.matmul(self.environment.PN, omega_BP_N) @property def battery_charge(self): + """Battery charge [W*s].""" return self.powerMonitor.batPowerOutMsg.read().storageLevel @property def battery_charge_fraction(self): + """Battery charge as a fraction of capacity.""" return self.battery_charge / self.powerMonitor.storageCapacity @property def wheel_speeds(self): + """Wheel speeds [rad/s].""" return np.array(self.rwStateEffector.rwSpeedOutMsg.read().wheelSpeeds) @property def wheel_speeds_fraction(self): + """Wheel speeds normalized by maximum.""" return self.wheel_speeds / (self.maxWheelSpeed * macros.rpm2radsec) def _init_dynamics_objects(self, **kwargs) -> None: @@ -219,7 +244,7 @@ def _set_spacecraft_hub( priority: int = 2000, **kwargs, ) -> None: - """Defines the spacecraft object properties. + """Set the spacecraft object properties. Args: mass: Hub mass [kg] @@ -233,6 +258,7 @@ def _set_spacecraft_hub( oe: (a, e, i, AN, AP, f); alternative to rN, vN [km, rad] mu: Gravitational parameter (used only with oe) priority: Model priority. + kwargs: Ignored """ if rN is not None and vN is not None and oe is None: pass @@ -280,6 +306,7 @@ def _set_disturbance_torque( Args: disturbance_vector: Constant disturbance torque [N*m]. + kwargs: Ignored """ if disturbance_vector is None: disturbance_vector = np.array([0, 0, 0]) @@ -309,6 +336,7 @@ def _set_drag_effector( height: Hub height [m] panelArea: Solar panel surface area [m**2] priority: Model priority. + kwargs: Ignored """ self.dragEffector = facetDragDynamicEffector.FacetDragDynamicEffector() self.dragEffector.ModelTag = "FacetDrag" @@ -342,10 +370,11 @@ def _set_drag_effector( ) def _set_simple_nav_object(self, priority: int = 1400, **kwargs) -> None: - """Defines the navigation module. + """Make the navigation module. Args: priority: Model priority. + kwargs: Ignored """ self.simpleNavObject = simpleNav.SimpleNav() self.simpleNavObject.ModelTag = "SimpleNav" @@ -356,8 +385,10 @@ def _set_simple_nav_object(self, priority: int = 1400, **kwargs) -> None: @aliveness_checker def altitude_valid(self) -> bool: - """Check for deorbit by checking if altitude is greater than 200km above Earth's - surface.""" + """Check that satellite has not deorbited. + + Checks if altitude is greater than 200km above Earth's surface. + """ return np.linalg.norm(self.r_BN_N) > (orbitalMotion.REQ_EARTH + 200) * 1e3 @default_args( @@ -373,12 +404,14 @@ def _set_reaction_wheel_dyn_effector( priority: int = 997, **kwargs, ) -> None: - """Defines the RW state effector. + """Set the RW state effector parameters. Args: wheelSpeeds: Initial speeds of each wheel [RPM] maxWheelSpeed: Failure speed for wheels [RPM] + u_max: Torque producible by wheel [N*m] priority: Model priority. + kwargs: Ignored """ self.maxWheelSpeed = maxWheelSpeed self.rwStateEffector, self.rwFactory, _ = aP.balancedHR16Triad( @@ -404,7 +437,7 @@ def rw_speeds_valid(self) -> bool: return valid def _set_thruster_dyn_effector(self, priority: int = 996) -> None: - """Defines the thruster state effector. + """Make the thruster state effector. Args: priority: Model priority. @@ -420,13 +453,13 @@ def _set_thruster_dyn_effector(self, priority: int = 996) -> None: def _set_thruster_power( self, thrusterPowerDraw, priority: int = 899, **kwargs ) -> None: - """Defines the thruster power draw. + """Set the thruster power draw. Args: thrusterPowerDraw: Constant power draw desat mode is active. [W] priority: Model priority. + kwargs: Ignored """ - self.thrusterPowerSink = simplePowerSink.SimplePowerSink() self.thrusterPowerSink.ModelTag = "thrustPowerSink" + self.satellite.id self.thrusterPowerSink.nodePowerOut = thrusterPowerDraw # Watts @@ -436,7 +469,7 @@ def _set_thruster_power( self.powerMonitor.addPowerNodeToModel(self.thrusterPowerSink.nodePowerOutMsg) def _set_eclipse_object(self) -> None: - """Adds the spacecraft to the eclipse module""" + """Add the spacecraft to the eclipse module.""" self.environment.eclipseObject.addSpacecraftToModel(self.scObject.scStateOutMsg) self.eclipse_index = len(self.environment.eclipseObject.eclipseOutMsgs) - 1 @@ -453,7 +486,7 @@ def _set_solar_panel( priority: int = 898, **kwargs, ) -> None: - """Sets the solar panel for power generation. + """Set the solar panel parameters for power generation. Args: panelArea: Solar panel surface area [m**2] @@ -461,6 +494,7 @@ def _set_solar_panel( conversion nHat_B: Body-fixed array normal vector priority: Model priority. + kwargs: Ignored """ self.solarPanel = simpleSolarPanel.SimpleSolarPanel() self.solarPanel.ModelTag = "solarPanel" + self.satellite.id @@ -493,12 +527,13 @@ def _set_battery( priority: int = 799, **kwargs, ) -> None: - """Sets the battery model. + """Set the battery model parameters. Args: batteryStorageCapacity: Maximum battery charge [W*s] storedCharge_Init: Initial battery charge [W*s] priority: Model priority. + kwargs: Ignored """ self.powerMonitor = simpleBattery.SimpleBattery() self.powerMonitor.ModelTag = "powerMonitor" @@ -525,7 +560,7 @@ def _set_reaction_wheel_power( priority: int = 987, **kwargs, ) -> None: - """Defines the reaction wheel power draw. + """Set the reaction wheel power draw. Args: rwBasePower: Constant power draw when operational [W] @@ -534,6 +569,7 @@ def _set_reaction_wheel_power( rwElecToMechEfficiency: Efficiency factor to convert electrical power to mechanical power priority: Model priority. + kwargs: Ignored """ self.rwPowerList = [] for i_device in range(self.rwFactory.getNumOfDevices()): @@ -551,7 +587,7 @@ def _set_reaction_wheel_power( class LOSCommDynModel(BasicDynamicsModel): - """For evaluating line-of-sight connections between satellites for communication""" + """For evaluating line-of-sight connections between satellites for communication.""" def _init_dynamics_objects(self, **kwargs) -> None: super()._init_dynamics_objects(**kwargs) @@ -562,6 +598,7 @@ def _set_los_comms(self, priority: int = 500, **kwargs) -> None: Args: priority: Model priority. + kwargs: Ignored """ self.losComms = spacecraftLocation.SpacecraftLocation() self.losComms.ModelTag = "losComms" @@ -599,10 +636,12 @@ class ImagingDynModel(BasicDynamicsModel): @property def storage_level(self): + """Storage level [bits].""" return self.storageUnit.storageUnitDataOutMsg.read().storageLevel @property def storage_level_fraction(self): + """Storage level as a fraction of capacity.""" return self.storage_level / self.storageUnit.storageCapacity def _init_dynamics_objects(self, **kwargs) -> None: @@ -623,6 +662,7 @@ def _set_instrument( Args: instrumentBaudRate: Data generated in a single step by an image [bits] priority: Model priority. + kwargs: Ignored """ self.instrument = simpleInstrument.SimpleInstrument() self.instrument.ModelTag = "instrument" + self.satellite.id @@ -650,6 +690,7 @@ def _set_transmitter( instrumentBaudRate: Image size, used to set packet size [bits] transmitterNumBuffers: Number of transmitter buffers priority: Model priority. + kwargs: Ignored """ if transmitterBaudRate > 0: self.logger.warning( @@ -663,18 +704,19 @@ def _set_transmitter( self.transmitter.packetSize = -instrumentBaudRate # bits self.transmitter.numBuffers = transmitterNumBuffers self.simulator.AddModelToTask( - self.task_name, self.transmitter, ModelPriority=798 + self.task_name, self.transmitter, ModelPriority=priority ) @default_args(instrumentPowerDraw=-30.0) def _set_instrument_power_sink( self, instrumentPowerDraw: float, priority: int = 897, **kwargs ) -> None: - """Defines the instrument power sink parameters. + """Set the instrument power sink parameters. Args: instrumentPowerDraw: Power draw when instrument is enabled [W] priority: Model priority. + kwargs: Ignored """ self.instrumentPowerSink = simplePowerSink.SimplePowerSink() self.instrumentPowerSink.ModelTag = "insPowerSink" + self.satellite.id @@ -688,11 +730,12 @@ def _set_instrument_power_sink( def _set_transmitter_power_sink( self, transmitterPowerDraw: float, priority: int = 896, **kwargs ) -> None: - """Defines the transmitter power sink parameters. + """Set the transmitter power sink parameters. Args: transmitterPowerDraw: Power draw when transmitter is enabled [W] priority: Model priority. + kwargs: Ignored """ self.transmitterPowerSink = simplePowerSink.SimplePowerSink() self.transmitterPowerSink.ModelTag = "transPowerSink" + self.satellite.id @@ -724,6 +767,7 @@ def _set_storage_unit( priority: Model priority. storageUnitValidCheck: If True, check that the storage level is below the storage capacity. + kwargs: Ignored """ self.storageUnit = partitionedStorageUnit.PartitionedStorageUnit() self.storageUnit.ModelTag = "storageUnit" + self.satellite.id @@ -778,8 +822,9 @@ def _set_imaging_target( priority: int = 2000, **kwargs, ) -> None: - """Add a generic imaging target to dynamics. The target must be updated with a - particular location when used. + """Add a generic imaging target to dynamics. + + The target must be updated with a particular location when used. Args: groundLocationPlanetRadius: Radius of ground locations from center of planet @@ -789,6 +834,7 @@ def _set_imaging_target( imageTargetMaximumRange: Maximum range from target to satellite when imaging. -1 to disable. [m] priority: Model priority. + kwargs: Ignored """ self.imagingTarget = groundLocation.GroundLocation() self.imagingTarget.ModelTag = "ImagingTarget" @@ -818,8 +864,11 @@ def reset_for_action(self) -> None: class ContinuousImagingDynModel(ImagingDynModel): - """Equips the satellite with an instrument, storage unit, and transmitter - for continuous nadir imaging.""" + """Equips the satellite for continuous nadir imaging. + + Equips satellite with an instrument, storage unit, and transmitter + for continuous nadir imaging. + """ @default_args(instrumentBaudRate=8e6) def _set_instrument( @@ -830,6 +879,7 @@ def _set_instrument( Args: instrumentBaudRate: Data generated in step by continuous imaging [bits] priority: Model priority. + kwargs: Ignored """ self.instrument = simpleInstrument.SimpleInstrument() self.instrument.ModelTag = "instrument" + self.satellite.id @@ -857,7 +907,8 @@ def _set_storage_unit( priority: Model priority. storageUnitValidCheck: If True, check that the storage level is below the storage capacity. - setStorageInit: Initial storage level [bits] + storageInit: Initial storage level [bits] + kwargs: Ignored """ self.storageUnit = simpleStorageUnit.SimpleStorageUnit() self.storageUnit.ModelTag = "storageUnit" + self.satellite.id @@ -885,13 +936,15 @@ def _set_imaging_target( priority: int = 2000, **kwargs, ) -> None: - """Add a generic imaging target to dynamics. The target must be updated with a - particular location when used. + """Add a generic imaging target to dynamics. + + The target must be updated with a particular location when used. Args: imageTargetMaximumRange: Maximum range from target to satellite when imaging. -1 to disable. [m] priority: Model priority. + kwargs: Ignored """ self.imagingTarget = groundLocation.GroundLocation() self.imagingTarget.ModelTag = "scanningTarget" @@ -914,12 +967,12 @@ def _set_imaging_target( class GroundStationDynModel(ImagingDynModel): - """Model that connects satellite to environment ground stations""" + """Model that connects satellite to environment ground stations.""" @classmethod @property - def requires_env(cls) -> list[type["EnvironmentModel"]]: - return super().requires_env + [environment.GroundStationEnvModel] + def _requires_env(cls) -> list[type["EnvironmentModel"]]: + return super()._requires_env + [environment.GroundStationEnvModel] def _init_dynamics_objects(self, **kwargs) -> None: super()._init_dynamics_objects(**kwargs) @@ -933,4 +986,6 @@ def _set_ground_station_locations(self) -> None: class FullFeaturedDynModel(GroundStationDynModel, LOSCommDynModel): + """Convenience class for a satellite with ground station and line-of-sight comms.""" + pass diff --git a/src/bsk_rl/envs/general_satellite_tasking/simulation/environment.py b/src/bsk_rl/envs/general_satellite_tasking/simulation/environment.py index 8ca10ad4..f458ca12 100644 --- a/src/bsk_rl/envs/general_satellite_tasking/simulation/environment.py +++ b/src/bsk_rl/envs/general_satellite_tasking/simulation/environment.py @@ -1,3 +1,5 @@ +"""Basilisk environment models.""" + import logging from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Optional, Union @@ -29,9 +31,15 @@ class EnvironmentModel(ABC): + """Abstract Basilisk environment model. + + One EnvironmentModel is instantiated for the environment each time a new simulator + is created. + """ + @classmethod def default_env_args(cls, **kwargs) -> dict[str, Any]: - """Compile default argments for the environment model""" + """Compile default argments for the environment model.""" defaults = collect_default_args(cls) for k, v in kwargs.items(): if k not in defaults: @@ -46,12 +54,13 @@ def __init__( priority: int = 300, **kwargs, ) -> None: - """Base environment model + """Construct base environment model. Args: simulator: Simulator using this model env_rate: Rate of environment simulation [s] priority: Model priority. + kwargs: Ignored """ self.simulator: Simulator = proxy(simulator) @@ -67,19 +76,21 @@ def __init__( self._init_environment_objects(**kwargs) def __del__(self): + """Log when environment is deleted.""" logger.debug("Basilisk environment deleted") @abstractmethod # pragma: no cover def _init_environment_objects(self, **kwargs) -> None: - """Caller for all environment objects""" + """Caller for all environment objects.""" pass class BasicEnvironmentModel(EnvironmentModel): - """Minimal set of Basilisk environment objects""" + """Basic Environment with minimum necessary Basilisk environment components.""" @property def PN(self): + """Planet relative to inertial frame rotation matrix.""" return np.array( self.gravFactory.spiceObject.planetStateOutMsgs[self.body_index] .read() @@ -88,6 +99,7 @@ def PN(self): @property def omega_PN_N(self): + """Planet angular velocity in inertial frame [rad/s].""" PNdot = np.array( self.gravFactory.spiceObject.planetStateOutMsgs[self.body_index] .read() @@ -111,6 +123,7 @@ def _set_gravity_bodies( Args: utc_init: UTC datetime string priority: Model priority. + kwargs: Ignored """ self.gravFactory = simIncludeGravBody.gravBodyFactory() self.gravFactory.createSun() @@ -141,6 +154,7 @@ def _set_epoch_object(self, priority: int = 988, **kwargs) -> None: Args: priority: Model priority. + kwargs: Ignored """ self.ephemConverter = ephemerisConverter.EphemerisConverter() self.ephemConverter.ModelTag = "ephemConverter" @@ -174,6 +188,7 @@ def _set_atmosphere_density_model( baseDensity: Exponential model parameter [kg/m^3] scaleHeight: Exponential model parameter [m] priority (int, optional): Model priority. + kwargs: Ignored """ self.densityModel = exponentialAtmosphere.ExponentialAtmosphere() self.densityModel.ModelTag = "expDensity" @@ -192,6 +207,7 @@ def _set_eclipse_object(self, priority: int = 988, **kwargs) -> None: Args: priority: Model priority. + kwargs: Ignored """ self.eclipseObject = eclipse.Eclipse() self.eclipseObject.addPlanetToModel( @@ -205,6 +221,7 @@ def _set_eclipse_object(self, priority: int = 988, **kwargs) -> None: ) def __del__(self) -> None: + """Log when environment is deleted and unload SPICE.""" super().__del__() try: self.gravFactory.unloadSpiceKernels() @@ -213,7 +230,7 @@ def __del__(self) -> None: class GroundStationEnvModel(BasicEnvironmentModel): - """Model that includes downlink ground stations""" + """Model that includes downlink ground stations.""" def _init_environment_objects(self, **kwargs) -> None: super()._init_environment_objects(**kwargs) @@ -254,6 +271,7 @@ def _set_ground_locations( gsMaximumRange: Maximum range from station to satellite when downlinking. -1 to disable. [m] priority: Model priority. + kwargs: Ignored """ self.groundStations = [] self.groundLocationPlanetRadius = groundLocationPlanetRadius diff --git a/src/bsk_rl/envs/general_satellite_tasking/simulation/fsw.py b/src/bsk_rl/envs/general_satellite_tasking/simulation/fsw.py index e238bf17..6bb33047 100644 --- a/src/bsk_rl/envs/general_satellite_tasking/simulation/fsw.py +++ b/src/bsk_rl/envs/general_satellite_tasking/simulation/fsw.py @@ -1,3 +1,5 @@ +"""Basilisk flight software models.""" + from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Callable, Iterable, Optional from weakref import proxy @@ -39,8 +41,7 @@ def action( func: Callable[..., None] ) -> Callable[Callable[..., None], Callable[..., None]]: - """Wrapper to do housekeeping for action functions that should be called by the - satellite class.""" + """Decorate to run housekeeping for action functions called by the satellite.""" def inner(self, *args, **kwargs) -> Callable[..., None]: self.fsw_proc.disableAllTasks() @@ -54,27 +55,33 @@ def inner(self, *args, **kwargs) -> Callable[..., None]: class FSWModel(ABC): + """Abstract Basilisk flight software model. + + One FSWModel is instantiated for each satellite in the environment each time a + new simulator is created. + """ + @classmethod @property - def requires_dyn(cls) -> list[type["DynamicsModel"]]: + def _requires_dyn(cls) -> list[type["DynamicsModel"]]: """Define minimum DynamicsModels for compatibility.""" return [] def __init__( self, satellite: "Satellite", fsw_rate: float, priority: int = 100, **kwargs ) -> None: - """Base FSWModel + """Construct a base flight software model. Args: satellite: Satellite modelled by this model fsw_rate: Rate of FSW simulation [s] priority: Model priority. + kwargs: Passed to task creation functions """ - self.satellite = satellite self.logger = self.satellite.logger.getChild(self.__class__.__name__) - for required in self.requires_dyn: + for required in self._requires_dyn: if not issubclass(satellite.dyn_type, required): raise TypeError( f"{satellite.dyn_type} must be a subclass of {required} to " @@ -91,25 +98,28 @@ def __init__( task.create_task() for task in self.tasks: - task.create_module_data() + task._create_module_data() self._set_messages() for task in self.tasks: - task.init_objects(**kwargs) + task._init_objects(**kwargs) self.fsw_proc.disableAllTasks() @property def simulator(self) -> "Simulator": + """Reference to the episode simulator.""" return self.satellite.simulator @property def environment(self) -> "EnvironmentModel": + """Reference to the episode environment model.""" return self.simulator.environment @property def dynamics(self) -> "DynamicsModel": + """Reference to the satellite dynamics model for the episode.""" return self.satellite.dynamics def _make_task_list(self) -> list["Task"]: @@ -117,7 +127,7 @@ def _make_task_list(self) -> list["Task"]: @abstractmethod # pragma: no cover def _set_messages(self) -> None: - """Message setup after task creation""" + """Message setup after task creation.""" pass def is_alive(self, log_failure=False) -> bool: @@ -129,17 +139,20 @@ def is_alive(self, log_failure=False) -> bool: return check_aliveness_checkers(self, log_failure=log_failure) def __del__(self): + """Log when FSW model is deleted.""" self.logger.debug("Basilisk FSW deleted") class Task(ABC): + """Abstract class for defining FSW tasks.""" + @property @abstractmethod # pragma: no cover - def name(self) -> str: + def name(self) -> str: # noqa: D102 pass def __init__(self, fsw: FSWModel, priority: int) -> None: - """Template class for defining FSW processes + """Template class for defining FSW processes. Args: fsw: FSW model task contributes to @@ -158,12 +171,12 @@ def create_task(self) -> None: ) @abstractmethod # pragma: no cover - def create_module_data(self) -> None: + def _create_module_data(self) -> None: """Create module data wrappers.""" pass @abstractmethod # pragma: no cover - def init_objects(self, **kwargs) -> None: + def _init_objects(self, **kwargs) -> None: """Initialize model parameters with satellite arguments.""" pass @@ -181,15 +194,19 @@ def _add_model_to_task(self, module, priority) -> None: ) def reset_for_action(self) -> None: - """Housekeeping for task when a new action is called; by default, disable - task.""" + """Housekeeping for task when a new action is called. + + Disables task by default, can be overridden by subclasses. + """ self.fsw.simulator.disableTask(self.name + self.fsw.satellite.id) class BasicFSWModel(FSWModel): + """Basic FSW model with minimum necessary Basilisk components.""" + @classmethod @property - def requires_dyn(cls) -> list[type["DynamicsModel"]]: + def _requires_dyn(cls) -> list[type["DynamicsModel"]]: return [dynamics.BasicDynamicsModel] def _make_task_list(self) -> list[Task]: @@ -222,13 +239,11 @@ def _set_thrusters_config_msg(self) -> None: self.thrusterConfigMsg = self.dynamics.thrFactory.getConfigMessage() def _set_rw_config_msg(self) -> None: - """Configure RW pyramid exactly as it is in the Dynamics (i.e. FSW with perfect - knowledge).""" + """Configure RW pyramid exactly as it is in dynamics.""" self.fswRwConfigMsg = self.dynamics.rwFactory.getConfigMessage() def _set_gateway_msgs(self) -> None: - """Create C-wrapped gateway messages such that different modules can write to - this message and provide a common input msg for down-stream modules.""" + """Create C-wrapped gateway messages.""" self.attRefMsg = cMsgPy.AttRefMsg_C() self.attGuidMsg = cMsgPy.AttGuidMsg_C() @@ -249,7 +264,7 @@ def _zero_gateway_msgs(self) -> None: @action def action_drift(self) -> None: - """Action to disable all tasks.""" + """Disable all tasks.""" self.simulator.disableTask( BasicFSWModel.MRPControlTask.name + self.satellite.id ) @@ -259,18 +274,19 @@ class SunPointTask(Task): name = "sunPointTask" - def __init__(self, fsw, priority=99) -> None: + def __init__(self, fsw, priority=99) -> None: # noqa: D107 super().__init__(fsw, priority) - def create_module_data(self) -> None: + def _create_module_data(self) -> None: self.sunPoint = self.fsw.sunPoint = locationPointing.locationPointing() self.sunPoint.ModelTag = "sunPoint" - def init_objects(self, nHat_B: Iterable[float], **kwargs) -> None: + def _init_objects(self, nHat_B: Iterable[float], **kwargs) -> None: """Configure the sun-pointing task. Args: nHat_B: Solar array normal vector + kwargs: Ignored """ self.sunPoint.pHat_B = nHat_B self.sunPoint.scAttInMsg.subscribeTo( @@ -293,7 +309,7 @@ def init_objects(self, nHat_B: Iterable[float], **kwargs) -> None: @action def action_charge(self) -> None: - """Action to charge solar panels.""" + """Charge battery using solar panels.""" self.sunPoint.Reset(self.simulator.sim_time_ns) self.simulator.enableTask(self.SunPointTask.name + self.satellite.id) @@ -302,14 +318,14 @@ class NadirPointTask(Task): name = "nadirPointTask" - def __init__(self, fsw, priority=98) -> None: + def __init__(self, fsw, priority=98) -> None: # noqa: D107 super().__init__(fsw, priority) - def create_module_data(self) -> None: + def _create_module_data(self) -> None: self.hillPoint = self.fsw.hillPoint = hillPoint.hillPoint() self.hillPoint.ModelTag = "hillPoint" - def init_objects(self, **kwargs) -> None: + def _init_objects(self, **kwargs) -> None: """Configure the nadir-pointing task.""" self.hillPoint.transNavInMsg.subscribeTo( self.fsw.dynamics.simpleNavObject.transOutMsg @@ -330,10 +346,10 @@ class RWDesatTask(Task): name = "rwDesatTask" - def __init__(self, fsw, priority=97) -> None: + def __init__(self, fsw, priority=97) -> None: # noqa: D107 super().__init__(fsw, priority) - def create_module_data(self) -> None: + def _create_module_data(self) -> None: """Set up momentum dumping and thruster control.""" # Momentum dumping configuration self.thrDesatControl = ( @@ -350,7 +366,7 @@ def create_module_data(self) -> None: ) = thrForceMapping.thrForceMapping() self.thrForceMapping.ModelTag = "thrForceMapping" - def init_objects(self, **kwargs) -> None: + def _init_objects(self, **kwargs) -> None: self._set_thruster_mapping(**kwargs) self._set_momentum_dumping(**kwargs) @@ -364,6 +380,7 @@ def _set_thruster_mapping( controlAxes_B: Control unit axes thrForceSign: Flag indicating if pos (+1) or negative (-1) thruster solutions are found + kwargs: Ignored """ self.thrForceMapping.cmdTorqueInMsg.subscribeTo( self.thrDesatControl.deltaHOutMsg @@ -399,6 +416,7 @@ def _set_momentum_dumping( desatAttitude: Direction to point while desaturating: "sun" points panels at sun, "nadir" points instrument nadir, None disables attitude control + kwargs: Ignored """ self.fsw.desatAttitude = desatAttitude self.thrDesatControl.hs_min = hs_min # Nms @@ -419,12 +437,13 @@ def _set_momentum_dumping( self._add_model_to_task(self.thrDump, priority=1191) def reset_for_action(self) -> None: + """Disable power draw for thrusters.""" super().reset_for_action() self.fsw.dynamics.thrusterPowerSink.powerStatus = 0 @action def action_desat(self) -> None: - """Action to charge while desaturating reaction wheels.""" + """Charge while desaturating reaction wheels.""" self.trackingError.Reset(self.simulator.sim_time_ns) self.thrDesatControl.Reset(self.simulator.sim_time_ns) self.thrDump.Reset(self.simulator.sim_time_ns) @@ -449,16 +468,16 @@ class TrackingErrorTask(Task): name = "trackingErrTask" - def __init__(self, fsw, priority=90) -> None: + def __init__(self, fsw, priority=90) -> None: # noqa: D107 super().__init__(fsw, priority) - def create_module_data(self) -> None: + def _create_module_data(self) -> None: self.trackingError = ( self.fsw.trackingError ) = attTrackingError.attTrackingError() self.trackingError.ModelTag = "trackingError" - def init_objects(self, **kwargs) -> None: + def _init_objects(self, **kwargs) -> None: self.trackingError.attNavInMsg.subscribeTo( self.fsw.dynamics.simpleNavObject.attOutMsg ) @@ -474,10 +493,10 @@ class MRPControlTask(Task): name = "mrpControlTask" - def __init__(self, fsw, priority=80) -> None: + def __init__(self, fsw, priority=80) -> None: # noqa: D107 super().__init__(fsw, priority) - def create_module_data(self) -> None: + def _create_module_data(self) -> None: # Attitude controller configuration self.mrpFeedbackControl = ( self.fsw.mrpFeedbackControl @@ -488,7 +507,7 @@ def create_module_data(self) -> None: self.rwMotorTorque = self.fsw.rwMotorTorque = rwMotorTorque.rwMotorTorque() self.rwMotorTorque.ModelTag = "rwMotorTorque" - def init_objects(self, **kwargs) -> None: + def _init_objects(self, **kwargs) -> None: self._set_mrp_feedback_rwa(**kwargs) self._set_rw_motor_torque(**kwargs) @@ -496,12 +515,13 @@ def init_objects(self, **kwargs) -> None: def _set_mrp_feedback_rwa( self, K: float, Ki: float, P: float, **kwargs ) -> None: - """Defines the control properties. + """Set the MRP feedback control properties. Args: K: Proportional gain Ki: Integral gain P: Derivative gain + kwargs: Ignored """ self.mrpFeedbackControl.guidInMsg.subscribeTo(self.fsw.attGuidMsg) self.mrpFeedbackControl.vehConfigInMsg.subscribeTo(self.fsw.vcConfigMsg) @@ -517,10 +537,11 @@ def _set_mrp_feedback_rwa( def _set_rw_motor_torque( self, controlAxes_B: Iterable[float], **kwargs ) -> None: - """Defines the motor torque from the control law. + """Set parameters for finding motor torque from the control law. Args: - controlAxes_B): Control unit axes + controlAxes_B: Control unit axes + kwargs: Ignored """ self.rwMotorTorque.rwParamsInMsg.subscribeTo(self.fsw.fswRwConfigMsg) self.rwMotorTorque.vehControlInMsg.subscribeTo( @@ -536,15 +557,16 @@ def reset_for_action(self) -> None: class ImagingFSWModel(BasicFSWModel): - """Extend FSW with instrument pointing and triggering control""" + """Extend FSW with instrument pointing and triggering control.""" @classmethod @property - def requires_dyn(cls) -> list[type["DynamicsModel"]]: - return super().requires_dyn + [dynamics.ImagingDynModel] + def _requires_dyn(cls) -> list[type["DynamicsModel"]]: + return super()._requires_dyn + [dynamics.ImagingDynModel] @property def c_hat_P(self): + """Instrument pointing direction in the planet frame.""" c_hat_B = self.locPoint.pHat_B return np.matmul(self.dynamics.BP.T, c_hat_B) @@ -558,14 +580,14 @@ def _set_gateway_msgs(self) -> None: ) class LocPointTask(Task): - """Task to point at targets and trigger the instrument""" + """Task to point at targets and trigger the instrument.""" name = "locPointTask" - def __init__(self, fsw, priority=96) -> None: + def __init__(self, fsw, priority=96) -> None: # noqa: D107 super().__init__(fsw, priority) - def create_module_data(self) -> None: + def _create_module_data(self) -> None: # Location pointing configuration self.locPoint = self.fsw.locPoint = locationPointing.locationPointing() self.locPoint.ModelTag = "locPoint" @@ -576,7 +598,7 @@ def create_module_data(self) -> None: ) = simpleInstrumentController.simpleInstrumentController() self.insControl.ModelTag = "instrumentController" - def init_objects(self, **kwargs) -> None: + def _init_objects(self, **kwargs) -> None: self._set_location_pointing(**kwargs) self._set_instrument_controller(**kwargs) @@ -584,10 +606,11 @@ def init_objects(self, **kwargs) -> None: def _set_location_pointing( self, inst_pHat_B: Iterable[float], **kwargs ) -> None: - """Defines the Earth location pointing guidance module. + """Set the Earth location pointing guidance module. Args: inst_pHat_B: Instrument pointing direction + kwargs: Ignored """ self.locPoint.pHat_B = inst_pHat_B self.locPoint.scAttInMsg.subscribeTo( @@ -613,13 +636,14 @@ def _set_instrument_controller( imageRateErrorRequirement: float, **kwargs, ) -> None: - """Defines the instrument controller parameters. + """Set the instrument controller parameters. Args: imageAttErrorRequirement: Pointing attitude error tolerance for imaging [MRP norm] imageRateErrorRequirement: Rate tolerance for imaging. Disable with None. [rad/s] + kwargs: Ignored """ self.insControl.attErrTolerance = imageAttErrorRequirement if imageRateErrorRequirement is not None: @@ -633,6 +657,7 @@ def _set_instrument_controller( self._add_model_to_task(self.insControl, priority=987) def reset_for_action(self) -> None: + """Reset pointing controller.""" self.fsw.dynamics.imagingTarget.Reset(self.fsw.simulator.sim_time_ns) self.locPoint.Reset(self.fsw.simulator.sim_time_ns) self.insControl.controllerStatus = 0 @@ -640,7 +665,7 @@ def reset_for_action(self) -> None: @action def action_image(self, location: Iterable[float], data_name: str) -> None: - """Action to image a target at a location. + """Attempt to image a target at a location. Args: location: PCPF target location [m] @@ -655,7 +680,7 @@ def action_image(self, location: Iterable[float], data_name: str) -> None: @action def action_downlink(self) -> None: - """Action to attempt to downlink data.""" + """Attempt to downlink data.""" self.hillPoint.Reset(self.simulator.sim_time_ns) self.trackingError.Reset(self.simulator.sim_time_ns) self.dynamics.transmitter.dataStatus = 1 @@ -667,10 +692,12 @@ def action_downlink(self) -> None: class ContinuousImagingFSWModel(ImagingFSWModel): + """FSW model for continuous nadir scanning.""" + class LocPointTask(ImagingFSWModel.LocPointTask): - """Task to point at targets and trigger the instrument""" + """Task to point at targets and trigger the instrument.""" - def create_module_data(self) -> None: + def _create_module_data(self) -> None: # Location pointing configuration self.locPoint = self.fsw.locPoint = locationPointing.locationPointing() self.locPoint.ModelTag = "locPoint" @@ -688,13 +715,14 @@ def _set_instrument_controller( imageRateErrorRequirement: float, **kwargs, ) -> None: - """Defines the instrument controller parameters. + """Set the instrument controller parameters. Args: imageAttErrorRequirement: Pointing attitude error tolerance for imaging [MRP norm] imageRateErrorRequirement: Rate tolerance for imaging. Disable with None. [rad/s] + kwargs: Ignored """ self.insControl.attErrTolerance = imageAttErrorRequirement if imageRateErrorRequirement is not None: @@ -708,6 +736,7 @@ def _set_instrument_controller( self._add_model_to_task(self.insControl, priority=987) def reset_for_action(self) -> None: + """Reset scanning controller.""" self.instMsg = cMsgPy.DeviceCmdMsg_C() self.instMsg.write(messaging.DeviceCmdMsgPayload()) self.fsw.dynamics.instrument.nodeStatusInMsg.subscribeTo(self.instMsg) @@ -715,7 +744,7 @@ def reset_for_action(self) -> None: @action def action_nadir_scan(self) -> None: - """Action scan nadir. + """Scan nadir. Args: location: PCPF target location [m] @@ -732,19 +761,22 @@ def action_nadir_scan(self) -> None: @action def action_image(self, *args, **kwargs) -> None: + """Disable imaging from parent class.""" raise NotImplementedError("Use action_nadir_scan instead") class SteeringFSWModel(BasicFSWModel): - """FSW extending MRP control to use MRP steering instesd of MRP feedback.""" + """FSW extending MRP control to use MRP steering instead of MRP feedback.""" class MRPControlTask(Task): + """Task that uses MRP steering to control reaction wheels.""" + name = "mrpControlTask" - def __init__(self, fsw, priority=80) -> None: + def __init__(self, fsw, priority=80) -> None: # noqa: D107 super().__init__(fsw, priority) - def create_module_data(self) -> None: + def _create_module_data(self) -> None: # Attitude controller configuration self.mrpSteeringControl = ( self.fsw.mrpSteeringControl @@ -761,7 +793,7 @@ def create_module_data(self) -> None: self.rwMotorTorque = self.fsw.rwMotorTorque = rwMotorTorque.rwMotorTorque() self.rwMotorTorque.ModelTag = "rwMotorTorque" - def init_objects(self, **kwargs) -> None: + def _init_objects(self, **kwargs) -> None: self._set_mrp_steering_rwa(**kwargs) self._set_rw_motor_torque(**kwargs) @@ -775,7 +807,7 @@ def _set_mrp_steering_rwa( servo_P: float, **kwargs, ) -> None: - """Defines the control properties. + """Define the control properties. Args: K1: MRP steering gain @@ -783,6 +815,7 @@ def _set_mrp_steering_rwa( omega_max: Maximum targetable spacecraft body rate [rad/s] servo_Ki: Servo gain servo_P: Servo gain + kwargs: Ignored """ self.mrpSteeringControl.guidInMsg.subscribeTo(self.fsw.attGuidMsg) self.mrpSteeringControl.K1 = K1 @@ -810,10 +843,11 @@ def _set_mrp_steering_rwa( def _set_rw_motor_torque( self, controlAxes_B: Iterable[float], **kwargs ) -> None: - """Defines the motor torque from the control law. + """Define the motor torque from the control law. Args: controlAxes_B: Control unit axes + kwargs: Ignored """ self.rwMotorTorque.rwParamsInMsg.subscribeTo(self.fsw.fswRwConfigMsg) self.rwMotorTorque.vehControlInMsg.subscribeTo(self.servo.cmdTorqueOutMsg) @@ -822,9 +856,11 @@ def _set_rw_motor_torque( self._add_model_to_task(self.rwMotorTorque, priority=1194) def reset_for_action(self) -> None: - # MRP control enabled by default + """Keep MRP control enabled on action calls.""" self.fsw.simulator.enableTask(self.name + self.fsw.satellite.id) class SteeringImagerFSWModel(SteeringFSWModel, ImagingFSWModel): + """Convenience type for ImagingFSWModel with MRP steering.""" + pass diff --git a/src/bsk_rl/envs/general_satellite_tasking/simulation/simulator.py b/src/bsk_rl/envs/general_satellite_tasking/simulation/simulator.py index 024c5a27..8b5fe8e4 100644 --- a/src/bsk_rl/envs/general_satellite_tasking/simulation/simulator.py +++ b/src/bsk_rl/envs/general_satellite_tasking/simulation/simulator.py @@ -1,3 +1,5 @@ +"""Extended Basilisk SimBaseClass for GeneralSatelliteTasking environments.""" + from typing import TYPE_CHECKING, Any if TYPE_CHECKING: # pragma: no cover @@ -15,6 +17,8 @@ class Simulator(SimulationBaseClass.SimBaseClass): + """Basilisk simulator for GeneralSatelliteTasking environments.""" + def __init__( self, satellites: list["Satellite"], @@ -24,7 +28,7 @@ def __init__( max_step_duration: float = 600.0, time_limit: float = float("inf"), ) -> None: - """Basilisk simulator for GeneralSatelliteTasking environments. + """Construct Basilisk simulator. Args: satellites: Satellites to be simulated @@ -69,7 +73,7 @@ def sim_time(self) -> float: def _set_environment( self, env_type: type["EnvironmentModel"], env_args: dict[str, Any] ) -> None: - """Construct the simulator environment model + """Construct the simulator environment model. Args: env_type: type of environment model to be constructed @@ -78,7 +82,7 @@ def _set_environment( self.environment = env_type(self, self.sim_rate, **env_args) def run(self) -> None: - """Propagate the simulator""" + """Propagate the simulator.""" simulation_time = mc.sec2nano( min(self.sim_time + self.max_step_duration, self.time_limit) ) @@ -87,10 +91,14 @@ def run(self) -> None: self.ExecuteSimulation() def delete_event(self, event_name) -> None: - """Removes an event from the event map. Makes event checking faster""" + """Remove an event from the event map. + + Makes event checking faster. + """ event = self.eventMap[event_name] self.eventList.remove(event) del self.eventMap[event_name] def __del__(self): + """Log when simulator is deleted.""" logger.debug("Basilisk simulator deleted") diff --git a/src/bsk_rl/envs/general_satellite_tasking/utils/functional.py b/src/bsk_rl/envs/general_satellite_tasking/utils/functional.py index 378d176b..5bbb64c1 100644 --- a/src/bsk_rl/envs/general_satellite_tasking/utils/functional.py +++ b/src/bsk_rl/envs/general_satellite_tasking/utils/functional.py @@ -1,3 +1,5 @@ +"""General utility functions.""" + import inspect import re import warnings @@ -8,7 +10,7 @@ def valid_func_name(name: str) -> str: - """Converts a string into a valid function name. + """Convert a string into a valid function name. Args: name: desired function name @@ -25,7 +27,7 @@ def valid_func_name(name: str) -> str: def safe_dict_merge(updates: dict, base: dict) -> dict: - """Merges a dict with another dict, warning for conflicts + """Merge a dict with another dict, warning for conflicts. Args: updates: dictionary to be added to base @@ -43,8 +45,7 @@ def safe_dict_merge(updates: dict, base: dict) -> dict: def default_args(**defaults) -> Callable: - """Decorator to enumerate default arguments of certain functions so they can be - collected""" + """Decorate function to enumerate default arguments for collection.""" def inner_dec(func) -> Callable: def inner(*args, **kwargs) -> Callable: @@ -57,7 +58,7 @@ def inner(*args, **kwargs) -> Callable: def collect_default_args(object: object) -> dict[str, Any]: - """Collect all function @default_args in an object + """Collect all function @default_args in an object. Args: object: object with @default_args decorated functions @@ -77,8 +78,7 @@ def collect_default_args(object: object) -> dict[str, Any]: def vectorize_nested_dict(dictionary: dict) -> np.ndarray: - """Flattens a dictionary of dictionaries, arrays, and scalars into a single - vector.""" + """Flattens a dictionary of dicts, arrays, and scalars into a single vector.""" values = list(dictionary.values()) for i, value in enumerate(values): if isinstance(value, np.ndarray): @@ -92,7 +92,7 @@ def vectorize_nested_dict(dictionary: dict) -> np.ndarray: def aliveness_checker(func: Callable[..., bool]) -> Callable[..., bool]: - """Decorator to evaluate func -> bool when checking for satellite aliveness""" + """Decorate function to evaluate when checking for satellite aliveness.""" def inner(*args, log_failure=False, **kwargs) -> bool: self = args[0] @@ -106,10 +106,11 @@ def inner(*args, log_failure=False, **kwargs) -> bool: def check_aliveness_checkers(model: Any, log_failure=False) -> bool: - """Evaluate all functions with @aliveness_checker in a model + """Evaluate all functions with @aliveness_checker in a model. Args: - model (Any): Model to search for checkers in + model: Model to search for checkers in + log_failure: Whether to log on checker failure Returns: bool: Model aliveness status @@ -127,15 +128,14 @@ def check_aliveness_checkers(model: Any, log_failure=False) -> bool: def is_property(obj: Any, attr_name: str) -> bool: - """Check if obj has an @property attr_name without calling it""" + """Check if obj has an @property attr_name without calling it.""" cls = type(obj) attribute = getattr(cls, attr_name, None) return attribute is not None and isinstance(attribute, property) def configurable(cls): - """Class decorator to create new instance of a class with different defaults to - __init__""" + """Class decorator to create class with different init defaults.""" @classmethod def configure(cls, **config_kwargs): @@ -162,10 +162,10 @@ def __init__(self, *args, **kwargs): def bind(instance, func, as_name=None): - """ - Bind the function *func* to *instance*, with either provided name *as_name* - or the existing name of *func*. The provided *func* should accept the - instance as the first argument, i.e. "self". + """Bind the function *func* to *instance*. + + Uses either provided name *as_name* or the existing name of *func*. The provided + *func* should accept the instance as the first argument, i.e. "self". """ if as_name is None: as_name = func.__name__ diff --git a/src/bsk_rl/envs/general_satellite_tasking/utils/logging_config.py b/src/bsk_rl/envs/general_satellite_tasking/utils/logging_config.py index 3be1771d..56fbe14f 100644 --- a/src/bsk_rl/envs/general_satellite_tasking/utils/logging_config.py +++ b/src/bsk_rl/envs/general_satellite_tasking/utils/logging_config.py @@ -1,10 +1,6 @@ +# ruff: noqa import logging -# sim_format = logging.Formatter( -# "\x1b[30;3m%(asctime)s\x1b[0m %(shortname)-30s %(levelname)-10s <%(sim_time)-.2f> %(message)s", -# defaults={"sim_time": -1}, -# ) - fstr = "\x1b[30;3m%(asctime)s\x1b[0m %(shortname)-30s %(levelname)-10s <%(sim_time)-.2f> %(message)s" colors = dict( diff --git a/src/bsk_rl/envs/general_satellite_tasking/utils/orbital.py b/src/bsk_rl/envs/general_satellite_tasking/utils/orbital.py index bd2a48a6..4ccc92b4 100644 --- a/src/bsk_rl/envs/general_satellite_tasking/utils/orbital.py +++ b/src/bsk_rl/envs/general_satellite_tasking/utils/orbital.py @@ -1,3 +1,5 @@ +"""Utilities for computing orbital events.""" + from typing import Iterable, Optional import numpy as np @@ -19,8 +21,9 @@ def random_orbit( omega: Optional[float] = 0, f: Optional[float] = None, ) -> ClassicElements: - """Create a set of orbit elements. Parameters are fixed if specified and randomized - if None. + """Create a set of orbit elements. + + Parameters are fixed if specified and randomized if None. Args: i: inclination [deg], randomized in [-pi, pi] @@ -49,7 +52,9 @@ def random_orbit( def random_epoch(start: int = 2000, end: int = 2022): - """Generates a random epoch. + """Generate a random epoch in a year range. + + Date will always be in the first 28 days of the month. Args: start: Initial year. @@ -91,7 +96,7 @@ def random_epoch(start: int = 2000, end: int = 2022): def elevation(r_sat: np.ndarray, r_target: np.ndarray) -> np.ndarray: - """Find the elevation angle from a target to a satellite + """Find the elevation angle from a target to a satellite. Args: r_sat: Satellite position(s) @@ -113,6 +118,8 @@ def elevation(r_sat: np.ndarray, r_target: np.ndarray) -> np.ndarray: class TrajectorySimulator(SimulationBaseClass.SimBaseClass): + """Class for propagating trajectory using a point mass simulation.""" + def __init__( self, utc_init: str, @@ -122,9 +129,11 @@ def __init__( mu: Optional[float] = None, dt: float = 30.0, ) -> None: - """Class for propagating trajectory using a point mass simulation under the - effect of Earth's gravity. Returns interpolators for position as well as - upcoming eclipse predictions. Specify either (rN, vN) or (oe, mu). + """Initialize simulator conditions. + + Simulated under the effect of Earth's gravity. Returns interpolators for + position as well as upcoming eclipse predictions. Specify either (rN, vN) or + (oe, mu). Args: utc_init: Simulation start time. @@ -158,9 +167,9 @@ def __init__( self._eclipse_ends: list[float] = [] self._eclipse_search_time = 0.0 - self.init_simulator() + self._init_simulator() - def init_simulator(self) -> None: + def _init_simulator(self) -> None: simTaskName = "simTask" simProcessName = "simProcess" dynProcess = self.CreateNewProcess(simProcessName) @@ -217,16 +226,16 @@ def init_simulator(self) -> None: @property def sim_time(self) -> float: - """Current simulator end time""" + """Current simulator end time.""" return macros.NANO2SEC * self.TotalSim.CurrentNanos @property def times(self) -> np.ndarray: - """Recorder times in seconds""" + """Recorder times in seconds.""" return np.array([macros.NANO2SEC * t for t in self.sc_state_log.times()]) def extend_to(self, t: float) -> None: - """Compute the trajectory of the satellite up to t + """Compute the trajectory of the satellite up to t. Args: t: Computation end [s] @@ -236,7 +245,7 @@ def extend_to(self, t: float) -> None: self.ConfigureStopTime(macros.sec2nano(t)) self.ExecuteSimulation() - def generate_eclipses(self, t: float) -> None: + def _generate_eclipses(self, t: float) -> None: self.extend_to(t + self.dt) upcoming_times = self.times[self.times > self._eclipse_search_time] upcoming_eclipse = ( @@ -250,11 +259,14 @@ def generate_eclipses(self, t: float) -> None: self._eclipse_search_time = t def next_eclipse(self, t: float, max_tries: int = 100) -> tuple[float, float]: - """Find the soonest eclipse transitions. The returned values are not necessarily - from the same eclipse event, such as when the search start time is in eclipse. + """Find the soonest eclipse transitions. + + The returned values are not necessarily from the same eclipse event, such as + when the search start time is in eclipse. Args: t: Time to start searching [s] + max_tries: Maximum number of times to search Returns: eclipse_start: Nearest upcoming eclipse beginning @@ -270,13 +282,13 @@ def next_eclipse(self, t: float, max_tries: int = 100) -> tuple[float, float]: eclipse_end = min([t_end for t_end in self._eclipse_ends if t_end > t]) return eclipse_start, eclipse_end - self.generate_eclipses(t + i * self.dt * 10) + self._generate_eclipses(t + i * self.dt * 10) return 1.0, 1.0 @property def r_BN_N(self) -> interp1d: - """Interpolator for r_BN_N""" + """Interpolator for r_BN_N.""" if self.sim_time < self.dt * 3: self.extend_to(self.dt * 3) return interp1d( @@ -289,7 +301,7 @@ def r_BN_N(self) -> interp1d: @property def r_BP_P(self) -> interp1d: - """Interpolator for r_BP_P""" + """Interpolator for r_BP_P.""" if self.sim_time < self.dt * 3: self.extend_to(self.dt * 3) return interp1d( @@ -306,6 +318,7 @@ def r_BP_P(self) -> interp1d: ) def __del__(self) -> None: + """Unload spice kernels when object is deleted.""" try: self.gravFactory.unloadSpiceKernels() except AttributeError: diff --git a/tests/unittest/envs/general_satellite_tasking/simulation/test_dynamics.py b/tests/unittest/envs/general_satellite_tasking/simulation/test_dynamics.py index 0198f6bf..1a9d610a 100644 --- a/tests/unittest/envs/general_satellite_tasking/simulation/test_dynamics.py +++ b/tests/unittest/envs/general_satellite_tasking/simulation/test_dynamics.py @@ -36,10 +36,10 @@ def test_is_alive(self): def test_basic_requires_env(): - assert environment.BasicEnvironmentModel in BasicDynamicsModel.requires_env + assert environment.BasicEnvironmentModel in BasicDynamicsModel._requires_env -@patch(basicdyn + "requires_env", MagicMock(return_value=[])) +@patch(basicdyn + "_requires_env", MagicMock(return_value=[])) @patch(basicdyn + "_set_spacecraft_hub") @patch(basicdyn + "_set_drag_effector") @patch(basicdyn + "_set_reaction_wheel_dyn_effector") @@ -56,7 +56,7 @@ def test_basic_init_objects(self, *args): setter.assert_called_once() -@patch(basicdyn + "requires_env", MagicMock(return_value=[])) +@patch(basicdyn + "_requires_env", MagicMock(return_value=[])) @patch(basicdyn + "_init_dynamics_objects", MagicMock()) class TestBasicDynamicsModel: def test_dynamic_properties(self): @@ -139,7 +139,7 @@ def test_battery_valid(self, level, valid): class TestLOSCommDynModel: losdyn = module + "LOSCommDynModel." - @patch(losdyn + "requires_env", MagicMock(return_value=[])) + @patch(losdyn + "_requires_env", MagicMock(return_value=[])) @patch(module + "BasicDynamicsModel._init_dynamics_objects", MagicMock()) @patch(losdyn + "_set_los_comms") def test_init_objects(self, *args): @@ -147,7 +147,7 @@ def test_init_objects(self, *args): for setter in args: setter.assert_called_once() - @patch(losdyn + "requires_env", MagicMock(return_value=[])) + @patch(losdyn + "_requires_env", MagicMock(return_value=[])) @patch(losdyn + "_init_dynamics_objects", MagicMock()) @patch(module + "spacecraftLocation", MagicMock()) def test_set_los_comms(self): @@ -185,7 +185,7 @@ def test_set_los_comms(self): imdyn = module + "ImagingDynModel." -@patch(imdyn + "requires_env", MagicMock(return_value=[])) +@patch(imdyn + "_requires_env", MagicMock(return_value=[])) @patch(module + "BasicDynamicsModel._init_dynamics_objects", MagicMock()) @patch(imdyn + "_set_instrument_power_sink") @patch(imdyn + "_set_transmitter_power_sink") @@ -199,7 +199,7 @@ def test_init_objects(*args): setter.assert_called_once() -@patch(imdyn + "requires_env", MagicMock(return_value=[])) +@patch(imdyn + "_requires_env", MagicMock(return_value=[])) @patch(imdyn + "_init_dynamics_objects", MagicMock()) class TestImagingDynModel: def test_storage_properties(self): @@ -255,11 +255,11 @@ def test_set_storage_unit(self, buffers, names, expected): class TestGroundStationDynModel: def test_requires_env(self): - assert environment.GroundStationEnvModel in GroundStationDynModel.requires_env + assert environment.GroundStationEnvModel in GroundStationDynModel._requires_env gsdyn = module + "GroundStationDynModel." - @patch(gsdyn + "requires_env", MagicMock(return_value=[])) + @patch(gsdyn + "_requires_env", MagicMock(return_value=[])) @patch(module + "ImagingDynModel._init_dynamics_objects", MagicMock()) @patch(gsdyn + "_set_ground_station_locations") def test_init_objects(self, *args): @@ -268,7 +268,7 @@ def test_init_objects(self, *args): setter.assert_called_once() -@patch(imdyn + "requires_env", MagicMock(return_value=[])) +@patch(imdyn + "_requires_env", MagicMock(return_value=[])) @patch(imdyn + "_init_dynamics_objects", MagicMock()) class TestContinuousImagingDynModel: def test_storage_properties(self): diff --git a/tests/unittest/envs/general_satellite_tasking/simulation/test_fsw.py b/tests/unittest/envs/general_satellite_tasking/simulation/test_fsw.py index 326fa1cd..364ded40 100644 --- a/tests/unittest/envs/general_satellite_tasking/simulation/test_fsw.py +++ b/tests/unittest/envs/general_satellite_tasking/simulation/test_fsw.py @@ -53,7 +53,7 @@ def test_base_class(self): task = Task(fsw, 1) task.create_task() task.fsw.simulator.CreateNewTask.assert_called_once() - task.init_objects() + task._init_objects() task.reset_for_action() task.fsw.simulator.disableTask.assert_called_once() @@ -61,7 +61,7 @@ def test_base_class(self): basicfsw = module + "BasicFSWModel." -@patch(basicfsw + "requires_dyn", MagicMock(return_value=[])) +@patch(basicfsw + "_requires_dyn", MagicMock(return_value=[])) class TestBasicFSWModel: @patch(basicfsw + "_set_messages", MagicMock()) @patch(basicfsw + "SunPointTask") @@ -73,14 +73,14 @@ def test_make_tasks(self, *args): fsw = BasicFSWModel(MagicMock(), 1) for task in fsw.tasks: task.create_task.assert_called_once() - task.create_module_data.assert_called_once() - task.init_objects.assert_called_once() + task._create_module_data.assert_called_once() + task._init_objects.assert_called_once() imagingfsw = module + "ImagingFSWModel." -@patch(imagingfsw + "requires_dyn", MagicMock(return_value=[])) +@patch(imagingfsw + "_requires_dyn", MagicMock(return_value=[])) @patch(imagingfsw + "_make_task_list", MagicMock()) @patch(imagingfsw + "_set_messages", MagicMock()) class TestImagingFSWModel: