Skip to content

Commit

Permalink
MPI parallelism for pbc
Browse files Browse the repository at this point in the history
  • Loading branch information
obackhouse committed Aug 17, 2023
1 parent c4237d2 commit 2d945c6
Show file tree
Hide file tree
Showing 10 changed files with 277 additions and 186 deletions.
11 changes: 11 additions & 0 deletions momentGW/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,17 @@ def with_df(self):
raise ValueError("GW solvers require density fitting.")
return self._scf.with_df

@property
def has_fock_loop(self):
"""
Returns a boolean indicating whether the solver requires a Fock
loop. For most GW methods, this is simply `self.fock_loop`. In
some methods such as qsGW, a Fock loop is required with or
without `self.fock_loop` for the quasiparticle self-consistency,
with this property acting as a hook to indicate this.
"""
return self.fock_loop

get_nmo = get_nmo
get_nocc = get_nocc
get_frozen_mask = get_frozen_mask
Expand Down
2 changes: 1 addition & 1 deletion momentGW/gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def ao2mo(self):
self.mo_occ,
compression=self.compression,
compression_tol=self.compression_tol,
store_full=self.fock_loop,
store_full=self.has_fock_loop,
)
integrals.transform()

Expand Down
2 changes: 1 addition & 1 deletion momentGW/pbc/gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def ao2mo(self):
self.mo_occ,
compression=self.compression,
compression_tol=self.compression_tol,
store_full=self.fock_loop,
store_full=self.has_fock_loop,
)
integrals.transform()

Expand Down
275 changes: 163 additions & 112 deletions momentGW/pbc/ints.py

Large diffs are not rendered by default.

34 changes: 30 additions & 4 deletions momentGW/pbc/kpts.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import scipy.linalg
from dyson import Lehmann
from pyscf import lib
from pyscf.agf2 import GreensFunction, SelfEnergy
from pyscf.agf2 import GreensFunction, SelfEnergy, mpi_helper
from pyscf.pbc.lib import kpts_helper

