Skip to content

Commit

Permalink
Fix: wrong STRU read in md restart case (#5157)
Browse files Browse the repository at this point in the history
* Fix: wrong STRU read in md restart case

* Refactor: update error info

* Fix: heap-buffer-overflow in stress_op_test.cpp

* Refactor: add a new para for STRU

* [pre-commit.ci lite] apply automatic fixes

* Fix: set_golbalv

* Tests: update unitests

---------

Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
  • Loading branch information
YuLiu98 and pre-commit-ci-lite[bot] authored Sep 24, 2024
1 parent eab9d73 commit f356958
Show file tree
Hide file tree
Showing 16 changed files with 79 additions and 62 deletions.
1 change: 1 addition & 0 deletions docs/advanced/input_files/input-main.md
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,7 @@ These variables are used to control parameters related to input files.
- **Type**: String
- **Description**: the name of the structure file
- Containing various information about atom species, including pseudopotential files, local orbitals files, cell information, atom positions, and whether atoms should be allowed to move.
- When [calculation](#calculation) is set to `md` and [md_restart](#md_restart) is set to `true`, this keyword will NOT work.
- Refer to [Doc](https://github.com/deepmodeling/abacus-develop/blob/develop/docs/advanced/input_files/stru.md)
- **Default**: STRU

Expand Down
2 changes: 1 addition & 1 deletion source/driver_run.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ void Driver::driver_run() {

// the life of ucell should begin here, mohan 2024-05-12
// delete ucell as a GlobalC in near future
GlobalC::ucell.setup_cell(PARAM.inp.stru_file, GlobalV::ofs_running);
GlobalC::ucell.setup_cell(PARAM.globalv.global_in_stru, GlobalV::ofs_running);
Check_Atomic_Stru::check_atomic_stru(GlobalC::ucell,
PARAM.inp.min_dist_coef);

Expand Down
2 changes: 0 additions & 2 deletions source/module_base/global_variable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@ int GSIZE = DSIZE;
//----------------------------------------------------------
// EXPLAIN : The input file name and directory
//----------------------------------------------------------
std::string stru_file = "STRU";

std::ofstream ofs_running;
std::ofstream ofs_warning;
std::ofstream ofs_info; // output math lib info
Expand Down
1 change: 0 additions & 1 deletion source/module_base/global_variable.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ extern int KPAR_LCAO;
// NAME : ofs_running( contain information during runnnig)
// NAME : ofs_warning( contain warning information, including error)
//==========================================================
extern std::string stru_file;
// extern std::string global_pseudo_type; // mohan add 2013-05-20 (xiaohui add
// 2013-06-23)
extern std::ofstream ofs_running;
Expand Down
14 changes: 9 additions & 5 deletions source/module_cell/module_paw/paw_cell_libpaw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,11 @@ void Paw_Cell::mix_dij(const int iat, double*dij_paw)
const int size_dij = nproj * (nproj+1) / 2;
for(int i = 0; i < size_dij * nspden; i ++)
{
if(!first_iter) dij_paw[i] = dij_save[iat][i] * (1.0 - mixing_beta) + dij_paw[i] * mixing_beta;
if(!first_iter) { dij_paw[i] = dij_save[iat][i] * (1.0 - mixing_beta) + dij_paw[i] * mixing_beta;
}

if(count > 30) dij_paw[i] = dij_save[iat][i];
if(count > 30) { dij_paw[i] = dij_save[iat][i];
}

dij_save[iat][i] = dij_paw[i];
}
Expand All @@ -224,7 +226,7 @@ void Paw_Cell::set_libpaw_files()
filename_list = new char[ntypat*264];
if(GlobalV::MY_RANK == 0)
{
std::ifstream ifa(PARAM.inp.stru_file.c_str(), std::ios::in);
std::ifstream ifa(PARAM.globalv.global_in_stru.c_str(), std::ios::in);
if (!ifa)
{
ModuleBase::WARNING_QUIT("set_libpaw_files", "can not open stru file");
Expand All @@ -234,7 +236,8 @@ void Paw_Cell::set_libpaw_files()
while(!ifa.eof())
{
getline(ifa,line);
if (line.find("PAW_FILES") != std::string::npos) break;
if (line.find("PAW_FILES") != std::string::npos) { break;
}
}

for(int i = 0; i < ntypat*264; i++)
Expand Down Expand Up @@ -681,7 +684,8 @@ void Paw_Cell::set_sij()
double* sij = new double[nproj * nproj];

#ifdef __MPI
if(GlobalV::RANK_IN_POOL == 0) extract_sij(it,size_sij,sij_libpaw);
if(GlobalV::RANK_IN_POOL == 0) { extract_sij(it,size_sij,sij_libpaw);
}
Parallel_Common::bcast_double(sij_libpaw,size_sij*nspden);
#else
extract_sij(it,size_sij,sij_libpaw);
Expand Down
2 changes: 1 addition & 1 deletion source/module_esolver/esolver_ks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ void ESolver_KS<T, Device>::before_all_runners(const Input_para& inp, UnitCell&

if (GlobalV::MY_RANK == 0)
{
std::ifstream ifa(PARAM.inp.stru_file.c_str(), std::ios::in);
std::ifstream ifa(PARAM.globalv.global_in_stru.c_str(), std::ios::in);
if (!ifa)
{
ModuleBase::WARNING_QUIT("set_libpaw_files", "can not open stru file");
Expand Down

Large diffs are not rendered by default.

13 changes: 0 additions & 13 deletions source/module_io/input_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,19 +166,6 @@ void Input_Conv::Convert()
//----------------------------------------------------------
// main parameters / electrons / spin ( 10/16 )
//----------------------------------------------------------
// suffix
if (PARAM.inp.calculation == "md" && PARAM.mdp.md_restart) // md restart liuyu add 2023-04-12
{
int istep = 0;
double temperature = 0.0;
MD_func::current_md_info(GlobalV::MY_RANK, PARAM.globalv.global_readin_dir, istep, temperature);
if (PARAM.inp.read_file_dir == "auto")
{
GlobalV::stru_file = PARAM.globalv.global_stru_dir + "STRU_MD_" + std::to_string(istep);
}
} else if (PARAM.inp.stru_file != "") {
GlobalV::stru_file = PARAM.inp.stru_file;
}

ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "pseudo_dir", PARAM.inp.pseudo_dir);
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "orbital_dir", PARAM.inp.orbital_dir);
Expand Down
2 changes: 1 addition & 1 deletion source/module_io/json_output/general_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ void gen_general_info(const Parameter& param)
AbacusJson::add_json({"general_info", "omp_num"}, omp_num, false);
AbacusJson::add_json({"general_info", "pseudo_dir"}, param.inp.pseudo_dir, false);
AbacusJson::add_json({"general_info", "orbital_dir"}, param.inp.orbital_dir, false);
AbacusJson::add_json({"general_info", "stru_file"}, param.inp.stru_file, false);
AbacusJson::add_json({"general_info", "stru_file"}, param.globalv.global_in_stru, false);
AbacusJson::add_json({"general_info", "kpt_file"}, param.inp.kpoint_file, false);
AbacusJson::add_json({"general_info", "start_time"}, start_time_str, false);
AbacusJson::add_json({"general_info", "end_time"}, end_time_str, false);
Expand Down
2 changes: 1 addition & 1 deletion source/module_io/json_output/test/para_json_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ TEST(AbacusJsonTest, GeneralInfo)
PARAM.input.device = "cpu";
PARAM.input.pseudo_dir = "./abacus/test/pseudo_dir";
PARAM.input.orbital_dir = "./abacus/test/orbital_dir";
PARAM.input.stru_file = "./abacus/test/stru_file";
PARAM.sys.global_in_stru = "./abacus/test/stru_file";
PARAM.input.kpoint_file = "./abacus/test/kpoint_file";
// output the json file
Json::AbacusJson::doc.Parse("{}");
Expand Down
33 changes: 26 additions & 7 deletions source/module_io/read_input.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,12 +304,7 @@ void ReadInput::read_txt_input(Parameter& param, const std::string& filename)
readvalue_item->read_value(*readvalue_item, param);
}

// 2) count the number of atom types from STRU file
if (this->check_ntype_flag) {
check_ntype(param.input.stru_file, param.input.ntype);
}

// 3) reset this value when some conditions are met
// 2) reset this value when some conditions are met
// e.g. if (calulation_type == "nscf") then set "init_chg" to "file".
for (auto& input_item: this->input_lists)
{
Expand All @@ -320,6 +315,12 @@ void ReadInput::read_txt_input(Parameter& param, const std::string& filename)
}
this->set_globalv(param);

// 3) count the number of atom types from STRU file
if (this->check_ntype_flag)
{
check_ntype(param.globalv.global_in_stru, param.input.ntype);
}

// 4) check the value of the parameters
for (auto& input_item: this->input_lists)
{
Expand Down Expand Up @@ -432,7 +433,7 @@ void ReadInput::check_ntype(const std::string& fn, int& param_ntype)
if (!ifa)
{
GlobalV::ofs_warning << fn;
ModuleBase::WARNING_QUIT("ReadInput::check_ntype", "Can not find the file containing atom positions.!");
ModuleBase::WARNING_QUIT("ReadInput::check_ntype", "Can not find the file: " + fn);
}

int ntype_stru = 0;
Expand Down Expand Up @@ -476,6 +477,24 @@ void ReadInput::check_ntype(const std::string& fn, int& param_ntype)
}
}

int ReadInput::current_md_step(const std::string& file_dir)
{
std::stringstream ssc;
ssc << file_dir << "Restart_md.dat";
std::ifstream file(ssc.str().c_str());

if (!file)
{
ModuleBase::WARNING_QUIT("current_md_step", "no Restart_md.dat");
}

int md_step;
file >> md_step;
file.close();

return md_step;
}

void ReadInput::add_item(const Input_Item& item)
{
// only rank 0 read the input file
Expand Down
7 changes: 7 additions & 0 deletions source/module_io/read_input.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ class ReadInput
* @param filename output file name
*/
void write_txt_input(const Parameter& param, const std::string& filename);
/**
* @brief determine the md step in restart case
*
* @param file_dir directory of Restart_md.dat
* @return md step
*/
int current_md_step(const std::string& file_dir);
/**
* @brief count_nype from STRU file
*
Expand Down
26 changes: 23 additions & 3 deletions source/module_io/read_set_globalv.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#include "read_input.h"
#include "read_input_tool.h"
#include "module_parameter/parameter.h"
#include "module_base/global_variable.h"
#include "module_base/tool_quit.h"
#include "module_parameter/parameter.h"
#include "read_input.h"
#include "read_input_tool.h"
namespace ModuleIO
{
void ReadInput::set_globalv(Parameter& para)
Expand Down Expand Up @@ -31,6 +31,26 @@ void ReadInput::set_globalv(Parameter& para)
para.sys.global_readin_dir = para.inp.read_file_dir + '/';
}
para.sys.global_readin_dir = to_dir(para.sys.global_readin_dir);

/// get the stru file for md restart case
if (para.inp.calculation == "md" && para.mdp.md_restart)
{
int istep = current_md_step(para.sys.global_readin_dir);

if (para.inp.read_file_dir == "auto")
{
para.sys.global_in_stru = para.sys.global_stru_dir + "STRU_MD_" + std::to_string(istep);
}
else
{
para.sys.global_in_stru = para.inp.read_file_dir + "STRU_MD_" + std::to_string(istep);
}
}
else
{
para.sys.global_in_stru = para.inp.stru_file;
}

/// caculate the gamma_only_pw and gamma_only_local
if (para.input.gamma_only)
{
Expand Down
3 changes: 2 additions & 1 deletion source/module_parameter/system_parameter.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ struct System_para
///< for plane wave basis.
bool gamma_only_local = false; ///< true if "gamma_only" is true and "lcao"
///< is true; for local orbitals.
std::string global_in_card = "INPUT"; ///< global input directory
std::string global_in_card = "INPUT"; ///< input file
std::string global_in_stru = "STRU"; ///< stru file
std::string global_out_dir = ""; ///< global output directory
std::string global_readin_dir = ""; ///< global readin directory
std::string global_stru_dir = ""; ///< global structure directory
Expand Down
19 changes: 0 additions & 19 deletions tests/integrate/120_PW_KP_MD_NVT/STRU

This file was deleted.

12 changes: 6 additions & 6 deletions tests/integrate/120_PW_KP_MD_NVT/STRU_MD_2
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
ATOMIC_SPECIES
Si 1 Si_ONCV_PBE-1.0.upf upf201
Si 1 Si_ONCV_PBE-1.0.upf

LATTICE_CONSTANT
10.2

LATTICE_VECTORS
0.5000000000 0.5000000000 0.0000000000 #latvec1
0.5000000000 0.0000000000 0.5000000000 #latvec2
0.0000000000 0.5000000000 0.5000000000 #latvec3
0.5 0.5 0 #latvec1
0.5 0 0.5 #latvec2
0 0.5 0.5 #latvec3

ATOMIC_POSITIONS
Cartesian

Si #label
0 #magnetism
2 #number of atoms
0.9961377061 0.5017059282 0.5004642243 m 1 1 1 v -0.0005823336 0.0003037059 0.0000737542
0.2448622939 0.2532940718 0.2505357757 m 1 1 1 v 0.0005823336 -0.0003037059 -0.0000737542
0 0 0 m 1 1 1 v -0.00016026383234 1.88587389656e-05 4.31519285331e-06
0.241 0.255 0.250999999999 m 1 1 1 v 0.00016026383234 -1.88587389656e-05 -4.31519285331e-06

0 comments on commit f356958

Please sign in to comment.