Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CCSDt' interface (and general EBCC interface updates) #121

Merged
merged 18 commits into from
Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
c8dbf28
Allow case-sensitive solver specifications.
cjcscott Jul 28, 2023
d0657e7
Working interface for CCSDt' calculations with ebCC.
cjcscott Jul 31, 2023
df47433
Added option to specify use of CCSD_Wavefunction objects for other an…
cjcscott Aug 1, 2023
b65b459
Added functionality to transform wavefunction representations to a di…
cjcscott Aug 1, 2023
ecf13d6
Initial functionality to rotate wavefunction objects into different b…
cjcscott Aug 1, 2023
0793308
Update functionality to rotate CCSD wavefunctions.
cjcscott Aug 2, 2023
399e084
Working interface to EBCC for arbitrary ansatzes.
cjcscott Aug 2, 2023
e0d7b40
Initial set of ebCC tests; not currently all passing.
cjcscott Aug 3, 2023
38f2cf7
Update to ensure coupled electron-boson CCSD wavefunctions are treate…
cjcscott Aug 3, 2023
2d29ed0
Move to minimal-basis tests for ebCC solvers.
cjcscott Aug 4, 2023
99bdc5d
Added manual setting of active spaces for embedded CC methods, approp…
cjcscott Aug 4, 2023
d620d4e
Add example for use of ebCC solvers.
cjcscott Aug 4, 2023
2024f5d
Fix bug in rotation of RCCSD lambda amplitudes.
cjcscott Aug 4, 2023
8000e6c
Small formatting fix to ccsd wavefunction file.
cjcscott Aug 4, 2023
79c085e
Updates for Ollie's review.
cjcscott Aug 7, 2023
5c0712d
Merge branch 'master' into CCSDt_interface
cjcscott Aug 8, 2023
416db7f
Add support and tests for active space cluster solvers with SCMF.
cjcscott Aug 11, 2023
02dc7f4
Change check in make_rdm1 to use CCSD routines for all CC solvers.
cjcscott Aug 11, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions examples/ewf/molecules/26-ebcc-solvers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pyscf
import pyscf.gto
import pyscf.scf
import pyscf.mcscf

import vayesta
import vayesta.ewf

mol = pyscf.gto.Mole()
mol.atom = ['N 0 0 0', 'N 0 0 2']
mol.basis = 'aug-cc-pvdz'
mol.output = 'pyscf.out'
mol.build()

# Hartree-Fock
mf = pyscf.scf.RHF(mol)
mf.kernel()

# Reference CASCI
casci = pyscf.mcscf.CASCI(mf, 8, 10)
casci.kernel()

# Reference CASSCF
casscf = pyscf.mcscf.CASSCF(mf, 8, 10)
casscf.kernel()

def get_emb_result(ansatz, bathtype='full'):
# Uses fastest available solver for given ansatz; PySCF if available, otherwise ebcc.
emb = vayesta.ewf.EWF(mf, solver=ansatz, bath_options=dict(bathtype=bathtype),
solver_options=dict(solve_lambda=False))
# Both these alternative specifications will always use an ebcc solver.
# Note that the capitalization of the solver name other than the ansatz is arbitrary.
#emb = vayesta.ewf.EWF(mf, solver=f'EB{ansatz}', bath_options=dict(bathtype=bathtype),
# solver_options=dict(solve_lambda=False))
#emb = vayesta.ewf.EWF(mf, solver='ebcc', bath_options=dict(bathtype=bathtype),
# solver_options=dict(solve_lambda=False, ansatz=ansatz))

with emb.iao_fragmentation() as f:
with f.rotational_symmetry(2, "y", center=(0, 0, 1)):
f.add_atomic_fragment(0)
emb.kernel()
return emb.e_tot

e_ccsd = get_emb_result('CCSD', 'full')
e_ccsdt = get_emb_result('CCSDT', 'dmet')
e_ccsdtprime = get_emb_result("CCSDt'", 'full')

