Skip to content

Commit

Permalink
Merge pull request #310 from ahmedfgad/github-actions
Browse files Browse the repository at this point in the history
Create the plot_pareto_front_curve() method to plot the pareto front …
  • Loading branch information
ahmedfgad authored Jan 7, 2025
2 parents 9eeaa8d + 82fa0f8 commit dd07de9
Showing 1 changed file with 101 additions and 0 deletions.
101 changes: 101 additions & 0 deletions pygad/visualize/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,3 +384,104 @@ def plot_genes(self,
matplotlib.pyplot.show()

return fig

def plot_pareto_front_curve(self,
title="Pareto Front Curve",
xlabel="Objective 1",
ylabel="Objective 2",
linewidth=3,
font_size=14,
label="Pareto Front",
color="#FF6347",
color_fitness="#4169E1",
grid=True,
alpha=0.7,
marker="o",
save_dir=None):
"""
Creates, shows, and returns the pareto front curve. Can only be used with multi-objective problems.
It only works with 2 objectives.
It also works only after completing at least 1 generation. If no generation is completed, an exception is raised.
Accepts the following:
title: Figure title.
xlabel: Label on the X-axis.
ylabel: Label on the Y-axis.
linewidth: Line width of the plot. Defaults to 3.
font_size: Font size for the labels and title. Defaults to 14.
label: The label used for the legend.
color: Color of the plot.
color_fitness: Color of the fitness points.
grid: Either True or False to control the visibility of the grid.
alpha: The transparency of the pareto front curve.
marker: The marker of the fitness points.
save_dir: Directory to save the figure.
Returns the figure.
"""

if self.generations_completed < 1:
self.logger.error("The plot_pareto_front_curve() method can only be called after completing at least 1 generation but ({self.generations_completed}) is completed.")
raise RuntimeError("The plot_pareto_front_curve() method can only be called after completing at least 1 generation but ({self.generations_completed}) is completed.")

if type(self.best_solutions_fitness[0]) in [list, tuple, numpy.ndarray] and len(self.best_solutions_fitness[0]) > 1:
# Multi-objective optimization problem.
if len(self.best_solutions_fitness[0]) == 2:
# Only 2 objectives. Proceed.
pass
else:
# More than 2 objectives.
self.logger.error(f"The plot_pareto_front_curve() method only supports 2 objectives but there are {self.best_solutions_fitness[0]} objectives.")
raise RuntimeError(f"The plot_pareto_front_curve() method only supports 2 objectives but there are {self.best_solutions_fitness[0]} objectives.")
else:
# Single-objective optimization problem.
self.logger.error("The plot_pareto_front_curve() method only works with multi-objective optimization problems.")
raise RuntimeError("The plot_pareto_front_curve() method only works with multi-objective optimization problems.")

# Plot the pareto front curve.
remaining_set = list(zip(range(0, self.last_generation_fitness.shape[0]), self.last_generation_fitness))
dominated_set, non_dominated_set = self.get_non_dominated_set(remaining_set)

# Extract the fitness values (objective values) of the non-dominated solutions for plotting.
pareto_front_x = [self.last_generation_fitness[item[0]][0] for item in dominated_set]
pareto_front_y = [self.last_generation_fitness[item[0]][1] for item in dominated_set]

# Sort the Pareto front solutions (optional but can make the plot cleaner)
sorted_pareto_front = sorted(zip(pareto_front_x, pareto_front_y))

# Plotting
fig = matplotlib.pyplot.figure()
# First, plot the scatter of all points (population)
all_points_x = [self.last_generation_fitness[i][0] for i in range(self.sol_per_pop)]
all_points_y = [self.last_generation_fitness[i][1] for i in range(self.sol_per_pop)]
matplotlib.pyplot.scatter(all_points_x,
all_points_y,
marker=marker,
color=color_fitness,
label='Fitness',
alpha=1.0)

# Then, plot the Pareto front as a curve
pareto_front_x_sorted, pareto_front_y_sorted = zip(*sorted_pareto_front)
matplotlib.pyplot.plot(pareto_front_x_sorted,
pareto_front_y_sorted,
marker=marker,
label=label,
alpha=alpha,
color=color,
linewidth=linewidth)

matplotlib.pyplot.title(title, fontsize=font_size)
matplotlib.pyplot.xlabel(xlabel, fontsize=font_size)
matplotlib.pyplot.ylabel(ylabel, fontsize=font_size)
matplotlib.pyplot.legend()

matplotlib.pyplot.grid(grid)

if not save_dir is None:
matplotlib.pyplot.savefig(fname=save_dir,
bbox_inches='tight')

matplotlib.pyplot.show()

return fig

0 comments on commit dd07de9

Please sign in to comment.