Skip to content

Commit

Permalink
replace hard-coded const multipliers for NLDFs with new normalizers
Browse files Browse the repository at this point in the history
  • Loading branch information
kylebystrom committed Jun 21, 2024
1 parent 7cbeb4b commit 26e01e7
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 17 deletions.
1 change: 1 addition & 0 deletions ciderpress/gpaw/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,7 @@ def calculate_6d_integral_deriv(
if n_g is not None:
dfeatdf = np.zeros([nexp - 1] + list(n_g.shape))
for i in range(nexp - 1):
# TODO this should be changed to use the feature normalizers
const = ((self.consts[i, 1] + self.consts[-1, 1]) / 2) ** 1.5
const = const + self.consts[i, 0] / (self.consts[i, 0] + const)
for ind, a in enumerate(self.alphas):
Expand Down
1 change: 1 addition & 0 deletions ciderpress/gpaw/atom_analysis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ def _get_paw_helper2(
dydf_obg = dFdf_oag
Nalpha = self.Nalpha

# TODO consts should be removed to make use of feat normalizers
x_sig = np.zeros((nspin, nfeat, ngrid))
dxdf_oig = np.zeros((norb, nfeat, ngrid))
for s in range(nspin):
Expand Down
22 changes: 8 additions & 14 deletions ciderpress/gpaw/cider_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,17 +261,15 @@ def call_xc_kernel(
X0T[:, start : start + nfeat_tmp] = nspin * feat_sg
start += nfeat_tmp

X0TN = (
X0T # self.mlfunc.settings.normalizers.get_normalized_feature_vector(X0T)
)
X0TN = self.mlfunc.settings.normalizers.get_normalized_feature_vector(X0T)
exc_ml, dexcdX0TN_ml = self.mlfunc(X0TN, rhocut=self.rhocut)
xmix = self.xmix # / rho.shape[0]
exc_ml *= xmix
dexcdX0TN_ml *= xmix
# vxc_ml = self.mlfunc.settings.normalizers.get_derivative_wrt_unnormed_features(
# X0T, dexcdX0TN_ml
# )
vxc_ml = dexcdX0TN_ml
vxc_ml = self.mlfunc.settings.normalizers.get_derivative_wrt_unnormed_features(
X0T, dexcdX0TN_ml
)
# vxc_ml = dexcdX0TN_ml
e_g[:] += exc_ml

start = 0
Expand Down Expand Up @@ -975,8 +973,6 @@ def calculate_6d_integral_fwd(self, n_g, cider_exp, c_abi=None):
else:
i_g = None
dq0_g = None
const = ((self.consts[i, 1] + self.consts[-1, 1]) / 2) ** 1.5
const = const + self.consts[i, 0] / (self.consts[i, 0] + const)
for ind, a in enumerate(self.alphas):
self.timer.start("COEFS")
pa_g, dpa_g = fnlxc.eval_cubic_interp(
Expand All @@ -987,8 +983,8 @@ def calculate_6d_integral_fwd(self, n_g, cider_exp, c_abi=None):
self.timer.stop()
p_iag[i, a] = pa_g

feat[i, :] += const * pa_g * self.rbuf_ag[a]
dfeat[i, :] += const * dpa_g * self.rbuf_ag[a]
feat[i, :] += pa_g * self.rbuf_ag[a]
dfeat[i, :] += dpa_g * self.rbuf_ag[a]

self.timer.start("6d comm fwd")
if n_g is not None:
Expand Down Expand Up @@ -1041,11 +1037,9 @@ def calculate_6d_integral_bwd(
for a in self.alphas:
self.rbuf_ag[a][:] = 0.0
for i in range(nexp - 1):
const = ((self.consts[i, 1] + self.consts[-1, 1]) / 2) ** 1.5
const = const + self.consts[i, 0] / (self.consts[i, 0] + const)
vfeati_g = self.domain_world2cider(vfeat_g[i])
for a in alphas:
self.rbuf_ag[a][:] += vfeati_g * const * p_iag[i, a]
self.rbuf_ag[a][:] += vfeati_g * p_iag[i, a]
self.timer.stop()
if compute_stress:
self.calculate_6d_stress_integral()
Expand Down
6 changes: 3 additions & 3 deletions ciderpress/gpaw/paw_cider_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,9 @@ def get_paw_atom_contribs_en(self, n_sg, sigma_xg, y_sbg, F_sag, ae=True):
x_sig[s, i] += pa_g * y_sbg[s, a]
xd_sig[s, i] += dpa_g * y_sbg[s, a]
p_siag[s, i, a] = pa_g
x_sig[s, i] *= ((self.consts[i, 1] + self.consts[-1, 1]) / 2) ** 1.5
xd_sig[s, i] *= ((self.consts[i, 1] + self.consts[-1, 1]) / 2) ** 1.5
p_siag[s, i] *= ((self.consts[i, 1] + self.consts[-1, 1]) / 2) ** 1.5
# x_sig[s, i] *= ((self.consts[i, 1] + self.consts[-1, 1]) / 2) ** 1.5
# xd_sig[s, i] *= ((self.consts[i, 1] + self.consts[-1, 1]) / 2) ** 1.5
# p_siag[s, i] *= ((self.consts[i, 1] + self.consts[-1, 1]) / 2) ** 1.5
dgdn_sig[s, i] = xd_sig[s, i] * dadn
dgdsigma_sig[s, i] = xd_sig[s, i] * dadsigma

Expand Down

0 comments on commit 26e01e7

Please sign in to comment.