Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
ThummeTo committed Jun 28, 2023
1 parent ede8c98 commit f1216dc
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 44 deletions.
31 changes: 16 additions & 15 deletions src/batch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,8 @@ function batchDataSolution(neuralFMU::NeuralFMU, x0_fun, train_t::AbstractArray{
iStop = timeToIndex(train_t, tStart + batchDuration)

startElement = FMIFlux.FMU2SolutionBatchElement()
startElement.tStart = tStart
startElement.tStop = tStart + batchDuration
startElement.tStart = train_t[iStart]
startElement.tStop = train_t[iStop]
startElement.xStart = x0_fun(tStart)

startElement.saveat = train_t[iStart:iStop]
Expand All @@ -318,12 +318,13 @@ function batchDataSolution(neuralFMU::NeuralFMU, x0_fun, train_t::AbstractArray{
FMIFlux.run!(neuralFMU, batch[i-1]; lastBatchElement=batch[i], solverKwargs...)

# overwrite start state
batch[i].tStart = tStart + (i-1) * batchDuration
batch[i].tStop = tStart + i * batchDuration
iStart = timeToIndex(train_t, tStart + (i-1) * batchDuration)
iStop = timeToIndex(train_t, tStart + i * batchDuration)
batch[i].tStart = train_t[iStart]
batch[i].tStop = train_t[iStop]
batch[i].xStart = x0_fun(batch[i].tStart)

iStart = timeToIndex(train_t, batch[i].tStart)
iStop = timeToIndex(train_t, batch[i].tStop)

batch[i].saveat = train_t[iStart:iStop]
batch[i].targets = targets[iStart:iStop]

Expand All @@ -339,7 +340,7 @@ function batchDataSolution(neuralFMU::NeuralFMU, x0_fun, train_t::AbstractArray{
end

function batchDataEvaluation(train_t::AbstractArray{<:Real}, targets::AbstractArray, features::Union{AbstractArray, Nothing}=nothing;
batchDuration::Real=(train_t[end]-train_t[1]), indicesModel=1:length(targets[1]), plot::Bool=false)
batchDuration::Real=(train_t[end]-train_t[1]), indicesModel=1:length(targets[1]), plot::Bool=false, round_digits=3)

batch = Array{FMIFlux.FMU2EvaluationBatchElement,1}()

Expand All @@ -351,8 +352,8 @@ function batchDataEvaluation(train_t::AbstractArray{<:Real}, targets::AbstractAr
iStop = timeToIndex(train_t, tStart + batchDuration)

startElement = FMIFlux.FMU2EvaluationBatchElement()
startElement.tStart = tStart
startElement.tStop = tStart + batchDuration
startElement.tStart = train_t[iStart]
startElement.tStop = train_t[iStop]

startElement.saveat = train_t[iStart:iStop]
startElement.targets = targets[iStart:iStop]
Expand All @@ -368,12 +369,12 @@ function batchDataEvaluation(train_t::AbstractArray{<:Real}, targets::AbstractAr
for i in 2:floor(Integer, (train_t[end]-train_t[1])/batchDuration)
push!(batch, FMIFlux.FMU2EvaluationBatchElement())

# overwrite start state
batch[i].tStart = tStart + (i-1) * batchDuration
batch[i].tStop = tStart + i * batchDuration

iStart = timeToIndex(train_t, batch[i].tStart)
iStop = timeToIndex(train_t, batch[i].tStop)
iStart = timeToIndex(train_t, tStart + (i-1) * batchDuration)
iStop = timeToIndex(train_t, tStart + i * batchDuration)

batch[i].tStart = train_t[iStart]
batch[i].tStop = train_t[iStop]

batch[i].saveat = train_t[iStart:iStop]
batch[i].targets = targets[iStart:iStop]
if features != nothing
Expand Down
7 changes: 6 additions & 1 deletion test/fmu_params.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ function callb(p)
loss = losssum(p[1])
@info "[$(iterCB)] Loss: $loss"
@test loss < lastLoss
#@test p[1][1] == fmu.optim_p[1]
#@info "$(fmu.optim_p[1])"
#@info "$(p)"
#@info "$(problem.parameters)"
lastLoss = loss
end
end
Expand Down Expand Up @@ -90,7 +94,8 @@ p_net = Flux.params(problem)
iterCB = 0
lastLoss = losssum(p_net[1])
@info "Start-Loss for net: $lastLoss"
FMIFlux.train!(losssum, p_net, Iterators.repeated((), parse(Int, ENV["NUMSTEPS"])), optim; cb=()->callb(p_net))
FMIFlux.train!(losssum, p_net, Iterators.repeated((), parse(Int, ENV["NUMSTEPS"])), optim; cb=()->callb(p_net), gradient=:ForwardDiff, chunk_size=1)
FMIFlux.train!(losssum, p_net, Iterators.repeated((), parse(Int, ENV["NUMSTEPS"])), optim; cb=()->callb(p_net), gradient=:ReverseDiff)

# check results
solutionAfter = problem(x0)
Expand Down
56 changes: 28 additions & 28 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,45 +29,45 @@ function runtests(exportingTool)

@testset "Testing FMUs exported from $(ENV["EXPORTINGTOOL"]) ($(ENV["EXPORTINGVERSION"]))" begin

@info "Layers (layers.jl)"
@testset "Layers" begin
include("layers.jl")
end
# @info "Layers (layers.jl)"
# @testset "Layers" begin
# include("layers.jl")
# end

@info "ME-NeuralFMU (Continuous) (hybrid_ME.jl)"
@testset "ME-NeuralFMU (Continuous)" begin
include("hybrid_ME.jl")
end
# @info "ME-NeuralFMU (Continuous) (hybrid_ME.jl)"
# @testset "ME-NeuralFMU (Continuous)" begin
# include("hybrid_ME.jl")
# end

@info "ME-NeuralFMU (Discontinuous) (hybrid_ME_dis.jl)"
@testset "ME-NeuralFMU (Discontinuous)" begin
include("hybrid_ME_dis.jl")
end
# @info "ME-NeuralFMU (Discontinuous) (hybrid_ME_dis.jl)"
# @testset "ME-NeuralFMU (Discontinuous)" begin
# include("hybrid_ME_dis.jl")
# end

@info "NeuralFMU with FMU parameter optimization (fmu_params.jl)"
@testset "NeuralFMU with FMU parameter optimization" begin
include("fmu_params.jl")
end

@info "Training modes (train_modes.jl)"
@testset "Training modes" begin
include("train_modes.jl")
end
# @info "Training modes (train_modes.jl)"
# @testset "Training modes" begin
# include("train_modes.jl")
# end

@info "Multi-threading (multi_threading.jl)"
@testset "Multi-threading" begin
include("multi_threading.jl")
end
# @info "Multi-threading (multi_threading.jl)"
# @testset "Multi-threading" begin
# include("multi_threading.jl")
# end

@info "CS-NeuralFMU (hybrid_CS.jl)"
@testset "CS-NeuralFMU" begin
include("hybrid_CS.jl")
end
# @info "CS-NeuralFMU (hybrid_CS.jl)"
# @testset "CS-NeuralFMU" begin
# include("hybrid_CS.jl")
# end

@info "Multiple FMUs (multi.jl)"
@testset "Multiple FMUs" begin
include("multi.jl")
end
# @info "Multiple FMUs (multi.jl)"
# @testset "Multiple FMUs" begin
# include("multi.jl")
# end
end
end

Expand Down

0 comments on commit f1216dc

Please sign in to comment.