
Scaling laws

Marcus Wieder edited this page Oct 3, 2024 · 29 revisions

Introduction

Understanding the scaling laws for inference in neural network potentials (NNPs) is essential for designing efficient architectures, especially as system size and model complexity increase. The computational cost and memory requirements typically scale with the number of atoms $N$ in the system. These scaling factors can become critical bottlenecks for large-scale molecular dynamics simulations, particularly when running on GPUs and optimizing for computational efficiency.

Note

For the sake of simplicity, we assume 64 bits (8 bytes) for both integers and floating-point numbers. In practice, we use 32 bits for floats and integers, except for integers used as indices, for which PyTorch requires 64 bits.
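The byte sizes behind these assumptions can be checked directly in PyTorch. The helper `bytes_per_element` below is a hypothetical name introduced only for this sketch; it relies on the standard `Tensor.element_size()` method.

```python
import torch


def bytes_per_element(dtype: torch.dtype) -> int:
    """Return the storage size of one element of `dtype` in bytes."""
    # element_size() depends only on the dtype, so an empty tensor is enough
    return torch.empty(0, dtype=dtype).element_size()


# The estimates on this page assume 8 bytes (64 bits) per number;
# float32 halves that, while int64 index tensors stay at 8 bytes.
for dtype in (torch.float64, torch.float32, torch.int64, torch.int32):
    print(f"{dtype}: {bytes_per_element(dtype)} bytes")
```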

(GPU) Memory consumption

Because GPU memory is limited compared with CPU memory, optimizing memory consumption during inference is crucial: GPU memory is often the limiting factor for large systems and complex models. Here, we break down the memory consumption for inference in NNPs into three main categories: model memory, force calculation memory, and system-dependent memory.

Model memory consumption

The model itself must be transferred to GPU memory before inference. For a typical NNP with 2 million float64 parameters, the model's memory footprint is straightforward to calculate. Each float64 parameter requires 8 bytes, resulting in $2 \times 10^6 \times 8~\text{bytes} = 16 \times 10^6~\text{bytes} \approx 16~\text{MB}$.

This is the baseline memory allocation for the model's parameters alone. During backpropagation (when derivatives are computed to obtain forces), memory requirements can increase substantially, because intermediate activations (and, during training, gradients for each parameter) need to be stored; how much depends on the model architecture.
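The parameter footprint can be computed from any PyTorch module by summing `numel() * element_size()` over its parameters. This is a minimal sketch: `parameter_memory_bytes` and the small MLP are hypothetical illustrations, not one of the modelforge potentials, and the count deliberately excludes buffers, activations, and gradients.

```python
import torch
from torch import nn


def parameter_memory_bytes(model: nn.Module) -> int:
    """Bytes occupied by the parameters only (no buffers, activations or gradients)."""
    return sum(p.numel() * p.element_size() for p in model.parameters())


# Worked example from the text: 2 million float64 parameters at 8 bytes each
n_params = 2_000_000
print(f"{n_params * 8 / 1e6:.0f} MB")  # 16 MB

# Hypothetical small MLP in float64: (128*128 + 128) + (128 + 1) = 16641 parameters
mlp = nn.Sequential(nn.Linear(128, 128), nn.SiLU(), nn.Linear(128, 1)).to(torch.float64)
print(parameter_memory_bytes(mlp))  # 133128 bytes
```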

Force calculation memory consumption

When forces are calculated (i.e., by backpropagation with respect to the atomic coordinates), memory consumption increases because the derivatives of the energy with respect to the coordinates must be computed and stored. This involves storing gradients for each of the $3N$ atomic coordinates, which for a system with $N$ atoms requires $3 \times N \times 8~\text{bytes}$.
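As a small illustration, the force tensor returned by `torch.autograd.grad` has one entry per coordinate, so its size is exactly $3 \times N \times 8$ bytes in float64. The quadratic "energy" below is a hypothetical toy function standing in for an NNP.

```python
import torch

N = 1000
# Hypothetical toy energy: sum of squared coordinates (stand-in for an NNP)
positions = torch.randn(N, 3, dtype=torch.float64, requires_grad=True)
energy = (positions ** 2).sum()

# Forces are the negative gradient of the energy with respect to the positions
forces = -torch.autograd.grad(energy, positions)[0]

print(forces.shape)                            # torch.Size([1000, 3])
print(forces.numel() * forces.element_size())  # 24000 bytes = 3 x N x 8
```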

The $3 \times N$ position gradients are only a small part of the total: memory consumption grows significantly with the number of layers ($L$) and the number of neurons ($H$) per layer. For each layer, the activations must be stored because they are reused during the backward pass to compute gradients. The memory required to store these activations therefore scales with the number of layers $L$, the number of hidden units $H$, and the number of interactions $N \times M$ (interacting atom pairs, with $M$ neighbors per atom):

$$ O(L \times N \times M \times H) $$

In practice, the memory consumption depends on the architecture of the model, but in general it can grow significantly as the depth of the model and the number of interactions increase. (Automatic) mixed precision and gradient checkpointing are viable strategies to mitigate some of these issues.
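The sketch below shows gradient checkpointing with `torch.utils.checkpoint.checkpoint` in a force calculation: the MLP's intermediate activations are not kept for the backward pass but recomputed when the forces are requested, trading compute for memory. The pair-energy model is hypothetical (not a modelforge potential), it uses squared distances to stay smooth, and mixed precision is not shown.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# Hypothetical pair-energy model: one MLP applied to every squared pair distance
mlp = nn.Sequential(
    nn.Linear(1, 128), nn.SiLU(),
    nn.Linear(128, 128), nn.SiLU(),
    nn.Linear(128, 1),
).to(torch.float64)


def energy(positions: torch.Tensor, use_checkpoint: bool) -> torch.Tensor:
    n = positions.shape[0]
    # All unique pairs i < j: a fully connected graph with N(N-1)/2 edges
    i, j = torch.triu_indices(n, n, offset=1)
    d2 = ((positions[i] - positions[j]) ** 2).sum(dim=-1, keepdim=True)
    if use_checkpoint:
        # Activations inside the MLP are recomputed during the backward pass
        pair_energies = checkpoint(mlp, d2, use_reentrant=False)
    else:
        pair_energies = mlp(d2)
    return pair_energies.sum()


def forces(positions: torch.Tensor, use_checkpoint: bool) -> torch.Tensor:
    positions = positions.detach().requires_grad_(True)
    e = energy(positions, use_checkpoint)
    # Forces are the negative gradient of the energy w.r.t. the positions
    return -torch.autograd.grad(e, positions)[0]


pos = torch.randn(50, 3, dtype=torch.float64)
f_plain = forces(pos, use_checkpoint=False)
f_ckpt = forces(pos, use_checkpoint=True)
print(torch.allclose(f_plain, f_ckpt))  # True
```

Both paths give the same forces; only the peak memory of the backward pass differs, at the cost of a second forward evaluation of the MLP.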

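As an example of the $O(L \times N \times M \times H)$ scaling, consider a fully connected graph ($N \times M = N^2$) whose pair features pass through 4 linear layers, each with output dimension 128. The number of stored activation values then scales as $$ N^2 \times 4 \times 128, $$ which at 8 bytes per value already amounts to about 4.1 GB for $N = 1000$.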
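The estimate can be written as a one-line function. `activation_memory_bytes` is a hypothetical helper for this back-of-the-envelope calculation only; it assumes one activation vector of width $H$ per layer and per atom pair, and ignores any architecture-specific extras.

```python
def activation_memory_bytes(n_layers: int, n_atoms: int, n_neighbors: int,
                            hidden: int, bytes_per_value: int = 8) -> int:
    """Rough O(L x N x M x H) estimate: one activation vector of width `hidden`
    per layer and per interacting atom pair, `bytes_per_value` bytes each."""
    return n_layers * n_atoms * n_neighbors * hidden * bytes_per_value


# Fully connected graph (M = N): 4 layers of width 128, 64-bit values
for n_atoms in (100, 1_000, 3_000):
    mem = activation_memory_bytes(4, n_atoms, n_atoms, 128)
    print(f"N = {n_atoms:>5}, M = N:  {mem / 1e9:.3f} GB")

# For comparison: a local neighborhood with M = 20 neighbors per atom
mem_local = activation_memory_bytes(4, 3_000, 20, 128)
print(f"N =  3000, M = 20: {mem_local / 1e9:.3f} GB")
```

For $N = 3000$ the fully connected estimate reaches about 36.9 GB, while restricting each atom to $M = 20$ neighbors brings it down to about 0.25 GB, which is why the neighbor list matters so much.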

System-dependent memory consumption

Neighborlist calculation

The memory consumption also scales with the number of atoms $N$, primarily because a neighbor list must be constructed to identify nearby atoms that contribute to local interactions. Current efficient neighbor list implementations involve an $O(N^2)$ step, in which each atom is compared with every other atom to check whether it lies within a specified cutoff distance. The memory consumption of the neighbor list calculation therefore scales as $N^2 \times (3 \times 3 \times 8~\text{bytes} + 1 \times 8~\text{bytes} + 2 \times 8~\text{bytes}) = N^2 \times 96~\text{bytes}$.

In this expression, the terms correspond to storing 3D positions (9 floats), atom pair distances (1 float), and atom pair indices (2 ints), each at 8 bytes. For $3 \times 10^3$ atoms, this leads to:

$$ (3 \times 10^3)^2 \times 96~\text{bytes} = 864 \times 10^6~\text{bytes} \approx 864~\text{MB} $$
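A brute-force version of the $O(N^2)$ step can be sketched with `torch.cdist`: the full $N \times N$ distance matrix is materialized before the cutoff is applied, which is where the quadratic memory comes from. This is a hypothetical illustration, not modelforge's neighbor list; `brute_force_neighbor_pairs` and `neighbor_list_memory_bytes` are names introduced for this sketch, and the latter simply encodes the 96-bytes-per-pair estimate above.

```python
import torch


def brute_force_neighbor_pairs(positions: torch.Tensor, cutoff: float) -> torch.Tensor:
    """Return a (2, n_pairs) tensor of atom-index pairs (i < j) closer than `cutoff`.

    All N x N distances are materialized, so memory grows as O(N^2).
    """
    n = positions.shape[0]
    distances = torch.cdist(positions, positions)  # (N, N)
    i, j = torch.triu_indices(n, n, offset=1, device=positions.device)
    mask = distances[i, j] < cutoff
    return torch.stack([i[mask], j[mask]])


def neighbor_list_memory_bytes(n_atoms: int, bytes_per_value: int = 8) -> int:
    # 9 position values + 1 distance + 2 indices per atom pair, as in the estimate above
    return n_atoms**2 * (9 + 1 + 2) * bytes_per_value


positions = torch.tensor([[0.0, 0.0, 0.0], [0.1, 0.0, 0.0], [1.0, 0.0, 0.0]])  # nm
print(brute_force_neighbor_pairs(positions, cutoff=0.5).tolist())  # [[0], [1]]
print(neighbor_list_memory_bytes(3_000) / 1e6, "MB")               # 864.0 MB
```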

Featurization of atom pair interaction

The neighbor list identifies the $N \times M$ atom pairs that contribute to local interactions. Each of these pair interactions is typically featurized using, e.g., radial symmetry functions. With, e.g., 32 radial symmetry functions, this requires $N \times M \times 32 \times 8~\text{bytes}$. For $N = 1 \times 10^3$ and $M = 20$, the memory consumption is about 5 MB ($5.12 \times 10^6$ bytes).
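To see where the $N \times M \times 32$ term comes from, the sketch below expands each pair distance into 32 Gaussian basis functions, one common choice of radial featurization. The Gaussian form, the 0.5 nm cutoff, and the random distances are illustrative assumptions, and `gaussian_rbf` is a hypothetical helper rather than a modelforge function.

```python
import torch


def gaussian_rbf(distances: torch.Tensor, n_rbf: int = 32, cutoff: float = 0.5) -> torch.Tensor:
    """Expand each pair distance into `n_rbf` Gaussians with evenly spaced centers in [0, cutoff]."""
    centers = torch.linspace(0.0, cutoff, n_rbf, dtype=distances.dtype, device=distances.device)
    width = centers[1] - centers[0]
    # (n_pairs, 1) - (n_rbf,) broadcasts to (n_pairs, n_rbf)
    return torch.exp(-((distances.unsqueeze(-1) - centers) ** 2) / (2 * width**2))


n_atoms, n_neighbors = 1_000, 20
# Hypothetical pair distances in nm, one per (atom, neighbor) pair
distances = torch.rand(n_atoms * n_neighbors, dtype=torch.float64) * 0.5
features = gaussian_rbf(distances)

print(features.shape)                               # torch.Size([20000, 32])
print(features.numel() * features.element_size())  # 5120000 bytes (about 5.1 MB)
```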

Experiment to evaluate memory footprint

The following plots were generated using a realistic set of hyperparameters and torch.float32.

[Figures: two memory footprint plots]

Timings

Forward pass only: [figure: computation time]
Forward and backward pass: [figure: computation time]

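import time  # used to time the forward and backward passes
from typing import List

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from openmm import unit  # formerly `from simtk import unit`
from openmmtools.testsystems import WaterBox

# NNPInput is used below but was never imported; this module path is an
# assumption and may differ between modelforge versions.
from modelforge.dataset.dataset import NNPInput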


def measure_performance_for_edge_sizes(
    edge_sizes: List[float],
    potential_names: List[str],
):
    """
    Measures GPU memory utilization and computation time for force calculations
    for water boxes of different edge sizes across multiple potentials.

    Parameters
    ----------
    edge_sizes : List[float]
        A list of edge sizes (in nanometers) for the water boxes.
    potential_names : List[str]
        A list of potential names to use in the model setup.

    Returns
    -------
    List[dict]
        A list of dictionaries containing edge size, number of water molecules,
        potential name, memory usage in bytes, and computation time in seconds.
    """
    # modelforge test helper that builds a potential with a fixed seed
    from modelforge.tests.helper_functions import setup_potential_for_test

    # The peak-memory statistics below are CUDA-only, so a GPU is required
    if not torch.cuda.is_available():
        raise RuntimeError("This benchmark measures GPU memory and requires a CUDA device.")
    device = torch.device('cuda')
    precision = torch.float32
    results = []

    for potential_name in potential_names:
        for edge_size in edge_sizes:
            # Generate a water box with the given edge size
            test_system = WaterBox(box_edge=edge_size * unit.nanometer)
            topology = test_system.topology
            positions_in_nanometers = test_system.positions.value_in_unit(unit.nanometer)

            # Extract atomic numbers; every residue is one water molecule
            atomic_numbers = [atom.element.atomic_number for atom in topology.atoms()]
            num_waters = topology.getNumResidues()

            # Convert to torch tensors on the GPU. All atoms belong to one subsystem
            # (index 0), so a single total charge is provided.
            torch_atomic_numbers = torch.tensor(atomic_numbers, dtype=torch.long, device=device)
            torch_positions = torch.tensor(
                positions_in_nanometers, dtype=precision, device=device, requires_grad=True
            )
            torch_atomic_subsystem_indices = torch.zeros_like(torch_atomic_numbers)
            torch_total_charge = torch.zeros(1, dtype=precision, device=device)

            nnp_input = NNPInput(
                atomic_numbers=torch_atomic_numbers,
                positions=torch_positions,
                atomic_subsystem_indices=torch_atomic_subsystem_indices,
                total_charge=torch_total_charge,
            ).to(dtype=precision)

            # Set up the potential in inference mode
            model = setup_potential_for_test(
                potential_name,
                "inference",
                potential_seed=42,
                use_training_mode_neighborlist=False,
                simulation_environment='PyTorch',
            )
            model.to(device)
            model.to(precision)
            total_params = sum(p.numel() for p in model.parameters())

            # Reset the peak-memory counter so that only this run is measured
            torch.cuda.reset_peak_memory_stats(device=device)
            torch.cuda.synchronize()

            output = forces = None
            start_time = time.perf_counter()
            try:
                # Forward pass: per-molecule energy
                output = model(nnp_input.as_namedtuple())["per_molecule_energy"]
                # Backward pass: forces are the negative gradient of the energy with
                # respect to the positions (create_graph=True, as in training mode)
                forces = -torch.autograd.grad(
                    output.sum(), nnp_input.positions, create_graph=True, retain_graph=True
                )[0]
                # Wait for all GPU kernels to finish before stopping the clock
                torch.cuda.synchronize()
                computation_time = time.perf_counter() - start_time

                results.append({
                    'potential_name': f"{potential_name}: {total_params:.1e} params",
                    'edge_size_nm': edge_size,
                    'num_waters': num_waters,
                    'memory_usage_bytes': torch.cuda.max_memory_allocated(device=device),
                    'computation_time_s': computation_time,
                })
            except torch.cuda.OutOfMemoryError:
                print(f"{potential_name}, edge size {edge_size} nm: out of memory")
            finally:
                # Release this run's tensors and cached blocks before the next run
                del nnp_input, model, output, forces
                torch.cuda.empty_cache()

    return results

def plot_computation_time(results):
    """
    Plots computation time against the number of water molecules for multiple potentials.

    Parameters
    ----------
    results : List[dict]
        A list of dictionaries containing edge size, number of water molecules,
        potential name, memory usage in bytes, and computation time in seconds.
    """
    # Create a DataFrame for plotting
    df = pd.DataFrame(results)
    df['computation_time_ms'] = df['computation_time_s'] * 1000  # Convert seconds to milliseconds

    # Plot using seaborn
    sns.set(style="whitegrid")
    plt.figure(figsize=(10, 6))
    sns.lineplot(
        data=df,
        x='num_waters',
        y='computation_time_ms',
        hue='potential_name',
        units='potential_name',
        estimator=None,  # Do not aggregate data
        marker='o',
        linewidth=2,
        markersize=8
    )
    plt.title('Computation Time vs Number of Water Molecules for Different Potentials')
    plt.xlabel('Number of Water Molecules')
    plt.ylabel('Computation Time (ms)')
    plt.xticks(sorted(df['num_waters'].unique()))
    plt.legend(title='Potential Name')
    plt.tight_layout()
    plt.show()

def plot_gpu_memory_usage(results):
    """
    Plots GPU memory usage against the number of water molecules for multiple potentials.

    Parameters
    ----------
    results : List[dict]
        A list of dictionaries containing edge size, number of water molecules,
        potential name, and memory usage in bytes.
    """
    # Create a DataFrame for plotting
    df = pd.DataFrame(results)
    df['memory_usage_mb'] = df['memory_usage_bytes'] / 1e6  # Convert bytes to megabytes

    # Plot using seaborn
    sns.set(style="whitegrid")
    plt.figure(figsize=(10, 6))
    sns.lineplot(
        data=df,
        x='num_waters',
        y='memory_usage_mb',
        units='potential_name',
        estimator=None,  # Do not aggregate data
        hue='potential_name',
        marker='o',
        linewidth=2,
        markersize=8,
    )
    plt.title('Backward pass: GPU Memory Usage vs Number of Water Molecules for Different Potentials')
    plt.xlabel('Number of Water Molecules')
    plt.ylabel('GPU Memory Usage (MB)')
    plt.xticks(sorted(df['num_waters'].unique()))
    plt.legend(title='Potential Name')
    plt.tight_layout()
    plt.show()


# Example usage
edge_sizes = [1.0, 1.5, 2.0, 2.5, 3.0, 3.5]  # Edge sizes in nanometers
potential_names = ['schnet', 'painn', 'physnet', 'ani2x', 'aimnet2', 'sake']

results = measure_performance_for_edge_sizes(
    edge_sizes=edge_sizes,
    potential_names=potential_names,
)

# Print the results
for result in results:
    print(f"Potential: {result['potential_name']}, "
          f"Edge Size: {result['edge_size_nm']} nm, "
          f"Number of Waters: {result['num_waters']}, "
          f"Memory Usage: {result['memory_usage_bytes']/1e6:.2f} MB, "
          f"Computation Time: {result['computation_time_s']*1000:.2f} ms")

# Plot the computation time and the peak GPU memory usage
plot_computation_time(results)
plot_gpu_memory_usage(results)