doc fix for FMIFlux.train! (#127)
0815Creeper authored Feb 21, 2024
1 parent 90ddf7b commit aeacdfd
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions src/neural.jl
@@ -1682,23 +1682,23 @@ end

"""
train!(loss, params::Union{Flux.Params, Zygote.Params}, data, optim::Flux.Optimise.AbstractOptimiser; gradient::Symbol=:Zygote, cb=nothing, chunk_size::Integer=64, printStep::Bool=false)
train!(loss, neuralFMU::Union{ME_NeuralFMU, CS_NeuralFMU}, data, optim; gradient::Symbol=:ReverseDiff, kwargs...)
A function analogous to `Flux.train!`, but with additional features and explicit parameters (faster).
# Arguments
- `loss` a loss function in the format `loss(p)`
- `params` an object holding the parameters
- `neuralFMU` an object holding the neuralFMU with its parameters
- `data` the training data (or often an iterator)
- `optim` the optimizer used for training
# Keywords
- `gradient` a symbol determining the AD-library for gradient computation, available are `:ForwardDiff`, `:Zygote` and `:ReverseDiff` (default)
- `cb` a custom callback function that is called after every training step
- `chunk_size` the chunk size for AD using ForwardDiff (ignored for other AD-methods)
- `printStep` a boolean determining whether the gradient min/max is printed after every step (for gradient debugging)
- `proceed_on_assert` a boolean that determines whether to throw an exception on error or proceed training and just print the error
- `numThreads` [WIP]: an integer determining how many threads are used for training (how many gradients are generated in parallel)
- `cb` a custom callback function that is called after every training step (default `nothing`)
- `chunk_size` the chunk size for AD using ForwardDiff (ignored for other AD-methods) (default `:auto_fmiflux`)
- `printStep` a boolean determining whether the gradient min/max is printed after every step (for gradient debugging) (default `false`)
- `proceed_on_assert` a boolean that determines whether to throw an exception on error or proceed training and just print the error (default `false`)
- `multiThreading` a boolean that determines whether multiple gradients are generated in parallel (default `false`)
- `multiObjective` set this if the loss function returns multiple values (multi-objective optimization); currently, the gradients are passed to the optimizer one after another (default `false`)
"""
function train!(loss, neuralFMU::Union{ME_NeuralFMU, CS_NeuralFMU}, data, optim; gradient::Symbol=:ReverseDiff, kwargs...)
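For context, a minimal usage sketch of the `train!` call documented above is shown below. The setup (`neuralFMU`, the initial state `x0`, and the reference data `posData`) is assumed to exist already, e.g. built as in the FMIFlux.jl tutorials; only the call signature and keyword names follow this docstring, the rest is illustrative.

```julia
# Hedged sketch of calling FMIFlux.train! with the signature documented above.
# `neuralFMU`, `x0` and `posData` are assumed to be defined elsewhere
# (e.g. as in the FMIFlux.jl examples); this is not taken from the diff itself.
using FMIFlux, Flux

# a loss in the documented format loss(p): simulate with parameters p and
# compare the first state trajectory against reference data
function loss(p)
    solution = neuralFMU(x0; p=p)                       # assumed call style
    return Flux.Losses.mse([u[1] for u in solution.states.u], posData)
end

optim = Flux.Adam(1e-3)

# 500 training steps; keywords correspond to the documentation in this diff
FMIFlux.train!(loss, neuralFMU, Iterators.repeated((), 500), optim;
               gradient=:ReverseDiff,    # default AD backend
               printStep=false,          # print gradient min/max for debugging
               proceed_on_assert=false,  # stop on errors instead of skipping them
               multiThreading=false)     # single-threaded gradient generation
```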
