Skip to content

Commit

Permalink
enable tyepwise cutoff in nep training
Browse files Browse the repository at this point in the history
  • Loading branch information
brucefan1983 committed Jun 29, 2024
1 parent 6cd079f commit 2f81eff
Show file tree
Hide file tree
Showing 10 changed files with 94 additions and 29 deletions.
4 changes: 2 additions & 2 deletions src/force/dftd3.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ J. Comput. Chem., 32, 1456 (2011).
#include "dftd3para.cuh"
#include "model/box.cuh"
#include "neighbor.cuh"
#include "utilities/common.cuh"
#include <algorithm>
#include <cctype>
#include <iostream>
Expand All @@ -40,15 +41,14 @@ J. Comput. Chem., 32, 1456 (2011).
namespace
{
const int MN = 10000; // maximum number of neighbors for one atom
const int NUM_ELEMENTS = 103;
const std::string ELEMENTS[NUM_ELEMENTS] = {
"H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", "Na", "Mg", "Al", "Si", "P",
"S", "Cl", "Ar", "K", "Ca", "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn",
"Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr", "Y", "Zr", "Nb", "Mo", "Tc", "Ru", "Rh",
"Pd", "Ag", "Cd", "In", "Sn", "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd",
"Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb", "Lu", "Hf", "Ta", "W", "Re",
"Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th",
"Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm", "Md", "No", "Lr"};
"Pa", "U", "Np", "Pu"};

void __global__ find_dftd3_coordination_number_small_box(
DFTD3::DFTD3_Para dftd3_para,
Expand Down
2 changes: 1 addition & 1 deletion src/force/nep3.cu
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ const std::string ELEMENTS[NUM_ELEMENTS] = {
"Pd", "Ag", "Cd", "In", "Sn", "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd",
"Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb", "Lu", "Hf", "Ta", "W", "Re",
"Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th",
"Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm", "Md", "No", "Lr"};
"Pa", "U", "Np", "Pu"};

void NEP3::initialize_dftd3()
{
Expand Down
2 changes: 1 addition & 1 deletion src/force/nep3_multigpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ const std::string ELEMENTS[NUM_ELEMENTS] = {
"Pd", "Ag", "Cd", "In", "Sn", "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd",
"Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb", "Lu", "Hf", "Ta", "W", "Re",
"Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th",
"Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm", "Md", "No", "Lr"};
"Pa", "U", "Np", "Pu"};

