Skip to content

Commit

Permalink
Issue #58: Expose state vector element names
Browse files Browse the repository at this point in the history
  • Loading branch information
Mark2000 committed Jan 30, 2024
1 parent f5b0e76 commit 93ee5a1
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 12 deletions.
6 changes: 4 additions & 2 deletions examples/general_satellite_tasking/satellite_customization.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ class CustomDynModel(dynamics.ImagingDynModel, dynamics.LOSCommDynModel):
sat_args,
variable_interval=True,
)
# The composed satellite action space returns a human-readable action map

# The composed satellite state and action space returns a human-readable map
print("Actions:", satellite.action_map)

# Make the environment with Gymnasium
Expand Down Expand Up @@ -204,8 +205,9 @@ class CustomDynModel(dynamics.ImagingDynModel, dynamics.LOSCommDynModel):
env.action_space.sample() # Task random actions
)

# Show the custom normalized observation vector
# Show the custom normalized observation vector and an array of what the states correspond to
print("\tObservation:", observation)
print("\tStates:", satellite.obs_array_keys)

# Using the composed satellite features also provides a human-readable state:
for k, v in env.satellite.obs_dict.items():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,14 @@ def obs_dict(self):
@property
def obs_ndarray(self):
"""Numpy vector observation format."""
return vectorize_nested_dict(self.obs_dict)
_, obs = vectorize_nested_dict(self.obs_dict)
return obs

@property
def obs_array_keys(self):
"""Utility to get the keys of the obs_ndarray."""
keys, _ = vectorize_nested_dict(self.obs_dict)
return keys

@property
def obs_list(self):
Expand Down
15 changes: 11 additions & 4 deletions src/bsk_rl/envs/general_satellite_tasking/utils/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,25 @@ def collect_default_args(object: object) -> dict[str, Any]:
return defaults


def vectorize_nested_dict(dictionary: dict) -> np.ndarray:
"""Flattens a dictionary of dicts, arrays, and scalars into a single vector."""
def vectorize_nested_dict(dictionary: dict) -> tuple[list[str], np.ndarray]:
"""Flattens a dictionary of dictionaries, arrays, and scalars into a vector."""
keys = list(dictionary.keys())
values = list(dictionary.values())
for i, value in enumerate(values):
if isinstance(value, np.ndarray):
values[i] = value.flatten()
keys[i] = [keys[i] + f"[{j}]" for j in range(len(value.flatten()))]
elif isinstance(value, list):
keys[i] = [keys[i] + f"[{j}]" for j in range(len(value))]
elif isinstance(value, (float, int)):
values[i] = [value]
keys[i] = [keys[i]]
elif isinstance(value, dict):
values[i] = vectorize_nested_dict(value)
prepend = keys[i]
keys[i], values[i] = vectorize_nested_dict(value)
keys[i] = [prepend + "." + key for key in keys[i]]

return np.concatenate(values)
return list(np.concatenate(keys)), np.concatenate(values)


def aliveness_checker(func: Callable[..., bool]) -> Callable[..., bool]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,24 @@ class C13(self.C1, self.C3):


@pytest.mark.parametrize(
"input,output",
"input,outkeys,outvec",
[
({"a": np.array([1]), "b": 2, "c": [3]}, np.array([1, 2, 3])),
({"a": {"b": 1, "c": 2}, "d": 3}, np.array([1, 2, 3])),
(
{"alpha": np.array([1]), "b": 2, "c": [3]},
["alpha[0]", "b", "c[0]"],
np.array([1, 2, 3]),
),
(
{"a": {"b": 1, "charlie": 2}, "d": 3},
["a.b", "a.charlie", "d"],
np.array([1, 2, 3]),
),
],
)
def test_vectorize_nested_dict(input, output):
assert np.equal(output, functional.vectorize_nested_dict(input)).all()
def test_vectorize_nested_dict(input, outkeys, outvec):
keys, vec = functional.vectorize_nested_dict(input)
assert np.equal(outvec, vec).all()
assert outkeys == keys


class TestAlivenessChecker:
Expand Down

0 comments on commit 93ee5a1

Please sign in to comment.