diff --git a/src/correlation.jl b/src/correlation.jl index 9b196ef0..a0280f0b 100644 --- a/src/correlation.jl +++ b/src/correlation.jl @@ -301,12 +301,18 @@ function MTK.ODESystem(c::CorrelationFunction; kwargs...) ps_ = [ps..., steady_vals...] else avg = average(c.op2_0) - if avg ∈ Set(c.de0.states) || _conj(avg) ∈ Set(c.de0.states) - ps_ = [ps..., average(c.op2)] + if avg ∈ Set(c.de0.states) + avg2 = average(c.op2) + ps_ = [ps..., avg] + de = substitute(c.de, Dict(avg2 => avg)) + elseif _conj(avg) ∈ Set(c.de0.states) + avg2 = average(c.op2) + ps_ = [ps..., _conj(avg)] + de = substitute(c.de, Dict(avg2 => avg)) else ps_ = [ps...] + de = c.de end - de = c.de end ps_avg = filter(x->x isa Average, ps_) @@ -321,16 +327,19 @@ function MTK.ODESystem(c::CorrelationFunction; kwargs...) de_.equations[i] = Symbolics.Equation(lhs, rhs) end + avg0 = average(c.op2_0) if c.steady_state steady_params = map(_make_parameter, steady_vals) subs_params = Dict(steady_vals .=> steady_params) - for i=1:length(de.equations) - de_.equations[i] = substitute(de_.equations[i], subs_params) - end + de_ = substitute(de_, subs_params) + elseif avg0 ∈ Set(c.de0.states) + avg0_par = _make_parameter(avg0) + de_ = substitute(de_, Dict(avg0 => avg0_par)) + elseif _conj(avg0) ∈ Set(c.de0.states) + avg0_par = _make_parameter(_conj(avg0)) + de_ = substitute(de_, Dict(_conj(avg0) => avg0_par)) end - ps_ = map(_make_parameter, ps_) - eqs = MTK.equations(de_) return MTK.ODESystem(eqs, τ; kwargs...) end diff --git a/test/test_correlation.jl b/test/test_correlation.jl index bfe95e0a..c236f870 100644 --- a/test/test_correlation.jl +++ b/test/test_correlation.jl @@ -127,6 +127,19 @@ csol = solve(cprob, RK4()) @test csol.retcode == :Success +# Mollow when not in steady state +c = CorrelationFunction(σ(:e,:g), σ(:g,:e), eqs; steady_state=false) +csys = ODESystem(c) +cu0 = correlation_u0(c, sol.u[end]) +@test length(cu0) == 3 +cp0 = correlation_p0(c, sol.u[end], ps .=> p0) +@test length(cp0) == 4 + +cprob = ODEProblem(csys,cu0,(0.0,20.0),cp0) +csol_ns = solve(cprob, RK4()) + +@test all(csol_ns.u .≈ csol.u) + # When not in steady state -- cavity that decays h = FockSpace(:fock) a = Destroy(h,:a)