diff --git a/mlptrain/config.py b/mlptrain/config.py index 9ef1d26f..0f8dd597 100644 --- a/mlptrain/config.py +++ b/mlptrain/config.py @@ -26,6 +26,13 @@ class _ConfigClass: 'l_max': 6, # n_max = 2 l_max 'sigma_at': 0.5, # Å } + # ACE params + ace_params = { + 'N': 4, # maximum correlation order + 'r_cut': 4.0, # outer cutoff of ACE + 'deg_pair': 5, # Specify the pair potential + 'r_cut_pair': 5.0, + } # NeQUIP params nequip_params = {'cutoff': 4.0, 'train_fraction': 0.9} @@ -48,7 +55,7 @@ class _ConfigClass: 'forces_weight': 5.0, 'hidden_irreps': '128x0e + 128x1o', 'batch_size': 10, - 'r_max': 5, + 'r_max': 5.0, 'correlation': 3, 'device': mace_device, 'calc_device': 'cpu', @@ -60,6 +67,7 @@ class _ConfigClass: 'amsgrad': True, 'restart_latest': False, 'save_cpu': True, + 'dtype': 'float32', } # --------------------- Internal properties --------------------------- diff --git a/mlptrain/potentials/_base.py b/mlptrain/potentials/_base.py index 1bab1e5d..6e97fe34 100644 --- a/mlptrain/potentials/_base.py +++ b/mlptrain/potentials/_base.py @@ -339,6 +339,7 @@ def copy(self) -> 'MLPotential': 'Ne': 1, 'Na': 2, 'Mg': 1, + 'S': 3, 'Cl': 2, 'Ar': 1, 'K': 2, diff --git a/mlptrain/potentials/ace/ace.py b/mlptrain/potentials/ace/ace.py index 08ecb9a1..25eae4d7 100644 --- a/mlptrain/potentials/ace/ace.py +++ b/mlptrain/potentials/ace/ace.py @@ -134,7 +134,8 @@ def _print_input(self, filename: str, **kwargs) -> None: _str = ', '.join([f':{s}' for s in self.system.unique_atomic_symbols]) print( - f'species = [{_str}]\n' 'N = 4', # maximum correlation order + f'species = [{_str}]\n' + f"N = {Config.ace_params['N']}", # maximum correlation order file=inp_file, ) @@ -173,10 +174,10 @@ def _print_input(self, filename: str, **kwargs) -> None: print( 'r0 = 1.3\n' f'r_in = {self._r_in_estimate:.4f}\n' # inner cutoff of ACE, choose a little more than min dist in dataset - 'r_cut = 4.0\n' # outer cutoff of ACE + f"r_cut = {Config.ace_params['r_cut']}\n" # outer cutoff of ACE '\n' - 'deg_pair = 5\n' # Specify the pair potential - 'r_cut_pair = 5.0\n', + f"deg_pair = {Config.ace_params['deg_pair']}\n" # Specify the pair potential + f"r_cut_pair = {Config.ace_params['r_cut_pair']}\n", file=inp_file, )