[Feature request]: Method For Directly Sampling From The Posterior Predictive From Output #416

Open
TimothyWillard opened this issue Dec 9, 2024 · 1 comment

Comments

@TimothyWillard
Contributor

Label

enhancement, gempyor, post-processing

Priority Label

medium priority

Is your feature request related to a problem? Please describe.

This issue was originally reported by @MacdonaldJoshuaCaleb in GH-413.

For more complicated or atypical post-processing it is helpful to sample directly from the posterior predictive distribution. This would let operations develop post-processing that takes the fitted parameters as input and produces a distribution of some derived quantity.
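One possible shape for such a method, sketched here using the GempyorInference.simulate_proposal call that appears in the code below (the helper name sample_posterior_predictive and its arguments are hypothetical):

import numpy as np

def sample_posterior_predictive(gempyor_inference, chains, n_samples=100, rng=None):
    # Draw parameter vectors from the last step of the emcee chains and run one
    # forward simulation per draw, returning the list of simulated data frames.
    rng = np.random.default_rng(rng)
    last_step = chains[-1, :, :]  # (n_walkers, n_params)
    draws = last_step[rng.integers(0, last_step.shape[0], size=n_samples), :]
    return [gempyor_inference.simulate_proposal(p) for p in draws]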

Is your feature request related to a new application, scenario round, pathogen? Please describe.

No response

Describe the solution you'd like

Here's the start of a function like that, also developed for the flu scenarios. Note that to really generalize this we would want to index by the parameter labels rather than just the ordering (see the sketch at the end of this comment). Chains can be read from an emcee HDF5 file like so (arviz is another Python package that can read these files):

import emcee

sampler = emcee.backends.HDFBackend(filename, read_only=True)
chains = sampler.get_chain()
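As an alternative, a minimal sketch of getting the same chains into arviz for diagnostics (the variable name "theta" is just a placeholder for the stacked parameter vector):

import arviz as az
import numpy as np

# emcee's get_chain() returns (n_steps, n_walkers, n_params); arviz expects
# (chain, draw, ...), so swap the first two axes.
idata = az.from_dict(posterior={"theta": np.swapaxes(chains, 0, 1)})
print(az.summary(idata))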

These chains can then be fed to gempyor to simulate the model, given a config:

import numpy as np

def shuffle_params(chains, idx_array, intersect, keep_list, Num_samples=100, Num_seasons=3, Num_params=9):
    # Build Num_samples parameter vectors by mixing, for every parameter,
    # a randomly chosen walker and a randomly chosen (kept) season.
    samples = chains[-1, :, :]  # last step of every walker
    shuffled_samples = np.zeros([Num_samples, Num_params])
    shuffled_chains = np.zeros([chains.shape[0], Num_samples, Num_params])
    r0_seasons = []
    indices = []
    for k in range(Num_samples):
        r_season_idx = np.random.randint(0, len(intersect), Num_params)
        r_chain_idx = np.random.randint(0, chains.shape[1], Num_params)
        r0_seasons.append(keep_list[r_season_idx[0]])  # season the R0 draw came from
        for j in range(Num_params):
            shuffled_samples[k, j] = samples[r_chain_idx[j], idx_array[r_season_idx[j]][j]]
            shuffled_chains[:, k, j] = chains[:, r_chain_idx[j], idx_array[r_season_idx[j]][j]]
            indices.append([r_chain_idx[j], idx_array[r_season_idx[j]][j]])

    return shuffled_chains, shuffled_samples, indices, np.array(r0_seasons)
######################################################################
# usage
gempyor_inference = GempyorInference(
    config_filepath=state_dst_config,
    run_id=run_id,
    prefix=None,
    first_sim_index=1,
    stoch_traj_flag=False,
    rng_seed=None,
    nslots=1,
    inference_filename_prefix="global/final/",  # usually for {global or chimeric}/{intermediate or final}
    inference_filepath_suffix="",  # usually for the slot_id
    out_run_id=None,  # if out_run_id is different from in_run_id, fill this
    out_prefix=None,  # if out_prefix is different from in_prefix, fill this
    # in case the data folder is on another directory
    autowrite_seir=False,
)

# generate a list of data frames from gempyor
result = gempyor_inference.simulate_proposal(shuffled_samples[0])
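As noted above, a fully general version would index the chain by the parameter labels rather than by position. A minimal sketch of that idea (the param_labels list is hypothetical; in practice it would come from the parameter names in the config):

# Hypothetical: map each parameter label to its chain column, then select
# columns by name instead of by hard-coded ordering.
param_labels = ["r0_14to15", "r0_17to18", "r0_22to23"]  # illustrative labels only
label_to_col = {label: j for j, label in enumerate(param_labels)}
r0_cols = [label_to_col[lbl] for lbl in param_labels if lbl.startswith("r0_")]
r0_draws = chains[-1][:, r0_cols]  # last step of every walker, selected named columns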
@TimothyWillard added the enhancement, gempyor, post-processing, and medium priority labels on Dec 9, 2024
@MacdonaldJoshuaCaleb
Collaborator

Fleshing this out with some more detail, here's a complete block of code that also allows for random sampling of initial conditions and includes simulation filtering:

# imports assumed by the code below (the GempyorInference import path is an assumption)
import glob
import shutil
from datetime import datetime

import dill
import emcee
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from gempyor.inference import GempyorInference


def intersection_indices(list1, list2):
    # indices of the elements of list1 that also appear in list2
    result = []
    for i, element in enumerate(list1):
        if element in list2:
            result.append(i)
    return result
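For example:

intersection_indices(['14to15', '17to18', '22to23'], ['14to15', '22to23'])  # -> [0, 2]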

####################################################################################################

def filter_chains(chains, keep_list=None, strain=None, Num_seasons=3, Num_params=9):
    '''
    Takes chain output and filters out 'bad' seasons. keep_list should be a list
    of the seasons that you don't want filtered out.
    '''
    if strain == 'H1N1':
        seasons = ['15to16', '19to20', '23to24']
    elif strain == 'H3N2':
        seasons = ['14to15', '17to18', '22to23']
    else:
        raise ValueError("please select either 'H1N1' or 'H3N2' for variable strain")
    if keep_list is None:
        keep_list = seasons

    # chain columns are ordered parameter by parameter, with Num_seasons consecutive
    # columns (one per season) for each parameter
    idx_1 = []
    idx_2 = []
    idx_3 = []
    for j in range(Num_params):
        idx_1.append(Num_seasons * j)
        idx_2.append(Num_seasons * j + 1)
        idx_3.append(Num_seasons * j + 2)
    idx_array = np.array([idx_1, idx_2, idx_3])
    intersect = intersection_indices(seasons, keep_list)
    idx_array = idx_array[intersect, :]
    return intersect, idx_array

####################################################################################################

def shuffle_params(chains, idx_array, intersect, keep_list, Num_samples=100, Num_seasons=3, Num_params=9):
    # Build Num_samples parameter vectors by mixing, for every parameter,
    # a randomly chosen walker and a randomly chosen (kept) season.
    samples = chains[-1, :, :]  # last step of every walker
    shuffled_samples = np.zeros([Num_samples, Num_params])
    shuffled_chains = np.zeros([chains.shape[0], Num_samples, Num_params])
    r0_seasons = []
    indices = []
    for k in range(Num_samples):
        r_season_idx = np.random.randint(0, len(intersect), Num_params)
        r_chain_idx = np.random.randint(0, chains.shape[1], Num_params)
        r0_seasons.append(keep_list[r_season_idx[0]])  # season the R0 draw came from
        for j in range(Num_params):
            shuffled_samples[k, j] = samples[r_chain_idx[j], idx_array[r_season_idx[j]][j]]
            shuffled_chains[:, k, j] = chains[:, r_chain_idx[j], idx_array[r_season_idx[j]][j]]
            indices.append([r_chain_idx[j], idx_array[r_season_idx[j]][j]])

    return shuffled_chains, shuffled_samples, indices, np.array(r0_seasons)

####################################################################################################

def get_seasons_keep(strain, fip):
    # tables of seasons to keep per state (paths are machine-specific)
    if strain == 'H1N1':
        data = pd.read_csv(r'~/Documents/seasons_keep_H1.csv')
    elif strain == 'H3N2':
        data = pd.read_csv(r'~/Documents/seasons_keep_H3.csv')
    else:
        raise ValueError("please select either 'H1N1' or 'H3N2' for variable strain")
    # zero-pad the state FIPS codes to five characters
    fips = []
    for j in range(data.shape[0]):
        temp = str(data['state_fips'].values[j])
        if len(temp) < 5:
            temp = '0' + temp
        fips.append(temp)
    data['state'] = ['Alabama','Alaska','Arizona','Arkansas','California','Colorado','Connecticut','Delaware','District of Columbia',
        'Florida','Georgia','Hawaii','Idaho','Illinois','Indiana','Iowa','Kansas','Kentucky','Louisiana','Maine','Maryland',
        'Massachusetts','Michigan','Minnesota','Mississippi','Missouri','Montana','Nebraska','Nevada','New Hampshire',
        'New Jersey','New Mexico','New York','North Carolina','North Dakota','Ohio','Oklahoma','Oregon','Pennsylvania',
        'Rhode Island','South Carolina','South Dakota','Tennessee','Texas','Utah','Vermont','Virginia','Washington',
        'West Virginia','Wisconsin','Wyoming']

    idx = np.where(np.array(fips) == fip)[0][0]

    # 'subpops_keep' stores the seasons to keep as a quoted, comma-separated string;
    # pull out each season name
    seasons_keep = []
    seasons = data['subpops_keep'].values[idx][2:-1].split(',')
    for j in range(len(seasons)):
        seasons_keep.append(seasons[j].split('"')[1])
    state_name = data['state'].values[idx]
    return seasons_keep, state_name, fips

################################################################################################

def plot_sample(chains, samples, n_per=None):
    if n_per is None:
        n_per = chains.shape[1]

    for k in range(chains.shape[2]):
        # dip, p_value = diptest(samples[:,k])
        # print('multimodality test p-value (null is unimodal):', np.round(p_value, 4))
        sns.histplot(samples[:, k])  # marginal histogram for parameter k
        plt.show()
        for j in range(n_per):
            plt.plot(chains[:, j, k], color='blue')  # walker traces for parameter k
        plt.show()
        print('######################')


#####################################

# hospitalization ground-truth data used for filtering (path is machine-specific)
data_hosp = pd.read_csv(r'~/Documents/GitHub/Flu_USA/model_input/SMH_Flu_2024_R1_allflu_medVax_H3_training_multiseason_emcee_difflocvarseas/us_data_Flu_2024_R1_allflu_training_multiseason_emcee_difflocvarseas.csv')
keep_list, state_name, fips = get_seasons_keep('H3N2', '01000')
lbs = []
ubs = []
t_min_idx = []
t_max_idx = []
locs = ['Alabama','Alaska','Arizona','Arkansas','California','Colorado','Connecticut','Delaware','District of Columbia',
        'Florida','Georgia','Hawaii','Idaho','Illinois','Indiana','Iowa','Kansas','Kentucky','Louisiana','Maine','Maryland',
        'Massachusetts','Michigan','Minnesota','Mississippi','Missouri','Montana','Nebraska','Nevada','New Hampshire',
        'New Jersey','New Mexico','New York','North Carolina','North Dakota','Ohio','Oklahoma','Oregon','Pennsylvania',
        'Rhode Island','South Carolina','South Dakota','Tennessee','Texas','Utah','Vermont','Virginia','Washington',
        'West Virginia','Wisconsin','Wyoming']



# date_min = datetime.strptime(dates_[tmin_idx[0]], '%Y-%m-%d')
# date_max =datetime.strptime(dates_[tmax_idx[0]], '%Y-%m-%d')

# per-state acceptance bounds: a simulation is kept if its peak incidH lies between
# lb and ub (and, optionally, if the peak time falls inside [time_lb, time_ub])
for k in range(len(locs)):

    temp = data_hosp[data_hosp['source'] == locs[k]]
    # seasons = temp['season'].unique()[:-1]
    maxes = []
    t_maxes = []
    seasons_keep, state_name, fips = get_seasons_keep('H3N2', fips[k])
    seasons = []
    for j in range(len(seasons_keep)):
        split = seasons_keep[j].split("to")
        seasons.append('20' + split[0] + '-' + split[1])  # e.g. '14to15' -> '2014-15'
    for j in range(len(seasons)):
        temp_season = temp[temp['season'] == seasons[j]]
        maxes.append(int(np.ceil(temp_season['incidH'].max())))
        t_maxes.append(temp_season['incidH'].argmax())
    lbs.append(np.ceil(0.25 * np.min(maxes)))
    t_min_idx.append(np.min(t_maxes) - 8)
    t_max_idx.append(np.max(t_maxes) + 8)
    ubs.append(np.ceil(1 * np.max(maxes)))
dates_ = temp_season['date'].values  # dates from the last season processed, reused below
  
keep_bounds = pd.DataFrame([locs,fips,lbs,ubs,t_min_idx,t_max_idx]).T
keep_bounds.columns = ['state','fips','lb','ub','time_lb','time_ub']

# clear out any previous model output
shutil.rmtree("model_output/", ignore_errors=True)
# subpop_names is assumed to be defined earlier (e.g. from the config / earlier setup)
states = [sp.split("_")[0] for sp in subpop_names]
states = list(set(states))
nsamples=100
chain_index = -1 # -1 for last
all_results = {}
all_params = {}
all_IC = {}

fips = list(np.array(fips)[19:])  # run only a subset of states (from index 19 onward)
for i, sp in enumerate(fips):
    lb = keep_bounds[keep_bounds['fips'] == sp]['lb'].values[0]
    ub = keep_bounds[keep_bounds['fips'] == sp]['ub'].values[0]
    date_min = datetime.strptime(dates_[keep_bounds['time_lb'].values[i]], '%Y-%m-%d')
    date_max =datetime.strptime(dates_[keep_bounds['time_ub'].values[i]], '%Y-%m-%d')
    
    # print(f"Subpop: {i} {sp}", end=" ")
    # try:
    filename = f"SMH_Flu_2024_R1_allflu_medVax_H3_training_multiseason_emcee_difflocvarseas_{sp}-*.h5"
    # find a file that matches the pattern
    import glob
    filename = glob.glob(filename)
    if len(filename) == 0:
        print(f"File {filename} does not exist")
        continue
    filename = filename[0]

    sampler = emcee.backends.HDFBackend(filename, read_only=True)
    chains = sampler.get_chain()


  
    # gempyor_inference.set_save(True)
    run_id = "flu_" + sp
    
    # scenario config
    state_src_config = "config_SMH_Flu_2024_R1_allflu_highVax_H3_projection.yml"

    results_list = []
    params_list = []
    IC_list = []
    count_trys = 0
    keep_list, state_name, fips = get_seasons_keep('H3N2',sp)
    print("################################################")
    print('Working on ' + state_name)
    print("################################################")
    count = 0
    for k in range(nsamples):
        flag = 0
        while flag == 0:
            count_trys = count_trys + 1
            keep_list, state_name, fips = get_seasons_keep('H3N2',sp)
            intersect, ids= filter_chains(chains, strain = 'H3N2',keep_list=keep_list)
            shuffled_chains, shuffled_samples, indicies, seasons = shuffle_params(chains = chains, idx_array = ids, intersect = intersect, keep_list = keep_list, Num_samples = 1)
    
        # state specific scenario config (to be generated)
            state_dst_config = f"config_SMH_Flu_2024_R1_allflu_highVax_H3_projection_{sp}.yml"

            shutil.copy(state_src_config, state_dst_config)

            # update the config file to use the correct subpop, replace SUBPOP_PLACEHOLDER with this_subpop
            with open(state_dst_config, 'r') as file:
                filedata = file.read()
            filedata = filedata.replace('SUBPOP_PLACEHOLDER_A', '"' + sp + '"')
            with open(state_dst_config, 'w') as file:
                file.write(filedata)

        
            with open(state_dst_config, 'r') as file:
                filedata = file.read()
            filedata = filedata.replace('SUBPOP_PLACEHOLDER_B', '"' + seasons[0] + '"')
            with open(state_dst_config, 'w') as file:
                file.write(filedata)

            # 1) modify the config with the sampled subpop for each season and call the initial condition plugin
            # 2) call gempyor_inference.get_logloss_as_single_number, [(shuffled_samples[i, :],) for i in range(shuffled_samples.shape[0])]
            #    for each sample
    
            gempyor_inference = GempyorInference(
                config_filepath=state_dst_config,
                run_id=run_id,
                prefix=None,
                first_sim_index=1,
                stoch_traj_flag=False,
                rng_seed=None,
                nslots=1,
                inference_filename_prefix="global/final/",  # usually for {global or chimeric}/{intermediate or final}
                inference_filepath_suffix="",  # usually for the slot_id
                out_run_id=None,  # if out_run_id is different from in_run_id, fill this
                out_prefix=None,  # if out_prefix is different from in_prefix, fill this
            # in case the data folder is on another directory
                autowrite_seir=False,
            )


            gempyor_inference.set_save(False)
        

            result = gempyor_inference.simulate_proposal(shuffled_samples[0])

            if result['incidH'].max() > lb and result['incidH'].max() < ub:  #and date_min < result.index[result['incidH'].argmax()] and date_max > result.index[result['incidH'].argmax()]:
                flag = 1
                count = count + 1
                print("#########################################################")
                print(state_name +', ' + str(count) + ' simulations accepted, acceptance rate: ', np.round(count/count_trys,2))
                print("#########################################################")
                results_list.append(result)
                params_list.append(shuffled_samples[0])
                IC_list.append(seasons[0])
    with open("./scenario_output/all_results_H3_HiVax_" + state_name + ".pkl", "wb") as f:
        dill.dump(results_list, f)

    with open("./scenario_output/all_params_H3_HiVax_" + state_name + ".pkl", "wb") as f:
        dill.dump(params_list, f)

    with open("./scenario_output/all_IC_H3_HiVax_" + state_name + ".pkl", "wb") as f:
        dill.dump(IC_list, f)
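
Downstream of this, a short sketch of turning the saved pickles into a posterior predictive distribution of a derived quantity (the state name 'Alabama' is only an example):

import dill
import numpy as np

# load the accepted simulations for one state (filenames as written above)
with open("./scenario_output/all_results_H3_HiVax_Alabama.pkl", "rb") as f:
    results_list = dill.load(f)

# posterior predictive distribution of a derived quantity, here the peak of incidH
peak_incidH = np.array([res['incidH'].max() for res in results_list])
print(np.percentile(peak_incidH, [2.5, 50, 97.5]))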
