Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Broadcasting for semiclassical objects #404

Merged
merged 13 commits into from
Aug 11, 2024
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ RecursiveArrayTools = "2, 3"
Reexport = "0.2, 1.0"
StochasticDiffEq = "6"
WignerSymbols = "1, 2"
julia = "1.3"
julia = "1.10"

[extras]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
1 change: 1 addition & 0 deletions src/QuantumOptics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module QuantumOptics
using Reexport
@reexport using QuantumOpticsBase
using SparseArrays, LinearAlgebra
import RecursiveArrayTools

export
ylm,
Expand Down
110 changes: 95 additions & 15 deletions src/semiclassical.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
module semiclassical

using QuantumOpticsBase
import Base: ==
import QuantumOpticsBase: IncompatibleBases
import Base: ==, isapprox, +, -, *, /
import ..timeevolution: integrate, recast!, jump, integrate_mcwf, jump_callback,
JumpRNGState, threshold, roll!, as_vector, QO_CHECKS
import LinearAlgebra: normalize, normalize!
import RecursiveArrayTools

using Random, LinearAlgebra
import OrdinaryDiffEq
Expand All @@ -31,26 +33,104 @@
new{B,T,C}(quantum, classical)
end
end

Base.length(state::State) = length(state.quantum) + length(state.classical)
Base.copy(state::State) = State(copy(state.quantum), copy(state.classical))
Base.eltype(state::State) = promote_type(eltype(state.quantum),eltype(state.classical))
normalize!(state::State) = (normalize!(state.quantum); state)
normalize(state::State) = State(normalize(state.quantum),copy(state.classical))

function ==(a::State, b::State)
QuantumOpticsBase.samebases(a.quantum, b.quantum) &&
length(a.classical)==length(b.classical) &&
(a.classical==b.classical) &&
(a.quantum==b.quantum)
end
State{B}(q::T, c::C) where {B,T<:QuantumState{B},C} = State(q,c)

# Standard interfaces
Base.zero(x::State) = State(zero(x.quantum), zero(x.classical))
Base.length(x::State) = length(x.quantum) + length(x.classical)
Base.axes(x::State) = (Base.OneTo(length(x)),)
Base.size(x::State) = size(x.quantum)
Base.ndims(x::Type{<:State{B,T,C}}) where {B,T<:QuantumState{B},C} = ndims(T)
Base.copy(x::State) = State(copy(x.quantum), copy(x.classical))
Base.copyto!(x::State, y::State) = (copyto!(x.quantum, y.quantum); copyto!(x.classical, y.classical); x)
Base.fill!(x::State, a) = (fill!(x.quantum, a), fill!(x.classical, a))
Base.eltype(x::State) = promote_type(eltype(x.quantum),eltype(x.classical))
Base.eltype(x::Type{<:State{B,T,C}}) where {B,T<:QuantumState{B},C} = promote_type(eltype(T), eltype(C))
Base.similar(x::State, ::Type{T} = eltype(x)) where {T} = State(similar(x.quantum, T), similar(x.classical, T))
Base.getindex(x::State, idx) = idx <= length(x.quantum) ? getindex(x.quantum, idx) : getindex(x.classical, idx-length(x.quantum))

normalize!(x::State) = (normalize!(x.quantum); x)
normalize(x::State) = State(normalize(x.quantum),copy(x.classical))
LinearAlgebra.norm(x::State) = LinearAlgebra.norm(x.quantum)

==(x::State{B}, y::State{B}) where {B} = (x.classical==y.classical) && (x.quantum==y.quantum)
==(x::State, y::State) = false

Check warning on line 57 in src/semiclassical.jl

View check run for this annotation

Codecov / codecov/patch

src/semiclassical.jl#L57

Added line #L57 was not covered by tests

isapprox(x::State{B}, y::State{B}; kwargs...) where {B} = isapprox(x.quantum,y.quantum; kwargs...) && isapprox(x.classical,y.classical; kwargs...)
isapprox(x::State, y::State; kwargs...) = false

Check warning on line 60 in src/semiclassical.jl

View check run for this annotation

Codecov / codecov/patch

src/semiclassical.jl#L60

Added line #L60 was not covered by tests