print("E(HF)= %+16.8f Ha" % mf.e_tot)
print("E(CASCI)= %+16.8f Ha" % casci.e_tot)
print("E(CASSCF)= %+16.8f Ha" % casscf.e_tot)
print("E(CCSD, complete)= %+16.8f Ha" % e_ccsd)
print("E(emb. CCSDT, DMET CAS)= %+16.8f Ha" % e_ccsdt)
print("E(emb. CCSDt', complete+DMET active space)= %+16.8f Ha" % e_ccsdtprime)
2 changes: 1 addition & 1 deletion vayesta/core/qemb/qemb.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class Options(OptionsBase):
# EBFCI/EBCCSD
max_boson_occ=2,
# EBCC
ansatz=None, fermion_wf=False,
ansatz=None, store_as_ccsd=None, fermion_wf=False,
# Dump
dumpfile='clusters.h5',
# MP2
Expand Down
91 changes: 89 additions & 2 deletions vayesta/core/types/wf/ccsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
# import pyscf
# import pyscf.cc
from vayesta.core import spinalg
from vayesta.core.util import NotCalculatedError, Object, callif, einsum
from vayesta.core.util import NotCalculatedError, Object, callif, einsum, dot

Check warning on line 8 in vayesta/core/types/wf/ccsd.py

View check run for this annotation

Codecov / codecov/patch

vayesta/core/types/wf/ccsd.py#L8

Added line #L8 was not covered by tests
from vayesta.core.types import wf as wf_types
from vayesta.core.types.orbitals import SpatialOrbitals
from vayesta.core.types.wf.project import (project_c1, project_c2, project_uc1, project_uc2, symmetrize_c2,
symmetrize_uc2)
symmetrize_uc2, transform_c1, transform_c2, transform_uc1, transform_uc2)
from vayesta.core.helper import pack_arrays, unpack_arrays
from scipy.linalg import block_diag

Check warning on line 14 in vayesta/core/types/wf/ccsd.py

View check run for this annotation

Codecov / codecov/patch

vayesta/core/types/wf/ccsd.py#L14

Added line #L14 was not covered by tests


def CCSD_WaveFunction(mo, t1, t2, **kwargs):
Expand Down Expand Up @@ -170,6 +171,42 @@
def as_fci(self):
raise NotImplementedError

def rotate(self, t, inplace=False):

Check warning on line 174 in vayesta/core/types/wf/ccsd.py

View check run for this annotation

Codecov / codecov/patch

vayesta/core/types/wf/ccsd.py#L174

Added line #L174 was not covered by tests
"""Rotate wavefunction representation to another basis.
Only rotations which don't mix occupied and virtual orbitals are supported.
Assumes rotated orbitals have same occupancy ordering as originals.
"""
o = self.mo.occ > 0
v = self.mo.occ == 0

Check warning on line 180 in vayesta/core/types/wf/ccsd.py

View check run for this annotation

Codecov / codecov/patch

vayesta/core/types/wf/ccsd.py#L179-L180

Added lines #L179 - L180 were not covered by tests

to = t[np.ix_(o, o)]
tv = t[np.ix_(v, v)]
tov = t[np.ix_(o, v)]
tvo = t[np.ix_(v, o)]
if abs(tov).max() > 1e-12 or abs(tvo).max() > 1e-12:

Check warning on line 186 in vayesta/core/types/wf/ccsd.py

View check run for this annotation

Codecov / codecov/patch

vayesta/core/types/wf/ccsd.py#L182-L186

Added lines #L182 - L186 were not covered by tests
raise ValueError("Provided rotation mixes occupied and virtual orbitals.")
return self.rotate_ov(to, tv, inplace=inplace)

Check warning on line 188 in vayesta/core/types/wf/ccsd.py

View check run for this annotation

Codecov / codecov/patch

vayesta/core/types/wf/ccsd.py#L188

Added line #L188 was not covered by tests

def rotate_ov(self, to, tv, inplace=False):

Check warning on line 190 in vayesta/core/types/wf/ccsd.py

View check run for this annotation

Codecov / codecov/patch

