-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feat: neat implementation, training code and running code
- Loading branch information
Showing
2 changed files
with
276 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
# https://neat-python.readthedocs.io/en/latest/xor_example.html | ||
from src.pong import Game | ||
import pygame | ||
import neat | ||
import os | ||
import time | ||
import pickle | ||
|
||
|
||
class PongGame: | ||
def __init__(self, window, width, height): | ||
self.game = Game(window, width, height) | ||
self.ball = self.game.ball | ||
self.left_paddle = self.game.left_paddle | ||
self.right_paddle = self.game.right_paddle | ||
|
||
def test_ai(self, net): | ||
""" | ||
Test the AI against a human player by passing a NEAT neural network | ||
""" | ||
clock = pygame.time.Clock() | ||
run = True | ||
while run: | ||
clock.tick(60) | ||
game_info = self.game.loop() | ||
|
||
for event in pygame.event.get(): | ||
if event.type == pygame.QUIT: | ||
run = False | ||
break | ||
|
||
output = net.activate( | ||
( | ||
self.right_paddle.y, | ||
abs(self.right_paddle.x - self.ball.x), | ||
self.ball.y, | ||
) | ||
) | ||
decision = output.index(max(output)) | ||
|
||
if decision == 1: # AI moves up | ||
self.game.move_paddle(left=False, up=True) | ||
elif decision == 2: # AI moves down | ||
self.game.move_paddle(left=False, up=False) | ||
|
||
keys = pygame.key.get_pressed() | ||
if keys[pygame.K_w]: | ||
self.game.move_paddle(left=True, up=True) | ||
elif keys[pygame.K_s]: | ||
self.game.move_paddle(left=True, up=False) | ||
|
||
self.game.draw(draw_score=True) | ||
pygame.display.update() | ||
|
||
def train_ai(self, genome1, genome2, config, draw=False): | ||
""" | ||
Train the AI by passing two NEAT neural networks and the NEAt config object. | ||
These AI's will play against eachother to determine their fitness. | ||
""" | ||
run = True | ||
start_time = time.time() | ||
|
||
net1 = neat.nn.FeedForwardNetwork.create(genome1, config) | ||
net2 = neat.nn.FeedForwardNetwork.create(genome2, config) | ||
self.genome1 = genome1 | ||
self.genome2 = genome2 | ||
|
||
max_hits = 50 | ||
|
||
while run: | ||
for event in pygame.event.get(): | ||
if event.type == pygame.QUIT: | ||
return True | ||
|
||
game_info = self.game.loop() | ||
|
||
self.move_ai_paddles(net1, net2) | ||
|
||
if draw: | ||
self.game.draw(draw_score=False, draw_hits=True) | ||
|
||
pygame.display.update() | ||
|
||
duration = time.time() - start_time | ||
if ( | ||
game_info.left_score == 1 | ||
or game_info.right_score == 1 | ||
or game_info.left_hits >= max_hits | ||
): | ||
self.calculate_fitness(game_info, duration) | ||
break | ||
|
||
return False | ||
|
||
def move_ai_paddles(self, net1, net2): | ||
""" | ||
Determine where to move the left and the right paddle based on the two | ||
neural networks that control them. | ||
""" | ||
players = [ | ||
(self.genome1, net1, self.left_paddle, True), | ||
(self.genome2, net2, self.right_paddle, False), | ||
] | ||
for genome, net, paddle, left in players: | ||
output = net.activate((paddle.y, abs(paddle.x - self.ball.x), self.ball.y)) | ||
decision = output.index(max(output)) | ||
|
||
valid = True | ||
if decision == 0: # Don't move | ||
genome.fitness -= 0.01 # we want to discourage this | ||
elif decision == 1: # Move up | ||
valid = self.game.move_paddle(left=left, up=True) | ||
else: # Move down | ||
valid = self.game.move_paddle(left=left, up=False) | ||
|
||
if ( | ||
not valid | ||
): # If the movement makes the paddle go off the screen punish the AI | ||
genome.fitness -= 1 | ||
|
||
def calculate_fitness(self, game_info, duration): | ||
self.genome1.fitness += game_info.left_hits + duration | ||
self.genome2.fitness += game_info.right_hits + duration | ||
|
||
|
||
def eval_genomes(genomes, config): | ||
""" | ||
Run each genome against eachother one time to determine the fitness. | ||
""" | ||
width, height = 700, 500 | ||
win = pygame.display.set_mode((width, height)) | ||
pygame.display.set_caption("Pong") | ||
|
||
for i, (genome_id1, genome1) in enumerate(genomes): | ||
print(round(i / len(genomes) * 100), end=" ") | ||
genome1.fitness = 0 | ||
for genome_id2, genome2 in genomes[min(i + 1, len(genomes) - 1) :]: | ||
genome2.fitness = 0 if genome2.fitness == None else genome2.fitness | ||
pong = PongGame(win, width, height) | ||
|
||
force_quit = pong.train_ai(genome1, genome2, config, draw=True) | ||
if force_quit: | ||
quit() | ||
|
||
|
||
def run_neat(config): | ||
checkpoint_folder = "checkpoints" | ||
os.makedirs(checkpoint_folder, exist_ok=True) | ||
checkpoint_prefix = os.path.join(checkpoint_folder, "neat-checkpoint-") | ||
|
||
# Create the population | ||
p = neat.Population(config) | ||
|
||
# Add reporters to show progress in the terminal | ||
p.add_reporter(neat.StdOutReporter(True)) | ||
stats = neat.StatisticsReporter() | ||
p.add_reporter(stats) | ||
|
||
# Add a Checkpointer reporter to save checkpoints in the specified folder | ||
p.add_reporter( | ||
neat.Checkpointer(generation_interval=1, filename_prefix=checkpoint_prefix) | ||
) | ||
|
||
# Run for up to 50 generations | ||
winner = p.run(eval_genomes, 50) | ||
|
||
# Save the winning genome | ||
with open("best.pickle", "wb") as f: | ||
pickle.dump(winner, f) | ||
|
||
|
||
def test_best_network(config): | ||
with open("best.pickle", "rb") as f: | ||
winner = pickle.load(f) | ||
winner_net = neat.nn.FeedForwardNetwork.create(winner, config) | ||
|
||
width, height = 700, 500 | ||
win = pygame.display.set_mode((width, height)) | ||
pygame.display.set_caption("Pong") | ||
pong = PongGame(win, width, height) | ||
pong.test_ai(winner_net) | ||
|
||
|
||
if __name__ == "__main__": | ||
local_dir = os.path.dirname(__file__) | ||
config_path = os.path.join(local_dir, "src/config.txt") | ||
|
||
config = neat.Config( | ||
neat.DefaultGenome, | ||
neat.DefaultReproduction, | ||
neat.DefaultSpeciesSet, | ||
neat.DefaultStagnation, | ||
config_path, | ||
) | ||
|
||
run_neat(config) | ||
# test_best_network(config) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
[NEAT] | ||
fitness_criterion = max | ||
fitness_threshold = 400 | ||
pop_size = 50 | ||
reset_on_extinction = False | ||
|
||
[DefaultStagnation] | ||
species_fitness_func = max | ||
max_stagnation = 20 | ||
species_elitism = 2 | ||
|
||
[DefaultReproduction] | ||
elitism = 2 | ||
survival_threshold = 0.2 | ||
|
||
[DefaultGenome] | ||
# node activation options | ||
activation_default = relu | ||
activation_mutate_rate = 1.0 | ||
activation_options = relu | ||
|
||
# node aggregation options | ||
aggregation_default = sum | ||
aggregation_mutate_rate = 0.0 | ||
aggregation_options = sum | ||
|
||
# node bias options | ||
bias_init_mean = 3.0 | ||
bias_init_stdev = 1.0 | ||
bias_max_value = 30.0 | ||
bias_min_value = -30.0 | ||
bias_mutate_power = 0.5 | ||
bias_mutate_rate = 0.7 | ||
bias_replace_rate = 0.1 | ||
|
||
# genome compatibility options | ||
compatibility_disjoint_coefficient = 1.0 | ||
compatibility_weight_coefficient = 0.5 | ||
|
||
# connection add/remove rates | ||
conn_add_prob = 0.5 | ||
conn_delete_prob = 0.5 | ||
|
||
# connection enable options | ||
enabled_default = True | ||
enabled_mutate_rate = 0.01 | ||
|
||
feed_forward = True | ||
initial_connection = full_direct | ||
|
||
# node add/remove rates | ||
node_add_prob = 0.2 | ||
node_delete_prob = 0.2 | ||
|
||
# network parameters | ||
num_hidden = 2 | ||
num_inputs = 3 | ||
num_outputs = 3 | ||
|
||
# node response options | ||
response_init_mean = 1.0 | ||
response_init_stdev = 0.0 | ||
response_max_value = 30.0 | ||
response_min_value = -30.0 | ||
response_mutate_power = 0.0 | ||
response_mutate_rate = 0.0 | ||
response_replace_rate = 0.0 | ||
|
||
# connection weight options | ||
weight_init_mean = 0.0 | ||
weight_init_stdev = 1.0 | ||
weight_max_value = 30 | ||
weight_min_value = -30 | ||
weight_mutate_power = 0.5 | ||
weight_mutate_rate = 0.8 | ||
weight_replace_rate = 0.1 | ||
|
||
[DefaultSpeciesSet] | ||
compatibility_threshold = 3.0 |