Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue #130: Infer environment model type #131

Merged
merged 1 commit into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,6 @@ TODO: Add more detail to this example
sats.FullFeaturedSatellite.default_sat_args(oe=random_orbit), n_ahead_observe=30,
n_ahead_act=15
),
env_type=environment.GroundStationEnvModel,
env_args=environment.GroundStationEnvModel.default_env_args(),
env_features=StaticTargets(n_targets=1000),
data_manager=data.UniqueImagingManager,
max_step_duration=600.0,
Expand Down
7 changes: 0 additions & 7 deletions examples/general_satellite_tasking/multisat_aeos.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from bsk_rl.envs.general_satellite_tasking.scenario.environment_features import (
CityTargets,
)
from bsk_rl.envs.general_satellite_tasking.simulation import environment
from bsk_rl.utilities.initial_conditions import leo_orbit


Expand Down Expand Up @@ -72,12 +71,6 @@ def run():
env = gym.make(
"GeneralSatelliteTasking-v1",
satellites=satellites,
# Pick the type for the Basilisk environment model. Note that it is not instantiated
# here.
env_type=environment.GroundStationEnvModel,
# Like default_sat_args, default_env_args infers model parameters from the type and
# specific parameters can be overridden or randomized.
env_args=environment.GroundStationEnvModel.default_env_args(),
# Pass configuration objects
env_features=env_features,
data_manager=data_manager,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from bsk_rl.envs.general_satellite_tasking.scenario.environment_features import (
CityTargets,
)
from bsk_rl.envs.general_satellite_tasking.simulation import dynamics, environment, fsw
from bsk_rl.envs.general_satellite_tasking.simulation import dynamics, fsw
from bsk_rl.envs.general_satellite_tasking.utils.orbital import random_orbit

bskLogging.setDefaultLogLevel(bskLogging.BSK_WARNING)
Expand Down Expand Up @@ -181,8 +181,6 @@ class CustomDynModel(dynamics.ImagingDynModel, dynamics.LOSCommDynModel):
"SingleSatelliteTasking-v1",
satellites=satellite,
# Select an EnvironmentModel compatible with the models in the satellite
env_type=environment.BasicEnvironmentModel,
env_args=environment.BasicEnvironmentModel.default_env_args(),
env_features=env_features,
data_manager=data_manager,
sim_rate=0.5,
Expand Down
10 changes: 4 additions & 6 deletions examples/general_satellite_tasking/single_sat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from bsk_rl.envs.general_satellite_tasking.scenario.environment_features import (
StaticTargets,
)
from bsk_rl.envs.general_satellite_tasking.simulation import environment
from bsk_rl.envs.general_satellite_tasking.utils.orbital import random_orbit

# This script demonstrates the configuration of an environment with a single imaging
Expand Down Expand Up @@ -44,12 +43,11 @@
"SingleSatelliteTasking-v1",
satellites=satellite,
# Pick the type for the Basilisk environment model. Note that it is not instantiated
# here.
env_type=environment.GroundStationEnvModel,
# here. This can be automatically inferred from the satellite types
# env_type=environment.GroundStationEnvModel,
# Like default_sat_args, default_env_args infers model parameters from the type and
# specific parameters can be
# overridden or randomized.
env_args=environment.GroundStationEnvModel.default_env_args(),
# specific parameters can be overridden or randomized.
# env_args={},
# Pass configuration objects
env_features=env_features,
data_manager=data_manager,
Expand Down
30 changes: 26 additions & 4 deletions src/bsk_rl/envs/general_satellite_tasking/gym_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ class GeneralSatelliteTasking(Env, Generic[SatObs, SatAct]):
def __init__(
self,
satellites: Union[Satellite, list[Satellite]],
env_type: type[EnvironmentModel],
env_features: EnvironmentFeatures,
data_manager: DataManager,
env_type: Optional[type[EnvironmentModel]] = None,
env_args: Optional[dict[str, Any]] = None,
communicator: Optional[CommunicationMethod] = None,
sim_rate: float = 1.0,
Expand All @@ -75,13 +75,13 @@ def __init__(

Args:
satellites: Satellites(s) to be simulated.
env_features: Information about the environment.
data_manager: Object to record and reward data collection.
communicator: Object to manage communication between satellites/
env_type: Type of environment model to be constructed.
env_args: Arguments for environment model construction. {key: value or key:
function}, where function is called at reset to set the value (used for
randomization).
env_features: Information about the environment.
data_manager: Object to record and reward data collection.
communicator: Object to manage communication between satellites
sim_rate: Rate for model simulation [s].
max_step_duration: Maximum time to propagate sim at a step [s].
failure_penalty: Reward for satellite failure. Should be nonpositive.
Expand All @@ -98,6 +98,8 @@ def __init__(
satellites = [satellites]
self.satellites = satellites
self.simulator: Simulator
if env_type is None:
env_type = self._minimum_env_model()
self.env_type = env_type
if env_args is None:
env_args = self.env_type.default_env_args()
Expand All @@ -121,6 +123,26 @@ def __init__(
self.latest_step_duration = 0
self.render_mode = render_mode

def _minimum_env_model(self) -> type[EnvironmentModel]:
"""Determine the minimum environment model required by the satellites."""
types = set(
sum(
[satellite.dyn_type._requires_env() for satellite in self.satellites],
[],
)
)
if len(types) == 1:
return list(types)[0]
for test_type in types:
if all([issubclass(test_type, other_type) for other_type in types]):
return test_type

# Else compose all types into a new class
class MinimumEnv(*types):
pass

return MinimumEnv

def _configure_logging(self, log_level, log_dir=None):
if isinstance(log_level, str):
log_level = log_level.upper()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ class DynamicsModel(ABC):
"""

@classmethod
@property
def _requires_env(cls) -> list[type["EnvironmentModel"]]:
"""Define minimum EnvironmentModels for compatibility."""
return []
Expand All @@ -75,7 +74,7 @@ def __init__(
self.satellite = satellite
self.logger = self.satellite.logger.getChild(self.__class__.__name__)

for required in self._requires_env:
for required in self._requires_env():
if not issubclass(type(self.simulator.environment), required):
raise TypeError(
f"{self.simulator.environment} must be a subclass of {required} to "
Expand Down Expand Up @@ -130,7 +129,6 @@ class BasicDynamicsModel(DynamicsModel):
"""Basic Dynamics model with minimum necessary Basilisk components."""

@classmethod
@property
def _requires_env(cls) -> list[type["EnvironmentModel"]]:
return [environment.BasicEnvironmentModel]

Expand Down Expand Up @@ -977,9 +975,8 @@ class GroundStationDynModel(ImagingDynModel):
"""Model that connects satellite to environment ground stations."""

@classmethod
@property
def _requires_env(cls) -> list[type["EnvironmentModel"]]:
return super()._requires_env + [environment.GroundStationEnvModel]
return super()._requires_env() + [environment.GroundStationEnvModel]

def _init_dynamics_objects(self, **kwargs) -> None:
super()._init_dynamics_objects(**kwargs)
Expand Down
7 changes: 2 additions & 5 deletions src/bsk_rl/envs/general_satellite_tasking/simulation/fsw.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ class FSWModel(ABC):
"""

@classmethod
@property
def _requires_dyn(cls) -> list[type["DynamicsModel"]]:
"""Define minimum DynamicsModels for compatibility."""
return []
Expand All @@ -81,7 +80,7 @@ def __init__(
self.satellite = satellite
self.logger = self.satellite.logger.getChild(self.__class__.__name__)

for required in self._requires_dyn:
for required in self._requires_dyn():
if not issubclass(satellite.dyn_type, required):
raise TypeError(
f"{satellite.dyn_type} must be a subclass of {required} to "
Expand Down Expand Up @@ -205,7 +204,6 @@ class BasicFSWModel(FSWModel):
"""Basic FSW model with minimum necessary Basilisk components."""

@classmethod
@property
def _requires_dyn(cls) -> list[type["DynamicsModel"]]:
return [dynamics.BasicDynamicsModel]

Expand Down Expand Up @@ -560,9 +558,8 @@ class ImagingFSWModel(BasicFSWModel):
"""Extend FSW with instrument pointing and triggering control."""

@classmethod
@property
def _requires_dyn(cls) -> list[type["DynamicsModel"]]:
return super()._requires_dyn + [dynamics.ImagingDynModel]
return super()._requires_dyn() + [dynamics.ImagingDynModel]

@property
def c_hat_P(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from bsk_rl.envs.general_satellite_tasking.scenario.environment_features import (
StaticTargets,
)
from bsk_rl.envs.general_satellite_tasking.simulation import environment
from bsk_rl.utilities.initial_conditions import leo_orbit

oes_visible = leo_orbit.walker_delta(
Expand Down Expand Up @@ -53,8 +52,6 @@ def make_communication_env(oes, comm_type):
env = gym.make(
"GeneralSatelliteTasking-v1",
satellites=satellites,
env_type=environment.GroundStationEnvModel,
env_args=environment.GroundStationEnvModel.default_env_args(),
env_features=StaticTargets(n_targets=1000),
data_manager=data.UniqueImagingManager(),
communicator=comm_type(satellites),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
CityTargets,
StaticTargets,
)
from bsk_rl.envs.general_satellite_tasking.simulation import dynamics, environment, fsw
from bsk_rl.envs.general_satellite_tasking.simulation import dynamics, fsw
from bsk_rl.envs.general_satellite_tasking.utils.orbital import random_orbit


Expand All @@ -30,8 +30,6 @@ class ImageSat(
imageRateErrorRequirement=0.05,
),
),
env_type=environment.GroundStationEnvModel,
env_args=environment.GroundStationEnvModel.default_env_args(),
env_features=env_features,
data_manager=data.UniqueImagingManager(),
sim_rate=1.0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
StaticTargets,
UniformNadirFeature,
)
from bsk_rl.envs.general_satellite_tasking.simulation import dynamics, environment, fsw
from bsk_rl.envs.general_satellite_tasking.simulation import dynamics, fsw
from bsk_rl.envs.general_satellite_tasking.utils.orbital import random_orbit

#########################
Expand Down Expand Up @@ -40,8 +40,6 @@ class ImageSat(
transmitterBaudRate=-1.0,
),
),
env_type=environment.GroundStationEnvModel,
env_args=environment.GroundStationEnvModel.default_env_args(),
env_features=StaticTargets(n_targets=1000),
data_manager=data.NoDataManager(),
sim_rate=1.0,
Expand Down Expand Up @@ -98,8 +96,6 @@ class ChargeSat(
storedCharge_Init=250_000,
),
),
env_type=environment.BasicEnvironmentModel,
env_args=environment.BasicEnvironmentModel.default_env_args(),
env_features=StaticTargets(n_targets=0),
data_manager=data.NoDataManager(),
sim_rate=1.0,
Expand Down Expand Up @@ -133,8 +129,6 @@ def make_env(self):
wheelSpeeds=[1000.0, -1000.0, 1000.0],
),
),
env_type=environment.BasicEnvironmentModel,
env_args=environment.BasicEnvironmentModel.default_env_args(),
env_features=StaticTargets(n_targets=0),
data_manager=data.NoDataManager(),
sim_rate=1.0,
Expand Down Expand Up @@ -201,8 +195,6 @@ class ImageSat(
transmitterBaudRate=-1.0,
),
),
env_type=environment.BasicEnvironmentModel,
env_args=environment.BasicEnvironmentModel.default_env_args(),
env_features=UniformNadirFeature(),
data_manager=data.NoDataManager(),
sim_rate=1.0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from bsk_rl.envs.general_satellite_tasking.scenario.environment_features import (
StaticTargets,
)
from bsk_rl.envs.general_satellite_tasking.simulation import dynamics, environment, fsw
from bsk_rl.envs.general_satellite_tasking.simulation import dynamics, fsw
from bsk_rl.envs.general_satellite_tasking.utils.orbital import random_orbit


Expand Down Expand Up @@ -39,8 +39,6 @@ class ComposedPropSat(
"Explorer 1",
sat_args=ComposedPropSat.default_sat_args(oe=random_orbit),
),
env_type=environment.BasicEnvironmentModel,
env_args=environment.BasicEnvironmentModel.default_env_args(),
env_features=StaticTargets(n_targets=1000),
data_manager=data.NoDataManager(),
sim_rate=1.0,
Expand Down Expand Up @@ -81,8 +79,6 @@ class NormdPropSat(
"Sputnik",
sat_args=NormdPropSat.default_sat_args(oe=random_orbit(r_body=7000, alt=0)),
),
env_type=environment.BasicEnvironmentModel,
env_args=environment.BasicEnvironmentModel.default_env_args(),
env_features=StaticTargets(n_targets=0),
data_manager=data.NoDataManager(),
sim_rate=1.0,
Expand Down Expand Up @@ -111,8 +107,6 @@ class TimedSat(
"Voyager",
sat_args=TimedSat.default_sat_args(oe=random_orbit()),
),
env_type=environment.BasicEnvironmentModel,
env_args=environment.BasicEnvironmentModel.default_env_args(),
env_features=StaticTargets(n_targets=0),
data_manager=data.NoDataManager(),
sim_rate=1.0,
Expand Down Expand Up @@ -144,8 +138,6 @@ class TargetSat(
n_ahead_observe=2,
sat_args=TargetSat.default_sat_args(oe=random_orbit()),
),
env_type=environment.BasicEnvironmentModel,
env_args=environment.BasicEnvironmentModel.default_env_args(),
env_features=StaticTargets(n_targets=100),
data_manager=data.NoDataManager(),
sim_rate=1.0,
Expand Down Expand Up @@ -176,8 +168,6 @@ class EclipseSat(
obs_type=list,
sat_args=EclipseSat.default_sat_args(oe=random_orbit()),
),
env_type=environment.BasicEnvironmentModel,
env_args=environment.BasicEnvironmentModel.default_env_args(),
env_features=StaticTargets(n_targets=0),
data_manager=data.NoDataManager(),
sim_rate=1.0,
Expand Down Expand Up @@ -207,8 +197,6 @@ class GroundSat(
obs_type=list,
sat_args=GroundSat.default_sat_args(oe=random_orbit()),
),
env_type=environment.GroundStationEnvModel,
env_args=environment.GroundStationEnvModel.default_env_args(),
env_features=StaticTargets(n_targets=0),
data_manager=data.NoDataManager(),
sim_rate=1.0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from bsk_rl.envs.general_satellite_tasking.scenario.environment_features import (
StaticTargets,
)
from bsk_rl.envs.general_satellite_tasking.simulation import dynamics, environment, fsw
from bsk_rl.envs.general_satellite_tasking.simulation import dynamics, fsw
from bsk_rl.envs.general_satellite_tasking.utils.orbital import random_orbit


Expand All @@ -33,8 +33,6 @@ class ImageSat(
imageRateErrorRequirement=0.05,
),
),
env_type=environment.BasicEnvironmentModel,
env_args=environment.BasicEnvironmentModel.default_env_args(),
env_features=StaticTargets(n_targets=5000),
data_manager=data.NoDataManager(),
sim_rate=1.0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from bsk_rl.envs.general_satellite_tasking.scenario.environment_features import (
StaticTargets,
)
from bsk_rl.envs.general_satellite_tasking.simulation import dynamics, environment, fsw
from bsk_rl.envs.general_satellite_tasking.simulation import dynamics, fsw
from bsk_rl.envs.general_satellite_tasking.utils.orbital import random_orbit

###########################
Expand Down Expand Up @@ -48,8 +48,6 @@ class ImageSat(
storageInit=initial_storage,
),
),
env_type=environment.GroundStationEnvModel,
env_args=environment.GroundStationEnvModel.default_env_args(),
env_features=StaticTargets(n_targets=1000),
data_manager=data.NoDataManager(),
sim_rate=1.0,
Expand Down Expand Up @@ -92,8 +90,6 @@ class ImageSat(
storageInit=initial_storage,
),
),
env_type=environment.GroundStationEnvModel,
env_args=environment.GroundStationEnvModel.default_env_args(),
env_features=StaticTargets(n_targets=1000),
data_manager=data.NoDataManager(),
sim_rate=1.0,
Expand Down
Loading
Loading