Skip to content

Commit

Permalink
test RPI - debug
Browse files Browse the repository at this point in the history
  • Loading branch information
dussong committed Oct 31, 2024
1 parent 3a9a424 commit 64d8b44
Show file tree
Hide file tree
Showing 4 changed files with 282 additions and 87 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Polynomials4ML = "03c4bcba-a943-47e9-bfa1-b1661fc2974f"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpheriCart = "5caf2b29-02d9-47a3-9434-5931c85ba645"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
StaticArrays = "1.5"
Expand Down
176 changes: 174 additions & 2 deletions src/O3_alternative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using PartialWaveFunctions
using Combinatorics
using LinearAlgebra

export re_basis_new
export re_basis_new, ri_basis_new, ind_corr_s1, ind_corr_s2, MatFmi, ML0

function CG(l,m,L,N)
M=m[1]+m[2]
Expand Down Expand Up @@ -50,6 +50,30 @@ function SetLl0(l,N)
return set
end

function SetLl(l,N,L)
set = Vector{Int64}[]
for k in abs(l[1]-l[2]):l[1]+l[2]
push!(set, [0; k])
end
for k in 3:N-1
setL=set
set=Vector{Int64}[]
for a in setL
for b in abs(a[k-1]-l[k]):a[k-1]+l[k]
push!(set, [a; b])
end
end
end
setL=set
set=Vector{Int64}[]
for a in setL
if (abs.(a[N-1]-l[N]) <= L)&&(L <= (a[N-1]+l[N]))
push!(set, [a; L])
end
end
return set
end

# Function that computes the set ML0
function ML0(l,N)
setML = [[i] for i in -abs(l[1]):abs(l[1])]
Expand All @@ -70,7 +94,29 @@ function ML0(l,N)
return setML0
end

function re_basis_new(l)
# Function that computes the set ML (relative to equivariance L)
function ML(l,N,L)
setML = [[i] for i in -abs(l[1]):abs(l[1])]
for k in 2:N-1
set = setML
setML = Vector{Int64}[]
for m in set
append!(setML, [m; lk] for lk in -abs(l[k]):abs(l[k]) )
end
end
setML0=Vector{Int64}[]
for m in setML
s=sum(m)
for mn in -L-s:L-s
if abs(mn) < abs(l[N])+1
push!(setML0, [m; mn])
end
end
end
return setML0
end

function ri_basis_new(l)
N=size(l,1)
L=SetLl0(l,N)
r=size(L,1)
Expand All @@ -89,4 +135,130 @@ function re_basis_new(l)
end
end
return U,M
end

function re_basis_new(l,L)
N=size(l,1)
Ll=SetLl(l,N,L)
r=size(Ll,1)
if r==0
return zeros(Float64, 0, 0)
else
setML0=ML(l,N,L)
sizeML0=length(setML0)
U=zeros(Float64, r, sizeML0)
M = Vector{Int64}[]
for (j,m) in enumerate(setML0)
push!(M,m)
for i in 1:r
U[i,j]=CG(l,m,Ll[i],N)
end
end
end
return U,M
end


# Function that computes the permutations that let n and l invariant
function Snl(N,n,l)
if n==n[1]*ones(N)
if l==l[1]*ones(N)
return permutations(1:N)
end
end
if N==1
return Set([[1]])
elseif (n[N-1],l[N-1])!=(n[N],l[N])
S=Set()
Sn=Snl(N-1,n[1:N-1],l[1:N-1])
for x in Sn
append!(x,[N])
union!(S,Set([x]))
end
else
S=Set()
k=N
while (n[k-1],l[k-1])==(n[k],l[k]) && k>2
k-=1
end
if k==2 && (n[1],l[1])==(n[2],l[2])
return Set(permutations(1:N))
else
Sn=Snl(k-1,n[1:k-1],l[1:k-1])
for x in Sn
for s in Set(permutations(k:N))
y=copy(x)
append!(y,s)
union!(S,Set([y]))
end
end
end
end
return S
end