vayesta/core/types/wf/ccsd.py#L190

Added line #L190 was not covered by tests
"""Rotate wavefunction representation to another basis.
Only rotations which don't mix occupied and virtual orbitals are supported.

Parameters
----------
to : new occupied orbital coefficients in terms of current ones.
tv : new virtual orbital coefficients in terms of current ones.
inplace : Whether to transform in-place or return a new object.
"""
wf = self if inplace else self.copy()
wf.mo.basis_transform(lambda c: dot(c, block_diag(to, tv)), inplace=True)
wf.t1 = transform_c1(wf.t1, to, tv)
wf.t2 = transform_c2(wf.t2, to, tv)
if wf.l1 is not None:
wf.l1 = transform_c1(wf.l1, to, tv)
if wf.l2 is not None:
wf.l2 = transform_c2(wf.l2, to, tv)
return wf

Check warning on line 208 in vayesta/core/types/wf/ccsd.py

View check run for this annotation

Codecov / codecov/patch

vayesta/core/types/wf/ccsd.py#L200-L208

Added lines #L200 - L208 were not covered by tests


class UCCSD_WaveFunction(RCCSD_WaveFunction):

Expand Down Expand Up @@ -296,6 +333,56 @@
if self.l2 is not None:
self.l2 = spinalg.multiply(self.l2, len(self.l2)*[factor])

def rotate(self, t, inplace=False):

Check warning on line 336 in vayesta/core/types/wf/ccsd.py

View check run for this annotation

Codecov / codecov/patch

vayesta/core/types/wf/ccsd.py#L336

Added line #L336 was not covered by tests
"""Rotate wavefunction representation to another basis.
Only rotations which don't mix occupied and virtual orbitals are supported.
Assumes rotated orbitals have same occupancy ordering as originals.
"""
# Allow support for same rotation for alpha and beta.
if isinstance(t, np.ndarray) and t.ndim == 2:
t = (t, t)
def get_components(tsp, occ):
o = occ > 0
v = occ == 0
tspo = tsp[np.ix_(o, o)]
tspv = tsp[np.ix_(v, v)]
tspov = tsp[np.ix_(o, v)]
tspvo = tsp[np.ix_(v, o)]
if abs(tspov).max() > 1e-12 or abs(tspvo).max() > 1e-12:

Check warning on line 351 in vayesta/core/types/wf/ccsd.py

View check run for this annotation

Codecov / codecov/patch

vayesta/core/types/wf/ccsd.py#L342-L351

Added lines #L342 - L351 were not covered by tests
raise ValueError("Provided rotation mixes occupied and virtual orbitals.")
return tspo, tspv

Check warning on line 353 in vayesta/core/types/wf/ccsd.py

View check run for this annotation

Codecov / codecov/patch

vayesta/core/types/wf/ccsd.py#L353

Added line #L353 was not covered by tests

toa, tva = get_components(t[0], self.mo.alpha.occ)
tob, tvb = get_components(t[1], self.mo.beta.occ)

Check warning on line 356 in vayesta/core/types/wf/ccsd.py

View check run for this annotation

Codecov / codecov/patch

vayesta/core/types/wf/ccsd.py#L355-L356

Added lines #L355 - L356 were not covered by tests

return self.rotate_ov((toa, tob), (tva, tvb), inplace=inplace)

Check warning on line 358 in vayesta/core/types/wf/ccsd.py

View check run for this annotation

Codecov / codecov/patch

vayesta/core/types/wf/ccsd.py#L358

Added line #L358 was not covered by tests

def rotate_ov(self, to, tv, inplace=False):

Check warning on line 360 in vayesta/core/types/wf/ccsd.py

View check run for this annotation

Codecov / codecov/patch

vayesta/core/types/wf/ccsd.py#L360

