Skip to content

Commit

Permalink
Config update (duartegroup#91)
Browse files Browse the repository at this point in the history
* Format update

* Double quotation marks correction

* Double quotation marks correction
  • Loading branch information
juraskov authored Mar 27, 2024
1 parent 839b77c commit cd35f02
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 5 deletions.
10 changes: 9 additions & 1 deletion mlptrain/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ class _ConfigClass:
'l_max': 6, # n_max = 2 l_max
'sigma_at': 0.5, # Å
}
# ACE params
ace_params = {
'N': 4, # maximum correlation order
'r_cut': 4.0, # outer cutoff of ACE
'deg_pair': 5, # Specify the pair potential
'r_cut_pair': 5.0,
}

# NeQUIP params
nequip_params = {'cutoff': 4.0, 'train_fraction': 0.9}
Expand All @@ -48,7 +55,7 @@ class _ConfigClass:
'forces_weight': 5.0,
'hidden_irreps': '128x0e + 128x1o',
'batch_size': 10,
'r_max': 5,
'r_max': 5.0,
'correlation': 3,
'device': mace_device,
'calc_device': 'cpu',
Expand All @@ -60,6 +67,7 @@ class _ConfigClass:
'amsgrad': True,
'restart_latest': False,
'save_cpu': True,
'dtype': 'float32',
}

# --------------------- Internal properties ---------------------------
Expand Down
1 change: 1 addition & 0 deletions mlptrain/potentials/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ def copy(self) -> 'MLPotential':
'Ne': 1,
'Na': 2,
'Mg': 1,
'S': 3,
'Cl': 2,
'Ar': 1,
'K': 2,
Expand Down
9 changes: 5 additions & 4 deletions mlptrain/potentials/ace/ace.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ def _print_input(self, filename: str, **kwargs) -> None:
_str = ', '.join([f':{s}' for s in self.system.unique_atomic_symbols])

print(
f'species = [{_str}]\n' 'N = 4', # maximum correlation order
f'species = [{_str}]\n'
f"N = {Config.ace_params['N']}", # maximum correlation order
file=inp_file,
)

Expand Down Expand Up @@ -173,10 +174,10 @@ def _print_input(self, filename: str, **kwargs) -> None:
print(
'r0 = 1.3\n'
f'r_in = {self._r_in_estimate:.4f}\n' # inner cutoff of ACE, choose a little more than min dist in dataset
'r_cut = 4.0\n' # outer cutoff of ACE
f"r_cut = {Config.ace_params['r_cut']}\n" # outer cutoff of ACE
'\n'
'deg_pair = 5\n' # Specify the pair potential
'r_cut_pair = 5.0\n',
f"deg_pair = {Config.ace_params['deg_pair']}\n" # Specify the pair potential
f"r_cut_pair = {Config.ace_params['r_cut_pair']}\n",
file=inp_file,
)

Expand Down

0 comments on commit cd35f02

Please sign in to comment.