QuantumOpticsBase.expect(op, state::State) = expect(op, state.quantum)
QuantumOpticsBase.variance(op, state::State) = variance(op, state.quantum)
QuantumOpticsBase.ptrace(state::State, indices) = State(ptrace(state.quantum, indices), state.classical)

QuantumOpticsBase.dm(x::State) = State(dm(x.quantum), x.classical)

Base.broadcastable(x::State) = x

# Custom broadcasting style
struct StateStyle{B} <: Broadcast.BroadcastStyle end

# Style precedence rules
Broadcast.BroadcastStyle(::Type{<:State{B}}) where {B} = StateStyle{B}()
Broadcast.BroadcastStyle(::StateStyle{B1}, ::StateStyle{B2}) where {B1,B2} = throw(IncompatibleBases())
Broadcast.BroadcastStyle(::StateStyle{B}, ::Broadcast.DefaultArrayStyle{0}) where {B} = StateStyle{B}()
Broadcast.BroadcastStyle(::Broadcast.DefaultArrayStyle{0}, ::StateStyle{B}) where {B} = StateStyle{B}()

# Out-of-place broadcasting
@inline function Base.copy(bc::Broadcast.Broadcasted{<:StateStyle{B},Axes,F,Args}) where {B,Axes,F,Args<:Tuple}
bcf = Broadcast.flatten(bc)
# extract quantum object from broadcast container
qobj = find_quantum(bcf)
data_q = zeros(eltype(qobj), size(qobj)...)
Nq = length(qobj)
# allocate quantum data from broadcast container
@inbounds @simd for I in 1:Nq
data_q[I] = bcf[I]
end

Check warning on line 88 in src/semiclassical.jl

View check run for this annotation

Codecov / codecov/patch

src/semiclassical.jl#L88

Added line #L88 was not covered by tests
# extract classical object from broadcast container
cobj = find_classical(bcf)
data_c = zeros(eltype(cobj), length(cobj))
Nc = length(cobj)
# allocate classical data from broadcast container
@inbounds @simd for I in 1:Nc
data_c[I] = bcf[I+Nq]
end

Check warning on line 96 in src/semiclassical.jl

View check run for this annotation

Codecov / codecov/patch

src/semiclassical.jl#L96

Added line #L96 was not covered by tests
type = eval(nameof(typeof(qobj)))
return State{B}(type(basis(qobj), data_q), data_c)
end

for f ∈ [:find_quantum, :find_classical]
@eval ($f)(bc::Broadcast.Broadcasted) = ($f)(bc.args)
@eval ($f)(args::Tuple) = ($f)(($f)(args[1]), Base.tail(args))
@eval ($f)(x) = x
@eval ($f)(::Any, rest) = ($f)(rest)

Check warning on line 105 in src/semiclassical.jl

View check run for this annotation

Codecov / codecov/patch

src/semiclassical.jl#L105

Added line #L105 was not covered by tests
end
find_quantum(x::State, rest) = x.quantum
find_classical(x::State, rest) = x.classical

# In-place broadcasting
@inline function Base.copyto!(dest::State{B}, bc::Broadcast.Broadcasted{<:StateStyle{B},Axes,F,Args}) where {B,Axes,F,Args<:Tuple}
axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc))
bc′ = Base.Broadcast.preprocess(dest, bc)
# write broadcasted quantum data to dest
qobj = dest.quantum
@inbounds @simd for I in 1:length(qobj)
qobj.data[I] = bc′[I]
end

Check warning on line 118 in src/semiclassical.jl

View check run for this annotation

Codecov / codecov/patch

src/semiclassical.jl#L118

Added line #L118 was not covered by tests
# write broadcasted classical data to dest
cobj = dest.classical
@inbounds @simd for I in 1:length(cobj)
cobj[I] = bc′[I+length(qobj)]
end

Check warning on line 123 in src/semiclassical.jl

View check run for this annotation

Codecov / codecov/patch

src/semiclassical.jl#L123

Added line #L123 was not covered by tests
return dest
end
@inline Base.copyto!(dest::State{B1}, bc::Broadcast.Broadcasted{<:StateStyle{B2},Axes,F,Args}) where {B1,B2,Axes,F,Args<:Tuple} =

Check warning on line 126 in src/semiclassical.jl

