Skip to content

Commit

Permalink
Array conversions.
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxBlesch committed Jun 1, 2024
1 parent e07432b commit 5fc4254
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions src/dcegm/pre_processing/state_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,19 +185,18 @@ def create_state_space(options):
(state_space_wo_exog_full, exog_state_space_full), axis=1
)

dtype_state_space = get_smallest_uint_type(state_space.shape[0])
max_int_state_space = np.iinfo(dtype_state_space).max
state_space = array_with_smallest_uint_dtype(state_space)

# Create indexer array that maps states to indexes
map_state_to_index = create_indexer_for_space(
state_space, dtype_state_space, max_int_state_space
)
map_state_to_index = create_indexer_for_space(state_space)

state_space_dict = {
key: state_space[:, i]
key: array_with_smallest_uint_dtype(state_space[:, i])
for i, key in enumerate(states_names_without_exog + exog_states_names)
}

exog_state_space = array_with_smallest_uint_dtype(exog_state_space)

return (
state_space,
state_space_dict,
Expand Down Expand Up @@ -347,9 +346,7 @@ def create_state_choice_space(
idx += 1

state_choice_space_final = state_choice_space[:idx]
map_state_choice_to_index = create_indexer_for_space(
state_choice_space_final, dtype_state_choice_space, max_int_state_choice_space
)
map_state_choice_to_index = create_indexer_for_space(state_choice_space_final)

return (
state_choice_space_final,
Expand Down Expand Up @@ -483,7 +480,7 @@ def add_endog_states(id_endog_state):
return add_endog_states


def create_indexer_for_space(space, dtype_state_space, max_int_state_space):
def create_indexer_for_space(space):
"""Creates indexer for spaces.
We need to think about which datatype we want to use and what is our invalid number.
Expand Down Expand Up @@ -535,8 +532,13 @@ def check_options(options):
return options


def array_with_smallest_uint_dtype(array):
"""Return array with the smallest unsigned integer dtype."""
return array.astype(get_smallest_uint_type(array.max()))


def get_smallest_uint_type(n_values):
uint_types = [np.uint8, np.uint16, np.uint32, np.uint64]
for dtype in uint_types:
if np.iinfo(dtype).max >= n_values:
if np.iinfo(dtype).max > n_values:
return dtype

0 comments on commit 5fc4254

Please sign in to comment.