Added line #L360 was not covered by tests
"""Rotate wavefunction representation to another basis.
Only rotations which don't mix occupied and virtual orbitals are supported.

Parameters
----------
to : new occupied orbital coefficients in terms of current ones.
tv : new virtual orbital coefficients in terms of current ones.
inplace : Whether to transform in-place or return a new object.
"""
wf = self if inplace else self.copy()
if isinstance(to, np.ndarray) and len(to) == 2:
assert(isinstance(tv, np.ndarray) and len(tv) == 2)
trafo = lambda c: dot(c, block_diag(to, tv))

Check warning on line 373 in vayesta/core/types/wf/ccsd.py

View check run for this annotation

Codecov / codecov/patch

vayesta/core/types/wf/ccsd.py#L370-L373

Added lines #L370 - L373 were not covered by tests
else:
trafo = [lambda c: dot(c, x) for x in (block_diag(to[0], tv[0]), block_diag(to[1], tv[1]))]

Check warning on line 375 in vayesta/core/types/wf/ccsd.py

View check run for this annotation

Codecov / codecov/patch

vayesta/core/types/wf/ccsd.py#L375

Added line #L375 was not covered by tests

wf.mo.basis_transform(trafo, inplace=True)
wf.t1 = transform_uc1(wf.t1, to, tv)
wf.t2 = transform_uc2(wf.t2, to, tv)
if wf.l1 is not None:
wf.l1 = transform_uc1(wf.l1, to, tv)
if wf.l2 is not None:
wf.l2 = transform_uc2(wf.l2, to, tv)
return wf

Check warning on line 384 in vayesta/core/types/wf/ccsd.py

View check run for this annotation

Codecov / codecov/patch

vayesta/core/types/wf/ccsd.py#L377-L384

Added lines #L377 - L384 were not covered by tests

#def pack(self, dtype=float):
# """Pack into a single array of data type `dtype`.

