diff --git a/src/dcegm/pre_processing/state_space.py b/src/dcegm/pre_processing/state_space.py index e12956fd..12b5a2f9 100644 --- a/src/dcegm/pre_processing/state_space.py +++ b/src/dcegm/pre_processing/state_space.py @@ -185,19 +185,18 @@ def create_state_space(options): (state_space_wo_exog_full, exog_state_space_full), axis=1 ) - dtype_state_space = get_smallest_uint_type(state_space.shape[0]) - max_int_state_space = np.iinfo(dtype_state_space).max + state_space = array_with_smallest_uint_dtype(state_space) # Create indexer array that maps states to indexes - map_state_to_index = create_indexer_for_space( - state_space, dtype_state_space, max_int_state_space - ) + map_state_to_index = create_indexer_for_space(state_space) state_space_dict = { - key: state_space[:, i] + key: array_with_smallest_uint_dtype(state_space[:, i]) for i, key in enumerate(states_names_without_exog + exog_states_names) } + exog_state_space = array_with_smallest_uint_dtype(exog_state_space) + return ( state_space, state_space_dict, @@ -347,9 +346,7 @@ def create_state_choice_space( idx += 1 state_choice_space_final = state_choice_space[:idx] - map_state_choice_to_index = create_indexer_for_space( - state_choice_space_final, dtype_state_choice_space, max_int_state_choice_space - ) + map_state_choice_to_index = create_indexer_for_space(state_choice_space_final) return ( state_choice_space_final, @@ -483,7 +480,7 @@ def add_endog_states(id_endog_state): return add_endog_states -def create_indexer_for_space(space, dtype_state_space, max_int_state_space): +def create_indexer_for_space(space): """Creates indexer for spaces. We need to think about which datatype we want to use and what is our invalid number. @@ -535,8 +532,13 @@ def check_options(options): return options +def array_with_smallest_uint_dtype(array): + """Return array with the smallest unsigned integer dtype.""" + return array.astype(get_smallest_uint_type(array.max())) + + def get_smallest_uint_type(n_values): uint_types = [np.uint8, np.uint16, np.uint32, np.uint64] for dtype in uint_types: - if np.iinfo(dtype).max >= n_values: + if np.iinfo(dtype).max > n_values: return dtype