Skip to content

Commit

Permalink
fixed issues with Arrow integration; create module CompareResonance f…
Browse files Browse the repository at this point in the history
…or analysing resonance optimization; first test runs
  • Loading branch information
Alexander-Reimer committed Mar 20, 2024
1 parent 9d81ba0 commit 6883a68
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 3 deletions.
106 changes: 106 additions & 0 deletions src/CompareResonance.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
module CompareResonance
using MNN
using DataFrames, CSV # for writing to csv
using Dates # for timestamp
using UUIDs # for uuid1
using Random # for setting random seed
using Arrow # for writing to Arrow

function create_df()
df = DataFrame(;
time=DateTime[],
uuid=UUID[],
epochs=Int64[],
rows=Int64[],
columns=Int64[],
goal_num=Int64[],
network=Network[],
trainer=Trainer[],
loss=Float64[],
)
timestamp = Dates.format(now(), "yyyy-mm-ddTHH-MM-SS")
metadata!(df, "time", timestamp; style=:note)
return df
end

function save_df(df)
filepath =
"src/data/ResonanceCurveOptimization/" *
"ResonanceEpochs" *
"_" *
metadata(df)["time"] *
".arrow"
return Arrow.write(filepath, df; maxdepth=7)
end

function make_entry(net, trainer, num_goals_resonance, loss, id)
lock(lk)
push!(
df,
(
now(),
id,
trainer.optimization.epochs,
net.rows,
net.columns,
num_goals_resonance,
net,
trainer,
loss,
),
)
save_df(df)
unlock(lk)
return nothing
end

function main()
global df = create_df()
global lk = ReentrantLock()
epochs = 1000
network_number = 10
number_goals = 3
@sync for _ in 1:network_number
Threads.@spawn begin
id = uuid1()
Random.seed!(id.value)
net = Network(11, 4)
b = MNN.Resonance(net, 3)
t = Trainer(b, Diff(100), PPS())
for _ in 1:(epochs / 20)
loss = train!(net, 20, t)
make_entry(net, t, number_goals, loss, id)
end
end
end
end

function load(path::String)
return DataFrame(Arrow.Table(path))
end

function load()
files = filter(
x -> length(x) > 15 && x[1:15] == "ResonanceEpochs",
readdir("src/data/ResonanceCurveOptimization/"),
)
if isempty(files)
return nothing
end
newest = files[1]
newest_time = Dates.DateTime(newest[17:35], "yyyy-mm-ddTHH-MM-SS")
if length(files) == 1
return load("src/data/ResonanceCurveOptimization/" * newest)
end

for filename in files[2:end]
time = Dates.DateTime(filename[17:35], "yyyy-mm-ddTHH-MM-SS")
if time > newest_time
newest = filename
newest_time = time
end
end
return load("src/data/ResonanceCurveOptimization/" * newest)
end

end
35 changes: 35 additions & 0 deletions src/DataHandling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,31 @@ function Base.convert(
return Diff(x.time)
end

function Base.convert(::Type{Simulation}, x::NamedTuple)
return Diff(x.time)
end

function Base.convert(
::Type{Simulation},
x::NamedTuple{
(:time, :modifier),
Tuple{
Int64,
NamedTuple{
(:behaviour,),
Tuple{
NamedTuple{
(:goals, :modifiers),
Tuple{Dict{Int64,Float64},Dict{Int64,Vector{Float64}}},
},
},
},
},
},
)
return Diff(x.time)
end

function Base.convert(
::Type{Optimization},
x::NamedTuple{
Expand All @@ -66,6 +91,16 @@ function Base.convert(
return PPS(x.initialized, x.init, x.increment, x.selected, x.epochs)
end

function Base.convert(
::Type{Optimization},
x::NamedTuple{
(:initialized, :init, :increment, :selected, :epochs),
Tuple{Bool,Float64,Float64,Set{Missing},Int64},
},
)
return PPS(x.initialized, x.init, x.increment, x.selected, x.epochs)
end

ArrowTypes.arrowname(::Type{Diff}) = :Diff
ArrowTypes.ArrowKind(::Type{Diff}) = ArrowTypes.StructKind()
ArrowTypes.JuliaType(::Val{:Diff}) = Diff
5 changes: 3 additions & 2 deletions src/Evolution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ function train!(
else
opt.mutation_strength *= 0.95
end
println(loss)
@info "Current best loss: $loss"

next_gen = [zeros(length(spring_data)) for _ in 1:(opt.popsize)]
next_gen[1:Int(floor(opt.popsize / 5))] = opt.candidates[index[1:Int(
Expand All @@ -139,7 +139,8 @@ function train!(
opt.candidates = copy(next_gen)
opt.epochs += 1
end
return set_spring_data!(network, spring_data)
set_spring_data!(network, spring_data)
return loss
end

#TODO
Expand Down
2 changes: 1 addition & 1 deletion src/PPSOptimizer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,5 +86,5 @@ function train!(
epochs == 0 && break
end
set_spring_data!(network, spring_data)
return nothing
return base_loss
end
Binary file not shown.

0 comments on commit 6883a68

Please sign in to comment.