Expand Down
25 changes: 25 additions & 0 deletions vayesta/core/types/wf/project.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Utility functions for projection of wave functions."""

import numpy as np
from vayesta.core.util import einsum, dot

Check warning on line 4 in vayesta/core/types/wf/project.py

View check run for this annotation

Codecov / codecov/patch

vayesta/core/types/wf/project.py#L4

Added line #L4 was not covered by tests


def project_c1(c1, p):
Expand Down Expand Up @@ -45,3 +46,27 @@
if len(c2) == 4:
c2ab = (c2ab + c2[2].transpose(1,0,3,2))/2
return (c2aa, c2ab, c2bb)

def transform_c1(c1, to, tv):
if c1 is None: return None
return dot(to.T, c1, tv)

Check warning on line 52 in vayesta/core/types/wf/project.py

View check run for this annotation

Codecov / codecov/patch

vayesta/core/types/wf/project.py#L50-L52

Added lines #L50 - L52 were not covered by tests

def transform_c2(c2, to, tv, to2=None, tv2=None):
if c2 is None: return None
if to2 is None: to2 = to
if tv2 is None: tv2 = tv

Check warning on line 57 in vayesta/core/types/wf/project.py

View check run for this annotation

Codecov / codecov/patch

vayesta/core/types/wf/project.py#L54-L57

Added lines #L54 - L57 were not covered by tests
# Use einsum for now- tensordot would be faster but less readable.
return einsum("ijab,iI,jJ,aA,bB->IJAB", c2, to, to2, tv, tv2)

Check warning on line 59 in vayesta/core/types/wf/project.py

View check run for this annotation

Codecov / codecov/patch

vayesta/core/types/wf/project.py#L59

Added line #L59 was not covered by tests

def transform_uc1(c1, to, tv):
if c1 is None: return None
return (transform_c1(c1[0], to[0], tv[0]),

Check warning on line 63 in vayesta/core/types/wf/project.py

View check run for this annotation

Codecov / codecov/patch

vayesta/core/types/wf/project.py#L61-L63

Added lines #L61 - L63 were not covered by tests
transform_c1(c1[1], to[1], tv[1]))

def transform_uc2(c2, to, tv):
if c2 is None: return None
c2ba = (c2[2] if len(c2) == 4 else c2[1].transpose(1,0,3,2))
return (transform_c2(c2[0], to[0], tv[0]),

Check warning on line 69 in vayesta/core/types/wf/project.py

View check run for this annotation

Codecov / codecov/patch

vayesta/core/types/wf/project.py#L66-L69

Added lines #L66 - L69 were not covered by tests
transform_c2(c2[1], to[0], tv[0], to[1], tv[1]),
transform_c2(c2ba, to[1], tv[1], to[0], tv[0]),
transform_c2(c2[-1], to[1], tv[1]))
6 changes: 6 additions & 0 deletions vayesta/core/types/wf/wf.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@
def make_rdm2(self, *args, **kwargs):
raise AbstractMethodError

def rotate_ov(self, *args, **kwargs):

Check warning on line 68 in vayesta/core/types/wf/wf.py

View check run for this annotation

Codecov / codecov/patch

vayesta/core/types/wf/wf.py#L68

Added line #L68 was not covered by tests
raise AbstractMethodError

def rotate(self, *args, **kwargs):

Check warning on line 71 in vayesta/core/types/wf/wf.py

View check run for this annotation

Codecov / codecov/patch

vayesta/core/types/wf/wf.py#L71

Added line #L71 was not covered by tests
raise AbstractMethodError

@staticmethod
def from_pyscf(obj, **kwargs):
# HF
Expand Down
2 changes: 1 addition & 1 deletion vayesta/ewf/ewf.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@
# Defaults

def make_rdm1(self, *args, **kwargs):
if self.solver.lower() == 'ccsd':
if "cc" in self.solver.lower():

Check warning on line 232 in vayesta/ewf/ewf.py

View check run for this annotation

Codecov / codecov/patch

vayesta/ewf/ewf.py#L232

Added line #L232 was not covered by tests
return self._make_rdm1_ccsd_global_wf(*args, **kwargs)
if self.solver.lower() == 'mp2':
return self._make_rdm1_mp2_global_wf(*args, **kwargs)
Expand Down
15 changes: 11 additions & 4 deletions vayesta/ewf/fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,13 @@
# For self-consistent mode
self.solver_results = None

def _reset(self, *args, **kwargs):
super()._reset(*args, **kwargs)

Check warning on line 104 in vayesta/ewf/fragment.py

View check run for this annotation

Codecov / codecov/patch

vayesta/ewf/fragment.py#L103-L104

Added lines #L103 - L104 were not covered by tests
# Need to unset these so can be regenerated each iteration.
self.opts.c_cas_occ = self.opts.c_cas_vir = None

Check warning on line 106 in vayesta/ewf/fragment.py

View check run for this annotation

Codecov / codecov/patch

vayesta/ewf/fragment.py#L106

Added line #L106 was not covered by tests

def set_cas(self, iaos=None, c_occ=None, c_vir=None, minao='auto', dmet_threshold=None):
"""Set complete active space for tailored CCSD"""
"""Set complete active space for tailored CCSD and active-space CC methods."""
if dmet_threshold is None:
dmet_threshold = 2*self.opts.bath_options['dmet_threshold']
if iaos is not None:
Expand Down Expand Up @@ -282,8 +287,9 @@
self.log.debugv("Passing fragment option %s to solver.", attr)
solver_opts[attr] = getattr(self.opts, attr)

if solver.upper() == 'TCCSD':
solver_opts['tcc'] = True
has_actspace = ((solver == "TCCSD") or ("CCSDt'" in solver) or

Check warning on line 290 in vayesta/ewf/fragment.py

View check run for this annotation

Codecov / codecov/patch

vayesta/ewf/fragment.py#L290

Added line #L290 was not covered by tests
(solver.upper() == "EBCC" and self.opts.solver_options['ansatz'] == "CCSDt'"))
if has_actspace:

Check warning on line 292 in vayesta/ewf/fragment.py

View check run for this annotation

Codecov / codecov/patch

vayesta/ewf/fragment.py#L292

Added line #L292 was not covered by tests
# Set CAS orbitals
if self.opts.c_cas_occ is None:
self.log.warning("Occupied CAS orbitals not set. Setting to occupied DMET cluster orbitals.")
Expand All @@ -293,7 +299,8 @@
self.opts.c_cas_vir = self._dmet_bath.c_cluster_vir
solver_opts['c_cas_occ'] = self.opts.c_cas_occ
solver_opts['c_cas_vir'] = self.opts.c_cas_vir
solver_opts['tcc_fci_opts'] = self.opts.tcc_fci_opts
if solver == "TCCSD":
solver_opts['tcc_fci_opts'] = self.opts.tcc_fci_opts

Check warning on line 303 in vayesta/ewf/fragment.py

View check run for this annotation

Codecov / codecov/patch

vayesta/ewf/fragment.py#L302-L303

Added lines #L302 - L303 were not covered by tests
elif solver.upper() == 'DUMP':
solver_opts['filename'] = self.opts.solver_options['dumpfile']
solver_opts['external_corrections'] = self.flags.external_corrections
Expand Down
10 changes: 8 additions & 2 deletions vayesta/ewf/ufragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,14 @@

class Fragment(RFragment, BaseFragment):

def set_cas(self, *args, **kwargs):
raise NotImplementedError()
def set_cas(self, iaos=None, c_occ=None, c_vir=None, minao='auto', dmet_threshold=None):

Check warning on line 14 in vayesta/ewf/ufragment.py

View check run for this annotation

Codecov / codecov/patch

vayesta/ewf/ufragment.py#L14

Added line #L14 was not covered by tests
"""Set complete active space for tailored CCSD and active-space CC methods."""
if iaos is not None:

Check warning on line 16 in vayesta/ewf/ufragment.py

View check run for this annotation

Codecov / codecov/patch

vayesta/ewf/ufragment.py#L16

Added line #L16 was not covered by tests
raise NotImplementedError("Unrestricted IAO-based CAS not implemented yet.")

self.opts.c_cas_occ = c_occ
self.opts.c_cas_vir = c_vir
return c_occ, c_vir

Check warning on line 21 in vayesta/ewf/ufragment.py

View check run for this annotation

Codecov / codecov/patch

vayesta/ewf/ufragment.py#L19-L21

Added lines #L19 - L21 were not covered by tests

def get_fragment_energy(self, c1, c2, hamil=None, fock=None, axis1='fragment', c2ba_order='ba'):
"""Calculate fragment correlation energy contribution from projected C1, C2.
Expand Down
35 changes: 22 additions & 13 deletions vayesta/solver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,19 @@
from vayesta.solver.cisd import RCISD_Solver, UCISD_Solver
from vayesta.solver.coupled_ccsd import coupledRCCSD_Solver
from vayesta.solver.dump import DumpSolver
from vayesta.solver.ebcc import REBCC_Solver, UEBCC_Solver, EB_REBCC_Solver, EB_UEBCC_Solver
from vayesta.solver.ebfci import EB_EBFCI_Solver, EB_UEBFCI_Solver
from vayesta.solver.ext_ccsd import extRCCSD_Solver, extUCCSD_Solver
from vayesta.solver.fci import FCI_Solver, UFCI_Solver
from vayesta.solver.hamiltonian import is_ham, is_uhf_ham, is_eb_ham, ClusterHamiltonian
from vayesta.solver.mp2 import RMP2_Solver, UMP2_Solver
from vayesta.solver.tccsd import TRCCSD_Solver

