Skip to content

Commit

Permalink
Improve inheritence of kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
obackhouse committed Aug 14, 2023
1 parent eae4a3d commit 9aa7245
Show file tree
Hide file tree
Showing 11 changed files with 128 additions and 303 deletions.
47 changes: 47 additions & 0 deletions momentGW/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,53 @@ def build_se_moments(self, *args, **kwargs):
def solve_dyson(self, *args, **kwargs):
raise NotImplementedError

def _kernel(self, *args, **kwargs):
raise NotImplementedError

def kernel(
self,
nmom_max,
mo_energy=None,
mo_coeff=None,
moments=None,
integrals=None,
):
if mo_coeff is None:
mo_coeff = self.mo_coeff
if mo_energy is None:
mo_energy = self.mo_energy

cput0 = (logger.process_clock(), logger.perf_counter())
self.dump_flags()
logger.info(self, "nmom_max = %d", nmom_max)

self.converged, self.gf, self.se, self._qp_energy = self._kernel(
nmom_max,
mo_energy,
mo_coeff,
integrals=integrals,
)

gf_occ = self.gf.get_occupied()
gf_occ.remove_uncoupled(tol=1e-1)
for n in range(min(5, gf_occ.naux)):
en = -gf_occ.energy[-(n + 1)]
vn = gf_occ.coupling[:, -(n + 1)]
qpwt = np.linalg.norm(vn) ** 2
logger.note(self, "IP energy level %d E = %.16g QP weight = %0.6g", n, en, qpwt)

gf_vir = self.gf.get_virtual()
gf_vir.remove_uncoupled(tol=1e-1)
for n in range(min(5, gf_vir.naux)):
en = gf_vir.energy[n]
vn = gf_vir.coupling[:, n]
qpwt = np.linalg.norm(vn) ** 2
logger.note(self, "EA energy level %d E = %.16g QP weight = %0.6g", n, en, qpwt)

logger.timer(self, self.name, *cput0)

return self.converged, self.gf, self.se, self.qp_energy

