Add units to lammps interface and fix kokkos+cpu bug. (#90)
* Add tentative fix for kokkos CPU, and add units to lammps interface

* Fixes to mliap to enable usage on other hardware

Fixes to triton kernel

* Fixed spelling

---------

Co-authored-by: Ben Nebgen <bnebgen@lanl.gov>
lubbersnick and bnebgen-LANL authored Aug 28, 2024
1 parent e2c24a2 commit 39b7bb2
Showing 2 changed files with 65 additions and 22 deletions.
9 changes: 4 additions & 5 deletions hippynn/custom_kernels/env_triton.py
@@ -10,13 +10,13 @@
# Load backup implementation for CPU tensors.
from .env_pytorch import envsum as envsum_alternative, sensesum as sensesum_alternative, featsum as featsum_alternative


def config_pruner(configs, kwargs):
def config_pruner(configs, nargs, **kwargs):
"""
Trims the unnecessary config options based on the sens. and feat. sizes
"""
p2_sens_size = triton.next_power_of_2(kwargs["sens_size"])
p2_feat_size = triton.next_power_of_2(kwargs["feat_size"])
#print("For some reason the config pruner also gets arguments:",kwargs)
p2_sens_size = triton.next_power_of_2(nargs["sens_size"])
p2_feat_size = triton.next_power_of_2(nargs["feat_size"])

used = set()
for config in configs:
@@ -40,7 +40,6 @@ def config_pruner(configs, kwargs):
num_warps=config.num_warps,
)


def get_autotune_config():
"""
Create a list of config options for the kernels
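Background on the config_pruner change above: newer Triton versions invoke the early_config_prune hook with the candidate configs, a dict of the kernel's named arguments, and extra keyword arguments, which is why the signature moves from (configs, kwargs) to (configs, nargs, **kwargs). Below is a minimal sketch of how such a pruner plugs into triton.autotune; the SENS_BLOCK/FEAT_BLOCK meta-parameters, the config values, and the toy kernel are illustrative placeholders, not the actual hippynn envsum/sensesum/featsum kernels.

import triton
import triton.language as tl


def config_pruner(configs, nargs, **kwargs):
    # Drop configs whose block sizes exceed the (power-of-2 rounded) problem sizes.
    p2_sens = triton.next_power_of_2(nargs["sens_size"])
    p2_feat = triton.next_power_of_2(nargs["feat_size"])
    pruned = []
    for config in configs:
        if config.kwargs["SENS_BLOCK"] <= p2_sens and config.kwargs["FEAT_BLOCK"] <= p2_feat:
            pruned.append(config)
    return pruned


@triton.autotune(
    configs=[
        triton.Config({"SENS_BLOCK": s, "FEAT_BLOCK": f}, num_warps=4)
        for s in (16, 32, 64)
        for f in (16, 32, 64)
    ],
    key=["sens_size", "feat_size"],
    prune_configs_by={"early_config_prune": config_pruner},
)
@triton.jit
def toy_kernel(out_ptr, sens_size, feat_size, SENS_BLOCK: tl.constexpr, FEAT_BLOCK: tl.constexpr):
    # Placeholder body; the real kernels gather/scatter environment features here.
    pid = tl.program_id(0)
    tl.store(out_ptr + pid, pid)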
78 changes: 61 additions & 17 deletions hippynn/interfaces/lammps_interface/mliap_interface.py
@@ -27,19 +27,36 @@ class MLIAPInterface(MLIAPUnified):
Class for creating ML-IAP Unified model based on hippynn graphs.
"""

def __init__(self, energy_node, element_types, ndescriptors=1, model_device=torch.device("cpu"), compute_dtype=torch.float32):
def __init__(
self,
energy_node,
element_types,
ndescriptors=1,
model_device=torch.device("cpu"),
compute_dtype=torch.float32,
energy_unit: float = None,
distance_unit: float = None,
):
"""
:param energy_node: Node for energy
:param element_types: list of atomic symbols corresponding to element types
:param ndescriptors: the number of descriptors to report to LAMMPS
:param model_device: the device to send torch data to (cpu or cuda)
:param energy_unit: If present, multiply output energies by this factor.
If your model was trained in Hartree and your lammps script will operate in eV,
use energy_unit = ase.units.Ha = 27.211386024367243
:param distance_unit: If present, multiply input distances by this factor and divide it out of the output forces.
If your model was trained to accept nm as input and lammps uses Angstroms,
use distance_unit = ase.units.nm = 10.
"""
super().__init__()
if hippynn.settings.PYTORCH_GPU_MEM_FRAC < 1.0:
torch.cuda.set_per_process_memory_fraction(hippynn.settings.PYTORCH_GPU_MEM_FRAC)
self.element_types = element_types
self.ndescriptors = ndescriptors
self.model_device = model_device
self.energy_unit = energy_unit
self.distance_unit = distance_unit

# Build the calculator
self.rcutfac, self.species_set, self.graph = setup_LAMMPS_graph(energy_node)
@@ -56,8 +73,8 @@ def compute_descriptors(self, data):
def as_tensor(self, array):
return torch.as_tensor(array, device=self.model_device)

def empty_tensor(self,dimensions):
return torch.empty(dimensions,device=self.model_device)
def empty_tensor(self, dimensions):
return torch.empty(dimensions, device=self.model_device)

def compute_forces(self, data):
"""
@@ -67,44 +84,66 @@ def compute_forces(self, data):
"""
nlocal = self.as_tensor(data.nlistatoms)
if nlocal.item() > 0:
#If there are no local atoms, do nothing
# Only compute if there are local atoms; otherwise do nothing
elems = self.as_tensor(data.elems).type(torch.int64).reshape(1, data.ntotal)
z_vals = self.species_set[elems + 1]
npairs = data.npairs

if npairs > 0:
pair_i = self.as_tensor(data.pair_i).type(torch.int64)
pair_j = self.as_tensor(data.pair_j).type(torch.int64)
rij = self.as_tensor(data.rij).type(self.compute_dtype)
else:
pair_i = self.empty_tensor(0).type(torch.int64)
pair_j = self.empty_tensor(0).type(torch.int64)
rij = self.empty_tensor([0,3]).type(self.compute_dtype)

rij = self.empty_tensor([0, 3]).type(self.compute_dtype)

if self.distance_unit is not None:
rij = self.distance_unit * rij

# note your sign for rij might need to be +1 or -1, depending on how your implementation works
inputs = [z_vals, pair_i, pair_j, -rij, nlocal]
atom_energy, total_energy, fij = self.graph(*inputs)

# Test if we are using lammps-kokkos or not. Is there a more clear way to do that?
if isinstance(data.elems, np.ndarray):
return_device = "cpu"
else:
# Hope that kokkos device and pytorch device are the same (default cuda)
using_kokkos = "kokkos" in data.__class__.__module__.lower()
if using_kokkos:
return_device = elems.device

else:
return_device = "cpu"

# convert units
if self.energy_unit is not None:
atom_energy = self.energy_unit * atom_energy
total_energy = self.energy_unit * total_energy
fij = self.energy_unit * fij

if self.distance_unit is not None:
fij = fij / self.distance_unit

atom_energy = atom_energy.squeeze(1).detach().to(return_device)
total_energy = total_energy.detach().to(return_device)

f = self.as_tensor(data.f)
fij = fij.type(f.dtype).detach().to(return_device)

if return_device == "cpu":

# hacky way to detect if we are in kokkos or not.

if not using_kokkos:
# write back to data.eatoms directly.
fij = fij.numpy()
data.eatoms = atom_energy.numpy().astype(np.double)
if npairs > 0:
data.update_pair_forces(fij)
else:
# view to data.eatoms using pytorch, and write into the view.
eatoms = torch.as_tensor(data.eatoms, device=return_device)
eatoms.copy_(atom_energy)
if npairs > 0:
data.update_pair_forces(fij)
if npairs > 0:
if return_device == "cpu":
data.update_pair_forces_cpu(fij)
else:
data.update_pair_forces_gpu(fij)

data.energy = total_energy.item()

def __getstate__(self):
@@ -116,11 +155,16 @@ def __setstate__(self, state):
self.__dict__.update(state)
try:
torch.ones(0).to(self.model_device)
except (RuntimeError, AssertionError):
except RuntimeError:
fallback = device_fallback()
warnings.warn(f"Model device ({self.model_device}) not found, falling back to f{fallback}")
self.model_device = fallback

if not hasattr(self, "en_unit"):
self.en_unit = None
if not hasattr(self, "dist_unit"):
self.dist_unit = None

self.species_set = self.species_set.to(self.model_device)
self.graph.to(self.model_device)

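With the new energy_unit and distance_unit arguments documented above, a LAMMPS-facing model can convert between its training units and the units of the LAMMPS script at interface construction time. A minimal sketch of building and saving the interface follows, assuming a trained hippynn checkpoint in the working directory; the node name "HEnergy", the element list, the file name, and the pair_style line are placeholders to adapt to your own setup.

import torch
import ase.units

from hippynn.experiment.serialization import load_checkpoint_from_cwd
from hippynn.interfaces.lammps_interface.mliap_interface import MLIAPInterface

# Load a previously trained model; the checkpoint layout and node name are assumptions.
check = load_checkpoint_from_cwd(map_location="cpu", restore_db=False)
model = check["training_modules"].model
energy_node = model.node_from_name("HEnergy")

unified = MLIAPInterface(
    energy_node,
    ["Al", "Mg"],  # atomic symbols corresponding to the LAMMPS element types
    model_device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    energy_unit=ase.units.Ha,  # model trained in Hartree, lammps script operating in eV
    # distance_unit=ase.units.nm,  # per the docstring: for a model trained on nm inputs with lammps in Angstroms
)

# Pickle the unified object so a LAMMPS input script can load it, e.g. with
# something along the lines of: pair_style mliap unified mliap_unified_hippynn.pt 0
torch.save(unified, "mliap_unified_hippynn.pt")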
