Scaling laws
Understanding the scaling laws for inference in neural network potentials (NNPs) is essential for designing efficient architectures, especially as system size and model complexity increase. The computational cost and memory requirements typically scale with the number of atoms and the number of interacting atom pairs.
Note
For the sake of simplicity, we assume 64 bits for both integers and floating-point numbers. In practice we use 32 bits for floats and integers, except for integers used as indices, for which PyTorch requires 64 bits.
GPU memory is much more limited than CPU memory, so optimizing memory consumption during inference is crucial. GPU memory is often the limiting factor for large systems and complex models. Here, we break down the memory consumption for inference in NNPs into three main categories: model memory, force calculation memory, and system-dependent memory.
The model itself must be transferred to GPU memory before inference. For a typical NNP with 2 million float64 parameters, the model's footprint in memory is straightforward to calculate. Each float64 parameter requires 8 bytes, resulting in 2,000,000 × 8 bytes = 16 MB.
This is the baseline memory allocation purely for the model's parameters. During backpropagation (when derivatives are computed for forces), memory requirements can increase substantially, as activations and gradients need to be stored for each parameter, which will depend on the model architecture.
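The parameter footprint can also be read directly from a PyTorch module. The following is a minimal sketch; the small `nn.Sequential` network is a hypothetical stand-in, not one of the potentials discussed here, and `parameter_memory_bytes` is an illustrative helper name.

```python
import torch
from torch import nn


def parameter_memory_bytes(model: nn.Module) -> int:
    """Bytes occupied by the model's parameters (buffers are not counted)."""
    return sum(p.numel() * p.element_size() for p in model.parameters())


# Hypothetical stand-in network; PyTorch creates it in float32 by default.
model = nn.Sequential(nn.Linear(128, 128), nn.SiLU(), nn.Linear(128, 1))
n_params = sum(p.numel() for p in model.parameters())

bytes_fp32 = parameter_memory_bytes(model)
model.double()  # converts the parameters to float64 in place
bytes_fp64 = parameter_memory_bytes(model)

print(f"{n_params} parameters: {bytes_fp32} bytes (float32), {bytes_fp64} bytes (float64)")
```

Switching from float64 to float32 halves this baseline, which is one reason the plots further down use `torch.float32`.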
When forces are calculated (i.e., backpropagation with respect to atomic coordinates), memory consumption increases because the derivatives of the energy with respect to the coordinates must be computed and stored. This involves storing gradients for each of the N atoms in three dimensions, i.e. N × 3 × 8 bytes.
This storage represents only the gradients of the atomic positions. Total memory consumption is considerably larger, because it also scales with the number of layers and the number of interactions, whose intermediate activations must be kept for the backward pass.
In practice, the memory consumption depends on the architecture of the model. In general, however, memory consumption can grow significantly as the depth of the model and the number of interactions increase. Automatic mixed (lower) precision and gradient checkpointing are viable strategies to mitigate some of these issues.
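Gradient checkpointing can be sketched with `torch.utils.checkpoint`. The example below is illustrative only: `ToyPotential` is a hypothetical per-atom MLP, not a modelforge model. It shows that forces computed with and without checkpointing agree, while the checkpointed variant recomputes block activations during the backward pass instead of storing them.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)


class ToyPotential(nn.Module):
    """Hypothetical toy model: maps positions of shape (N, 3) to a scalar energy."""

    def __init__(self, width: int = 64, n_blocks: int = 4):
        super().__init__()
        self.embed = nn.Linear(3, width)
        self.blocks = nn.ModuleList(
            [nn.Sequential(nn.Linear(width, width), nn.SiLU()) for _ in range(n_blocks)]
        )
        self.readout = nn.Linear(width, 1)
        self.use_checkpoint = False

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        h = self.embed(positions)
        for block in self.blocks:
            if self.use_checkpoint:
                # Drop this block's activations and recompute them in backward.
                h = checkpoint(block, h, use_reentrant=False)
            else:
                h = block(h)
        return self.readout(h).sum()


def compute_forces(model: nn.Module, positions: torch.Tensor) -> torch.Tensor:
    """Forces are the negative gradient of the energy w.r.t. the positions."""
    energy = model(positions)
    return -torch.autograd.grad(energy, positions)[0]


positions = torch.randn(100, 3, requires_grad=True)
model = ToyPotential()

forces_plain = compute_forces(model, positions)
model.use_checkpoint = True
forces_checkpointed = compute_forces(model, positions)

print("forces agree:", torch.allclose(forces_plain, forces_checkpointed, atol=1e-6))
```

The trade-off is extra compute: each checkpointed block runs its forward pass a second time during backpropagation.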
As an example, for a fully connected graph (N × M = N^2) that passes through 4 linear layers, each with output dimension 128, the total memory scales with N^2 × 4 × 128 × 8 bytes.
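To get a feeling for the numbers in this example, the estimate can be evaluated for a few system sizes. This is a rough sketch that assumes one stored activation of width 128 per atom pair and per layer, at 8 bytes per value; `activation_memory_bytes` is an illustrative helper, and real architectures will store more or less than this.

```python
def activation_memory_bytes(
    n_atoms: int,
    n_layers: int = 4,
    width: int = 128,
    bytes_per_value: int = 8,
) -> int:
    """Rough activation memory for a fully connected graph (N x M = N^2 pairs)."""
    n_pairs = n_atoms * n_atoms
    return n_pairs * n_layers * width * bytes_per_value


for n_atoms in (100, 1_000, 10_000):
    gigabytes = activation_memory_bytes(n_atoms) / 1e9
    print(f"N = {n_atoms:>6}: {gigabytes:10.3f} GB")
```

The quadratic growth in N is why a cutoff-based neighborlist, rather than a fully connected graph, is needed for larger systems.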
The memory consumption also scales with the number of atoms
In this expression, the terms correspond to storing 3D positions (9 floats), atom pair distances (1 float), and atom pair indices (2 ints). For
The neighborlist identifies the atom pairs that lie within the cutoff distance.
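The system-dependent tensors can be inspected directly. The sketch below uses a brute-force O(N^2) pair search without periodic boundary conditions, which is an assumption made for brevity and not the neighborlist implementation used in practice. The random coordinates, box size, and cutoff are arbitrary illustrative values.

```python
import torch

torch.manual_seed(0)


def brute_force_pairs(positions: torch.Tensor, cutoff: float):
    """Return unique pairs (i < j) closer than `cutoff` and their distances."""
    n_atoms = positions.shape[0]
    i, j = torch.triu_indices(n_atoms, n_atoms, offset=1)  # int64 indices
    d_ij = (positions[i] - positions[j]).norm(dim=-1)
    mask = d_ij < cutoff
    pair_indices = torch.stack((i[mask], j[mask]))  # shape (2, n_pairs)
    return pair_indices, d_ij[mask]


def tensor_bytes(t: torch.Tensor) -> int:
    return t.numel() * t.element_size()


# Illustrative system: 200 random points in a 2 nm cube, 0.5 nm cutoff.
positions = torch.rand(200, 3, dtype=torch.float64) * 2.0
pair_indices, distances = brute_force_pairs(positions, cutoff=0.5)

print(f"pairs within cutoff: {pair_indices.shape[1]}")
print(f"positions:    {tensor_bytes(positions)} bytes")
print(f"distances:    {tensor_bytes(distances)} bytes")
print(f"pair indices: {tensor_bytes(pair_indices)} bytes")
```

With a fixed cutoff and roughly constant density, the number of pairs grows linearly with the number of atoms, so these tensors scale far more gently than in the fully connected case.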
The following plots are generated using a realistic set of hyperparameters and torch.float32.
[Plots: forward pass only; forward and backward pass]
import time  # Wall-clock timing of the forward and backward pass
from typing import List

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from openmm import unit  # replaces the deprecated `simtk` namespace
from openmmtools.testsystems import WaterBox

from modelforge.tests.helper_functions import setup_potential_for_test

# NOTE: NNPInput must also be imported from modelforge; the module path
# depends on the installed modelforge version.


def measure_performance_for_edge_sizes(
    edge_sizes: List[float],
    potential_names: List[str],
):
    """
    Measures GPU memory utilization and computation time for force calculations
    for water boxes of different edge sizes across multiple potentials.

    Parameters
    ----------
    edge_sizes : List[float]
        A list of edge sizes (in nanometers) for the water boxes.
    potential_names : List[str]
        A list of potential names to use in the model setup.

    Returns
    -------
    List[dict]
        A list of dictionaries containing edge size, number of water molecules,
        potential name, memory usage in bytes, and computation time in seconds.
    """
    results = []
    # The peak-memory statistics used below are only available on CUDA devices.
    if not torch.cuda.is_available():
        raise RuntimeError("This benchmark requires a CUDA device.")
    device = torch.device("cuda")
    precision = torch.float32

    for potential_name in potential_names:
        for edge_size in edge_sizes:
            # Generate water box with the given edge size
            test_system = WaterBox(box_edge=edge_size * unit.nanometer)
            positions = test_system.positions  # Positions in nanometers
            topology = test_system.topology

            # Extract atomic numbers and residue indices
            atomic_numbers = []
            residue_indices = []
            for residue_index, residue in enumerate(topology.residues()):
                for atom in residue.atoms():
                    atomic_numbers.append(atom.element.atomic_number)
                    residue_indices.append(residue_index)
            num_waters = len(list(topology.residues()))
            positions_in_nanometers = positions.value_in_unit(unit.nanometer)

            # Convert to torch tensors and move to GPU
            torch_atomic_numbers = torch.tensor(
                atomic_numbers, dtype=torch.long, device=device
            )
            torch_positions = torch.tensor(
                positions_in_nanometers,
                dtype=precision,
                device=device,
                requires_grad=True,
            )
            torch_atomic_subsystem_indices = torch.zeros_like(
                torch_atomic_numbers, dtype=torch.long, device=device
            )
            torch_total_charge = torch.zeros(
                num_waters, dtype=precision, device=device
            )
            nnp_input = NNPInput(
                atomic_numbers=torch_atomic_numbers,
                positions=torch_positions,
                atomic_subsystem_indices=torch_atomic_subsystem_indices,
                total_charge=torch_total_charge,
            ).to(dtype=precision)

            # Setup model
            model = setup_potential_for_test(
                potential_name,
                "inference",
                potential_seed=42,
                use_training_mode_neighborlist=False,
                simulation_environment="PyTorch",
            )
            model.to(device)
            model.to(precision)
            total_params = sum(p.numel() for p in model.parameters())

            # Measure GPU memory usage and computation time
            torch.cuda.reset_peak_memory_stats(device=device)
            torch.cuda.synchronize()

            # Run forward and backward pass and time them
            start_time = time.perf_counter()
            try:
                output = model(nnp_input.as_namedtuple())["per_molecule_energy"]
            except RuntimeError as err:
                # CUDA out-of-memory errors are raised as RuntimeError subclasses
                print(f"Forward pass failed (likely out of memory): {err}")
                continue
            try:
                F_training = -torch.autograd.grad(
                    output.sum(),
                    nnp_input.positions,
                    create_graph=True,
                    retain_graph=True,
                )[0]
            except RuntimeError as err:
                print(f"Backward pass failed (likely out of memory): {err}")
                continue
            torch.cuda.synchronize()
            end_time = time.perf_counter()

            max_memory_allocated = torch.cuda.max_memory_allocated(device=device)
            computation_time = end_time - start_time
            results.append(
                {
                    "potential_name": f"{potential_name}: {total_params:.1e} params",
                    "edge_size_nm": edge_size,
                    "num_waters": num_waters,
                    "memory_usage_bytes": max_memory_allocated,
                    "computation_time_s": computation_time,
                }
            )

            # Clean up so the next measurement starts from a clean allocator state
            del nnp_input, output, model, F_training
            torch.cuda.empty_cache()

    return results


def plot_computation_time(results):
    """
    Plots computation time against the number of water molecules for multiple potentials.

    Parameters
    ----------
    results : List[dict]
        A list of dictionaries containing edge size, number of water molecules,
        potential name, memory usage in bytes, and computation time in seconds.
    """
    # Create a DataFrame for plotting
    df = pd.DataFrame(results)
    df['computation_time_ms'] = df['computation_time_s'] * 1000  # Convert seconds to milliseconds

    # Plot using seaborn
    sns.set(style="whitegrid")
    plt.figure(figsize=(10, 6))
    sns.lineplot(
        data=df,
        x='num_waters',
        y='computation_time_ms',
        hue='potential_name',
        units='potential_name',
        estimator=None,  # Do not aggregate data
        marker='o',
        linewidth=2,
        markersize=8,
    )
    plt.title('Computation Time vs Number of Water Molecules for Different Potentials')
    plt.xlabel('Number of Water Molecules')
    plt.ylabel('Computation Time (ms)')
    plt.xticks(sorted(df['num_waters'].unique()))
    plt.legend(title='Potential Name')
    plt.tight_layout()
    plt.show()


# Example usage:
edge_sizes = [1.0, 1.5, 2.0, 2.5, 3.0, 3.5]  # Edge sizes in nanometers
potential_names = ['schnet', 'painn', 'physnet', 'ani2x', 'aimnet2', 'sake']
results = measure_performance_for_edge_sizes(
    edge_sizes=edge_sizes,
    potential_names=potential_names,
)


def plot_gpu_memory_usage(results):
    """
    Plots GPU memory usage against the number of water molecules for multiple potentials.

    Parameters
    ----------
    results : List[dict]
        A list of dictionaries containing edge size, number of water molecules,
        potential name, and memory usage in bytes.
    """
    # Create a DataFrame for plotting
    df = pd.DataFrame(results)
    df['memory_usage_mb'] = df['memory_usage_bytes'] / 1e6  # Convert bytes to megabytes

    # Plot using seaborn
    sns.set(style="whitegrid")
    plt.figure(figsize=(10, 6))
    sns.lineplot(
        data=df,
        x='num_waters',
        y='memory_usage_mb',
        units='potential_name',
        estimator=None,  # Do not aggregate data
        hue='potential_name',
        marker='o',
        linewidth=2,
        markersize=8,
    )
    plt.title('Backward pass: GPU Memory Usage vs Number of Water Molecules for Different Potentials')
    plt.xlabel('Number of Water Molecules')
    plt.ylabel('GPU Memory Usage (MB)')
    plt.xticks(sorted(df['num_waters'].unique()))
    plt.legend(title='Potential Name')
    plt.tight_layout()
    plt.show()


# Print the results
for result in results:
    print(
        f"Potential: {result['potential_name']}, "
        f"Edge Size: {result['edge_size_nm']} nm, "
        f"Number of Waters: {result['num_waters']}, "
        f"Memory Usage: {result['memory_usage_bytes']/1e6:.2f} MB, "
        f"Computation Time: {result['computation_time_s']*1000:.2f} ms"
    )

# Plot the computation time and the GPU memory usage
plot_computation_time(results)
plot_gpu_memory_usage(results)
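One way to turn such measurements into an empirical scaling law is a straight-line fit in log-log space, where the slope is the scaling exponent. The sketch below is an assumption about how one might analyze the results, not part of the benchmark above; `fit_scaling_exponent` is a hypothetical helper, and the data are synthetic placeholders rather than measurements.

```python
import numpy as np


def fit_scaling_exponent(sizes, values) -> float:
    """Least-squares slope of log(values) versus log(sizes)."""
    slope, _intercept = np.polyfit(np.log(sizes), np.log(values), deg=1)
    return float(slope)


# Synthetic placeholder data (NOT measurements): values grow exactly as size**2.
sizes = np.array([100.0, 200.0, 400.0, 800.0])
values = 3.0 * sizes**2

exponent = fit_scaling_exponent(sizes, values)
print(f"fitted scaling exponent: {exponent:.2f}")
```

Applied to the benchmark, `sizes` would be the `num_waters` values and `values` the `memory_usage_bytes` or `computation_time_s` values of a single potential. Because measured memory includes a constant baseline (for example the model parameters), the fitted exponent is only meaningful for the larger system sizes.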