Skip to content

Commit

Permalink
debugging of cepa_tpsci
Browse files Browse the repository at this point in the history
  • Loading branch information
arnab82 committed Jun 18, 2024
1 parent 15f9155 commit 5dfcacc
Showing 1 changed file with 193 additions and 5 deletions.
198 changes: 193 additions & 5 deletions src/tpsci_outer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1132,7 +1132,7 @@ function do_fois_ci(ref::TPSCIstate{T,N,R}, cluster_ops, clustered_ham;
end

"""
tucker_cepa_solve(ref_vector::TPSCIstate, cepa_vector::TPSCIstate, cluster_ops, clustered_ham; tol=1e-5, cache=true)
tpsci_cepa_solve(ref_vector::TPSCIstate, cepa_vector::TPSCIstate, cluster_ops, clustered_ham; tol=1e-5, cache=true)
# Arguments
- `ref_vector`: Input reference state.
Expand Down Expand Up @@ -1185,7 +1185,8 @@ function tpsci_cepa_solve(ref_vector::TPSCIstate{T,N,R}, cepa_vector::TPSCIstate
# b=deepcopy(x_vector)
# zero!(b)

b=open_matvec_thread(ref_vector, cluster_ops, clustered_ham)
# b=open_matvec_thread(ref_vector, cluster_ops, clustered_ham)
b=deepcopy(sig)

@printf(" Overlap between <0|0>: %18.12e\n", orth_dot(ref_vector, ref_vector)[1])
@printf(" Overlap between <1|0>: %18.12e\n", overlap(x_vector, ref_vector)[1])
Expand Down Expand Up @@ -1405,7 +1406,7 @@ function do_fois_cepa(ref::TPSCIstate{T,N,R}, cluster_ops, clustered_ham;
println(" Do CEPA: Dim = ", length(cepa_vec_i))
println("debugging")
# error()
@time e_cepa, x_cepa_i = tpsci_cepa_solve(ref_vec_i, cepa_vec_i, cluster_ops, clustered_ham, cepa_shift, cepa_mit, tol=tol, max_iter=max_iter, verbose=verbose)
@time e_cepa, x_cepa_i = tpsci_cepa_solve2(ref_vec_i, cepa_vec_i, cluster_ops, clustered_ham, cepa_shift, cepa_mit, tol=tol, max_iter=max_iter, verbose=verbose)

@printf(" E(cepa) corr = %12.8f\n", e_cepa[1])
@printf(" X(cepa) norm = %12.8f\n", sqrt(orth_dot(x_cepa, x_cepa)[1]))
Expand All @@ -1421,5 +1422,192 @@ function do_fois_cepa(ref::TPSCIstate{T,N,R}, cluster_ops, clustered_ham;
end
add!(x_cepa, ref_vec)
orthonormalize!(x_cepa)
return e_cepa, x_cepa
end
return e_cepa_r, x_cepa
end

function do_fois_cepa2(ref::TPSCIstate{T,N,R}, cluster_ops, clustered_ham;
max_iter=20,
cepa_shift="cepa",
cepa_mit=30,
nbody=4,
thresh_foi=1e-6,
thresh_clip=1e-5,
tol=1e-5,
compress=false,
compress_type="matvec",
verbose=true) where {T,N,R}
@printf("\n-------------------------------------------------------\n")
@printf(" Do CEPA\n")
@printf(" thresh_foi = %-8.1e\n", thresh_foi)
@printf(" nbody = %-i\n", nbody)
@printf("\n")
@printf(" Length of Reference = %-i\n", length(ref))
@printf(" Calculation type = %s\n", cepa_shift)
@printf(" Compression type = %s\n", compress_type)
@printf("\n-------------------------------------------------------\n")

#
# Solve variationally in reference space
println()
ref_vec = deepcopy(ref)
@printf(" Solve zeroth-order problem. Dimension = %10i\n", length(ref_vec))
@time e0, ref_vec = tps_ci_direct(ref_vec, cluster_ops, clustered_ham, conv_thresh=tol)

#
# Get First order wavefunction
println()
println(" Compute FOIS. Reference space dim = ", length(ref_vec))
pt1_vec = deepcopy(ref_vec)
pt1_vec=open_matvec_thread(pt1_vec, cluster_ops, clustered_ham, nbody=nbody, thresh=thresh_foi)
for i in 1:R
@printf("Arnab: %12.8f\n", sqrt.(orth_dot(pt1_vec, pt1_vec))[i])
end
project_out!(pt1_vec, ref)
# display(pt1_vec)

# Compress FOIS
if compress==true
norm1 = sqrt.(orth_dot(pt1_vec, pt1_vec))
dim1 = length(pt1_vec)
clip!(pt1_vec, thresh=thresh_clip)
norm2 = sqrt.(orth_dot(pt1_vec, pt1_vec))
dim2 = length(pt1_vec)
@printf(" %-50s%10i → %-10i (thresh = %8.1e)\n", "FOIS Compressed from: ", dim1, dim2, thresh_foi)
for i in 1:R
@printf(" %-50s%10.2e → %-10.2e (thresh = %8.1e)\n", "Norm of |1>: ",norm1[i], norm2[i], thresh_foi)
end
end
for i in 1:R
@printf(" %-50s%10.6f\n", "Overlap between <1|0>: ", overlap(pt1_vec, ref_vec)[i])
end
#

# Solve CEPA
println()
cepa_vec = deepcopy(pt1_vec)
e_cepa_r=[]
println("Do CEPA: Dim = ", length(cepa_vec))

# x_cepa=TPSCIstate(ref.clusters,T=T,R=R)
x_cepa = deepcopy(ref)
zero!(x_cepa)
for i in 1:R
ref_vec_i=extract_chosen_root(ref_vec, i)
cepa_vec_i=extract_chosen_root(cepa_vec, i)
zero!(cepa_vec)
display(cepa_vec_i)
display(ref_vec_i)
println(" Do CEPA: Dim = ", length(cepa_vec_i))
println("debugging")
# error()
@time e_cepa_corr,e_cepa = tpsci_cepa_solve2(ref_vec_i, cepa_vec_i, cluster_ops, clustered_ham, cepa_shift, cepa_mit, tol=tol, max_iter=max_iter, verbose=verbose)

@printf(" E(cepa) corr = %12.8f\n", e_cepa[1])
push!(e_cepa_r, e_cepa[1])

end
return e_cepa_r
end

function tpsci_cepa_solve2(ref_vector::TPSCIstate{T,N,R}, cepa_vector::TPSCIstate, cluster_ops, clustered_ham,
cepa_shift="cepa",
cepa_mit = 50;
tol=1e-5,
max_iter=30,
thresh_foi=1e-6,
verbose=false) where {T,N,R}

ts=1
H00=build_full_H(ref_vector, cluster_ops, clustered_ham)
e0,v0=eigen(H00)
e0=e0[ts]
@printf("Reference Energy: %12.8f\n", e0)

Ec = 0.0
Ecepa = 0
pt1_vec = deepcopy(cepa_vector)
pt1_vec=open_matvec_thread(ref_vector, cluster_ops, clustered_ham, nbody=4, thresh=thresh_foi)
b_correct=deepcopy(cepa_vector)
zero!(b_correct)
for (fock,configs) in pt1_vec.data
for (config, coeffs) in configs
if haskey(cepa_vector, fock)
if haskey(cepa_vector[fock], config)
b_coeffs = pt1_vec[fock][config]
b_correct[fock][config]=b_coeffs
end
end
end
end
n_clusters = length(cepa_vector.clusters)
for it in 1:cepa_mit

if cepa_shift == "cepa"
cepa_mit = 1
shift = 0.0
elseif cepa_shift == "acpf"

shift = Ec * 2.0 / n_clusters
elseif cepa_shift == "aqcc"
shift = (1.0 - (n_clusters-3.0)*(n_clusters - 2.0)/(n_clusters * ( n_clusters-1.0) )) * Ec
elseif cepa_shift == "cisd"
shift = Ec
else
println()
println("NYI: cepa_shift is not available:",cepa_shift)
println()
exit()
end
Hdd=build_full_H(b_correct, cluster_ops, clustered_ham)
# display(size(Hdd))
Hdd .+= -Matrix{eltype(Hdd)}(I(size(Hdd, 1))) * (e0 + shift)
H0d=build_full_H_parallel(ref_vector,b_correct, cluster_ops, clustered_ham)
# display(size(H0d))
Hd0=H0d'
# display(size(v0))
Hd0 = Hd0 * v0[:, ts]
# display(size(Hd0))
Cd = Hdd \ -Hd0
# display(size(Cd))

println(" CEPA(0) Norm : ", @sprintf("%16.12f", norm(Cd)))
v0 = reshape(copy(v0[:, ts]), :, 1)
# display(size(v0))
Cd = reshape(Cd, (size(Cd, 1), 1))
# display(size(Cd))
C = vcat(v0[1,:], Cd)#if I add full v0, then the size of C exceeds the size of H00d as e0 just a number
# display(size(C))#so this is a problem.
num_rows = size(H0d, 1)
E0_column = fill(e0, num_rows)
H00d = hcat(E0_column, H0d)
# H00d = hcat([E0], H0d)

# Compute the energy E
V0_ts = v0[:, ts]
C = reshape(C, :, 1)
# display(size(H00d))
# display(size(C))
# display(size(V0_ts))
E = (V0_ts' * H00d) * C

# E = (v0[:, ts]' * H00d) * C
cepa_last_vectors = C
cepa_last_values = E
println(" CEPA(0) Energy: ", @sprintf("%16.12f", E[1]))

# Check for convergence
if abs(E[1] + e0[1] - Ec[1]) < 1e-10
println("Converged")
@printf("Reference Energy: %12.8f\n", e0)
# display(E)
# display(Ec)
@printf(" E(CEPA) = %18.12f\n", (Ec))

break
end
Ec += (E[1] + e0[1])
end

return Ec-e0, Ec
end

0 comments on commit 5dfcacc

Please sign in to comment.