Skip to content

Commit

Permalink
Allow only for non-negative entries in state space
Browse files Browse the repository at this point in the history
  • Loading branch information
segsell committed Jun 6, 2024
1 parent 4d248d7 commit 42ca95c
Showing 1 changed file with 9 additions and 34 deletions.
43 changes: 9 additions & 34 deletions src/dcegm/pre_processing/state_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def create_state_space_and_choice_objects(
state_space_names=states_names_without_exog + exog_states_names,
)

model_structure = {
model_structure_raw = {
"state_space": state_space,
"choice_range": jnp.asarray(options["state_space"]["choices"]),
"state_space_dict": state_space_dict,
Expand All @@ -70,8 +70,8 @@ def create_state_space_and_choice_objects(
"map_state_choice_to_parent_state": map_state_choice_to_parent_state,
"map_state_choice_to_child_states": map_state_choice_to_child_states,
}
# Return model structure with reduced dtypes
return jax.tree.map(create_array_with_smallest_int_dtype, model_structure)

return jax.tree.map(create_array_with_smallest_int_dtype, model_structure_raw)


def test_state_space_objects(
Expand Down Expand Up @@ -136,8 +136,6 @@ def create_state_space(options):
n_periods = state_space_options["n_periods"]
n_choices = len(state_space_options["choices"])

need_signed_dtype = _contains_negative_value(state_space_options)

(
add_endog_state_func,
endog_states_names,
Expand Down Expand Up @@ -273,10 +271,7 @@ def create_state_choice_space(
state_space_names = states_names_without_exog + exog_state_names
n_periods = state_space_options["n_periods"]

need_signed_dtype = _contains_negative_value(state_space_options)

dtype_exog_state_space = get_smallest_int_type(n_exog_states)

dtype_state_choice_space = get_smallest_int_type(n_states * n_choices)
max_int_state_choice_space = np.iinfo(dtype_state_choice_space).max

Expand Down Expand Up @@ -487,18 +482,12 @@ def add_endog_states(id_endog_state):


def create_indexer_for_space(space):
"""Create indexer for space.
We need to think about which datatype we want to use and what is our invalid number.
Who doesn't like -99999999? Will anybody ever have 10 Billion state choices.
"""
"""Create indexer for space."""

# Indexer has always unsigned data type with integers starting at zero
data_type = get_smallest_int_type(space.shape[0])
max_value = np.iinfo(data_type).max

# Account for negative entries
max_var_values = np.max(space, axis=0) - np.min(space, axis=0)

map_vars_to_index = np.full(
Expand Down Expand Up @@ -549,30 +538,16 @@ def create_array_with_smallest_int_dtype(arr):
return arr.astype(get_smallest_int_type(arr.max()))
else:
return arr
return arr

else:
return arr


def get_smallest_int_type(n_values):
"""Return the smallest integer type that can hold n_values."""

int_types = [np.uint8, np.uint16, np.uint32, np.uint64]
uint_types = [np.uint8, np.uint16, np.uint32, np.uint64]

for dtype in int_types:
for dtype in uint_types:
if np.iinfo(dtype).max > n_values:
return dtype


def _contains_negative_value(d):
"""Check recursively if (nested) dictionary contains any negative values."""

if isinstance(d, dict):
return any(_contains_negative_value(v) for v in d.values())
elif isinstance(d, list):
# Convert list to numpy array then check for negatives
return _contains_negative_value(np.array(d))
elif isinstance(d, np.ndarray):
return np.any(d < 0)
elif isinstance(d, (int, float)):
return d < 0

return False

0 comments on commit 42ca95c

Please sign in to comment.