try:
from vayesta.solver.ebcc import REBCC_Solver, UEBCC_Solver, EB_REBCC_Solver, EB_UEBCC_Solver

Check warning on line 13 in vayesta/solver/__init__.py

View check run for this annotation

Codecov / codecov/patch

vayesta/solver/__init__.py#L12-L13

Added lines #L12 - L13 were not covered by tests
except ImportError:
_has_ebcc = False
else:
_has_ebcc = True

Check warning on line 17 in vayesta/solver/__init__.py

View check run for this annotation

Codecov / codecov/patch

vayesta/solver/__init__.py#L17

Added line #L17 was not covered by tests

def get_solver_class(ham, solver):
assert (is_ham(ham))
Expand All @@ -36,7 +41,6 @@


def _get_solver_class_internal(is_uhf, is_eb, solver):
solver = solver.upper()
# First check if we have a CC approach as implemented in pyscf.
if solver == "CCSD" and not is_eb:
# Use pyscf solvers.
Expand All @@ -48,21 +52,25 @@
if is_uhf or is_eb:
raise ValueError("TCCSD is not implemented for unrestricted or electron-boson calculations!")
return TRCCSD_Solver
if solver == "EXTCCSD":
if solver == "extCCSD":

Check warning on line 55 in vayesta/solver/__init__.py