#Function that computes the set of classes using the set Ml0 and the possible permutations
function class(setML0,sigma,N,l)
setclass=Vector{Vector{Int64}}[]
pop!(setML0,zeros(Int64,N))
while setML0!=Set()
x=pop!(setML0)
p=[x]
for s in sigma
y=x[s]
if y in setML0
append!(p,[y])
pop!(setML0,y)
end
end
append!(setclass,[p])
end
setclasses=Vector{Vector{Int64}}[]
for x in setclass
for y in setclass
if x==y
if minimum(x)==minimum(-x)
if iseven(sum(l))
append!(setclasses,[x])
end
end
elseif minimum(x)==minimum(-y)
if y<x
append!(setclasses,[x])
end
end
end
end
if iseven(sum(l))
append!(setclasses,[[zeros(N)]])
end
setclasses
end



# Function that computes the matrix ( f(m,i) )
function MatFmi(n,l)
N=size(l,1)
L=SetLl0(l,N)
r=size(L,1)
if r==0
return zeros(Float64, 0, 0)
else
ML00 = ML0(l,N)
setML0=Set(ML00)
sigma = Snl(N,n,l)
setclass=class(setML0,sigma,N,l)
sizeML0=length(setclass)
Matrix=zeros(Float64, r, sizeML0)
for i in 1:r
for j in 1:sizeML0
for m in setclass[j]
Matrix[i,j]+=CG(l,m,L[i],N)
end
end
end
end
return Matrix, ML00
end
66 changes: 62 additions & 4 deletions test/test_RI_basis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,14 @@ end
# for the moment the code with generalized CG only works with L=0
L = 0
cc = Rot3DCoeffs(L)
ll = SA[2,2,2,3,3]
ll = SA[1,1,2,2,4]

# version with svd
@time coeffs1, MM1 = O3.re_basis(cc, ll)
nbas = size(coeffs1, 1)

#version with gen CG coefficients
@time coeffs2, MM2 = re_basis_new(ll)
@time coeffs2, MM2 = ri_basis_new(ll)

# simple test on size
@test size(coeffs1) == size(coeffs2)
Expand All @@ -85,7 +85,7 @@ coeffsp2 = coeffs2[:,P2]
@test rank(coeffsp2) == size(coeffsp2,1)

# check that the coef span the same space - test fails
@test nbas == rank([coeffsp; coeffsp2], rtol = 1e-12)
@test nbas == rank([coeffsp1; coeffsp2], rtol = 1e-12)


Rs = [rand_sphere() for _ in 1:length(ll)]
Expand Down Expand Up @@ -122,4 +122,62 @@ end

# check that functions span same space
rk = rank([A1;A2]; rtol = 1e-12)
@test rk == nbas
@test rk == nbas

# ----------------------------
# Extension to equivariance
# ----------------------------

# Test SetLl0
N = 5
l = rand(1:10, N)
SL1 = RepLieGroups.SetLl(l,N,0)
SL2 = RepLieGroups.SetLl0(l,N)
@test SL1 == SL2

# Test SetML
N = 5
l = rand(1:10, N)
SM1 = RepLieGroups.ML(l,N,0)
SM2 = RepLieGroups.ML0(l,N)
@test SM1 == SM2


# # Version CO
# L = 1
# cc = Rot3DCoeffs(L)
# ll = SA[1,1,2,3,4]

# # version with svd
# @time coeffs1, MM1 = O3.re_basis(cc, ll)
# nbas = size(coeffs1, 1)

# # Version GD
# @time coeffs2, MM2 = re_basis_new(ll,L)

# # simple test on size
# @test size(coeffs1) == size(coeffs2)
# @test size(MM1) == size(MM2)

# P1 = sortperm(MM1)
# P2 = sortperm(MM2)
# MMsorted1 = MM1[P1]
# MMsorted2 = MM2[P2]
# # check that same mm values
# @test MMsorted1 == MMsorted2

# coeffsp1 = coeffs1[:,P1]
# coeffsp2 = coeffs2[:,P2]

# # test that full rank
# @test rank(coeffsp1) == size(coeffsp1,1)
# @test rank(coeffsp2) == size(coeffsp2,1)

# # check that the coef span the same space
# @test nbas == rank([coeffsp1; coeffsp2], rtol = 1e-12)

# Rs = [rand_sphere() for _ in 1:length(ll)]
# Q = rand_rot()
# QRs = [Q*Rs[i] for i in 1:length(Rs)]
# fRs1 = [ f(Rs, q; coeffs=coeffs1, MM=MM1, ll=ll) for q = 1:nbas ]
# fRs1Q = [ f(QRs, q; coeffs=coeffs1, MM=MM1, ll=ll) for q = 1:nbas ]
Loading

0 comments on commit 64d8b44

Please sign in to comment.