diff --git a/src/dcegm/pre_processing/state_space.py b/src/dcegm/pre_processing/state_space.py index 2158d33d..a34192dd 100644 --- a/src/dcegm/pre_processing/state_space.py +++ b/src/dcegm/pre_processing/state_space.py @@ -56,7 +56,7 @@ def create_state_space_and_choice_objects( state_space_names=states_names_without_exog + exog_states_names, ) - model_structure = { + model_structure_raw = { "state_space": state_space, "choice_range": jnp.asarray(options["state_space"]["choices"]), "state_space_dict": state_space_dict, @@ -70,8 +70,8 @@ def create_state_space_and_choice_objects( "map_state_choice_to_parent_state": map_state_choice_to_parent_state, "map_state_choice_to_child_states": map_state_choice_to_child_states, } - # Return model structure with reduced dtypes - return jax.tree.map(create_array_with_smallest_int_dtype, model_structure) + + return jax.tree.map(create_array_with_smallest_int_dtype, model_structure_raw) def test_state_space_objects( @@ -136,8 +136,6 @@ def create_state_space(options): n_periods = state_space_options["n_periods"] n_choices = len(state_space_options["choices"]) - need_signed_dtype = _contains_negative_value(state_space_options) - ( add_endog_state_func, endog_states_names, @@ -273,10 +271,7 @@ def create_state_choice_space( state_space_names = states_names_without_exog + exog_state_names n_periods = state_space_options["n_periods"] - need_signed_dtype = _contains_negative_value(state_space_options) - dtype_exog_state_space = get_smallest_int_type(n_exog_states) - dtype_state_choice_space = get_smallest_int_type(n_states * n_choices) max_int_state_choice_space = np.iinfo(dtype_state_choice_space).max @@ -487,18 +482,12 @@ def add_endog_states(id_endog_state): def create_indexer_for_space(space): - """Create indexer for space. - - We need to think about which datatype we want to use and what is our invalid number. - Who doesn't like -99999999? Will anybody ever have 10 Billion state choices. - - """ + """Create indexer for space.""" # Indexer has always unsigned data type with integers starting at zero data_type = get_smallest_int_type(space.shape[0]) max_value = np.iinfo(data_type).max - # Account for negative entries max_var_values = np.max(space, axis=0) - np.min(space, axis=0) map_vars_to_index = np.full( @@ -549,30 +538,16 @@ def create_array_with_smallest_int_dtype(arr): return arr.astype(get_smallest_int_type(arr.max())) else: return arr - return arr + + else: + return arr def get_smallest_int_type(n_values): """Return the smallest integer type that can hold n_values.""" - int_types = [np.uint8, np.uint16, np.uint32, np.uint64] + uint_types = [np.uint8, np.uint16, np.uint32, np.uint64] - for dtype in int_types: + for dtype in uint_types: if np.iinfo(dtype).max > n_values: return dtype - - -def _contains_negative_value(d): - """Check recursively if (nested) dictionary contains any negative values.""" - - if isinstance(d, dict): - return any(_contains_negative_value(v) for v in d.values()) - elif isinstance(d, list): - # Convert list to numpy array then check for negatives - return _contains_negative_value(np.array(d)) - elif isinstance(d, np.ndarray): - return np.any(d < 0) - elif isinstance(d, (int, float)): - return d < 0 - - return False