# TODO make sure this is rigorous
Expand Down Expand Up @@ -98,14 +98,40 @@ def conserve(self, ki, kj, kk):
"""
return self._kconserv[ki, kj, kk]

def loop(self, depth):
def loop(self, depth, mpi=False):
"""
Iterate over all combinations of k-points up to a given depth.
"""

if depth == 1:
yield from range(len(self))
seq = range(len(self))
else:
yield from itertools.product(range(len(self)), repeat=depth)
seq = itertools.product(range(len(self)), repeat=depth)

if mpi:
size = len(self) * depth
split = lambda x: x * size // mpi_helper.size

p0 = split(mpi_helper.rank)
p1 = size if mpi_helper.rank == (mpi_helper.size - 1) else split(mpi_helper.rank + 1)

seq = itertools.islice(seq, p0, p1)

yield from seq

def loop_size(self, depth=1):
"""
Return the size of `loop`. Without MPI, this is equivalent to
`len(self)**depth`.
"""

size = len(self) * depth
split = lambda x: x * size // mpi_helper.size

p0 = split(mpi_helper.rank)
p1 = size if mpi_helper.rank == (mpi_helper.size - 1) else split(mpi_helper.rank + 1)

return p1 - p0

@allow_single_kpt(output_is_kpts=False)
def is_zero(self, kpts):
Expand Down
120 changes: 61 additions & 59 deletions momentGW/pbc/tda.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,6 @@ def __init__(
else:
self.compression_tol = None

def compress_eris(self):
"""Compress the ERI tensors."""

return # TODO

def build_dd_moments(self):
"""Build the moments of the density-density response."""

Expand All @@ -91,38 +86,39 @@ def build_dd_moments(self):
moments = np.zeros((self.nkpts, self.nkpts, self.nmom_max + 1), dtype=object)

# Get the zeroth order moment
for q, kb in kpts.loop(2):
kj = kpts.member(kpts.wrap_around(self.kpts[kb] - self.kpts[q]))
moments[q, kb, 0] += self.integrals.Lia[kj, kb] / self.nkpts
for q in kpts.loop(1):
for kj in kpts.loop(1, mpi=True):
kb = kpts.member(kpts.wrap_around(kpts[q] + kpts[kj]))
moments[q, kb, 0] += self.integrals.Lia[kj, kb] / self.nkpts
cput1 = lib.logger.timer(self.gw, "zeroth moment", *cput0)

# Get the higher order moments
for i in range(1, self.nmom_max + 1):
for q, kb in kpts.loop(2):
kj = kpts.member(kpts.wrap_around(self.kpts[kb] - self.kpts[q]))

d = lib.direct_sum(
"a-i->ia",
self.mo_energy_w[kb][self.mo_occ_w[kb] == 0],
self.mo_energy_w[kj][self.mo_occ_w[kj] > 0],
)
moments[q, kb, i] += moments[q, kb, i - 1] * d.ravel()[None]

for q, ka, kb in kpts.loop(3):
ki = kpts.member(kpts.wrap_around(self.kpts[ka] - self.kpts[q]))
kj = kpts.member(kpts.wrap_around(self.kpts[kb] - self.kpts[q]))

moments[q, kb, i] += (
np.linalg.multi_dot(
(
moments[q, ka, i - 1],
self.integrals.Lia[ki, ka].T.conj(), # NOTE missing conj in notes
self.integrals.Lai[kj, kb].conj(),
)
for q in kpts.loop(1):
for kj in kpts.loop(1, mpi=True):
kb = kpts.member(kpts.wrap_around(kpts[q] + kpts[kj]))

d = lib.direct_sum(
"a-i->ia",
self.mo_energy_w[kb][self.mo_occ_w[kb] == 0],
self.mo_energy_w[kj][self.mo_occ_w[kj] > 0],
)
* 2.0
/ self.nkpts
)
moments[q, kb, i] += moments[q, kb, i - 1] * d.ravel()[None]

tmp = np.zeros((self.naux[q], self.naux[q]), dtype=complex)
for ki in kpts.loop(1, mpi=True):
ka = kpts.member(kpts.wrap_around(kpts[q] + kpts[ki]))

tmp += np.dot(moments[q, ka, i - 1], self.integrals.Lia[ki, ka].T.conj())

tmp = mpi_helper.allreduce(tmp)
tmp *= 2.0
tmp /= self.nkpts

for kj in kpts.loop(1, mpi=True):
kb = kpts.member(kpts.wrap_around(kpts[q] + kpts[kj]))

moments[q, kb, i] += np.dot(tmp, self.integrals.Lai[kj, kb].conj())

cput1 = lib.logger.timer(self.gw, "moment %d" % i, *cput1)

Expand All @@ -135,6 +131,8 @@ def build_se_moments(self, moments_dd):
lib.logger.info(self.gw, "Building self-energy moments")
lib.logger.debug(self.gw, "Memory usage: %.2f GB", self._memory_usage())

kpts = self.kpts

# Setup dependent on diagonal SE
if self.gw.diagonal_se:
pqchar = pchar = qchar = "p"
Expand All @@ -148,25 +146,27 @@ def build_se_moments(self, moments_dd):

# Get the moments in (aux|aux) and rotate to (mo|mo)
for n in range(self.nmom_max + 1):
for q, qpt in enumerate(self.kpts):
for q in kpts.loop(1):
eta_aux = 0
for kb, kptb in enumerate(self.kpts):
kj = self.kpts.member(self.kpts.wrap_around(kptb - qpt))
for kj in kpts.loop(1, mpi=True):
kb = kpts.member(kpts.wrap_around(kpts[q] + kpts[kj]))
eta_aux += np.dot(moments_dd[q, kb, n], self.integrals.Lia[kj, kb].T.conj())

for kp, kptp in enumerate(self.kpts):
kx = self.kpts.member(self.kpts.wrap_around(kptp - qpt))
eta_aux = mpi_helper.allreduce(eta_aux)
eta_aux *= 2.0
eta_aux /= self.nkpts

for kp in kpts.loop(1, mpi=True):
kx = kpts.member(kpts.wrap_around(kpts[kp] - kpts[q]))

if not isinstance(eta[kp, q], np.ndarray):
eta[kp, q] = np.zeros(eta_shape(kx), dtype=eta_aux.dtype)

for x in range(self.mo_energy_g[kx].size):
Lp = self.integrals.Lpx[kp, kx][:, :, x]
eta[kp, q][x, n] += (
lib.einsum(f"P{pchar},Q{qchar},PQ->{pqchar}", Lp, Lp.conj(), eta_aux)
* 2.0
/ self.nkpts
)
subscript = f"P{pchar},Q{qchar},PQ->{pqchar}"
eta[kp, q][x, n] += lib.einsum(subscript, Lp, Lp.conj(), eta_aux)

cput1 = lib.logger.timer(self.gw, "rotating DD moments", *cput0)

# Construct the self-energy moments
Expand All @@ -176,26 +176,28 @@ def build_se_moments(self, moments_dd):
for n in moms:
fp = scipy.special.binom(n, moms)
fh = fp * (-1) ** moms
for q, kp in self.kpts.loop(2):
kx = self.kpts.member(self.kpts.wrap_around(self.kpts[kp] - self.kpts[q]))

eo = np.power.outer(self.mo_energy_g[kx][self.mo_occ_g[kx] > 0], n - moms)
to = lib.einsum(
f"t,kt,kt{pqchar}->{pqchar}", fh, eo, eta[kp, q][self.mo_occ_g[kx] > 0]
)
moments_occ[kp, n] += fproc(to)

ev = np.power.outer(self.mo_energy_g[kx][self.mo_occ_g[kx] == 0], n - moms)
tv = lib.einsum(
f"t,ct,ct{pqchar}->{pqchar}", fp, ev, eta[kp, q][self.mo_occ_g[kx] == 0]
)
moments_vir[kp, n] += fproc(tv)

for k, kpt in enumerate(self.kpts):
for n in range(self.nmom_max + 1):
for q in kpts.loop(1):
for kp in kpts.loop(1, mpi=True):
kx = kpts.member(kpts.wrap_around(kpts[kp] - kpts[q]))
subscript = f"t,kt,kt{pqchar}->{pqchar}"

eo = np.power.outer(self.mo_energy_g[kx][self.mo_occ_g[kx] > 0], n - moms)
to = lib.einsum(subscript, fh, eo, eta[kp, q][self.mo_occ_g[kx] > 0])
moments_occ[kp, n] += fproc(to)

ev = np.power.outer(self.mo_energy_g[kx][self.mo_occ_g[kx] == 0], n - moms)
tv = lib.einsum(subscript, fp, ev, eta[kp, q][self.mo_occ_g[kx] == 0])
moments_vir[kp, n] += fproc(tv)

# Numerical integration can lead to small non-hermiticity
for n in range(self.nmom_max + 1):
for k in kpts.loop(1, mpi=True):
moments_occ[k, n] = 0.5 * (moments_occ[k, n] + moments_occ[k, n].T.conj())
moments_vir[k, n] = 0.5 * (moments_vir[k, n] + moments_vir[k, n].T.conj())

moments_occ = mpi_helper.allreduce(moments_occ)
moments_vir = mpi_helper.allreduce(moments_vir)

cput1 = lib.logger.timer(self.gw, "constructing SE moments", *cput1)

return moments_occ, moments_vir
Expand Down
4 changes: 4 additions & 0 deletions momentGW/qsgw.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,3 +341,7 @@ def build_static_potential(self, mo_energy, se):
return 0.5 * (se_i + se_j)

check_convergence = evGW.check_convergence

@property
def has_fock_loop(self):
return True
9 changes: 0 additions & 9 deletions momentGW/scgw.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,6 @@ def kernel(
),
mo_occ_w=None if gw.w0 else gw._gf_to_occ(gf),
)
if mo_coeff.ndim == 3:
v = integrals.Lia[0, 0].real
ci = lib.einsum("pq,qi->pi", mo_coeff[0], gf[0].get_occupied().coupling).real
ca = lib.einsum("pq,qi->pi", mo_coeff[0], gf[0].get_virtual().coupling).real
else:
v = integrals.Lia
m = gf.moment(1)
ci = lib.einsum("pq,qi->pi", mo_coeff, gf.get_occupied().coupling)
ca = lib.einsum("pq,qi->pi", mo_coeff, gf.get_virtual().coupling)

# Update the moments of the SE
if moments is not None and cycle == 1:
Expand Down
5 changes: 5 additions & 0 deletions momentGW/tda.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def build_se_moments(self, moments_dd):
for x in range(q1 - q0):
Lp = self.integrals.Lpx[:, :, x]
eta[x, n] = lib.einsum(f"P{p},Q{q},PQ->{pq}", Lp, Lp, eta_aux) * 2.0

cput1 = lib.logger.timer(self.gw, "rotating DD moments", *cput0)

# Construct the self-energy moments
Expand All @@ -169,10 +170,14 @@ def build_se_moments(self, moments_dd):
ev = np.power.outer(self.mo_energy_g[q0:q1][self.mo_occ_g[q0:q1] == 0], n - moms)
tv = lib.einsum(f"t,ct,ct{pq}->{pq}", fp, ev, eta[self.mo_occ_g[q0:q1] == 0])
moments_vir[n] += fproc(tv)

moments_occ = mpi_helper.allreduce(moments_occ)
moments_vir = mpi_helper.allreduce(moments_vir)

# Numerical integration can lead to small non-hermiticity
moments_occ = 0.5 * (moments_occ + moments_occ.swapaxes(1, 2))
moments_vir = 0.5 * (moments_vir + moments_vir.swapaxes(1, 2))

cput1 = lib.logger.timer(self.gw, "constructing SE moments", *cput1)

return moments_occ, moments_vir
Expand Down
1 change: 1 addition & 0 deletions tests/test_kgw.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def setUpClass(cls):

smf = k2gamma.k2gamma(mf, kmesh=kmesh)
smf = smf.density_fit(auxbasis="weigend")
smf.exxdiv = None
smf.with_df._prefer_ccdf = True
smf.with_df.force_dm_kbuild = True

Expand Down

0 comments on commit 2d945c6

Please sign in to comment.