-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.lua
123 lines (94 loc) · 3.35 KB
/
main.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
require 'torch'
require 'rnn'
utils = require './utils'
Inputs = require './inputs'
NeuralNetwork = require './network'
CMAES = require './cmaes'
--- Configuration and one-time initialisation for the CMA-ES trained
-- neural-network Mario agent. Everything assigned here is a global that the
-- per-frame loop further down reads or updates.
-- NOTE(review): kept global as in the original; it is not visible from this
-- file whether ./inputs, ./network or ./cmaes also read any of these globals.

-- Every individual is evaluated starting from the same savestate slot.
StateNumber = 1
State = savestate.create(StateNumber)
savestate.load(State)

-- Controller port passed to joypad.set.
Player = 1

-- HUD text layout (x/y coordinates handed to gui.text).
LeftMargin, TopMargin, BottomMargin, LineHeight = 10, 40, 230, 10

-- Evaluation and fitness parameters used by the main loop.
MaxDistance = 255     -- sprite distances are divided by this before scaling
MaxTime = 400         -- fitness adds (MaxTime - Inputs.getTime())
MaxEvaluations = 1000 -- an evaluation ends once the frame counter exceeds this
EndLevel = 3000       -- Mario x beyond this earns EndLevelBonus
EndLevelBonus = 1000

-- Network input layout: a (2*BoxRadius+1) x (2*BoxRadius+1) tile box plus
-- MaxEnemies sprite distances.
MaxEnemies = 5
BoxRadius = 5
do
  local side = 2 * BoxRadius + 1
  NTiles = side * side
end
Ninputs = MaxEnemies + NTiles
Nhidden, Noutputs = 5, 3
net = NeuralNetwork(Ninputs, Nhidden, Noutputs)
-- Genome = hidden weights + hidden biases + output weights + output biases,
-- i.e. (inputs + 1 bias) per hidden unit and (hidden + 1 bias) per output.
GenomeSize = Nhidden * (Ninputs + 1) + Noutputs * (Nhidden + 1)

-- CMA-ES state: each generation holds Lambda offspring; load the first now.
MaxGenerations = 100
cmaes = CMAES(GenomeSize)
Lambda = cmaes.lambda
offspring = cmaes:generateOffspring()
currentOffspring, generationCount = 1, 1
net:setWeights(offspring[currentOffspring].genome)

GenerationStats = {} -- one entry per finished generation (cmaes:endGeneration())
framecounter = 0     -- frames played by the current individual
--- Main evaluation loop.
-- The current CMA-ES offspring's genome is loaded into `net`, which drives the
-- joypad from savestate `State` until Mario dies or the frame counter exceeds
-- `MaxEvaluations`. The individual's fitness is then reported to CMA-ES, the
-- next offspring is loaded and the savestate restored. After `Lambda`
-- offspring a generation ends; after `MaxGenerations` generations the best
-- fitness and genome of every generation are printed and the script exits.

-- Last network output. It must outlive a single frame because the network is
-- only queried every third frame and the joypad reuses it in between.
local output

-- Print the best fitness and genome recorded for every finished generation.
local function printGenerationStats()
  for _, stats in ipairs(GenerationStats) do
    print(string.format('%.4f', stats.best.fitness))
    print(stats.best.genome)
  end
end

-- Score the individual that just finished, report it to CMA-ES, then load the
-- next individual (starting a new generation, or exiting, when needed).
local function finishEvaluation(mario)
  local marioScore = Inputs.getMarioScore() + mario.x
    + (mario.x > EndLevel and EndLevelBonus or 0)
  local gameTime = Inputs.getTime()
  local fitness = marioScore + (MaxTime - gameTime)
  cmaes:setFitness(currentOffspring, fitness)

  framecounter = 0
  currentOffspring = currentOffspring + 1
  if currentOffspring > Lambda then
    local stats = cmaes:endGeneration()
    print('Best fit (' .. generationCount .. '): ' .. stats.best.fitness)
    GenerationStats[#GenerationStats + 1] = stats
    if generationCount == MaxGenerations then
      printGenerationStats()
      os.exit()
    end
    offspring = cmaes:generateOffspring()
    currentOffspring = 1
    generationCount = generationCount + 1
  end
  net:setWeights(offspring[currentOffspring].genome)
  savestate.load(State)
end

-- Build a 1 x MaxEnemies distance row mapped as d -> 2*d/MaxDistance - 1.
-- Slots without a reported distance stay at MaxDistance (i.e. +1). Extra
-- distances beyond MaxEnemies are ignored so the network input size always
-- matches Ninputs = MaxEnemies + NTiles.
local function distanceTensor(distances)
  local t = torch.Tensor(1, MaxEnemies):fill(MaxDistance)
  for i = 1, math.min(#distances, MaxEnemies) do
    t[1][i] = distances[i]
  end
  return (t:div(MaxDistance) * 2) - 1
end

-- Overlay Mario's position and each tracked sprite's position and scaled
-- distance.
local function drawHud(mario, sprites, tDistances)
  gui.text(LeftMargin, TopMargin,
    'Mario ' .. (mario and string.format('%d, %d', mario.x, mario.y) or 'NaN'))
  for i = 1, MaxEnemies do
    local text = sprites[i]
      and string.format('%d, %d, %.3f', sprites[i].x, sprites[i].y, tDistances[1][i])
      or 'NaN'
    gui.text(LeftMargin, TopMargin + (i*LineHeight), 'Sprite' .. i .. ' ' .. text)
  end
end

while true do
  emu.speedmode('turbo')
  local mario = Inputs.getMario()
  local marioState = Inputs.getMarioState()
  local marioDead = marioState == 'Dying' or marioState == 'Player dies'
  gui.text(LeftMargin, BottomMargin - 3*LineHeight, 'Generation: ' .. generationCount)
  gui.text(LeftMargin, BottomMargin - 2*LineHeight, 'Individual: ' .. currentOffspring)

  if framecounter > MaxEvaluations or marioDead then
    finishEvaluation(mario)
  else
    local sprites = Inputs.getSprites()
    local tDistances = distanceTensor(Inputs.getDistances(mario, sprites))
    -- Query the network every third frame only. framecounter is reset to 0
    -- for each new individual, so its first frame always refreshes `output`.
    if framecounter % 3 == 0 then
      local tiles = Inputs.getTiles(BoxRadius, mario)
      local tTiles = torch.Tensor(tiles):reshape(1, #tiles)
      output = net:feed(torch.cat(tTiles, tDistances))
    end
    joypad.set(Player, {
      A = (output[1][1] > 0),
      left = (output[1][2] > 0),
      right = (output[1][3] > 0),
    })
    drawHud(mario, sprites, tDistances)
    framecounter = framecounter + 1
  end

  emu.frameadvance()
end