Skip to content

Commit

Permalink
interface added - rpe_basis_new
Browse files Browse the repository at this point in the history
Add an interface for EQM - to use the current code, at most 5 lines in EQM should be changed
  • Loading branch information
zhanglw0521 committed Jan 9, 2025
1 parent f2392d7 commit 5e9a5ff
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 49 deletions.
23 changes: 21 additions & 2 deletions src/O3_alternative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using PartialWaveFunctions
using Combinatorics
using LinearAlgebra

export re_basis_new, ri_basis_new, ind_corr_s1, ind_corr_s2, MatFmi, ML0, MatFmi2, ri_rpi, re_rpe
export re_basis_new, ri_basis_new, ind_corr_s1, ind_corr_s2, MatFmi, ML0, MatFmi2, ri_rpi, re_rpe, rpe_basis_new

function CG(l,m,L,N)
M=m[1]+m[2]
Expand Down Expand Up @@ -502,4 +502,23 @@ function re_rpe(n::SVector{N,Int64},l::SVector{N,Int64},L::Int64) where N
end
end
return UMatrix, FMatrix, MMmat, MM
end
end

function gram(X)
G = zeros(ComplexF64, size(X,1), size(X,1))
for i = 1:size(X,1)
for j = i:size(X,1)
G[i,j] = sum(dot(X[i,t], X[j,t]') for t = 1:size(X,2))
i == j ? nothing : (G[j,i]=G[i,j]')
end
end
return G
end

function rpe_basis_new(nn::SVector{N, Int64}, ll::SVector{N, Int64}, L::Int64) where N
t_re = @elapsed UMatrix, FMatrix, MMmat, MM = re_rpe(nn, ll, L)
@show t_re # should be removed in the final version
U, S, V = svd(gram(FMatrix))
rk = rank(Diagonal(S); rtol = 1e-12)
return Diagonal(S[1:rk]) * (U[:, 1:rk]' * UMatrix), MM
end
129 changes: 82 additions & 47 deletions test/new_rpe_test.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@

using SpheriCart, StaticArrays, LinearAlgebra, RepLieGroups, WignerD,
Combinatorics, Rotations
using RepLieGroups.O3: Rot3DCoeffs, Rot3DCoeffs_real
using RepLieGroups.O3: Rot3DCoeffs, Rot3DCoeffs_real, Rot3DCoeffs_long, coco_dot
using RepLieGroups: gram
O3 = RepLieGroups.O3
using Test

Expand Down Expand Up @@ -140,12 +141,30 @@ function sym_rand_batch(; coeffs, MM, ll, nn,
return BB
end

function gram(X)
G = zeros(ComplexF64, size(X,1), size(X,1))
for i = 1:size(X,1)
for j = i:size(X,1)
G[i,j] = sum(dot(X[i,t], X[j,t]') for t = 1:size(X,2))
G[j,i] = G[i,j]'
# The following two functions are hacked from the EQM package, just using as reference and for comparison
function rpe_basis(A::Union{Rot3DCoeffs,Rot3DCoeffs_long,Rot3DCoeffs_real}, nn::SVector{N, TN}, ll::SVector{N, Int}) where {N, TN}
t_re_old = @elapsed Ure, Mre = O3.re_basis(A, ll)
@show t_re_old
G = _gramian(nn, ll, Ure, Mre)
S = svd(G)
rk = rank(Diagonal(S.S); rtol = 1e-7)
Urpe = S.U[:, 1:rk]'
return Diagonal(sqrt.(S.S[1:rk])) * Urpe * Ure, Mre
end


function _gramian(nn, ll, Ure, Mre)
N = length(nn)
nre = size(Ure, 1)
G = zeros(Complex{Float64}, nre, nre)
for σ in permutations(1:N)
if (nn[σ] != nn) || (ll[σ] != ll); continue; end
for (iU1, mm1) in enumerate(Mre), (iU2, mm2) in enumerate(Mre)
if mm1[σ] == mm2
for i1 = 1:nre, i2 = 1:nre
G[i1, i2] += coco_dot(Ure[i1, iU1], Ure[i2, iU2])
end
end
end
end
return G
Expand All @@ -154,15 +173,15 @@ end

##
# CASE 2: 4-correlations
for L = 1:4
for L = 0:4
@info("Testing L = $L")
cc = Rot3DCoeffs(L)
# now we fix an ll = (l1, l2, l3) triple ask for all possible linear combinations
# of the tensor product basis Y[l1, m1] * Y[l2, m2] * Y[l3, m3] * Y[l4, m4]
# that are invariant under O(3) rotations.
# ll = SA[1,2,2,2,3]
# ll_list = [SA[2,2,2,2], SA[2,2,2,2], SA[2,2,2,4], SA[2,2,3,3], SA[1,1,2,2,2], SA[1,1,2,2,2], SA[1,2,2,2,3], SA[2,2,2,2,2] ]
# nn_list = [SA[1,1,1,2], SA[1,1,2,3], SA[1,1,1,1], SA[1,1,1,1], SA[1,2,1,1,1], SA[2,2,1,1,2], SA[1,1,1,1,1], SA[1,1,1,1,1] ]
# nn_list = [SA[1,1,1,2], SA[1,1,2,3], SA[1,1,1,1], SA[1,1,1,1], SA[1,2,1,1,1], SA[2,2,1,1,2], SA[1,1,1,1,1], SA[1,1,1,1,1] ]



Expand Down Expand Up @@ -194,11 +213,11 @@ for L = 1:4

verbose = true

@info("Using ultra short nnll list for testing")
nnll_list = ultra_short_nnll_list
# @info("Using ultra short nnll list for testing")
# nnll_list = ultra_short_nnll_list

# @info("Using short nnll list for testing")
# nnll_list = short_nnll_list
@info("Using short nnll list for testing")
nnll_list = short_nnll_list

# @info("Using long nnll list for testing")
# nnll_list = long_nnll_list
Expand All @@ -207,85 +226,101 @@ for L = 1:4
for (itest, (nn, ll)) in enumerate(nnll_list)
N = length(ll)
@assert length(ll) == length(nn)
t1 = @elapsed coeffs1, MM1 = O3.re_basis(cc, ll)
nbas_ri1 = size(coeffs1, 1)
# rank(coeffs1, rtol = 1e-12)
# t_re_old = @elapsed O3.re_basis(cc, ll)
t_rpe_old = @elapsed coeffs_ind1_origin, MM1_origin = rpe_basis(cc, nn, ll) # This is the rpe coupling coefficients in EQM, which is our reference
rk1 = rank(gram(coeffs_ind1_origin); rtol=1e-12) # rank of the reference coupling coefficients

# coeffs1, MM1 = O3.re_basis(cc, ll)
# nbas_ri1 = size(coeffs1, 1)
# rk1 = rank(gram(coeffs1); rtol=1e-12)
# U, S, V = svd(gram(coeffs1))
# coeffs_ind1 = Diagonal(S[1:rk1]) * (U[:, 1:rk1]' * coeffs1)
# @test sort(MM1) == sort(MM1_origin)

# NOTE: Such constructed coeff_ind1 should span a larger space as the original coeffs1
# which can be seen by comparing the functions `gram`` and `_gramian`
# However, the function `gram`` should work for the FMatrix as the permutation has been taken care of.

# @test rank(gram(coeffs_ind1_origin); rtol=1e-12) <= rank(gram(coeffs_ind1); rtol=1e-12)

Rs = rand_config(length(ll))
θ = rand(3) * 2pi
Q = RotZYZ...)
D = transpose(wignerD(L, θ...))
QRs = [Q*Rs[i] for i in 1:length(Rs)]
fRs1 = eval_basis(Rs; coeffs = coeffs1, MM = MM1, ll = ll, nn = nn)
fRs1Q = eval_basis(QRs; coeffs = coeffs1, MM = MM1, ll = ll, nn = nn)
@test norm(fRs1 - Ref(D) .* fRs1Q) < 1e-15
# fRs1 = eval_basis(Rs; coeffs = coeffs_ind1, MM = MM1, ll = ll, nn = nn)
# fRs1Q = eval_basis(QRs; coeffs = coeffs_ind1, MM = MM1, ll = ll, nn = nn)
# @test norm(fRs1 - Ref(D) .* fRs1Q) < 1e-15
fRs1 = eval_basis(Rs; coeffs = coeffs_ind1_origin, MM = MM1_origin, ll = ll, nn = nn)
fRs1Q = eval_basis(QRs; coeffs = coeffs_ind1_origin, MM = MM1_origin, ll = ll, nn = nn)
L == 0 ? (@test norm(fRs1 - fRs1Q) < 1e-14) : (@test norm(fRs1 - Ref(D) .* fRs1Q) < 1e-14)

ntest = 1000

RR = make_batch(ntest, length(ll))

X = rand_batch(; coeffs=coeffs1, MM=MM1, ll=ll, nn=nn, batch = RR)
X = rand_batch(; coeffs=coeffs_ind1_origin, MM=MM1_origin, ll=ll, nn=nn, batch = RR)
@test rank(gram(X); rtol=1e-12) == size(X,1)

Xsym = sym_rand_batch(; coeffs=coeffs1, MM=MM1, ll=ll, nn=nn, batch = RR)
rk1 = rank(gram(Xsym); rtol=1e-12)
U, S, V = svd(gram(Xsym))
coeffs_ind1 = Diagonal(S[1:rk1]) \ (U[:, 1:rk1]' * coeffs1)

# Version GD
t_rpe = @elapsed coeffs2, coeffs_rpe, MMmat, MM2 = re_rpe(nn,ll,L)
# computes the RI coupling coefs and RPI coefs at the same time

rk2 = rank(gram(coeffs_rpe),rtol = 1e-12)
@test rk1 == rk2
Xsym = sym_rand_batch(; coeffs=coeffs_ind1_origin, MM=MM1_origin, ll=ll, nn=nn, batch = RR)
@test rank(gram(Xsym); rtol=1e-12) == rk1

# if RepLieGroups.SetLl_new(ll,L) |> length != 0
if rk1>0
U, S, V = svd(gram(coeffs_rpe))
coeffs_ind2 = Diagonal(S[1:rk2]) \ (U[:, 1:rk2]' * coeffs2)

if rk1 > 0
# Version GD
# rewritten as a new interface
# t_re = @elapsed re_rpe(nn,ll,L) # this is slightly longer than the new re, because it computes also the FMatrix for RPE
t_rpe = @elapsed coeffs_ind2, MM2 = rpe_basis_new(nn,ll,L)
# computes the RI coupling coefs and RPI coefs at the same time

# rk2 = rank(gram(coeffs2),rtol = 1e-12)
rk2 = rank(gram(coeffs_ind2),rtol = 1e-12)
@test rk1 == rk2

Xsym_new = rand_batch(; coeffs=coeffs_ind2, MM=MM2, ll=ll, nn=nn, batch=RR) #this is symmetric
@test rank(gram(Xsym_new); rtol=1e-12) == rk2

# NOTE FROM CO: same batch is used so can compare!!!
# @show rank(Xsym)
# @show rank(Xsym_new)
# @show rank([Xsym; Xsym_new], rtol = 1e-12)
fRs1 = eval_basis(Rs; coeffs = coeffs2, MM = MM2, ll = ll, nn = nn)
fRs1Q = eval_basis(QRs; coeffs = coeffs2, MM = MM2, ll = ll, nn = nn)
@test norm(fRs1 - Ref(D) .* fRs1Q) < 1e-15

fRs1 = eval_basis(Rs; coeffs = coeffs_ind2, MM = MM2, ll = ll, nn = nn)
fRs1Q = eval_basis(QRs; coeffs = coeffs_ind2, MM = MM2, ll = ll, nn = nn)
L == 0 ? (@test norm(fRs1 - fRs1Q) < 1e-14) : (@test norm(fRs1 - Ref(D) .* fRs1Q) < 1e-14)
# Up to here, we checked that coeff_ind1_origin and coeff_ind2 has the same rank and span the space with correct equivariance,
# and hence they span the same space.

P1 = sortperm(MM1)
P1 = sortperm(MM1_origin)
P2 = sortperm(MM2)
MMsorted1 = MM1[P1]
MMsorted1 = MM1_origin[P1]
MMsorted2 = MM2[P2]
# check that same mm values
@test MMsorted1 == MMsorted2

coeffsp1 = coeffs_ind1[:,P1]
coeffsp1 = coeffs_ind1_origin[:,P1]
coeffsp2 = coeffs_ind2[:,P2]

# Check that coefficients span same space
@test rank(gram([coeffsp1;coeffsp2]); rtol=1e-12) == rk2
@test rank(gram([coeffsp1;coeffsp2]); rtol=1e-12) == rank(gram(coeffsp2); rtol=1e-12) == rank(gram(coeffsp2); rtol=1e-12)


# Do the rand batch on the same set of points
ORD = length(ll) # length of each group
BB1 = complex.(zeros(typeof(coeffs_ind1[1]), size(coeffs_ind1, 1), ntest))
BB1 = complex.(zeros(typeof(coeffs_ind1_origin[1]), size(coeffs_ind1_origin, 1), ntest))
BB2 = complex.(zeros(typeof(coeffs_ind2[1]), size(coeffs_ind2, 1), ntest))
for i = 1:ntest
# construct a random set of particles with 𝐫 ∈ ball(radius=1)
Rs = [ rand_ball() for _ in 1:ORD ]
BB1[:, i] = eval_basis(Rs; coeffs=coeffs_ind1, MM=MM1, ll=ll, nn=nn)
BB1[:, i] = eval_basis(Rs; coeffs=coeffs_ind1_origin, MM=MM1_origin, ll=ll, nn=nn)
BB2[:, i] = eval_basis(Rs; coeffs=coeffs_ind2, MM=MM2, ll=ll, nn=nn)
end

# Check that values span same space
@test rank(gram([BB1;BB2]); rtol=1e-11) == rk2
@test rank(gram([BB1;BB2]); rtol=1e-11) == rank(gram(BB1); rtol=1e-11) == rank(gram(BB2); rtol=1e-11)

if verbose
@info("Test $itest: t1 = $t1, t_rpe = $t_rpe")
# @info("Test $itest: t_re_old = $t_re_old, t_re = $t_re")
@info("Test $itest: t_rpe_old = $t_rpe_old, t_rpe = $t_rpe")
else
print(".")
end
Expand Down

0 comments on commit 5e9a5ff

Please sign in to comment.