@staticmethod
def _moment_error(t, t_prev):
"""Compute scaled error between moments."""
Expand Down
57 changes: 6 additions & 51 deletions momentGW/evgw.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ def kernel(
Green's function object
se : pyscf.agf2.SelfEnergy
Self-energy object
qp_energy : numpy.ndarray
Quasiparticle energies. Always None for evGW, returned for
compatibility with other evGW methods.
"""

logger.warn(gw, "evGW is untested!")
Expand Down Expand Up @@ -118,7 +121,7 @@ def kernel(
if conv:
break

return conv, gf, se
return conv, gf, se, None


class evGW(GW):
Expand Down Expand Up @@ -178,6 +181,8 @@ class evGW(GW):
def name(self):
return "evG%sW%s" % ("0" if self.g0 else "", "0" if self.w0 else "")

_kernel = kernel

def check_convergence(self, mo_energy, mo_energy_prev, th, th_prev, tp, tp_prev):
"""Check for convergence, and print a summary of changes."""

Expand All @@ -201,53 +206,3 @@ def check_convergence(self, mo_energy, mo_energy_prev, th, th_prev, tp, tp_prev)
max(error_th, error_tp) < self.conv_tol_moms,
)
)

def kernel(
self,
nmom_max,
mo_energy=None,
mo_coeff=None,
moments=None,
integrals=None,
):
if mo_coeff is None:
mo_coeff = self.mo_coeff
if mo_energy is None:
mo_energy = self.mo_energy

cput0 = (logger.process_clock(), logger.perf_counter())
self.dump_flags()
logger.info(self, "nmom_max = %d", nmom_max)

self.converged, self.gf, self.se = kernel(
self,
nmom_max,
mo_energy,
mo_coeff,
integrals=integrals,
)

gf_occ = self.gf.get_occupied()
gf_occ.remove_uncoupled(tol=1e-1)
for n in range(min(5, gf_occ.naux)):
en = -gf_occ.energy[-(n + 1)]
vn = gf_occ.coupling[:, -(n + 1)]
qpwt = np.linalg.norm(vn) ** 2
logger.note(self, "IP energy level %d E = %.16g QP weight = %0.6g", n, en, qpwt)

gf_vir = self.gf.get_virtual()
gf_vir.remove_uncoupled(tol=1e-1)
for n in range(min(5, gf_vir.naux)):
en = gf_vir.energy[n]
vn = gf_vir.coupling[:, n]
qpwt = np.linalg.norm(vn) ** 2
logger.note(self, "EA energy level %d E = %.16g QP weight = %0.6g", n, en, qpwt)

if self.converged:
logger.note(self, "%s converged", self.name)
else:
logger.note(self, "%s failed to converge", self.name)

logger.timer(self, self.name, *cput0)

return self.converged, self.gf, self.se
54 changes: 7 additions & 47 deletions momentGW/gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,15 @@ def kernel(
Returns
-------
conv : bool
Convergence flag. Always True for AGW, returned for
Convergence flag. Always True for GW, returned for
compatibility with other GW methods.
gf : pyscf.agf2.GreensFunction
Green's function object
se : pyscf.agf2.SelfEnergy
Self-energy object
qp_energy : numpy.ndarray
Quasiparticle energies. Always None for GW, returned for
compatibility with other GW methods.
"""

if integrals is None:
Expand All @@ -85,7 +88,7 @@ def kernel(
gf, se = gw.solve_dyson(th, tp, se_static, integrals=integrals)
conv = True

return conv, gf, se
return conv, gf, se, None


class GW(BaseGW):
Expand All @@ -98,6 +101,8 @@ class GW(BaseGW):
def name(self):
return "G0W0"

_kernel = kernel

def build_se_static(self, integrals, mo_coeff=None, mo_energy=None):
"""Build the static part of the self-energy, including the
Fock matrix.
Expand Down Expand Up @@ -301,48 +306,3 @@ def moment_error(self, se_moments_hole, se_moments_part, se):
)

return eh, ep

def kernel(
self,
nmom_max,
mo_energy=None,
mo_coeff=None,
moments=None,
integrals=None,
):
if mo_coeff is None:
mo_coeff = self.mo_coeff
if mo_energy is None:
mo_energy = self.mo_energy

cput0 = (logger.process_clock(), logger.perf_counter())
self.dump_flags()
logger.info(self, "nmom_max = %d", nmom_max)

self.converged, self.gf, self.se = kernel(
self,
nmom_max,
mo_energy,
mo_coeff,
integrals=integrals,
)

gf_occ = self.gf.get_occupied()
gf_occ.remove_uncoupled(tol=1e-1)
for n in range(min(5, gf_occ.naux)):
en = -gf_occ.energy[-(n + 1)]
vn = gf_occ.coupling[:, -(n + 1)]
qpwt = np.linalg.norm(vn) ** 2
logger.note(self, "IP energy level %d E = %.16g QP weight = %0.6g", n, en, qpwt)

gf_vir = self.gf.get_virtual()
gf_vir.remove_uncoupled(tol=1e-1)
for n in range(min(5, gf_vir.naux)):
en = gf_vir.energy[n]
vn = gf_vir.coupling[:, n]
qpwt = np.linalg.norm(vn) ** 2
logger.note(self, "EA energy level %d E = %.16g QP weight = %0.6g", n, en, qpwt)

logger.timer(self, self.name, *cput0)

return self.converged, self.gf, self.se
44 changes: 44 additions & 0 deletions momentGW/pbc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,50 @@ def __init__(self, mf, **kwargs):

self._keys = set(self.__dict__.keys()).union(self._opts)

def kernel(
self,
nmom_max,
mo_energy=None,
mo_coeff=None,
moments=None,
integrals=None,
):
if mo_coeff is None:
mo_coeff = self.mo_coeff
if mo_energy is None:
mo_energy = self.mo_energy

cput0 = (logger.process_clock(), logger.perf_counter())
self.dump_flags()
logger.info(self, "nmom_max = %d", nmom_max)

self.converged, self.gf, self.se, self._qp_energy = self._kernel(
nmom_max,
mo_energy,
mo_coeff,
integrals=integrals,
)

gf_occ = self.gf[0].get_occupied()
gf_occ.remove_uncoupled(tol=1e-1)
for n in range(min(5, gf_occ.naux)):
en = -gf_occ.energy[-(n + 1)]
vn = gf_occ.coupling[:, -(n + 1)]
qpwt = np.linalg.norm(vn) ** 2
logger.note(self, "IP energy level (Γ) %d E = %.16g QP weight = %0.6g", n, en, qpwt)

gf_vir = self.gf[0].get_virtual()
gf_vir.remove_uncoupled(tol=1e-1)
for n in range(min(5, gf_vir.naux)):
en = gf_vir.energy[n]
vn = gf_vir.coupling[:, n]
qpwt = np.linalg.norm(vn) ** 2
logger.note(self, "EA energy level (Γ) %d E = %.16g QP weight = %0.6g", n, en, qpwt)

logger.timer(self, self.name, *cput0)

return self.converged, self.gf, self.se, self.qp_energy

@staticmethod
def _gf_to_occ(gf):
return tuple(BaseGW._gf_to_occ(g) for g in gf)
Expand Down
47 changes: 1 addition & 46 deletions momentGW/pbc/evgw.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pyscf.pbc import dft, gto
from pyscf.pbc.tools import k2gamma

from momentGW.evgw import evGW, kernel
from momentGW.evgw import evGW
from momentGW.pbc.gw import KGW


Expand Down Expand Up @@ -51,48 +51,3 @@ def check_convergence(self, mo_energy, mo_energy_prev, th, th_prev, tp, tp_prev)
max(error_th, error_tp) < self.conv_tol_moms,
)
)

def kernel(
self,
nmom_max,
mo_energy=None,
mo_coeff=None,
moments=None,
integrals=None,
):
if mo_coeff is None:
mo_coeff = self.mo_coeff
if mo_energy is None:
mo_energy = self.mo_energy

cput0 = (logger.process_clock(), logger.perf_counter())
self.dump_flags()
logger.info(self, "nmom_max = %d", nmom_max)

self.converged, self.gf, self.se = kernel(
self,
nmom_max,
mo_energy,
mo_coeff,
integrals=integrals,
)

gf_occ = self.gf[0].get_occupied()
gf_occ.remove_uncoupled(tol=1e-1)
for n in range(min(5, gf_occ.naux)):
en = -gf_occ.energy[-(n + 1)]
vn = gf_occ.coupling[:, -(n + 1)]
qpwt = np.linalg.norm(vn) ** 2
logger.note(self, "IP energy level (Γ) %d E = %.16g QP weight = %0.6g", n, en, qpwt)

gf_vir = self.gf[0].get_virtual()
gf_vir.remove_uncoupled(tol=1e-1)
for n in range(min(5, gf_vir.naux)):
en = gf_vir.energy[n]
vn = gf_vir.coupling[:, n]
qpwt = np.linalg.norm(vn) ** 2
logger.note(self, "EA energy level (Γ) %d E = %.16g QP weight = %0.6g", n, en, qpwt)

logger.timer(self, self.name, *cput0)

return self.converged, self.gf, self.se
47 changes: 1 addition & 46 deletions momentGW/pbc/gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pyscf.lib import logger
from pyscf.pbc import scf

from momentGW.gw import GW, kernel
from momentGW.gw import GW
from momentGW.pbc.base import BaseKGW
from momentGW.pbc.fock import fock_loop, minimize_chempot, search_chempot
from momentGW.pbc.ints import KIntegrals
Expand Down Expand Up @@ -207,48 +207,3 @@ def make_rdm1(self, gf=None):
gf = [GreensFunction(self.mo_energy, np.eye(self.nmo))]

return np.array([g.make_rdm1() for g in gf])

def kernel(
self,
nmom_max,
mo_energy=None,
mo_coeff=None,
moments=None,
integrals=None,
):
if mo_coeff is None:
mo_coeff = self.mo_coeff
if mo_energy is None:
mo_energy = self.mo_energy

cput0 = (logger.process_clock(), logger.perf_counter())
self.dump_flags()
logger.info(self, "nmom_max = %d", nmom_max)

self.converged, self.gf, self.se = kernel(
self,
nmom_max,
mo_energy,
mo_coeff,
integrals=integrals,
)

gf_occ = self.gf[0].get_occupied()
gf_occ.remove_uncoupled(tol=1e-1)
for n in range(min(5, gf_occ.naux)):
en = -gf_occ.energy[-(n + 1)]
vn = gf_occ.coupling[:, -(n + 1)]
qpwt = np.linalg.norm(vn) ** 2
logger.note(self, "IP energy level (Γ) %d E = %.16g QP weight = %0.6g", n, en, qpwt)

gf_vir = self.gf[0].get_virtual()
gf_vir.remove_uncoupled(tol=1e-1)
for n in range(min(5, gf_vir.naux)):
en = gf_vir.energy[n]
vn = gf_vir.coupling[:, n]
qpwt = np.linalg.norm(vn) ** 2
logger.note(self, "EA energy level (Γ) %d E = %.16g QP weight = %0.6g", n, en, qpwt)

logger.timer(self, self.name, *cput0)

return self.converged, self.gf, self.se
Loading

0 comments on commit 9aa7245

Please sign in to comment.