View check run for this annotation

Codecov / codecov/patch

vayesta/solver/__init__.py#L55

Added line #L55 was not covered by tests
if is_eb:
raise ValueError("extCCSD is not implemented for electron-boson calculations!")
if is_uhf:
return extUCCSD_Solver
return extRCCSD_Solver
if solver == "COUPLEDCCSD":
if solver == "coupledCCSD":

Check warning on line 61 in vayesta/solver/__init__.py

View check run for this annotation

Codecov / codecov/patch

vayesta/solver/__init__.py#L61

Added line #L61 was not covered by tests
if is_eb:
raise ValueError("coupledCCSD is not implemented for electron-boson calculations!")
if is_uhf:
raise ValueError("coupledCCSD is not implemented for unrestricted calculations!")
return coupledRCCSD_Solver

# Now consider general CC ansatzes; these are solved via EBCC.
if "CC" in solver:
# Note that we support all capitalisations of `ebcc`, but need `CC` to be capitalised when also using this to
# specify an ansatz.
if "CC" in solver.upper():
if not _has_ebcc:
raise ImportError(f"{solver} solver is only accessible via ebcc. Please install ebcc.")

Check warning on line 73 in vayesta/solver/__init__.py

View check run for this annotation

Codecov / codecov/patch

vayesta/solver/__init__.py#L71-L73

Added lines #L71 - L73 were not covered by tests
if is_uhf:
if is_eb:
solverclass = EB_UEBCC_Solver
Expand All @@ -73,23 +81,24 @@
solverclass = EB_REBCC_Solver
else:
solverclass = REBCC_Solver
if solver == "EBCC":
if solver.upper() == "EBCC":

Check warning on line 84 in vayesta/solver/__init__.py

View check run for this annotation

Codecov / codecov/patch

vayesta/solver/__init__.py#L84

Added line #L84 was not covered by tests
# Default to `opts.ansatz`.
return solverclass
if solver[:2] == "EB":
if solver[:2].upper() == "EB":

Check warning on line 87 in vayesta/solver/__init__.py

View check run for this annotation

Codecov / codecov/patch

vayesta/solver/__init__.py#L87

Added line #L87 was not covered by tests
solver = solver[2:]
if solver == "CCSD" and is_eb:
# Need to specify CC level for coupled electron-boson model; throw an error rather than assume.
raise ValueError(
"Please specify a coupled electron-boson CC ansatz as a solver, for example CCSD-S-1-1,"
"rather than CCSD")

# This is just a wrapper to allow us to use the solver option as the ansatz kwarg in this case.
def get_right_CC(*args, **kwargs):

if kwargs.get("ansatz", None) is not None:
raise ValueError(
"Desired CC ansatz specified differently in solver and solver_options.ansatz."
"Please use only specify via one approach, or ensure they agree.")
setansatz = kwargs.get("ansatz", None)
if setansatz is not None:
if setansatz != solver:

Check warning on line 98 in vayesta/solver/__init__.py

View check run for this annotation

Codecov / codecov/patch

vayesta/solver/__init__.py#L96-L98

Added lines #L96 - L98 were not covered by tests
raise ValueError(
"Desired CC ansatz specified differently in solver and solver_options.ansatz."
"Please use only specify via one approach, or ensure they agree.")
kwargs["ansatz"] = solver
return solverclass(*args, **kwargs)

Expand Down
Loading
Loading