Skip to content

Commit

Permalink
Reset inference workflows (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
KCGallagher committed Sep 9, 2022
1 parent 71e0ebb commit ce1979e
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 136 deletions.
119 changes: 4 additions & 115 deletions periodic_sampling/fixed_bias_sampler.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -205,14 +205,14 @@
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 91,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 6/6 [00:00<00:00, 22.05it/s]\n"
"100%|██████████| 6/6 [00:00<00:00, 25.88it/s]\n"
]
}
],
Expand All @@ -237,125 +237,14 @@
"\n",
"step_num = int(6)\n",
"sampler = MixedSampler(params=params)\n",
"output1 = sampler.sampling_routine(step_num=step_num, sample_burnin=0)\n",
"output = sampler.sampling_routine(step_num=step_num, sample_burnin=0)\n",
"\n",
"filename = f\"synth_inference_T_{bias_method}_{time_steps}_N0_{N_0}_R0diff_{R0_diff}_It_{step_num}_seed_{seed}.csv\"\n",
"output1.to_csv('../data/outputs/scale_bias_model/' + filename)\n",
"output.to_csv('../data/outputs/scale_bias_model/' + filename)\n",
"\n",
"images_path = \"synthetic_inference/stepped_R/scale_bias_model/normalised_bias/\""
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 7/7 [00:00<00:00, 24.85it/s]\n"
]
}
],
"source": [
"I_data = list(bias_df['Confirmed'])\n",
"\n",
"params = {'bias_prior_alpha': 1, 'bias_prior_beta': 1,\n",
" 'rt_prior_alpha': 1, 'rt_prior_beta': 1} # Gamma dist\n",
"\n",
"params['serial_interval'] = RenewalModel(R0=None).serial_interval\n",
"params['Rt_window'] = 7 # Assume it is constant for 7 days\n",
"params['error_threshold_bias'] = 1 # resample bias if this is not satisfied\n",
"\n",
"for i, val in enumerate(I_data): # Observed cases - not a Parameter\n",
" params[(\"data_\" + str(i))] = val\n",
"\n",
"for i in range(0, len(I_data)): # Reproductive number values\n",
" params[(\"R_\" + str(i))] = rt_parameter(value=1, index=i)\n",
"\n",
"for i in range(7): # Weekday bias parameters\n",
" params[(\"bias_\" + str(i))] = scale_bias_parameter(1, index=i)\n",
"\n",
"step_num = int(7)\n",
"sampler = MixedSampler(params=params)\n",
"output2 = sampler.sampling_routine(step_num=step_num, sample_burnin=0)\n",
"\n",
"filename = f\"synth_inference_T_{bias_method}_{time_steps}_N0_{N_0}_R0diff_{R0_diff}_It_{step_num}_seed_{seed}.csv\"\n",
"output2.to_csv('../data/outputs/scale_bias_model/' + filename)\n",
"\n",
"images_path = \"synthetic_inference/stepped_R/scale_bias_model/normalised_bias/\""
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 88%|████████▊ | 7/8 [12:53<01:50, 110.45s/it]\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m/auto/users/kither/Documents/periodic-sampling/periodic_sampling/fixed_bias_sampler.ipynb Cell 13\u001b[0m in \u001b[0;36m<cell line: 21>\u001b[0;34m()\u001b[0m\n\u001b[1;32m <a href='vscode-notebook-cell:/auto/users/kither/Documents/periodic-sampling/periodic_sampling/fixed_bias_sampler.ipynb#X41sZmlsZQ%3D%3D?line=18'>19</a>\u001b[0m step_num \u001b[39m=\u001b[39m \u001b[39mint\u001b[39m(\u001b[39m8\u001b[39m)\n\u001b[1;32m <a href='vscode-notebook-cell:/auto/users/kither/Documents/periodic-sampling/periodic_sampling/fixed_bias_sampler.ipynb#X41sZmlsZQ%3D%3D?line=19'>20</a>\u001b[0m sampler \u001b[39m=\u001b[39m MixedSampler(params\u001b[39m=\u001b[39mparams)\n\u001b[0;32m---> <a href='vscode-notebook-cell:/auto/users/kither/Documents/periodic-sampling/periodic_sampling/fixed_bias_sampler.ipynb#X41sZmlsZQ%3D%3D?line=20'>21</a>\u001b[0m output3 \u001b[39m=\u001b[39m sampler\u001b[39m.\u001b[39;49msampling_routine(step_num\u001b[39m=\u001b[39;49mstep_num, sample_burnin\u001b[39m=\u001b[39;49m\u001b[39m0\u001b[39;49m)\n\u001b[1;32m <a href='vscode-notebook-cell:/auto/users/kither/Documents/periodic-sampling/periodic_sampling/fixed_bias_sampler.ipynb#X41sZmlsZQ%3D%3D?line=22'>23</a>\u001b[0m filename \u001b[39m=\u001b[39m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39msynth_inference_T_\u001b[39m\u001b[39m{\u001b[39;00mbias_method\u001b[39m}\u001b[39;00m\u001b[39m_\u001b[39m\u001b[39m{\u001b[39;00mtime_steps\u001b[39m}\u001b[39;00m\u001b[39m_N0_\u001b[39m\u001b[39m{\u001b[39;00mN_0\u001b[39m}\u001b[39;00m\u001b[39m_R0diff_\u001b[39m\u001b[39m{\u001b[39;00mR0_diff\u001b[39m}\u001b[39;00m\u001b[39m_It_\u001b[39m\u001b[39m{\u001b[39;00mstep_num\u001b[39m}\u001b[39;00m\u001b[39m_seed_\u001b[39m\u001b[39m{\u001b[39;00mseed\u001b[39m}\u001b[39;00m\u001b[39m.csv\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m <a href='vscode-notebook-cell:/auto/users/kither/Documents/periodic-sampling/periodic_sampling/fixed_bias_sampler.ipynb#X41sZmlsZQ%3D%3D?line=23'>24</a>\u001b[0m output3\u001b[39m.\u001b[39mto_csv(\u001b[39m'\u001b[39m\u001b[39m../data/outputs/scale_bias_model/\u001b[39m\u001b[39m'\u001b[39m \u001b[39m+\u001b[39m filename)\n",
"File \u001b[0;32m/auto/users/kither/Documents/periodic-sampling/periodic_sampling/sampling_methods/mixed_sampler.py:88\u001b[0m, in \u001b[0;36mMixedSampler.sampling_routine\u001b[0;34m(self, step_num, sample_period, sample_burnin, random_order, chain_num)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(params[key], GibbsParameter): \u001b[39m# resample both bias and Rt \u001b[39;00m\n\u001b[1;32m 87\u001b[0m gibbs\u001b[39m.\u001b[39mparams \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mparams\n\u001b[0;32m---> 88\u001b[0m row[key] \u001b[39m=\u001b[39m gibbs\u001b[39m.\u001b[39;49msingle_sample(key)\n\u001b[1;32m 89\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mparams[key]\u001b[39m.\u001b[39mvalue \u001b[39m=\u001b[39m row[key]\n\u001b[1;32m 90\u001b[0m bias_sum \u001b[39m=\u001b[39m \u001b[39msum\u001b[39m([row[key] \u001b[39mfor\u001b[39;00m key \u001b[39min\u001b[39;00m params\u001b[39m.\u001b[39mkeys() \u001b[39mif\u001b[39;00m (key\u001b[39m.\u001b[39mstartswith(\u001b[39m'\u001b[39m\u001b[39mbias_\u001b[39m\u001b[39m'\u001b[39m) \u001b[39mand\u001b[39;00m \u001b[39mnot\u001b[39;00m key\u001b[39m.\u001b[39mstartswith(\u001b[39m'\u001b[39m\u001b[39mbias_prior\u001b[39m\u001b[39m'\u001b[39m))])\n",
"File \u001b[0;32m/auto/users/kither/Documents/periodic-sampling/periodic_sampling/sampling_methods/gibbs_sampler.py:118\u001b[0m, in \u001b[0;36mGibbsSampler.single_sample\u001b[0;34m(self, param_name)\u001b[0m\n\u001b[1;32m 106\u001b[0m \u001b[39m\"\"\"Runs single sample of a parameter, updating the value \u001b[39;00m\n\u001b[1;32m 107\u001b[0m \u001b[39minplace in the params dictionary and returning the updated value\u001b[39;00m\n\u001b[1;32m 108\u001b[0m \u001b[39mfor recording.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[39m instance to sample from.\u001b[39;00m\n\u001b[1;32m 115\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 116\u001b[0m \u001b[39massert\u001b[39;00m \u001b[39misinstance\u001b[39m(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mparams[param_name], GibbsParameter), \\\n\u001b[1;32m 117\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mParameter name supplied must correspond to Parameter instance\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[0;32m--> 118\u001b[0m value \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mparams[param_name]\u001b[39m.\u001b[39;49msample(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mparams)\n\u001b[1;32m 119\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mparams[param_name]\u001b[39m.\u001b[39mvalue \u001b[39m=\u001b[39m value\n\u001b[1;32m 120\u001b[0m \u001b[39mreturn\u001b[39;00m value\n",
"File \u001b[0;32m/auto/users/kither/Documents/periodic-sampling/periodic_sampling/sampling_methods/gibbs_sampler.py:85\u001b[0m, in \u001b[0;36mGibbsParameter.sample\u001b[0;34m(self, sample_params)\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 84\u001b[0m param_values[k] \u001b[39m=\u001b[39m v\n\u001b[0;32m---> 85\u001b[0m post_params \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mposterior_params(\u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mparam_values)\n\u001b[1;32m 86\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mvalue \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconditional_posterior(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39mpost_params)\n\u001b[1;32m 87\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mvalue\n",
"File \u001b[0;32m/auto/users/kither/Documents/periodic-sampling/periodic_sampling/periodic_model.py:397\u001b[0m, in \u001b[0;36mrt_parameter.<locals>.<lambda>\u001b[0;34m(**kwargs)\u001b[0m\n\u001b[1;32m 378\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mrt_parameter\u001b[39m(value, index, sampling_freq \u001b[39m=\u001b[39m \u001b[39m1\u001b[39m):\n\u001b[1;32m 379\u001b[0m \u001b[39m\"\"\"Parameters for the probability density function (pdf) for \u001b[39;00m\n\u001b[1;32m 380\u001b[0m \u001b[39m one index of the time-varying reproductive number.\u001b[39;00m\n\u001b[1;32m 381\u001b[0m \u001b[39m \u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 394\u001b[0m \u001b[39m GibbsParameter : Parameter object for constant reproductive number \u001b[39;00m\n\u001b[1;32m 395\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[1;32m 396\u001b[0m \u001b[39mreturn\u001b[39;00m GibbsParameter(value\u001b[39m=\u001b[39mvalue, conditional_posterior\u001b[39m=\u001b[39mss\u001b[39m.\u001b[39mgamma\u001b[39m.\u001b[39mrvs, sampling_freq\u001b[39m=\u001b[39msampling_freq,\n\u001b[0;32m--> 397\u001b[0m posterior_params\u001b[39m=\u001b[39m\u001b[39mlambda\u001b[39;00m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs : _rt_params(final_index\u001b[39m=\u001b[39;49mindex, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs))\n",
"File \u001b[0;32m/auto/users/kither/Documents/periodic-sampling/periodic_sampling/periodic_model.py:326\u001b[0m, in \u001b[0;36m_rt_params\u001b[0;34m(initial_index, final_index, **kwargs)\u001b[0m\n\u001b[1;32m 323\u001b[0m truth_values \u001b[39m=\u001b[39m []; gamma_values \u001b[39m=\u001b[39m []\n\u001b[1;32m 325\u001b[0m \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m data_indicies[window_start:final_index]:\n\u001b[0;32m--> 326\u001b[0m gamma_values\u001b[39m.\u001b[39mappend(_calculate_lambda(params, i))\n\u001b[1;32m 327\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 328\u001b[0m truth_values\u001b[39m.\u001b[39mappend(params[\u001b[39m'\u001b[39m\u001b[39mtruth_\u001b[39m\u001b[39m'\u001b[39m \u001b[39m+\u001b[39m \u001b[39mstr\u001b[39m(i)])\n",
"File \u001b[0;32m/auto/users/kither/Documents/periodic-sampling/periodic_sampling/periodic_model.py:92\u001b[0m, in \u001b[0;36m_calculate_lambda\u001b[0;34m(params, max_t)\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[39mreturn\u001b[39;00m (params[\u001b[39m'\u001b[39m\u001b[39mdata_0\u001b[39m\u001b[39m'\u001b[39m] \u001b[39m/\u001b[39m params[\u001b[39m'\u001b[39m\u001b[39mbias_0\u001b[39m\u001b[39m'\u001b[39m]) \u001b[39m# Best guess of initial point\u001b[39;00m\n\u001b[1;32m 91\u001b[0m omega \u001b[39m=\u001b[39m params[\u001b[39m'\u001b[39m\u001b[39mserial_interval\u001b[39m\u001b[39m'\u001b[39m]\n\u001b[0;32m---> 92\u001b[0m cases \u001b[39m=\u001b[39m [params[k] \u001b[39mfor\u001b[39;00m k \u001b[39min\u001b[39;00m params\u001b[39m.\u001b[39mkeys() \u001b[39mif\u001b[39;00m k\u001b[39m.\u001b[39mstartswith(\u001b[39m'\u001b[39m\u001b[39mdata_\u001b[39m\u001b[39m'\u001b[39m)]\n\u001b[1;32m 93\u001b[0m n_terms_lambda \u001b[39m=\u001b[39m \u001b[39mmin\u001b[39m(max_t \u001b[39m+\u001b[39m \u001b[39m1\u001b[39m, \u001b[39mlen\u001b[39m(omega)) \u001b[39m# Number of terms in sum for lambda\u001b[39;00m\n\u001b[1;32m 94\u001b[0m \u001b[39mif\u001b[39;00m max_t \u001b[39m<\u001b[39m \u001b[39mlen\u001b[39m(omega):\n",
"File \u001b[0;32m/auto/users/kither/Documents/periodic-sampling/periodic_sampling/periodic_model.py:92\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[39mreturn\u001b[39;00m (params[\u001b[39m'\u001b[39m\u001b[39mdata_0\u001b[39m\u001b[39m'\u001b[39m] \u001b[39m/\u001b[39m params[\u001b[39m'\u001b[39m\u001b[39mbias_0\u001b[39m\u001b[39m'\u001b[39m]) \u001b[39m# Best guess of initial point\u001b[39;00m\n\u001b[1;32m 91\u001b[0m omega \u001b[39m=\u001b[39m params[\u001b[39m'\u001b[39m\u001b[39mserial_interval\u001b[39m\u001b[39m'\u001b[39m]\n\u001b[0;32m---> 92\u001b[0m cases \u001b[39m=\u001b[39m [params[k] \u001b[39mfor\u001b[39;00m k \u001b[39min\u001b[39;00m params\u001b[39m.\u001b[39mkeys() \u001b[39mif\u001b[39;00m k\u001b[39m.\u001b[39;49mstartswith(\u001b[39m'\u001b[39;49m\u001b[39mdata_\u001b[39;49m\u001b[39m'\u001b[39;49m)]\n\u001b[1;32m 93\u001b[0m n_terms_lambda \u001b[39m=\u001b[39m \u001b[39mmin\u001b[39m(max_t \u001b[39m+\u001b[39m \u001b[39m1\u001b[39m, \u001b[39mlen\u001b[39m(omega)) \u001b[39m# Number of terms in sum for lambda\u001b[39;00m\n\u001b[1;32m 94\u001b[0m \u001b[39mif\u001b[39;00m max_t \u001b[39m<\u001b[39m \u001b[39mlen\u001b[39m(omega):\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"I_data = list(bias_df['Confirmed'])\n",
"\n",
"params = {'bias_prior_alpha': 1, 'bias_prior_beta': 1,\n",
" 'rt_prior_alpha': 1, 'rt_prior_beta': 1} # Gamma dist\n",
"\n",
"params['serial_interval'] = RenewalModel(R0=None).serial_interval\n",
"params['Rt_window'] = 7 # Assume it is constant for 7 days\n",
"params['error_threshold_bias'] = 1 # resample bias if this is not satisfied\n",
"\n",
"for i, val in enumerate(I_data): # Observed cases - not a Parameter\n",
" params[(\"data_\" + str(i))] = val\n",
"\n",
"for i in range(0, len(I_data)): # Reproductive number values\n",
" params[(\"R_\" + str(i))] = rt_parameter(value=1, index=i)\n",
"\n",
"for i in range(7): # Weekday bias parameters\n",
" params[(\"bias_\" + str(i))] = scale_bias_parameter(1, index=i)\n",
"\n",
"step_num = int(8)\n",
"sampler = MixedSampler(params=params)\n",
"output3 = sampler.sampling_routine(step_num=step_num, sample_burnin=0)\n",
"\n",
"filename = f\"synth_inference_T_{bias_method}_{time_steps}_N0_{N_0}_R0diff_{R0_diff}_It_{step_num}_seed_{seed}.csv\"\n",
"output3.to_csv('../data/outputs/scale_bias_model/' + filename)\n",
"\n",
"images_path = \"synthetic_inference/stepped_R/scale_bias_model/normalised_bias/\""
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [],
"source": [
"output = output2"
]
},
{
"cell_type": "code",
"execution_count": 72,
Expand Down
19 changes: 11 additions & 8 deletions periodic_sampling/sampling_methods/mixed_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,24 +70,27 @@ def sampling_routine(self, step_num, sample_period = 1,
row[key] = gibbs.single_sample(key)
self.params[key].value = row[key]

bias_sum = sum([row[key] for key in params.keys() if (key.startswith('bias_') and not key.startswith('bias_prior'))])
# bias_sum = sum([row[key] for key in params.keys() if (key.startswith('bias_') and not key.startswith('bias_prior'))])

# Post-Hoc nromalisation
# for key in list(params.keys()):
# if key.startswith('bias_') and not key.startswith('bias_prior'):
# row[key] /= (bias_sum / 7)
# self.params[key].value = row[key]

# Final value adjustment
# random_bias = 'bias_' + str(random.randint(0, 6))
# row[random_bias] = max(0.01, 7 - (bias_sum - row[random_bias]))
# self.params[random_bias].value = row[random_bias]

while abs(bias_sum - 7) > params['error_threshold_bias']:
for key in list(params.keys()):
if isinstance(params[key], GibbsParameter): # resample both bias and Rt
gibbs.params = self.params
row[key] = gibbs.single_sample(key)
self.params[key].value = row[key]
bias_sum = sum([row[key] for key in params.keys() if (key.startswith('bias_') and not key.startswith('bias_prior'))])
# Forced resampling
# while abs(bias_sum - 7) > params['error_threshold_bias']:
# for key in list(params.keys()):
# if isinstance(params[key], GibbsParameter): # resample both bias and Rt
# gibbs.params = self.params
# row[key] = gibbs.single_sample(key)
# self.params[key].value = row[key]
# bias_sum = sum([row[key] for key in params.keys() if (key.startswith('bias_') and not key.startswith('bias_prior'))])


if (((n + 1) > sample_burnin) & ((n + 1) % sample_period == 0)):
Expand Down
13 changes: 0 additions & 13 deletions periodic_sampling/varying_Rt_sampler.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -220,18 +220,6 @@
"from periodic_model import truth_parameter, poisson_bias_parameter, rt_parameter, single_r_parameter, constant_r_parameter"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def final_bias_sample(**kwargs):\n",
" params = kwargs\n",
" bias_sum = sum([params['bias_' + str(i)] for i in range(6)])\n",
" return max(0, 7 - bias_sum)"
]
},
{
"cell_type": "code",
"execution_count": 16,
Expand Down Expand Up @@ -264,7 +252,6 @@
"\n",
"for i in range(7): # Weekday bias parameters\n",
" params[(\"bias_\" + str(i))] = poisson_bias_parameter(value=1, index=i)\n",
"# params['bias_6'] = GibbsParameter(value=1, conditional_posterior=final_bias_sample)\n",
"\n",
"for i in range(0, len(I_data)): # Reproductive number values\n",
" params[(\"R_\" + str(i))] = rt_parameter(value=1, index=i)\n",
Expand Down

0 comments on commit ce1979e

Please sign in to comment.