Skip to content

Commit

Permalink
Add state indicator.
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthewGerber committed Sep 22, 2024
1 parent 66ae40c commit 4198689
Showing 1 changed file with 84 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,24 +32,64 @@ def extract(
"""


class StateDimensionIndicator(ABC):
class StateIndicator(ABC):
"""
Abstract state-dimension indicator.
Abstract state indicator for one-hot feature encoding.
"""

@abstractmethod
def __str__(
self
) -> str:
"""
Get string.
:return: String.
"""

@abstractmethod
def get_range(
self
) -> List[Any]:
"""
Get the range (possible values) of the current indicator.
:return: Range of values.
"""

@abstractmethod
def get_value(
self,
state_vector: np.ndarray
) -> Any:
"""
Get the value of the current indicator for a state.
:param state_vector: State vector.
:return: Value, which must be in the range returned by `get_range`.
"""


class StateLambdaIndicator(StateIndicator):
"""
State indicator via lambda function.
"""

def __init__(
self,
dimension: Optional[int]
function: Callable[[np.ndarray], Any],
function_range: List[Any]
):
"""
Initialize the indicator.
:param dimension: Dimension, or None for an indicator that is not based on a value of the state vector.
:param function: Function to apply to states.
:param function_range: Range of function.
"""

self.dimension = dimension
self.function = function
self.function_range = function_range

@abstractmethod
def __str__(
self
) -> str:
Expand All @@ -59,7 +99,8 @@ def __str__(
:return: String.
"""

@abstractmethod
return '<function>'

def get_range(
self
) -> List[Any]:
Expand All @@ -69,22 +110,43 @@ def get_range(
:return: Range of values.
"""

@abstractmethod
return self.function_range

def get_value(
self,
state: np.ndarray
state_vector: np.ndarray
) -> Any:
"""
Get the value of the current indicator for a state.
:param state: State vector.
:return: Value.
:param state_vector: State vector.
:return: Value, which must be in the range returned by `get_range`.
"""

return self.function(state_vector)


class StateDimensionIndicator(StateIndicator, ABC):
"""
State-dimension indicator.
"""

def __init__(
self,
dimension: int
):
"""
Initialize the indicator.
:param dimension: Dimension.
"""

self.dimension = dimension


class StateDimensionSegment(StateDimensionIndicator):
"""
Segment of a state dimension.
Indicates a segment of a state dimension.
"""

@staticmethod
Expand Down Expand Up @@ -146,18 +208,17 @@ def get_range(

def get_value(
self,
state: np.ndarray
state_vector: np.ndarray
) -> Any:
"""
Get the value of the current indicator for a state.
:param state: State vector.
:param state_vector: State vector.
:return: Value.
"""

assert self.dimension is not None
dimension_value = float(state_vector[self.dimension])

dimension_value = float(state[self.dimension])
above_low = self.low is None or dimension_value >= self.low
below_high = self.high is None or dimension_value < self.high

Expand All @@ -171,14 +232,14 @@ class StateDimensionLambda(StateDimensionIndicator):

def __init__(
self,
dimension: Optional[int],
function: Callable[[Optional[float]], Any],
dimension: int,
function: Callable[[float], Any],
function_range: List[Any]
):
"""
Initialize the segment.
:param dimension: Dimension, or None for an indicator that is not based on a value of the state vector.
:param dimension: Dimension.
:param function: Function to apply to values in the given dimension.
:param function_range: Range of function.
"""
Expand Down Expand Up @@ -210,21 +271,16 @@ def get_range(self) -> List[Any]:

def get_value(
self,
state: np.ndarray
state_vector: np.ndarray
) -> Any:
"""
Get the value of the current indicator for a state.
:param state: State vector.
:param state_vector: State vector.
:return: Value.
"""

if self.dimension is None:
dimension_value = None
else:
dimension_value = float(state[self.dimension])

return self.function(dimension_value)
return self.function(float(state_vector[self.dimension]))


class OneHotStateIndicatorFeatureInteracter:
Expand Down Expand Up @@ -263,7 +319,7 @@ def interact(

def __init__(
self,
indicators: List[StateDimensionIndicator]
indicators: List[StateIndicator]
):
"""
Initialize the interacter.
Expand Down

0 comments on commit 4198689

Please sign in to comment.