From 5914e2318bd2d563329a6d0039e15de3b9075c79 Mon Sep 17 00:00:00 2001 From: Lint Action Date: Tue, 7 May 2024 10:13:30 +0000 Subject: [PATCH] Fix code style issues with Black --- src/postprocessing.py | 46 ++++++++++++++++++++++++++---------- src/preprocessing.py | 20 +++++++++++----- src/utils.py | 37 ++++++++++++++++++++--------- tests/test_postprocessing.py | 4 +++- 4 files changed, 76 insertions(+), 31 deletions(-) diff --git a/src/postprocessing.py b/src/postprocessing.py index 1346916..1723f77 100644 --- a/src/postprocessing.py +++ b/src/postprocessing.py @@ -187,7 +187,8 @@ def _compute_imports_difference(self) -> None: imports_difference[country] = 0 else: imports_difference[country] = ( - (imports[country] - self.imports[0][country]) / self.imports[0][country] + (imports[country] - self.imports[0][country]) + / self.imports[0][country] ) * 100 self.imports_difference.append(imports_difference) @@ -980,10 +981,15 @@ def _compute_community_satisfaction_difference(self) -> None: """ self.community_satisfaction_difference = [ { - country: satisfaction - self.community_satisfaction[0][country] - if country in self.community_satisfaction[0] - else (print(f"Warning: {country} not found in the base scenario.") or np.nan) - for country, satisfaction in community_satisfaction.items() + country: ( + satisfaction - self.community_satisfaction[0][country] + if country in self.community_satisfaction[0] + else ( + print(f"Warning: {country} not found in the base scenario.") + or np.nan + ) + ) + for country, satisfaction in community_satisfaction.items() } for community_satisfaction in self.community_satisfaction[1:] ] @@ -1241,7 +1247,9 @@ def _compute_participation(self) -> None: for community in scenario.trade_communities ] # we sum up the squares of the number of edges to each community ) - / (total_degree[country] ** 2 if total_degree[country] != 0 else 1) # Avoid division by zero + / ( + total_degree[country] ** 2 if total_degree[country] != 0 else 1 + ) # Avoid division by zero for country in undirected_trade_graph # for each country } self.participation.append(coefficients) @@ -1308,7 +1316,12 @@ def plot_roles( axs = [axs] for ax, (idx, scenario) in zip(axs, enumerate(self.scenarios)): ax.scatter( - self.participation[idx].values(), self.zscores[idx].values(), zorder=5, color="black", alpha=0.8, **kwargs + self.participation[idx].values(), + self.zscores[idx].values(), + zorder=5, + color="black", + alpha=0.8, + **kwargs, ) ax.set_title( f"Country roles for {scenario.crop} with base year {scenario.base_year[1:]}" @@ -1318,7 +1331,9 @@ def plot_roles( else "\n(no scenario)" ) ) - fill_sector_by_colour(ax, z_threshold, p_thresholds, alpha, labels, fontsize) + fill_sector_by_colour( + ax, z_threshold, p_thresholds, alpha, labels, fontsize + ) ax.set_xlabel("Participation coefficient") ax.set_ylabel("Within community degree") # Turn off the grid @@ -1419,11 +1434,16 @@ def _compute_node_stability_difference(self) -> None: """ self.node_stability_difference = [ { - country: (stability - self.node_stability[0][country]) - / self.node_stability[0][country] - if country in self.node_stability[0] - else (print(f"Warning: {country} not found in the base scenario."), np.nan) - for country, stability in node_stability.items() + country: ( + (stability - self.node_stability[0][country]) + / self.node_stability[0][country] + if country in self.node_stability[0] + else ( + print(f"Warning: {country} not found in the base scenario."), + np.nan, + ) + ) + for country, stability in node_stability.items() } for node_stability in self.node_stability[1:] ] diff --git a/src/preprocessing.py b/src/preprocessing.py index 4524884..eca5cfe 100644 --- a/src/preprocessing.py +++ b/src/preprocessing.py @@ -108,7 +108,9 @@ def _prep_trade_matrix( # Make sure that we are not trying to filter for a unit, element or item which is not in the data # This can happen because the FAO data is not consistent across all datasets assert unit in trad["Unit"].unique(), f"unit {unit} not in {trad['Unit'].unique()}" - assert element in trad["Element"].unique(), f"element {element} not in {trad['Element'].unique()}" + assert ( + element in trad["Element"].unique() + ), f"element {element} not in {trad['Element'].unique()}" assert item in trad["Item"].unique(), f"item {item} not in {trad['Item'].unique()}" print("Filter trade matrix") trad = trad[ @@ -280,8 +282,8 @@ def rename_countries( # different states now # rename China; Taiwan Province of to Taiwan codes.loc[codes["Area"] == "China; Taiwan Province of", "Area"] = "Taiwan" - codes.loc[codes["Area"] == 'Serbia and Montenegro', "Area"] = "Serbia" - codes.loc[codes["Area"] == 'Belgium-Luxembourg', "Area"] = "Belgium" + codes.loc[codes["Area"] == "Serbia and Montenegro", "Area"] = "Serbia" + codes.loc[codes["Area"] == "Belgium-Luxembourg", "Area"] = "Belgium" # Create a dictionary with the country codes as keys and country names as values cc = coco.CountryConverter() @@ -431,9 +433,15 @@ def main( # Make sure that production index and trade matrix index/columns are the same # and print out the difference if there is any - assert trade_matrix.index.equals(trade_matrix.columns), f"difference: {trade_matrix.index.difference(trade_matrix.columns)}" - assert production.index.equals(trade_matrix.index), f"difference: {production.index.difference(trade_matrix.index)}" - assert production.index.equals(trade_matrix.columns), f"difference: {production.index.difference(trade_matrix.columns)}" + assert trade_matrix.index.equals( + trade_matrix.columns + ), f"difference: {trade_matrix.index.difference(trade_matrix.columns)}" + assert production.index.equals( + trade_matrix.index + ), f"difference: {production.index.difference(trade_matrix.index)}" + assert production.index.equals( + trade_matrix.columns + ), f"difference: {production.index.difference(trade_matrix.columns)}" # Replace "All_Data" with "global" for readability if region == "All_Data": diff --git a/src/utils.py b/src/utils.py index 87ecdb7..a0296a8 100644 --- a/src/utils.py +++ b/src/utils.py @@ -589,7 +589,12 @@ def get_percolation_threshold( def fill_sector_by_colour( - ax: Axes, z_threshold: float, p_thresholds: list[float], alpha: float, labels=True, fontsize=7 + ax: Axes, + z_threshold: float, + p_thresholds: list[float], + alpha: float, + labels=True, + fontsize=7, ): """ This is a helper function to colour the background in the z-score, participation @@ -701,13 +706,23 @@ def fill_sector_by_colour( alpha=alpha, ) if labels: - ax.text(0.01, 5.5, 'Provincial hub', fontsize=fontsize, color='red', alpha=0.5) - ax.text(0.31, 5.5, 'Connector hub', fontsize=fontsize, color='green', alpha=0.5) - ax.text(0.76, 5.5, 'Kinless hub', fontsize=fontsize, color='gold', alpha=1) - ax.text(0.001, -1.8, 'Ultra\nperi-\npheral\nnon\nhub', fontsize=fontsize, color='purple', alpha=0.5) - ax.text(0.06, -1.8, 'Peripheral non-hub', fontsize=fontsize, color='blue', alpha=0.5) - ax.text(0.63, -1.8, 'Connector non-hub', fontsize=fontsize, color='orange', alpha=1) - ax.text(0.81, -1.8, 'Kinless non-hub', fontsize=fontsize, color='brown', alpha=1) - - - + ax.text(0.01, 5.5, "Provincial hub", fontsize=fontsize, color="red", alpha=0.5) + ax.text(0.31, 5.5, "Connector hub", fontsize=fontsize, color="green", alpha=0.5) + ax.text(0.76, 5.5, "Kinless hub", fontsize=fontsize, color="gold", alpha=1) + ax.text( + 0.001, + -1.8, + "Ultra\nperi-\npheral\nnon\nhub", + fontsize=fontsize, + color="purple", + alpha=0.5, + ) + ax.text( + 0.06, -1.8, "Peripheral non-hub", fontsize=fontsize, color="blue", alpha=0.5 + ) + ax.text( + 0.63, -1.8, "Connector non-hub", fontsize=fontsize, color="orange", alpha=1 + ) + ax.text( + 0.81, -1.8, "Kinless non-hub", fontsize=fontsize, color="brown", alpha=1 + ) diff --git a/tests/test_postprocessing.py b/tests/test_postprocessing.py index ad0598a..780114d 100644 --- a/tests/test_postprocessing.py +++ b/tests/test_postprocessing.py @@ -219,7 +219,9 @@ def test_compute_participation(self, postprocessing_object) -> None: def test_compute_imports(self, postprocessing_object) -> None: assert postprocessing_object.imports is not None - assert len(postprocessing_object.imports) == len(postprocessing_object.scenarios) + assert len(postprocessing_object.imports) == len( + postprocessing_object.scenarios + ) assert all( [ isinstance(v, float)