Merge pull request #214 from RaulPPelaez/output_graph
Make OutputModel aware of CUDA graph capturing.
RaulPPelaez authored Sep 22, 2023
2 parents fd83954 + ba036a9 commit 4e24910
Showing 3 changed files with 113 additions and 4 deletions.
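
Background for the change, as an illustrative note rather than part of the patch: when dim_size is not given, torch_scatter.scatter infers the output size from the batch indices, which needs a host read such as batch.max().item(); that synchronization is forbidden while a CUDA graph is being captured, so the output model caches the size beforehand and passes it explicitly. A minimal sketch of the distinction, with made-up tensors:

import torch
from torch_scatter import scatter

x = torch.randn(6, 1, device="cuda")
batch = torch.tensor([0, 0, 0, 1, 1, 1], device="cuda")

# Outside capture the number of samples can be read back from the device.
dim_size = int(batch.max().item()) + 1

# During capture that read-back is not allowed, so the cached value is
# passed explicitly instead of letting scatter infer it.
y = scatter(x, batch, dim=0, dim_size=dim_size, reduce="sum")
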
42 changes: 42 additions & 0 deletions tests/test_model.py
@@ -82,6 +82,48 @@ def test_torchscript_dynamic_shapes(model_name, device):
        grad_outputs=grad_outputs,
    )[0]

# Currently only tensornet is CUDA graph compatible
@mark.parametrize("model_name", ["tensornet"])
def test_cuda_graph_compatible(model_name):
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    z, pos, batch = create_example_batch()
    args = {"model": model_name,
            "embedding_dimension": 128,
            "num_layers": 2,
            "num_rbf": 32,
            "rbf_type": "expnorm",
            "trainable_rbf": False,
            "activation": "silu",
            "cutoff_lower": 0.0,
            "cutoff_upper": 5.0,
            "max_z": 100,
            "max_num_neighbors": 128,
            "equivariance_invariance_group": "O(3)",
            "prior_model": None,
            "atom_filter": -1,
            "derivative": True,
            "output_model": "Scalar",
            "reduce_op": "sum",
            "precision": 32}
    model = create_model(args).to(device="cuda")
    z = z.to("cuda")
    pos = pos.to("cuda").requires_grad_(True)
    batch = batch.to("cuda")
    model = torch.jit.script(model).to(device="cuda")
    # Save and load the model, do not use a file
    import io
    buffer = io.BytesIO()
    torch.jit.save(model, buffer)
    buffer.seek(0)
    model = torch.jit.load(buffer)
    for _ in range(0, 5):
        y, neg_dy = model(z, pos, batch=batch)
    model = torch.cuda.make_graphed_callables(model, (z, pos, batch), allow_unused_input=True)
    y2, neg_dy2 = model(z, pos, batch=batch)
    assert torch.allclose(y, y2)
    assert torch.allclose(neg_dy, neg_dy2)

@mark.parametrize("model_name", models.__all__)
def test_seed(model_name):
args = load_example_args(model_name, remove_prior=True)
23 changes: 19 additions & 4 deletions torchmdnet/models/output_modules.py
@@ -1,13 +1,12 @@
from abc import abstractmethod, ABCMeta
from torch_scatter import scatter
from typing import Optional
from torchmdnet.models.utils import act_class_mapping, GatedEquivariantBlock
from torchmdnet.models.utils import act_class_mapping, GatedEquivariantBlock, check_stream_capturing
from torchmdnet.utils import atomic_masses
from torch_scatter import scatter
import torch
from torch import nn


from warnings import warn
__all__ = ["Scalar", "DipoleMoment", "ElectronicSpatialExtent"]


@@ -16,6 +15,7 @@ def __init__(self, allow_prior_model, reduce_op):
        super(OutputModel, self).__init__()
        self.allow_prior_model = allow_prior_model
        self.reduce_op = reduce_op
        self.dim_size = 0

    def reset_parameters(self):
        pass
@@ -25,7 +25,22 @@ def pre_reduce(self, x, v, z, pos, batch):
        return

    def reduce(self, x, batch):
        return scatter(x, batch, dim=0, reduce=self.reduce_op)
        is_capturing = (
            x.is_cuda
            and check_stream_capturing()
        )
        if not x.is_cuda or not is_capturing:
            self.dim_size = int(batch.max().item() + 1)
        if is_capturing:
            assert (
                self.dim_size > 0
            ), "Warming up is needed before capturing the model into a CUDA graph"
            warn(
                "CUDA graph capture will lock the batch to the current number of samples ({}). Changing this will result in a crash".format(
                    self.dim_size
                )
            )
        return scatter(x, batch, dim=0, dim_size=self.dim_size, reduce=self.reduce_op)

    def post_reduce(self, x):
        return x
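
The assertion and warning in reduce() spell out the capture contract: run the model eagerly at least once so dim_size is cached, then capture, after which the number of samples per batch is locked. A stand-alone sketch of that contract; SumReduce is an illustrative stand-in rather than TorchMD-Net code, and the eager builtin torch.cuda.is_current_stream_capturing stands in for the scripted helper:

import torch
from torch_scatter import scatter


class SumReduce(torch.nn.Module):
    # Stand-in for OutputModel.reduce(): cache dim_size while not capturing.
    def __init__(self):
        super().__init__()
        self.dim_size = 0

    def forward(self, x, batch):
        if not torch.cuda.is_current_stream_capturing():
            # Host read is only safe outside graph capture.
            self.dim_size = int(batch.max().item()) + 1
        return scatter(x, batch, dim=0, dim_size=self.dim_size, reduce="sum")


module = SumReduce().cuda()
x = torch.randn(6, 1, device="cuda")
batch = torch.tensor([0, 0, 0, 1, 1, 1], device="cuda")

# Warm-up on a side stream (as the torch.cuda.graph docs recommend) caches dim_size.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        out = module(x, batch)
torch.cuda.current_stream().wait_stream(s)

# Capture: batch.max().item() is skipped, so no synchronization happens,
# but the batch is now locked to two samples, as the warning states.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out = module(x, batch)
g.replay()
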
52 changes: 52 additions & 0 deletions torchmdnet/models/utils.py
@@ -521,6 +521,58 @@ def forward(self, x, v):
        x = self.act(x)
        return x, v

def _compile_check_stream_capturing():
    """
    Compiles the check_stream_capturing function.
    This is required because the builtin torch function that does this is not scriptable.
    """
    # Check if the function is already compiled
    if hasattr(torch.ops.torch_extension, "is_stream_capturing"):
        return
    from torch.utils.cpp_extension import load_inline
    cpp_source = '''
#include <torch/script.h>
#if defined(WITH_CUDA)
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime_api.h>
#endif
bool is_stream_capturing() {
#if defined(WITH_CUDA)
    auto current_stream = at::cuda::getCurrentCUDAStream().stream();
    cudaStreamCaptureStatus capture_status;
    cudaError_t err = cudaStreamGetCaptureInfo(current_stream, &capture_status, nullptr);
    if (err != cudaSuccess) {
        throw std::runtime_error(cudaGetErrorString(err));
    }
    return capture_status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive;
#else
    return false;
#endif
}
static auto registry = torch::RegisterOperators()
    .op("torch_extension::is_stream_capturing", &is_stream_capturing);
'''

    # Create an inline extension
    load_inline(
        "is_stream_capturing",
        cpp_sources=cpp_source,
        functions=["is_stream_capturing"],
        with_cuda=torch.cuda.is_available(),
        extra_cflags=["-DWITH_CUDA"] if torch.cuda.is_available() else None,
        verbose=True,
    )

_compile_check_stream_capturing()

@torch.jit.script
def check_stream_capturing():
    """
    Returns True if the current CUDA stream is capturing.
    Returns False if CUDA is not available or the current stream is not capturing.
    This utility is required because the builtin torch function that does this is not scriptable.
    """
    return torch.ops.torch_extension.is_stream_capturing()

rbf_class_mapping = {"gauss": GaussianSmearing, "expnorm": ExpNormalSmearing}

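A quick sanity check of the new helper, as an illustrative snippet that assumes the inline extension compiles on the target machine: outside capture it should report False and agree with the eager builtin torch.cuda.is_current_stream_capturing, and inside a capture it should report True.

import torch
from torchmdnet.models.utils import check_stream_capturing

# Outside any capture both checks report False.
assert not check_stream_capturing()
assert not torch.cuda.is_current_stream_capturing()

if torch.cuda.is_available():
    x = torch.zeros(1, device="cuda")
    # Warm-up on a side stream before capturing, per the torch.cuda.graph docs.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        x = x + 1
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        # While the stream is being captured both checks report True.
        assert check_stream_capturing()
        assert torch.cuda.is_current_stream_capturing()
        x = x + 1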