void NEP3_MULTIGPU::initialize_dftd3()
{
Expand Down
66 changes: 48 additions & 18 deletions src/main_nep/nep3.cu
Original file line number Diff line number Diff line change
Expand Up @@ -126,19 +126,25 @@ static __global__ void find_descriptors_radial(
float z12 = g_z12[index];
float d12 = sqrt(x12 * x12 + y12 * y12 + z12 * z12);
float fc12;
find_fc(paramb.rc_radial, paramb.rcinv_radial, d12, fc12);
int t2 = g_type[n2];
float rc = paramb.rc_radial;
if (paramb.use_typewise_cutoff) {
rc = min((COVALENT_RADIUS[paramb.atomic_numbers[t1]] + COVALENT_RADIUS[paramb.atomic_numbers[t2]]) * 2.5f, rc);
}
float rcinv = 1.0f / rc;
find_fc(rc, rcinv, d12, fc12);

float fn12[MAX_NUM_N];
if (paramb.version == 2) {
find_fn(paramb.n_max_radial, paramb.rcinv_radial, d12, fc12, fn12);
find_fn(paramb.n_max_radial, rcinv, d12, fc12, fn12);
for (int n = 0; n <= paramb.n_max_radial; ++n) {
float c = (paramb.num_types == 1)
? 1.0f
: annmb.c[(n * paramb.num_types + t1) * paramb.num_types + t2];
q[n] += fn12[n] * c;
}
} else {
find_fn(paramb.basis_size_radial, paramb.rcinv_radial, d12, fc12, fn12);
find_fn(paramb.basis_size_radial, rcinv, d12, fc12, fn12);
for (int n = 0; n <= paramb.n_max_radial; ++n) {
float gn12 = 0.0f;
for (int k = 0; k <= paramb.basis_size_radial; ++k) {
Expand Down Expand Up @@ -185,11 +191,16 @@ static __global__ void find_descriptors_angular(
float z12 = g_z12[index];
float d12 = sqrt(x12 * x12 + y12 * y12 + z12 * z12);
float fc12;
find_fc(paramb.rc_angular, paramb.rcinv_angular, d12, fc12);
int t2 = g_type[n2];
float rc = paramb.rc_angular;
if (paramb.use_typewise_cutoff) {
rc = min((COVALENT_RADIUS[paramb.atomic_numbers[t1]] + COVALENT_RADIUS[paramb.atomic_numbers[t2]]) * 2.0f, rc);
}
float rcinv = 1.0f / rc;
find_fc(rc, rcinv, d12, fc12);
if (paramb.version == 2) {
float fn;
find_fn(n, paramb.rcinv_angular, d12, fc12, fn);
find_fn(n, rcinv, d12, fc12, fn);
fn *=
(paramb.num_types == 1)
? 1.0f
Expand All @@ -198,7 +209,7 @@ static __global__ void find_descriptors_angular(
accumulate_s(d12, x12, y12, z12, fn, s);
} else {
float fn12[MAX_NUM_N];
find_fn(paramb.basis_size_angular, paramb.rcinv_angular, d12, fc12, fn12);
find_fn(paramb.basis_size_angular, rcinv, d12, fc12, fn12);
float gn12 = 0.0f;
for (int k = 0; k <= paramb.basis_size_angular; ++k) {
int c_index = (n * (paramb.basis_size_angular + 1) + k) * paramb.num_types_sq;
Expand Down Expand Up @@ -242,6 +253,7 @@ NEP3::NEP3(
paramb.rcinv_radial = 1.0f / paramb.rc_radial;
paramb.rc_angular = para.rc_angular;
paramb.rcinv_angular = 1.0f / paramb.rc_angular;
paramb.use_typewise_cutoff = para.use_typewise_cutoff;
paramb.num_types = para.num_types;
paramb.n_max_radial = para.n_max_radial;
paramb.n_max_angular = para.n_max_angular;
Expand Down Expand Up @@ -269,6 +281,7 @@ NEP3::NEP3(
zbl.rc_outer = para.zbl_rc_outer;
for (int n = 0; n < para.atomic_numbers.size(); ++n) {
zbl.atomic_numbers[n] = para.atomic_numbers[n];
paramb.atomic_numbers[n] = para.atomic_numbers[n];
}
if (zbl.flexibled) {
zbl.num_types = para.num_types;
Expand Down Expand Up @@ -563,13 +576,18 @@ static __global__ void find_force_radial(
float d12 = sqrt(r12[0] * r12[0] + r12[1] * r12[1] + r12[2] * r12[2]);
float d12inv = 1.0f / d12;
float fc12, fcp12;
find_fc_and_fcp(paramb.rc_radial, paramb.rcinv_radial, d12, fc12, fcp12);
float rc = paramb.rc_radial;
if (paramb.use_typewise_cutoff) {
rc = min((COVALENT_RADIUS[paramb.atomic_numbers[t1]] + COVALENT_RADIUS[paramb.atomic_numbers[t2]]) * 2.5f, rc);
}
float rcinv = 1.0f / rc;
find_fc_and_fcp(rc, rcinv, d12, fc12, fcp12);
float fn12[MAX_NUM_N];
float fnp12[MAX_NUM_N];
float f12[3] = {0.0f};

if (paramb.version == 2) {
find_fn_and_fnp(paramb.n_max_radial, paramb.rcinv_radial, d12, fc12, fcp12, fn12, fnp12);
find_fn_and_fnp(paramb.n_max_radial, rcinv, d12, fc12, fcp12, fn12, fnp12);
for (int n = 0; n <= paramb.n_max_radial; ++n) {
float tmp12 = g_Fp[n1 + n * N] * fnp12[n] * d12inv;
tmp12 *= (paramb.num_types == 1)
Expand All @@ -581,7 +599,7 @@ static __global__ void find_force_radial(
}
} else {
find_fn_and_fnp(
paramb.basis_size_radial, paramb.rcinv_radial, d12, fc12, fcp12, fn12, fnp12);
paramb.basis_size_radial, rcinv, d12, fc12, fcp12, fn12, fnp12);
for (int n = 0; n <= paramb.n_max_radial; ++n) {
float gnp12 = 0.0f;
for (int k = 0; k <= paramb.basis_size_radial; ++k) {
Expand Down Expand Up @@ -670,15 +688,20 @@ static __global__ void find_force_angular(
float r12[3] = {g_x12[index], g_y12[index], g_z12[index]};
float d12 = sqrt(r12[0] * r12[0] + r12[1] * r12[1] + r12[2] * r12[2]);
float fc12, fcp12;
find_fc_and_fcp(paramb.rc_angular, paramb.rcinv_angular, d12, fc12, fcp12);
int t2 = g_type[n2];
float rc = paramb.rc_angular;
if (paramb.use_typewise_cutoff) {
rc = min((COVALENT_RADIUS[paramb.atomic_numbers[t1]] + COVALENT_RADIUS[paramb.atomic_numbers[t2]]) * 2.0f, rc);
}
float rcinv = 1.0f / rc;
find_fc_and_fcp(rc, rcinv, d12, fc12, fcp12);
float f12[3] = {0.0f};

if (paramb.version == 2) {
for (int n = 0; n <= paramb.n_max_angular; ++n) {
float fn;
float fnp;
find_fn_and_fnp(n, paramb.rcinv_angular, d12, fc12, fcp12, fn, fnp);
find_fn_and_fnp(n, rcinv, d12, fc12, fcp12, fn, fnp);
const float c =
(paramb.num_types == 1)
? 1.0f
Expand All @@ -691,8 +714,7 @@ static __global__ void find_force_angular(
} else {
float fn12[MAX_NUM_N];
float fnp12[MAX_NUM_N];
find_fn_and_fnp(
paramb.basis_size_angular, paramb.rcinv_angular, d12, fc12, fcp12, fn12, fnp12);
find_fn_and_fnp(paramb.basis_size_angular, rcinv, d12, fc12, fcp12, fn12, fnp12);
for (int n = 0; n <= paramb.n_max_angular; ++n) {
float gn12 = 0.0f;
float gnp12 = 0.0f;
Expand Down Expand Up @@ -746,6 +768,7 @@ static __global__ void find_force_angular(

static __global__ void find_force_ZBL(
const int N,
const NEP3::ParaMB paramb,
const NEP3::ZBL zbl,
const int* g_NN,
const int* g_NL,
Expand All @@ -769,8 +792,8 @@ static __global__ void find_force_ZBL(
float s_virial_yz = 0.0f;
float s_virial_zx = 0.0f;
int type1 = g_type[n1];
float zi = zbl.atomic_numbers[type1];
float pow_zi = pow(zi, 0.23f);
int zi = zbl.atomic_numbers[type1];
float pow_zi = pow(float(zi), 0.23f);
int neighbor_number = g_NN[n1];
for (int i1 = 0; i1 < neighbor_number; ++i1) {
int index = i1 * N + n1;
Expand All @@ -780,8 +803,8 @@ static __global__ void find_force_ZBL(
float d12inv = 1.0f / d12;
float f, fp;
int type2 = g_type[n2];
float zj = zbl.atomic_numbers[type2];
float a_inv = (pow_zi + pow(zj, 0.23f)) * 2.134563f;
int zj = zbl.atomic_numbers[type2];
float a_inv = (pow_zi + pow(float(zj), 0.23f)) * 2.134563f;
float zizj = K_C_SP * zi * zj;
if (zbl.flexibled) {
int t1, t2;
Expand All @@ -799,7 +822,13 @@ static __global__ void find_force_ZBL(
}
find_f_and_fp_zbl(ZBL_para, zizj, a_inv, d12, d12inv, f, fp);
} else {
find_f_and_fp_zbl(zizj, a_inv, zbl.rc_inner, zbl.rc_outer, d12, d12inv, f, fp);
float rc_inner = zbl.rc_inner;
float rc_outer = zbl.rc_outer;
if (paramb.use_typewise_cutoff) {
rc_outer = min((COVALENT_RADIUS[zi] + COVALENT_RADIUS[zj]) * 0.7f, rc_outer);
rc_inner = rc_outer * 0.5f;
}
find_f_and_fp_zbl(zizj, a_inv, rc_inner, rc_outer, d12, d12inv, f, fp);
}
float f2 = fp * d12inv * 0.5f;
float f12[3] = {r12[0] * f2, r12[1] * f2, r12[2] * f2};
Expand Down Expand Up @@ -999,6 +1028,7 @@ void NEP3::find_force(
if (zbl.enabled) {
find_force_ZBL<<<grid_size, block_size>>>(
dataset[device_id].N,
paramb,
zbl,
nep_data[device_id].NN_angular.data(),
nep_data[device_id].NL_angular.data(),
Expand Down
10 changes: 6 additions & 4 deletions src/main_nep/nep3.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class NEP3 : public Potential
{
public:
struct ParaMB {
bool use_typewise_cutoff = false;
float rc_radial = 0.0f; // radial cutoff
float rc_angular = 0.0f; // angular cutoff
float rcinv_radial = 0.0f; // inverse of the radial cutoff
Expand All @@ -56,15 +57,16 @@ public:
int num_types_sq = 0; // for nep3
int num_c_radial = 0; // for nep3
int version = 2; // 2 for NEP2 and 3 for NEP3
int atomic_numbers[NUM_ELEMENTS];
};

struct ANN {
int dim = 0; // dimension of the descriptor
int num_neurons1 = 0; // number of neurons in the hidden layer
int num_para = 0; // number of parameters
const float* w0[100]; // weight from the input layer to the hidden layer
const float* b0[100]; // bias for the hidden layer
const float* w1[100]; // weight from the hidden layer to the output layer
const float* w0[NUM_ELEMENTS]; // weight from the input layer to the hidden layer
const float* b0[NUM_ELEMENTS]; // bias for the hidden layer
const float* w1[NUM_ELEMENTS]; // weight from the hidden layer to the output layer
const float* b1; // bias for the output layer
// for the scalar part of polarizability
const float* w0_pol[10]; // weight from the input layer to the hidden layer
Expand All @@ -82,7 +84,7 @@ public:
float rc_outer = 2.0f;
int num_types;
float para[550];
float atomic_numbers[NUM_ELEMENTS];
int atomic_numbers[NUM_ELEMENTS];
};

NEP3(
Expand Down
20 changes: 19 additions & 1 deletion src/main_nep/parameters.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ const std::string ELEMENTS[NUM_ELEMENTS] = {
"Pd", "Ag", "Cd", "In", "Sn", "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd",
"Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb", "Lu", "Hf", "Ta", "W", "Re",
"Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th",
"Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm", "Md", "No", "Lr"};
"Pa", "U", "Np", "Pu"};

Parameters::Parameters()
{
Expand Down Expand Up @@ -72,6 +72,7 @@ void Parameters::set_default_parameters()
is_type_weight_set = false;
is_zbl_set = false;
is_force_delta_set = false;
is_use_typewise_cutoff_set = false;

train_mode = 0; // potential
prediction = 0; // not prediction mode
Expand All @@ -97,6 +98,7 @@ void Parameters::set_default_parameters()
maximum_generation = 100000; // a good starting point
initial_para = 1.0f;
sigma0 = 0.1f;
use_typewise_cutoff = false;

type_weight_cpu.resize(NUM_ELEMENTS);
zbl_para.resize(550); // Maximum number of zbl parameters
Expand Down Expand Up @@ -299,6 +301,12 @@ void Parameters::report_inputs()
printf(" (default) angular cutoff = %g A.\n", rc_angular);
}

if (is_use_typewise_cutoff_set) {
printf(" (input) use %s cutoff.\n", use_typewise_cutoff ? "typewise" : "global");
} else {
printf(" (default) use %s cutoff.\n", use_typewise_cutoff ? "typewise" : "global");
}

if (is_n_max_set) {
printf(" (input) n_max_radial = %d.\n", n_max_radial);
printf(" (input) n_max_angular = %d.\n", n_max_angular);
Expand Down Expand Up @@ -461,6 +469,8 @@ void Parameters::parse_one_keyword(std::vector<std::string>& tokens)
parse_initial_para(param, num_param);
} else if (strcmp(param[0], "sigma0") == 0) {
parse_sigma0(param, num_param);
} else if (strcmp(param[0], "use_typewise_cutoff") == 0) {
parse_use_typewise_cutoff(param, num_param);
} else {
PRINT_KEYWORD_ERROR(param[0]);
}
Expand Down Expand Up @@ -960,3 +970,11 @@ void Parameters::parse_sigma0(const char** param, int num_param)
PRINT_INPUT_ERROR("sigma0 should be within [0.01, 0.1].");
}
}

void Parameters::parse_use_typewise_cutoff(const char** param, int num_param)
{
if (num_param != 1) {
PRINT_INPUT_ERROR("use_typewise_cutoff should have no parameter.\n");
}
use_typewise_cutoff = true;
}
3 changes: 3 additions & 0 deletions src/main_nep/parameters.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ public:
int prediction; // 0=no, 1=yes
float initial_para;
float sigma0;
bool use_typewise_cutoff;

// check if a parameter has been set:
bool is_train_mode_set;
Expand All @@ -78,6 +79,7 @@ public:
bool is_type_weight_set;
bool is_force_delta_set;
bool is_zbl_set;
bool is_use_typewise_cutoff_set;

// other parameters
int dim; // dimension of the descriptor vector
Expand Down Expand Up @@ -129,4 +131,5 @@ private:
void parse_generation(const char** param, int num_param);
void parse_initial_para(const char** param, int num_param);
void parse_sigma0(const char** param, int num_param);
void parse_use_typewise_cutoff(const char** param, int num_param);
};
2 changes: 1 addition & 1 deletion src/mc/nep_energy.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ const std::string ELEMENTS[NUM_ELEMENTS] = {
"Pd", "Ag", "Cd", "In", "Sn", "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd",
"Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb", "Lu", "Hf", "Ta", "W", "Re",
"Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th",
"Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm", "Md", "No", "Lr"};
"Pa", "U", "Np", "Pu"};

void NEP_Energy::initialize(const char* file_potential)
{
Expand Down
2 changes: 1 addition & 1 deletion src/utilities/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#pragma once

const int MAX_NUM_BEADS = 128;
const int NUM_ELEMENTS = 103;
const int NUM_ELEMENTS = 94;
#define PI 3.14159265358979
#define HBAR 6.465412e-2 // Planck's constant
#define K_B 8.617343e-5 // Boltzmann's constant
Expand Down
12 changes: 12 additions & 0 deletions src/utilities/nep_utilities.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@ __constant__ float C4B[5] = {
-0.809943929279723f};
__constant__ float C5B[3] = {0.026596810706114f, 0.053193621412227f, 0.026596810706114f};

__constant__ float COVALENT_RADIUS[94] =
{0.426667f,0.613333f,1.6f,1.25333f,1.02667f,1.0f,0.946667f,0.84f,0.853333f,0.893333f,
1.86667f,1.66667f,1.50667f,1.38667f,1.46667f,1.36f,1.32f,1.28f,2.34667f,2.05333f,
1.77333f,1.62667f,1.61333f,1.46667f,1.42667f,1.38667f,1.33333f,1.32f,1.34667f,1.45333f,
1.49333f,1.45333f,1.53333f,1.46667f,1.52f,1.56f,2.52f,2.22667f,1.96f,1.85333f,
1.76f,1.65333f,1.53333f,1.50667f,1.50667f,1.44f,1.53333f,1.64f,1.70667f,1.68f,
1.68f,1.64f,1.76f,1.74667f,2.78667f,2.34667f,2.16f,1.96f,2.10667f,2.09333f,
2.08f,2.06667f,2.01333f,2.02667f,2.01333f,2.0f,1.98667f,1.98667f,1.97333f,2.04f,
1.94667f,1.82667f,1.74667f,1.64f,1.57333f,1.54667f,1.48f,1.49333f,1.50667f,1.76f,
1.73333f,1.73333f,1.81333f,1.74667f,1.84f,1.89333f,2.68f,2.41333f,2.22667f,2.10667f,
2.02667f,2.04f,2.05333f,2.06667f};

const int SIZE_BOX_AND_INVERSE_BOX = 18; // (3 * 3) * 2
const int MAX_NUM_N = 20; // n_max+1 = 19+1
const int MAX_DIM = MAX_NUM_N * 7;
Expand Down

0 comments on commit 2f81eff

Please sign in to comment.