[Feature request]: Method For Directly Sampling From The Posterior Predictive From Output #416
Comments
TimothyWillard added the enhancement, gempyor, post-processing, and medium priority labels on Dec 9, 2024
Fleshing this out with some more detail, here's a complete block of code that also allows for random sampling of initial conditions and has simulation filtering:

```python
# Imports needed by the snippet below (not all were shown in the original paste).
import glob
import shutil
from datetime import datetime

import dill
import emcee
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Import path for GempyorInference may differ depending on the gempyor version.
from gempyor.inference import GempyorInference


def intersection_indices(list1, list2):
    """Return the indices of elements of `list1` that also appear in `list2`."""
    result = []
    for i, element in enumerate(list1):
        if element in list2:
            result.append(i)
    return result
####################################################################################################
def filter_chains(chains, keep_list=None, strain=None, Num_seasons=None, Num_params=None):
    '''
    Takes chain output and filters out 'bad' seasons; `keep_list` should be a list of the seasons
    that you don't want filtered out.
    '''
    if Num_seasons is None:
        Num_seasons = 3
    if Num_params is None:
        Num_params = 9
    if strain == 'H1N1':
        seasons = ['15to16', '19to20', '23to24']
        if keep_list is None:
            keep_list = seasons
    elif strain == 'H3N2':
        seasons = ["14to15", "17to18", "22to23"]
        if keep_list is None:
            keep_list = ["14to15", "17to18", "22to23"]
    else:
        print("please select either 'H1N1' or 'H3N2' for variable strain")
    # Column indices of each season's copy of every parameter (seasons vary fastest).
    idx_1 = []
    idx_2 = []
    idx_3 = []
    for j in range(Num_params):
        idx_1.append(Num_seasons * j)
        idx_2.append(Num_seasons * j + 1)
        idx_3.append(Num_seasons * j + 2)
    idx_array = np.array([idx_1, idx_2, idx_3])
    # Keep only the rows corresponding to seasons that appear in keep_list.
    intersect = intersection_indices(seasons, keep_list)
    idx_array = idx_array[intersect, :]
    return intersect, idx_array
####################################################################################################
def shuffle_params(chains, idx_array, intersect, keep_list, Num_samples=None, Num_seasons=None, Num_params=None):
    '''
    Draws `Num_samples` parameter vectors by randomly mixing seasons and walkers from the last
    step of the chains, and records which season each draw's first parameter came from
    (used later to pick initial conditions).
    '''
    if Num_samples is None:
        Num_samples = 100
    if Num_seasons is None:
        Num_seasons = 3
    if Num_params is None:
        Num_params = 9
    samples = chains[-1, :, :]
    shuffled_samples = np.zeros([Num_samples, Num_params])
    shuffled_chains = np.zeros([chains.shape[0], Num_samples, Num_params])
    r0_seasons = []
    indices = []
    for k in range(Num_samples):
        # For each parameter, pick a random (season, walker) pair.
        r_season_idx = np.random.randint(0, len(intersect), Num_params)
        r_chain_idx = np.random.randint(0, chains.shape[1], Num_params)
        r0_seasons.append(keep_list[r_season_idx[0]])
        for j in range(Num_params):
            shuffled_samples[k, j] = samples[r_chain_idx[j], idx_array[r_season_idx[j]][j]]
            shuffled_chains[:, k, j] = chains[:, r_chain_idx[j], idx_array[r_season_idx[j]][j]]
            indices.append([r_chain_idx[j], idx_array[r_season_idx[j]][j]])
    return shuffled_chains, shuffled_samples, indices, np.array(r0_seasons)
####################################################################################################
def get_seasons_keep(strain, fip):
    '''
    Look up which seasons to keep for the subpopulation with FIPS code `fip`
    from a per-strain CSV of seasons to keep.
    '''
    if strain == 'H1N1':
        data = pd.read_csv(r'~/Documents/seasons_keep_H1.csv')
    if strain == 'H3N2':
        data = pd.read_csv(r'~/Documents/seasons_keep_H3.csv')
    # Left-pad FIPS codes that lost a leading zero.
    fips = []
    for j in range(data.shape[0]):
        temp = str(data['state_fips'].values[j])
        if len(temp) < 5:
            temp = '0' + temp
        fips.append(temp)
    data['state'] = ['Alabama', 'Alaska', 'Arizona', 'Arkansas', 'California', 'Colorado', 'Connecticut', 'Delaware', 'District of Columbia',
                     'Florida', 'Georgia', 'Hawaii', 'Idaho', 'Illinois', 'Indiana', 'Iowa', 'Kansas', 'Kentucky', 'Louisiana', 'Maine', 'Maryland',
                     'Massachusetts', 'Michigan', 'Minnesota', 'Mississippi', 'Missouri', 'Montana', 'Nebraska', 'Nevada', 'New Hampshire',
                     'New Jersey', 'New Mexico', 'New York', 'North Carolina', 'North Dakota', 'Ohio', 'Oklahoma', 'Oregon', 'Pennsylvania',
                     'Rhode Island', 'South Carolina', 'South Dakota', 'Tennessee', 'Texas', 'Utah', 'Vermont', 'Virginia', 'Washington',
                     'West Virginia', 'Wisconsin', 'Wyoming']
    idx = np.where(np.array(fips) == fip)[0][0]
    # Parse the stringified list of seasons to keep for this subpopulation.
    seasons_keep = []
    seasons = data['subpops_keep'].values[idx][2:-1].split(',')
    for j in range(len(seasons)):
        seasons_keep.append(seasons[j].split('"')[1])
    state_name = data['state'].values[idx]
    return seasons_keep, state_name, fips
################################################################################################
def plot_sample(chains, samples, n_per=None):
    '''Plot a histogram of the samples and the chain traces for each parameter.'''
    if n_per is None:
        n_per = chains.shape[1]
    for k in range(chains.shape[2]):
        # dip, p_value = diptest(samples[:, k])
        # print('multimodality test p-value (null is unimodal):', np.round(p_value, 4))
        sns.histplot(samples[:, k])
        plt.show()
        for j in range(n_per):
            plt.plot(chains[:, j, k], color='blue')
        plt.show()
        print('######################')
#####################################
# Build per-state acceptance bounds from the training hospitalization data.
data_hosp = hosps = pd.read_csv(r'~/Documents/GitHub/Flu_USA/model_input/SMH_Flu_2024_R1_allflu_medVax_H3_training_multiseason_emcee_difflocvarseas/us_data_Flu_2024_R1_allflu_training_multiseason_emcee_difflocvarseas.csv')
keep_list, state_name, fips = get_seasons_keep('H3N2', '01000')
lbs = []
ubs = []
t_min_idx = []
t_max_idx = []
locs = ['Alabama', 'Alaska', 'Arizona', 'Arkansas', 'California', 'Colorado', 'Connecticut', 'Delaware', 'District of Columbia',
        'Florida', 'Georgia', 'Hawaii', 'Idaho', 'Illinois', 'Indiana', 'Iowa', 'Kansas', 'Kentucky', 'Louisiana', 'Maine', 'Maryland',
        'Massachusetts', 'Michigan', 'Minnesota', 'Mississippi', 'Missouri', 'Montana', 'Nebraska', 'Nevada', 'New Hampshire',
        'New Jersey', 'New Mexico', 'New York', 'North Carolina', 'North Dakota', 'Ohio', 'Oklahoma', 'Oregon', 'Pennsylvania',
        'Rhode Island', 'South Carolina', 'South Dakota', 'Tennessee', 'Texas', 'Utah', 'Vermont', 'Virginia', 'Washington',
        'West Virginia', 'Wisconsin', 'Wyoming']
# date_min = datetime.strptime(dates_[tmin_idx[0]], '%Y-%m-%d')
# date_max = datetime.strptime(dates_[tmax_idx[0]], '%Y-%m-%d')
for k in range(len(locs)):
    temp = data_hosp[data_hosp['source'] == locs[k]]
    # seasons = temp['season'].unique()[:-1]
    maxes = []
    t_maxes = []
    seasons_keep, state_name, fips = get_seasons_keep('H3N2', fips[k])
    # Convert season labels like '14to15' to the '2014-15' format used in the data.
    seasons = []
    for j in range(len(seasons_keep)):
        split = seasons_keep[j].split("to")
        seasons.append('20' + split[0] + '-' + split[1])
    for j in range(len(seasons)):
        temp_season = temp[temp['season'] == seasons[j]]
        maxes.append(int(np.ceil(temp_season['incidH'].max())))
        t_maxes.append(temp_season['incidH'].argmax())
    # Acceptance bounds on peak size (and, optionally, peak timing) for this state.
    lbs.append(np.ceil(0.25 * np.min(maxes)))
    t_min_idx.append(np.min(t_maxes) - 8)
    t_max_idx.append(np.max(t_maxes) + 8)
    ubs.append(np.ceil(1 * np.max(maxes)))
    dates_ = temp_season['date'].values
keep_bounds = pd.DataFrame([locs, fips, lbs, ubs, t_min_idx, t_max_idx]).T
keep_bounds.columns = ['state', 'fips', 'lb', 'ub', 'time_lb', 'time_ub']
# Clear any previous model output.
shutil.rmtree("model_output/", ignore_errors=True)
shutil.rmtree("/model_output/", ignore_errors=True)

states = [sp.split("_")[0] for sp in subpop_names]  # `subpop_names` is assumed to be defined earlier (not shown here)
states = list(set(states))

nsamples = 100
chain_index = -1  # -1 for last
all_results = {}
all_params = {}
all_IC = {}
fips = list(np.array(fips)[19:])
for i, sp in enumerate(fips):
    lb = keep_bounds[keep_bounds['fips'] == sp]['lb'].values[0]
    ub = keep_bounds[keep_bounds['fips'] == sp]['ub'].values[0]
    date_min = datetime.strptime(dates_[keep_bounds['time_lb'].values[i]], '%Y-%m-%d')
    date_max = datetime.strptime(dates_[keep_bounds['time_ub'].values[i]], '%Y-%m-%d')
    # print(f"Subpop: {i} {sp}", end=" ")
    # try:
    # Find the emcee backend file for this subpopulation.
    filename = f"SMH_Flu_2024_R1_allflu_medVax_H3_training_multiseason_emcee_difflocvarseas_{sp}-*.h5"
    filename = glob.glob(filename)
    if len(filename) == 0:
        print(f"File {filename} does not exist")
        continue
    filename = filename[0]
    sampler = emcee.backends.HDFBackend(filename, read_only=True)
    chains = sampler.get_chain()
    # gempyor_inference.set_save(True)
    run_id = "flu_" + sp
    # scenario config
    state_src_config = "config_SMH_Flu_2024_R1_allflu_highVax_H3_projection.yml"
    results_list = []
    params_list = []
    IC_list = []
    count_trys = 0
    keep_list, state_name, fips = get_seasons_keep('H3N2', sp)
    print("################################################")
    print('Working on ' + state_name)
    print("################################################")
    count = 0
    for k in range(nsamples):
        # Rejection-sample: keep proposing until a simulation passes the peak-size filter.
        flag = 0
        while flag == 0:
            count_trys = count_trys + 1
            keep_list, state_name, fips = get_seasons_keep('H3N2', sp)
            intersect, ids = filter_chains(chains, strain='H3N2', keep_list=keep_list)
            shuffled_chains, shuffled_samples, indices, seasons = shuffle_params(chains=chains, idx_array=ids, intersect=intersect, keep_list=keep_list, Num_samples=1)
            # state specific scenario config (to be generated)
            state_dst_config = f"config_SMH_Flu_2024_R1_allflu_highVax_H3_projection_{sp}.yml"
            shutil.copy(state_src_config, state_dst_config)
            # update the config file to use the correct subpop: replace SUBPOP_PLACEHOLDER_A with this subpop
            with open(state_dst_config, 'r') as file:
                filedata = file.read()
            filedata = filedata.replace('SUBPOP_PLACEHOLDER_A', '"' + sp + '"')
            with open(state_dst_config, 'w') as file:
                file.write(filedata)
            # and SUBPOP_PLACEHOLDER_B with the randomly drawn season (used for the initial conditions)
            with open(state_dst_config, 'r') as file:
                filedata = file.read()
            filedata = filedata.replace('SUBPOP_PLACEHOLDER_B', '"' + seasons[0] + '"')
            with open(state_dst_config, 'w') as file:
                file.write(filedata)
            # 1) modify the config with the sampled subpop for each season and call initial condition plugin
            # 2) call gempyor_inference.get_logloss_as_single_number, [(shuffled_samples[i, :],) for i in range(shuffled_samples.shape[0])]
            #    for each sample
            gempyor_inference = GempyorInference(
                config_filepath=state_dst_config,
                run_id=run_id,
                prefix=None,
                first_sim_index=1,
                stoch_traj_flag=False,
                rng_seed=None,
                nslots=1,
                inference_filename_prefix="global/final/",  # usually for {global or chimeric}/{intermediate or final}
                inference_filepath_suffix="",  # usually for the slot_id
                out_run_id=None,  # if out_run_id is different from in_run_id, fill this
                out_prefix=None,  # if out_prefix is different from in_prefix, fill this
                # in case the data folder is on another directory
                autowrite_seir=False,
            )
            gempyor_inference.set_save(False)
            result = gempyor_inference.simulate_proposal(shuffled_samples[0])
            if result['incidH'].max() > lb and result['incidH'].max() < ub:  # and date_min < result.index[result['incidH'].argmax()] and date_max > result.index[result['incidH'].argmax()]:
                flag = 1
                count = count + 1
                print("#########################################################")
                print(state_name + ', ' + str(count) + ' simulations accepted, acceptance rate: ', np.round(count / count_trys, 2))
                print("#########################################################")
                results_list.append(result)
                params_list.append(shuffled_samples[0])
                IC_list.append(seasons[0])
    # Save the accepted simulations, parameters, and initial-condition seasons for this state.
    with open("./scenario_output/all_results_H3_HiVax_" + state_name + ".pkl", "wb") as f:
        dill.dump(results_list, f)
    with open("./scenario_output/all_params_H3_HiVax_" + state_name + ".pkl", "wb") as f:
        dill.dump(params_list, f)
    with open("./scenario_output/all_IC_H3_HiVax_" + state_name + ".pkl", "wb") as f:
        dill.dump(IC_list, f)
```
Label
enhancement, gempyor, post-processing
Priority Label
medium priority
Is your feature request related to a problem? Please describe.
This issue was originally reported by @MacdonaldJoshuaCaleb in GH-413.
For more complicated or atypical post-processing it is helpful to sample directly from the posterior predictive distribution. This would allow operations to develop post-processing that takes the fitted parameters as input and produces a distribution of some derived quantity.
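To make the requested interface concrete, here is a minimal sketch of what such a method could look like. The helper name `sample_posterior_predictive` and its signature are hypothetical, not an existing gempyor API; it simply wraps the pieces used in the code above (drawing parameter vectors from the last step of an emcee chain and simulating them with `GempyorInference.simulate_proposal`), applying a user-supplied function to each simulation to build the distribution of a derived quantity:

```python
import numpy as np


def sample_posterior_predictive(gempyor_inference, chains, n_samples, derived_fn, rng=None):
    """Hypothetical sketch: draw parameter vectors from the last step of `chains`
    (shape: n_steps x n_walkers x n_params), simulate each one with gempyor, and
    return the values of `derived_fn` applied to each simulation."""
    rng = np.random.default_rng(rng)
    last_step = chains[-1, :, :]
    walker_idx = rng.integers(0, last_step.shape[0], size=n_samples)
    draws = []
    for idx in walker_idx:
        result = gempyor_inference.simulate_proposal(last_step[idx, :])
        draws.append(derived_fn(result))
    return np.array(draws)


# Example derived quantity: the posterior predictive distribution of peak incident hospitalizations.
# peak_dist = sample_posterior_predictive(gempyor_inference, chains, 100, lambda sim: sim["incidH"].max())
```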
Is your feature request related to a new application, scenario round, pathogen? Please describe.
No response
Describe the solution you'd like
Here's the start of a function like that (the comment above has a fleshed-out version), also developed for flu scenarios. Note that to really generalize this we would want to index by the parameter labels rather than just their ordering. Chains can be read from an HDF5 file like so (arviz is another Python package that can read HDF5 files):
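A minimal reading step, following the emcee backend usage in the comment above (the backend file name below is illustrative):

```python
import emcee

# Illustrative backend file name; use the HDF5 file written by the emcee fit.
backend_file = "SMH_Flu_2024_R1_allflu_medVax_H3_training_multiseason_emcee_difflocvarseas_01000-000.h5"
sampler = emcee.backends.HDFBackend(backend_file, read_only=True)
chains = sampler.get_chain()  # array of shape (n_steps, n_walkers, n_params)
```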
These chains can then be fed to gempyor to simulate the model given a config:
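A sketch of that step, mirroring the `GempyorInference` usage in the comment above; the config file name is a placeholder, and most constructor arguments are omitted here for brevity (the full set used in the comment may be required depending on the gempyor version):

```python
from gempyor.inference import GempyorInference  # import path may vary by gempyor version

gempyor_inference = GempyorInference(
    config_filepath="config_projection_example.yml",  # placeholder config name
    run_id="flu_example",
    nslots=1,
)
gempyor_inference.set_save(False)

# Simulate a single proposal taken from the last step of the chains read above.
theta = chains[-1, 0, :]
result = gempyor_inference.simulate_proposal(theta)
```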