View check run for this annotation

Codecov / codecov/patch

src/semiclassical.jl#L126

Added line #L126 was not covered by tests
throw(IncompatibleBases())

Base.@propagate_inbounds Base.Broadcast._broadcast_getindex(x::State, i) = Base.getindex(x, i)
RecursiveArrayTools.recursive_unitless_bottom_eltype(x::State) = eltype(x)
RecursiveArrayTools.recursivecopy!(dest::State, src::State) = copyto!(dest, src)
RecursiveArrayTools.recursivecopy(x::State) = copy(x)
RecursiveArrayTools.recursivefill!(x::State, a) = fill!(x, a)

Check warning on line 133 in src/semiclassical.jl

View check run for this annotation

Codecov / codecov/patch

src/semiclassical.jl#L133

Added line #L133 was not covered by tests

"""
semiclassical.schroedinger_dynamic(tspan, state0, fquantum, fclassical[; fout, ...])
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ names = [

"test_timeevolution_abstractdata.jl",

"test_sciml_broadcast_interfaces.jl",
"test_ForwardDiff.jl"
]

Expand Down
25 changes: 25 additions & 0 deletions test/test_sciml_broadcast_interfaces.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
using Test
using QuantumOptics
using OrdinaryDiffEq

@testset "sciml interface" begin

# semiclassical ODE problem
b = SpinBasis(1//2)
psi0 = spindown(b)
u0 = ComplexF64[0.5, 0.75]
sc = semiclassical.State(psi0, u0)
t₀, t₁ = (0.0, pi)
σx = sigmax(b)

fquantum(t, q, u) = σx + cos(u[1])*identityoperator(σx)
fclassical!(du, u, q, t) = (du[1] = sin(u[2]); du[2] = 2*u[1])
f!(dstate, state, p, t) = semiclassical.dschroedinger_dynamic!(dstate, fquantum, fclassical!, state, t)
prob = ODEProblem(f!, sc, (t₀, t₁))

sol = solve(prob, DP5(); reltol = 1.0e-8, abstol = 1.0e-10, save_everystep=false)
tout, ψt = semiclassical.schroedinger_dynamic([t₀, t₁], sc, fquantum, fclassical!; reltol = 1.0e-8, abstol = 1.0e-10)

@test sol[end] ≈ ψt[end]

end
25 changes: 25 additions & 0 deletions test/test_semiclassical.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Test
using QuantumOptics
using LinearAlgebra
using QuantumOpticsBase: IncompatibleBases

@testset "semiclassical" begin

Expand Down Expand Up @@ -175,4 +176,28 @@ after_jump = findlast(t-> !(t∈T), tout4)
@test ψt4[before_jump].quantum == ψ0.quantum
@test ψt4[after_jump].quantum == spindown(ba)⊗fockstate(bf,0)

# Test broadcasting interface
b = FockBasis(10)
bn = FockBasis(20)
u0 = ComplexF64[0.7, 0.2]
psi = fockstate(b, 2)
psin = fockstate(bn, 2)
rho = dm(psi)

sc_ket = semiclassical.State(psi, u0)
sc_ketn = semiclassical.State(psin, u0)
sc_dm = semiclassical.State(rho, u0)

@test Base.size(sc_dm) == Base.size(rho)
@test (copy_sc = copy(sc_ket); Base.fill!(copy_sc, 0.0); copy_sc) == semiclassical.State(fill!(copy(psi), 0.0), fill!(copy(u0), 0.0))
@test Base.similar(sc_ket, Int) isa semiclassical.State
@test normalize!(copy(sc_ket)) == semiclassical.State(normalize!(copy(psi)), u0)
@test !(sc_ket == sc_ketn)
@test !(isapprox(sc_ket, sc_ketn))
@test sc_ket .* 1.0 == sc_ket
@test sc_dm .* 1.0 == sc_dm
@test sc_ket .+ 2.0 == semiclassical.State(psi .+ 2.0, u0 .+ 2.0)
@test sc_dm .+ 2.0 == semiclassical.State(rho .+ 2.0, u0 .+ 2.0)
@test_throws IncompatibleBases sc_ket .+ semiclassical.State(spinup(SpinBasis(10)), u0)

end # testsets