diff --git a/app/index.py b/app/index.py index a69816e..ec49ba1 100644 --- a/app/index.py +++ b/app/index.py @@ -339,10 +339,10 @@ def update_production_based_on_user_data(df: pd.DataFrame) -> pd.DataFrame: | UID | SupplyAmount | Branch | |-----|-------------------|---------------| | 0 | 1 | NaN | - | 1 | 0.25 | [0,1] | + | 1 | 0.25 | [0,1] | NOTA BENE! | 2 | 0.2 * (0.25/0.5) | [0,1,2] | | 3 | 0.1 | [0,3] | - | 4 | 0.18 | [0,1,2,4] | + | 4 | 0.18 | [0,1,2,4] | NOTA BENE! | 5 | 0.05 * (0.1/0.18) | [0,1,2,4,5] | | 6 | 0.01 * (0.1/0.18) | [0,1,2,4,5,6] | @@ -391,10 +391,20 @@ def update_production_based_on_user_data(df: pd.DataFrame) -> pd.DataFrame: def multiplier(row): if not isinstance(row['Branch'], list): return row['SupplyAmount'] - for branch_UID in reversed(row['Branch']): - if branch_UID in dict_user_input: - return row['SupplyAmount'] * dict_user_input[branch_UID] - return row['SupplyAmount'] + elif ( + row['UID'] == row['Branch'][-1] and + np.isnan(row['SupplyAmount_USER']) + ): + return row['SupplyAmount'] + elif ( + row['UID'] == row['Branch'][-1] and not + np.isnan(row['SupplyAmount_USER']) + ): + return row['SupplyAmount_USER'] + else: + for branch_UID in reversed(row['Branch']): + if branch_UID in dict_user_input: + return row['SupplyAmount'] * dict_user_input[branch_UID] df['SupplyAmount_EDITED'] = df.apply(multiplier, axis=1) diff --git a/dev/test_edit_graph.py b/dev/test_edit_graph.py index e4d2c42..d9f056f 100644 --- a/dev/test_edit_graph.py +++ b/dev/test_edit_graph.py @@ -81,10 +81,7 @@ def create_user_input_column( return df_merged -def update_production_based_on_user_data( - df: pd.DataFrame, - column_name: str - ) -> pd.DataFrame: +def update_production_based_on_user_data(df: pd.DataFrame) -> pd.DataFrame: """ Updates the production amount of all nodes which are upstream of a node with user-supplied production amount. @@ -143,8 +140,8 @@ def update_production_based_on_user_data( Output DataFrame. """ - df_filtered = df[~df[f'{column_name}_USER'].isna()] - dict_user_input = df_filtered.set_index('UID').to_dict()[f'{column_name}_USER'] + df_filtered = df[~df['SupplyAmount_USER'].isna()] + dict_user_input = df_filtered.set_index('UID').to_dict()['SupplyAmount_USER'] """ For the example DataFrame from the docstrings above, @@ -159,20 +156,34 @@ def update_production_based_on_user_data( df = df.copy(deep=True) def multiplier(row): if not isinstance(row['Branch'], list): - return row[column_name] - for branch_UID in reversed(row['Branch']): - if branch_UID in dict_user_input: - return row[column_name] * dict_user_input[branch_UID] - return row[column_name] - - df[column_name] = df.apply(multiplier, axis=1) - df.drop(columns=[f'{column_name}_USER'], inplace=True) - - return df, dict_user_input - + return row['SupplyAmount'] + elif ( + row['UID'] == row['Branch'][-1] and + np.isnan(row['SupplyAmount_USER']) + ): + return row['SupplyAmount'] + elif ( + row['UID'] == row['Branch'][-1] and not + np.isnan(row['SupplyAmount_USER']) + ): + return row['SupplyAmount_USER'] + else: + for branch_UID in reversed(row['Branch']): + if branch_UID in dict_user_input: + return row['SupplyAmount'] * dict_user_input[branch_UID] + + df['SupplyAmount_EDITED'] = df.apply(multiplier, axis=1) + + df.drop(columns=['SupplyAmount_USER'], inplace=True) + df['SupplyAmount'] = df['SupplyAmount_EDITED'] + df.drop(columns=['SupplyAmount_EDITED'], inplace=True) + + return df df_user_col = create_user_input_column( df_original=df_original, df_user_input=df_user_input, column_name='SupplyAmount' ) + +df_updated = update_production_based_on_user_data(df_user_col) \ No newline at end of file diff --git a/pyodide/index.html b/pyodide/index.html index 8c5d0d3..74495d4 100644 --- a/pyodide/index.html +++ b/pyodide/index.html @@ -107,14 +107,14 @@