[Feature request]: Add Output Plotting Options Of Panel Figure With Main Scenario Hub Targets #415

Open
TimothyWillard opened this issue Dec 9, 2024 · 2 comments
Labels
enhancement Request for improvement or addition of new feature(s). gempyor Concerns the Python core. medium priority Medium priority. plotting Relating to plotting and/or visualizations. post-processing Concern the post-processing.

Comments

@TimothyWillard
Contributor

Label

enhancement, gempyor, plotting, post-processing

Priority Label

medium priority

Is your feature request related to a problem? Please describe.

This issue was originally reported by @MacdonaldJoshuaCaleb in GH-413.

When assessing scenario plots and/or model fit to empirical data, there are a number of common targets across pathogens that would be useful to plot together with sample time trajectories. This can be done fairly easily with seaborn and subplots in matplotlib. Here's an implementation of the basic idea for a results list like the one returned by the gempyor package, which is a list of data frames. It should probably be generalized to read the .parquet files from the model_output folder, for the case where other gempyor functions populate that folder because the inference object is set to save.
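As a rough sketch of that generalization, the helper below collects every matching file under a folder into a list of date-indexed data frames, which is the shape the plotting code expects. The function name `load_results_list`, the glob pattern, and the `date` column name are assumptions for illustration; the real layout of the model_output folder may differ.

```python
from pathlib import Path
from typing import Callable

import pandas as pd


def load_results_list(
    folder: str,
    pattern: str = "**/*.parquet",  # assumed layout; adjust to the real output tree
    date_col: str = "date",  # assumed name of the date column
    reader: Callable[..., pd.DataFrame] = pd.read_parquet,
) -> list:
    """Read every matching file under `folder` into a list of data frames."""
    files = sorted(Path(folder).glob(pattern))
    if not files:
        raise FileNotFoundError(f"No files matching {pattern!r} under {folder!r}")
    results = []
    for path in files:
        df = reader(path)
        # Index by date when the (assumed) date column is present
        if date_col in df.columns:
            df[date_col] = pd.to_datetime(df[date_col])
            df = df.set_index(date_col)
        results.append(df)
    return results


# Hypothetical call: results_list = load_results_list("./model_output")
```

The `reader` argument is only there so a different file format can be swapped in without touching the rest of the helper.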

Is your feature request related to a new application, scenario round, pathogen? Please describe.

SMH submissions

Describe the solution you'd like

Incorporate something like the functions below into the default post-processing plots, with automated post-processing like what currently exists for R-inference runs and will hopefully (soon) exist for emcee runs.

Code without scenario facets:

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

def plot_state(results_list, strain, scenario, fip, save=True, save_path=None, display_plot=True, data_hosp=None):
    # Load hospitalization data
    if data_hosp is None:
        # wherever your calibration data lives
        data_hosp = pd.read_csv(r'./model_input/SMH_Flu_2024_R1_allflu_medVax_H1_training_multiseason_emcee_difflocvarseas/us_data_Flu_2024_R1_allflu_training_multiseason_emcee_difflocvarseas.csv')
    
    # Get seasons to keep and state name
    seasons_keep, state_name, fips = get_seasons_keep(strain, fip)
    dates = results_list[0].index

    # Prepare simulation data
    sim_data = {
        'Date': np.concatenate([dates] * len(results_list)),
        'sim_id': np.concatenate([np.full(len(dates), str(i + 1)) for i in range(len(results_list))]),
        'value': np.concatenate([result['incidH_AllFlu'].values for result in results_list])
    }
    sim_df = pd.DataFrame(sim_data)

    # Create subplots
    fig, axs = plt.subplots(2, 2, figsize=(15, 10))

    # Plot 1: Simulation results
    sns.lineplot(data=sim_df, x='Date', y='value', hue='sim_id', alpha=0.25, palette=["blue"], ax=axs[0, 0])
    axs[0, 0].set_ylabel('Hospitalization incidence')
    axs[0, 0].tick_params(axis='x', rotation=45)
    axs[0, 0].legend_.remove()

    # Prepare seasons_keep_2
    seasons_keep_2 = ['20' + season.split("to")[0] + '-' + season.split("to")[1] for season in seasons_keep]

    # Filter data_hosp for the state
    temp = data_hosp[data_hosp['source'] == state_name]
    seasons = temp['season'].unique()

    # Plot historical data
    for season in seasons:
        temp_season = temp[temp['season'] == season]
        if season in seasons_keep_2:
            sns.lineplot(x=temp_season['date'].astype('datetime64[ns]'), y=temp_season['incidH'], color='red', marker='o', zorder=2, alpha=0.5, ax=axs[0, 0])
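    # Remove the legend only if one exists (the earlier remove() may already have cleared it)
    if axs[0, 0].get_legend() is not None:
        axs[0, 0].get_legend().remove()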

    # Prepare outcomes
    outcomes = []
    for sim_id in sim_df['sim_id'].unique():
        temp = sim_df[sim_df['sim_id'] == sim_id]
        outcomes.append([scenario, sim_id, temp['value'].max(), temp['value'].sum(), temp['Date'].values[temp['value'].argmax()]])
    outcomes = pd.DataFrame(outcomes, columns=['scenario', 'sim_id', 'max', 'cuml', 'max_date'])

    # Plot 2: Max hospitalization incidence
    sns.histplot(data=outcomes, x='max', ax=axs[0, 1], stat='probability')
    axs[0, 1].set_xlabel('Max hospitalization incidence')

    # Plot 3: Cumulative hospitalizations
    sns.histplot(data=outcomes, x='cuml', ax=axs[1, 0], stat='probability')
    axs[1, 0].set_xlabel('Cumulative hospitalizations')

    # Plot 4: Date of max hospitalization incidence
    sns.boxplot(data=outcomes, x='max_date', ax=axs[1, 1])
    axs[1, 1].set_xlabel('Date of max hospitalization incidence')
    axs[1, 1].tick_params(axis='x', rotation=45)

    # Set the title of the figure
    fig.suptitle(f'{state_name} {strain} {scenario}', ha='center', va='bottom')
    plt.tight_layout()

    # Replace spaces in scenario with underscores for file naming
    if ' ' in scenario:
        scenario = scenario.replace(' ', '_')
    
    # Save the figure if save is True
    if save:
        if save_path is None:
            fig.savefig(fname=f'plot_{strain}_{scenario}_{state_name}.pdf', bbox_inches='tight')
        else:
            fig.savefig(fname=f'{save_path}/plot_{strain}_{scenario}_{state_name}.pdf', bbox_inches='tight')
    
    # Display the plot if display_plot is True
    if display_plot:
        plt.show()

# Usage
# Note: get_seasons_keep is only relevant for the current round of flu, because we had an issue with "bad" seasons for simulations.
import dill

keep_list, state_name, fips = get_seasons_keep('H1N1', '01000')
path = "./scenario_output/all_results_H1_HiVax_" + state_name + ".pkl"

# Note: the pickle is used because we are stitching together multiple seasons. A general function
# should read output from the model_output folder or a user-specified location.
with open(path, 'rb') as f:
    data = dill.load(f)

# j is not defined in this snippet; it is presumably the index of the location within fips.
plot_state(data, 'H1N1', 'High Vax', fips[j], save=True, save_path='./model_plots', display_plot=False)
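
The long-format frame that plot_state builds with np.concatenate can also be assembled with a single pd.concat call. The sketch below shows that alternative on toy data; the helper name `results_to_long` and the random toy values are made up for illustration, and only the `incidH_AllFlu` column name comes from the code above.

```python
import numpy as np
import pandas as pd


def results_to_long(results_list, value_col="incidH_AllFlu"):
    """Stack date-indexed frames into one long frame with one row per date and simulation."""
    stacked = pd.concat(
        {str(i + 1): df[value_col] for i, df in enumerate(results_list)},
        names=["sim_id", "Date"],
    )
    long_df = stacked.rename("value").reset_index()
    return long_df[["Date", "sim_id", "value"]]


# Toy stand-in for a gempyor results list: three simulations, four weekly dates
dates = pd.date_range("2024-10-05", periods=4, freq="W-SAT")
rng = np.random.default_rng(0)
results_list = [
    pd.DataFrame({"incidH_AllFlu": rng.poisson(20, size=len(dates))}, index=dates)
    for _ in range(3)
]
sim_df = results_to_long(results_list)
```

The resulting `sim_df` has the same `Date`, `sim_id`, and `value` columns that the seaborn calls in plot_state use.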

Code with scenario facets:

def plot_scenario_comp(results_dict, strain, fip, save=True, save_path=None, display_plot=True, data_hosp=None):
    # Load hospitalization data
    if data_hosp is None:
        # wherever your calibration data lives
        data_hosp = pd.read_csv(r'./model_input/SMH_Flu_2024_R1_allflu_medVax_H1_training_multiseason_emcee_difflocvarseas/us_data_Flu_2024_R1_allflu_training_multiseason_emcee_difflocvarseas.csv')
    
    # Prepare simulation data
    def get_sim_df(results_list, scenario):
        dates = results_list[0].index
        sim_data = {
            'Date': np.concatenate([dates] * len(results_list)),
            'sim_id': np.concatenate([np.full(len(dates), str(i + 1)) for i in range(len(results_list))]),
            'scenario': np.concatenate([np.full(len(dates), scenario) for i in range(len(results_list))]),
            'value': np.concatenate([result['incidH_AllFlu'].values for result in results_list])
        }
        return pd.DataFrame(sim_data)
    
    # Get seasons to keep and state name
    seasons_keep, state_name, fips = get_seasons_keep(strain, fip)
   

    sim_dfs = {}
    for key in results_dict.keys():
        results_list = results_dict[key]
        sim_dfs[key] = get_sim_df(results_list, key)
    sim_df = pd.concat(sim_dfs.values(), ignore_index=True)

    # Create subplots
    fig, axs = plt.subplots(2, 2, figsize=(15, 10))

    # Plot 1: Simulation results
    sns.lineplot(data=sim_df, x='Date', y='value', style='sim_id', hue='scenario', ax=axs[0, 0], palette=sns.color_palette("Set3",3))
    axs[0, 0].set_ylabel('Hospitalization incidence')
    axs[0, 0].tick_params(axis='x', rotation=45)
    axs[0, 0].legend_.remove()

    # Prepare seasons_keep_2
    seasons_keep_2 = ['20' + season.split("to")[0] + '-' + season.split("to")[1] for season in seasons_keep]

    # Filter data_hosp for the state
    temp = data_hosp[data_hosp['source'] == state_name]
    seasons = temp['season'].unique()

    # Plot historical data
    for season in seasons:
        temp_season = temp[temp['season'] == season]
        if season in seasons_keep_2:
            sns.lineplot(x=temp_season['date'].astype('datetime64[ns]'), y=temp_season['incidH'], color='red', marker='o', zorder=2, alpha=0.5, ax=axs[0, 0])
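    # Remove the legend only if one exists (the earlier remove() may already have cleared it)
    if axs[0, 0].get_legend() is not None:
        axs[0, 0].get_legend().remove()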

    # Prepare outcomes
    outcomes = []
    for key in results_dict.keys():
        for sim_id in sim_df['sim_id'].unique():
            temp = sim_df[(sim_df['sim_id'] == sim_id) & (sim_df['scenario'] == key)]
            outcomes.append([key, sim_id, temp['value'].max(), temp['value'].sum(), temp['Date'].values[temp['value'].argmax()]])
    outcomes = pd.DataFrame(outcomes, columns=['scenario', 'sim_id', 'max', 'cuml', 'max_date'])

    # Plot 2: Max hospitalization incidence
    sns.histplot(data=outcomes, x='max', ax=axs[0, 1], stat='probability', hue='scenario', palette=sns.color_palette("Set3",3),multiple="dodge")
    axs[0, 1].set_xlabel('Max hospitalization incidence')

    # Plot 3: Cumulative hospitalizations
    sns.histplot(data=outcomes, x='cuml', ax=axs[1, 0], stat='probability', hue='scenario', palette=sns.color_palette("Set3",3),multiple="dodge")
    axs[1, 0].set_xlabel('Cumulative hospitalizations')

    # Plot 4: Relative reduction in each target between scenario pairs (hospitalizations averted)
    metrics = ['max', 'cuml']
    combinations = ['High vs. Med', 'Med vs. Low']
    difs = []
    labels = []
    combos = []

    for combo in combinations:
        if combo == 'High vs. Med':
            hi_vals = outcomes[outcomes['scenario'] == 'HiVax'][metrics].values
            med_vals = outcomes[outcomes['scenario'] == 'MedVax'][metrics].values
            # Percent reduction of HiVax relative to MedVax (positive means averted)
            vals = -100 * (hi_vals - med_vals) / med_vals
        elif combo == 'Med vs. Low':
            med_vals = outcomes[outcomes['scenario'] == 'MedVax'][metrics].values
            low_vals = outcomes[outcomes['scenario'] == 'LowVax'][metrics].values
            # Percent reduction of MedVax relative to LowVax (positive means averted)
            vals = -100 * (med_vals - low_vals) / low_vals
        
        for i, metric in enumerate(metrics):
            difs.extend(vals[:, i])
            labels.extend([metric] * len(vals))
            combos.extend([combo] * len(vals))

    result_df = pd.DataFrame({'combination': combos, 'label': labels, 'difference': difs})

    sns.violinplot(data=result_df,x='label',y='difference', hue='combination',ax=axs[1,1], palette=sns.color_palette("Set2",2))
    axs[1, 1].set_ylabel(r'Hospitalizations averted (%)')
    axs[1, 1].set_xlabel('Target')


    # Set the title of the figure
    fig.suptitle(f'{state_name} {strain}', ha='center', va='bottom')
    plt.tight_layout()

    if save:
        if save_path is None:
            fig.savefig(fname=f'plot_{strain}_scenario_comp_{state_name}.pdf', bbox_inches='tight')
        else:
            fig.savefig(fname=f'{save_path}/plot_{strain}_scenario_comp_{state_name}.pdf', bbox_inches='tight')
    
    # Display the plot if display_plot is True
    if display_plot:
        plt.show()
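
The averted-percentage step in plot_scenario_comp relies on the rows of each scenario lining up by position. A minimal sketch of the same calculation that pairs runs explicitly by sim_id is shown below. The helper name `percent_averted` and the toy numbers are assumptions for illustration; the column names and scenario labels follow the outcomes frame above.

```python
import pandas as pd


def percent_averted(outcomes, scenario, comparator, metrics=("max", "cuml")):
    """Percent reduction of `scenario` relative to `comparator`, paired by sim_id."""
    wide = outcomes.pivot(index="sim_id", columns="scenario", values=list(metrics))
    averted = {}
    for metric in metrics:
        base = wide[(metric, comparator)]
        averted[metric] = 100.0 * (base - wide[(metric, scenario)]) / base
    return pd.DataFrame(averted)


# Toy outcomes frame with two simulations per scenario
outcomes = pd.DataFrame({
    "scenario": ["HiVax", "HiVax", "MedVax", "MedVax"],
    "sim_id": ["1", "2", "1", "2"],
    "max": [80.0, 90.0, 100.0, 100.0],
    "cuml": [400.0, 450.0, 500.0, 500.0],
})
averted = percent_averted(outcomes, "HiVax", "MedVax")
```

Pairing by sim_id only makes sense if the same sim_id refers to matched runs across scenarios, which is an assumption about how the results were generated.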

The outputs of plot_state and plot_scenario_comp look like plot_H1N1_High_Vax_Alabama.pdf and plot_H1N1_scenario_comp_Alabama.pdf, respectively. The two plotting methods share a lot of code but have different interfaces. It probably makes sense to keep both interfaces and extract the core logic into one function.
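
One candidate for that shared core is the per-simulation summary (peak, cumulative total, and peak date), which both functions compute with nearly identical loops. Below is a minimal sketch using groupby. The name `summarize_outcomes` is hypothetical and the toy values are made up; the column names match the sim_df and outcomes frames used above.

```python
import pandas as pd


def summarize_outcomes(sim_df):
    """Peak value, cumulative total, and peak date for each (scenario, sim_id) pair."""
    df = sim_df.reset_index(drop=True)  # idxmax below needs unique row labels
    grouped = df.groupby(["scenario", "sim_id"], sort=False)["value"]
    outcomes = grouped.agg(max="max", cuml="sum").reset_index()
    # Look up the date of the row where each group's value peaks
    peak_dates = df.loc[grouped.idxmax(), ["scenario", "sim_id", "Date"]]
    peak_dates = peak_dates.rename(columns={"Date": "max_date"})
    return outcomes.merge(peak_dates, on=["scenario", "sim_id"])


# Toy long-format frame: one simulation for each of two scenarios
dates = pd.to_datetime(["2024-11-02", "2024-11-09", "2024-11-16"])
sim_df = pd.DataFrame({
    "Date": list(dates) * 2,
    "sim_id": ["1"] * 6,
    "scenario": ["HiVax"] * 3 + ["MedVax"] * 3,
    "value": [5, 9, 4, 6, 8, 12],
})
outcomes = summarize_outcomes(sim_df)
```

With something like this in place, the single-scenario interface could add a constant scenario column before calling the helper, while the faceted interface could pass its concatenated frame directly.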

@TimothyWillard TimothyWillard added enhancement Request for improvement or addition of new feature(s). gempyor Concerns the Python core. post-processing Concern the post-processing. medium priority Medium priority. plotting Relating to plotting and/or visualizations. labels Dec 9, 2024
@MacdonaldJoshuaCaleb MacdonaldJoshuaCaleb changed the title [Feature request]: Add Output Plotting Option Of Panel Figure With Main Scenario Hub Targets [Feature request]: Add Utilities for Scenario Hub evaluation and submission Dec 17, 2024
@MacdonaldJoshuaCaleb MacdonaldJoshuaCaleb changed the title [Feature request]: Add Utilities for Scenario Hub evaluation and submission [Feature request]: Add Plotting Utilities for Scenario Hub targets Dec 17, 2024
@TimothyWillard TimothyWillard changed the title [Feature request]: Add Plotting Utilities for Scenario Hub targets [Feature request]: Add Output Plotting Option Of Panel Figure With Main Scenario Hub Targets Dec 17, 2024
@MacdonaldJoshuaCaleb MacdonaldJoshuaCaleb changed the title [Feature request]: Add Output Plotting Option Of Panel Figure With Main Scenario Hub Targets [Feature request]: Add Output Plottings Option Of Panel Figure With Main Scenario Hub Targets Dec 18, 2024
@MacdonaldJoshuaCaleb MacdonaldJoshuaCaleb changed the title [Feature request]: Add Output Plottings Option Of Panel Figure With Main Scenario Hub Targets [Feature request]: Add Output Plotting Options Of Panel Figure With Main Scenario Hub Targets Dec 18, 2024
@MacdonaldJoshuaCaleb
Collaborator

Here's some code for making time-resolved confidence interval plots, per @shauntruelove's request on Slack. Note that the formatted argument is the model output as an SMH-formatted data frame (see #430 for details).

from datetime import timedelta

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

def get_week_number(date_obj, start_date):
    return (date_obj - start_date).days // 7

def get_saturday_date(week_number, start_date):
    return start_date + timedelta(days=(week_number * 7) + 5)

# Find the global maximum value across all scenarios
def global_max(loc, formatted):
    maxes = []
    
    for scenario in formatted['scenario_id'].unique():
        data = formatted[(formatted['age_group'] == '0-130') & (formatted['location'] == loc) & (formatted['scenario_id'] == scenario)]
        pivoted_data = data.pivot(index='run_grouping', columns='horizon', values='value')
        maxes.append(pivoted_data.quantile(0.975).max())
    return np.max(maxes)

def get_max_season(loc):
    data_hosp = pd.read_csv(r'~/Documents/weekly_flu_incid_complete_fixed.csv')
    data_hosp = data_hosp[(data_hosp['season'] == '2022-23') | (data_hosp['season'] == '2023-24')]

    states = ['Alabama','Alaska','Arizona','Arkansas','California','Colorado','Connecticut','Delaware','District of Columbia',
        'Florida','Georgia','Hawaii','Idaho','Illinois','Indiana','Iowa','Kansas','Kentucky','Louisiana','Maine','Maryland',
        'Massachusetts','Michigan','Minnesota','Mississippi','Missouri','Montana','Nebraska','Nevada','New Hampshire',
        'New Jersey','New Mexico','New York','North Carolina','North Dakota','Ohio','Oklahoma','Oregon','Pennsylvania',
        'Rhode Island','South Carolina','South Dakota','Tennessee','Texas','Utah','Vermont','Virginia','Washington',
        'West Virginia','Wisconsin','Wyoming','US']

    fips = ['01', '02', '04', '05', '06', '08', '09',
        '10', '11', '12', '13', '15', '16', '17',
        '18', '19', '20', '21', '22', '23', '24',
        '25', '26', '27', '28', '29', '30', '31',
        '32', '33', '34', '35', '36', '37', '38',
        '39', '40', '41', '42', '44', '45', '46',
        '47', '48', '49', '50', '51', '53', '54',
        '55', '56', 'US']

    state = states[np.where(np.array(fips)==loc)[0][0]]
    data_hosp = data_hosp[data_hosp['state'] == state]
    maxes = []
    for season in data_hosp['season'].unique():
        hosp_season = data_hosp[data_hosp['season'] == season]
        maxes.append(hosp_season['incidH'].max())
    return np.array(maxes)

def get_cuml_season(loc):
    data_hosp = pd.read_csv(r'~/Documents/weekly_flu_incid_complete_fixed.csv')
    data_hosp = data_hosp[(data_hosp['season'] == '2022-23') | (data_hosp['season'] == '2023-24')]

    states = ['Alabama','Alaska','Arizona','Arkansas','California','Colorado','Connecticut','Delaware','District of Columbia',
        'Florida','Georgia','Hawaii','Idaho','Illinois','Indiana','Iowa','Kansas','Kentucky','Louisiana','Maine','Maryland',
        'Massachusetts','Michigan','Minnesota','Mississippi','Missouri','Montana','Nebraska','Nevada','New Hampshire',
        'New Jersey','New Mexico','New York','North Carolina','North Dakota','Ohio','Oklahoma','Oregon','Pennsylvania',
        'Rhode Island','South Carolina','South Dakota','Tennessee','Texas','Utah','Vermont','Virginia','Washington',
        'West Virginia','Wisconsin','Wyoming','US']

    fips = ['01', '02', '04', '05', '06', '08', '09',
        '10', '11', '12', '13', '15', '16', '17',
        '18', '19', '20', '21', '22', '23', '24',
        '25', '26', '27', '28', '29', '30', '31',
        '32', '33', '34', '35', '36', '37', '38',
        '39', '40', '41', '42', '44', '45', '46',
        '47', '48', '49', '50', '51', '53', '54',
        '55', '56', 'US']

    state = states[np.where(np.array(fips)==loc)[0][0]]
    data_hosp = data_hosp[data_hosp['state'] == state]
    cuml = []
    for season in data_hosp['season'].unique():
        hosp_season = data_hosp[data_hosp['season'] == season]
        if season == '2023-24':
            lb = hosp_season[hosp_season['yr.wk'] == 2023.30].index[0]
            ub = hosp_season[hosp_season['yr.wk'] == 2024.23].index[0]
        else:
            lb = hosp_season[hosp_season['yr.wk'] == 2022.30].index[0]
            ub = hosp_season[hosp_season['yr.wk'] == 2023.23].index[0]
        hosp_season = hosp_season.loc[lb:ub]
        cuml.append(hosp_season['incidH'].sum())
    cuml = np.array(cuml)
    return [np.floor(0.75 * np.min(cuml)), np.ceil(1.25 * np.max(cuml))]

def scenario_plot(loc, formatted, display=True):
    states = ['Alabama', 'Alaska', 'Arizona', 'Arkansas', 'California', 'Colorado', 'Connecticut', 'Delaware', 'District of Columbia',
              'Florida', 'Georgia', 'Hawaii', 'Idaho', 'Illinois', 'Indiana', 'Iowa', 'Kansas', 'Kentucky', 'Louisiana', 'Maine', 'Maryland',
              'Massachusetts', 'Michigan', 'Minnesota', 'Mississippi', 'Missouri', 'Montana', 'Nebraska', 'Nevada', 'New Hampshire',
              'New Jersey', 'New Mexico', 'New York', 'North Carolina', 'North Dakota', 'Ohio', 'Oklahoma', 'Oregon', 'Pennsylvania',
              'Rhode Island', 'South Carolina', 'South Dakota', 'Tennessee', 'Texas', 'Utah', 'Vermont', 'Virginia', 'Washington',
              'West Virginia', 'Wisconsin', 'Wyoming', 'US']

    fips = ['01', '02', '04', '05', '06', '08', '09', '10', '11', '12', '13', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24',
            '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '44', '45', '46',
            '47', '48', '49', '50', '51', '53', '54', '55', '56', 'US']

    state = states[fips.index(loc)]

    fig, axs = plt.subplots(3, 2, figsize=(10, 15))

    start_date = pd.to_datetime(formatted['origin_date'].unique()[0])
    past_maxes = get_max_season(loc)
    global_max_val = global_max(loc, formatted)

    scenarios = formatted['scenario_id'].unique()
    for i, scenario in enumerate(scenarios):
        row, col = divmod(i, 2)
        data = formatted[(formatted['age_group'] == '0-130') & (formatted['location'] == loc) & (formatted['scenario_id'] == scenario)]
        pivoted_data = data.pivot(index='run_grouping', columns='horizon', values='value')

        quantiles = pivoted_data.quantile([0.025, 0.25, 0.75, 0.975])
        axs[row, col].fill_between(pivoted_data.columns, quantiles.loc[0.025], quantiles.loc[0.975], alpha=0.2, color='black')
        axs[row, col].fill_between(pivoted_data.columns, quantiles.loc[0.25], quantiles.loc[0.75], alpha=0.2, color='black')
        axs[row, col].set_ylim(0, global_max_val)

        max_2022_23, max_2023_24 = past_maxes
        if max_2022_23 > max_2023_24:
            axs[row, col].axhline(max_2022_23, color='red', linestyle='--', label=f'2022-23 Max: {int(np.round(max_2022_23))}')
            axs[row, col].text(0, max_2022_23, f'2022-23 Max: {int(np.round(max_2022_23))}', color='red', ha='left', va='bottom')
            axs[row, col].axhline(max_2023_24, color='red', linestyle='--')
            axs[row, col].text(0, max_2023_24, f'2023-24 Max: {int(np.round(max_2023_24))}', color='red', ha='left', va='top')
        else:
            axs[row, col].axhline(max_2023_24, color='red', linestyle='--', label=f'2023-24 Max: {int(np.round(max_2023_24))}')
            axs[row, col].text(0, max_2023_24, f'2023-24 Max: {int(np.round(max_2023_24))}', color='red', ha='left', va='bottom')
            axs[row, col].axhline(max_2022_23, color='red', linestyle='--')
            axs[row, col].text(0, max_2022_23, f'2022-23 Max: {int(np.round(max_2022_23))}', color='red', ha='left', va='top')

        if col == 0:
            axs[row, col].set_ylabel('Weekly Hosp. Incid.')
        else:
            axs[row, col].set_yticklabels([])

        axs[row, col].set_title(f'Scenario: {scenario}')
        axs[row, col].set_xlim([0, 44])
        axs[row, col].set_xticks(pivoted_data.columns[::4])
        if row == 2:
            axs[row, col].set_xticklabels([get_saturday_date(week, start_date).strftime('%Y-%m-%d') for week in pivoted_data.columns[::4]], rotation=45)
        else:
            axs[row, col].set_xticklabels([''] * len(pivoted_data.columns[::4]))

    fig.suptitle(state, ha='center', va='bottom')
    plt.tight_layout()
    save_path = './model_plots'
    fig.savefig(fname=f'{save_path}/scenariohub_{state}.pdf', bbox_inches='tight')
    if display:
        plt.show()

# Usage

scenario_plot('01', formatted, display=True)
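
To make the band calculation inside scenario_plot easier to check in isolation, here is a minimal sketch of the pivot-then-quantile step on a toy frame. The helper name `quantile_bands`, the scenario label "A", and the toy values are assumptions for illustration; the column names are the ones the code above reads from the SMH-formatted frame.

```python
import pandas as pd


def quantile_bands(formatted, loc, scenario, age_group="0-130",
                   probs=(0.025, 0.25, 0.75, 0.975)):
    """Quantiles across runs at each horizon for one location and scenario."""
    subset = formatted[
        (formatted["age_group"] == age_group)
        & (formatted["location"] == loc)
        & (formatted["scenario_id"] == scenario)
    ]
    # Rows are runs, columns are horizons, so quantile() summarizes across runs
    wide = subset.pivot(index="run_grouping", columns="horizon", values="value")
    return wide.quantile(list(probs))


# Toy frame: four runs and three horizons for a single hypothetical scenario "A"
rows = [
    {"scenario_id": "A", "location": "01", "age_group": "0-130",
     "run_grouping": run, "horizon": h, "value": float(run * h)}
    for run in range(1, 5)
    for h in range(1, 4)
]
toy_formatted = pd.DataFrame(rows)
bands = quantile_bands(toy_formatted, "01", "A")
```

The rows of `bands` are the requested probabilities and the columns are horizons, so `bands.loc[0.025]` and `bands.loc[0.975]` give the outer band that fill_between draws.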

@MacdonaldJoshuaCaleb
Collaborator

Sample output from this newest block of code:
scenariohub_Alabama.pdf
