
Problem with CUDA error 59: Device-side assert triggered #1

Open · bug (Something isn't working)
andyco98 opened this issue Sep 29, 2021 · 9 comments

Hello Eric!

First, thank you very much for this very interesting work!

I was trying to reproduce the code from the demo "Cell Detection with Contour Proposal Networks.ipynb" and everything works fine until I start training the model. After 2-3 epochs I get the error:

/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:97: operator(): block: [37,0,0], thread: [0,0,0] Assertion index >= -sizes[i] && index < sizes[i] && "index out of bounds" failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:97: operator(): block: [37,0,0], thread: [1,0,0] Assertion index >= -sizes[i] && index < sizes[i] && "index out of bounds" failed.
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

I have read online that this common PyTorch error can be caused by an indexing problem with the labels, but I was unable to solve it. Do you know how it can be fixed?

Thank you very much in advance!
André

ericup (Collaborator) commented Sep 29, 2021

Hi André,

Happy to hear that you are interested in this work!
Also thank you very much for reporting this issue.

The issue seems to be an underflow that occurs in the convolutions of at least one of the regression heads when float16 is used.
Disabling AMP should prevent this error.
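
As a side note, here is a toy illustration of why float16 can cause this (not the actual CPN code): float16 has a much narrower range than float32, so small intermediate values underflow to zero, which can then produce inf/NaN downstream:

import torch

x32 = torch.tensor(1e-5)
x16 = x32.half()
print(x32 * x32)          # ~1e-10, still representable in float32
print(x16 * x16)          # 0.0 -> underflow in float16
print(x16 / (x16 * x16))  # inf; further arithmetic on this easily yields NaN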

In the mentioned demo Notebook you can change this line

amp=torch.cuda.is_available(),  # Automatic Mixed Precision (https://pytorch.org/docs/stable/amp.html)

to

amp=False,  # Automatic Mixed Precision (https://pytorch.org/docs/stable/amp.html)

Please let me know if this helps.
I'll see if there are other ways to prevent this with AMP enabled.

Best regards,
Eric

ericup added the "bug" label Sep 29, 2021
ericup self-assigned this Sep 29, 2021
andyco98 (Author) commented

Thank you very much Eric! It works fine now! It would also be great to find a way to still use AMP, but this way I can already start training the model, with promising results.
If it is not too much to ask, how did you find out that the problem was the enabled AMP?

Best regards,
André

ericup (Collaborator) commented Sep 29, 2021

I agree, AMP is nice to have!
The error message above only tells you that there is an 'index out of bounds'.
There is an important indexing operation that uses the predicted contour coordinates.
Its bounds are checked with torch.clamp, but you would get the same error if the indices contain NaN.
Normally this should not be a problem, unless something is numerically unstable, which can be a consequence of using float16.
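
To illustrate that last point with a toy example (hypothetical values, not the actual CPN code): torch.clamp keeps finite indices in bounds, but NaN passes straight through and turns into an invalid integer index:

import torch

data = torch.arange(10.)
good = torch.tensor([2.7, 11.3]).clamp(0, 9).long()        # clamped into the valid range
bad = torch.tensor([float('nan'), 2.0]).clamp(0, 9).long()  # NaN survives clamp; the integer cast is undefined
print(data[good])  # fine
print(data[bad])   # CUDA: device-side assert "index out of bounds"; CPU: IndexError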

Best regards,
Eric

makangzhe commented Nov 23, 2021

When I set amp=False and run the code from the demos "demo-binary.ipynb" and "demo-multiclass.ipynb" on GPU, I get the error:

/pytorch/aten/src/ATen/native/cuda/Loss.cu:455: nll_loss_backward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [16,0,0] Assertion t >= 0 && t < n_classes failed. Epoch 1: 0%| | 0/512 [00:02<?, ?it/s]

When I run the code on CPU, there is no problem.

~/anaconda3/envs/py38torch19/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    152     retain_graph = create_graph
    153
--> 154 Variable._execution_engine.run_backward(
    155     tensors, grad_tensors_, retain_graph, create_graph, inputs,
    156     allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag

RuntimeError: transform: failed to synchronize: cudaErrorAssert: device-side assert triggered

Can you help me? Thank you very much in advance!

ericup (Collaborator) commented Nov 23, 2021

Hi, thanks for posting!
I suspect this might be an unrelated issue.

Could you please run python -m torch.utils.collect_env and post the output?
Also, could you add these two lines to the top of your notebook, run it again on GPU, and post the entire stacktrace?

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

It seems to be a problem in the loss calculation of the score head.
Please make sure that you are working with the original number of classes.
In case you are using an older PyTorch release, could you update it and rerun the Notebook?
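
For reference, here is a minimal reproduction of that particular assert (illustrative only, unrelated to the CPN internals): the check t >= 0 && t < n_classes fires when a target label lies outside the number of classes the score head was configured with:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 2, requires_grad=True)  # head configured for 2 classes
targets = torch.tensor([0, 1, 1, 2])            # label 2 is out of range for 2 classes
loss = F.cross_entropy(logits, targets)         # on CPU this raises an IndexError right here
loss.backward()                                 # on CUDA it typically surfaces as the device-side assert above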

ericup (Collaborator) commented Mar 5, 2022

@andyco98 Could you try to update to the latest version via

pip install git+https://github.com/FZJ-INM1-BDA/celldetection.git

and add the following line after the model definition:

cd.wrap_module_(model, cd.models.ReadOut, cd.models.NoAmp)

This practically disables AMP for the readout layers, but still allows you to use AMP everywhere else.
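
For clarity, a sketch of where that line would go relative to the model definition (the actual model constructor and its arguments should be taken from the demo notebook):

import celldetection as cd

model = ...  # the CPN model definition from the demo notebook goes here
cd.wrap_module_(model, cd.models.ReadOut, cd.models.NoAmp)  # keep the readout heads out of AMP, even when AMP is enabled elsewhere
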
Please let me know if it helps.

tommy2k0 commented May 4, 2022

Hello Eric,

I also get the same error while trying to run the code in the notebooks on GPU, but it runs without any issue on the CPU. I have set amp=False in the config.

Here is the output from running python -m torch.utils.collect_env

Collecting environment information...
PyTorch version: 1.10.0+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.4 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.31

Python version: 3.8.10 (default, Mar 15 2022, 12:22:08) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.4.0-109-generic-x86_64-with-glibc2.29
Is CUDA available: True
CUDA runtime version: 10.1.243
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB
GPU 2: NVIDIA A100-SXM4-40GB
GPU 3: NVIDIA A100-SXM4-40GB

Nvidia driver version: 510.47.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.4.0
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.22.3
[pip3] torch==1.10.0+cu111
[pip3] torch-tb-profiler==0.4.0
[pip3] torchaudio==0.10.0+rocm4.1
[pip3] torchvision==0.11.0+cu111
[conda] Could not collect

And here is the complete stacktrace

Epoch 1/100:   0%|          | 0/512 [00:00<?, ?it/s]/pytorch/aten/src/ATen/native/cuda/Loss.cu:455: nll_loss_backward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [27,0,0] Assertion `t >= 0 && t < n_classes` failed.
/pytorch/aten/src/ATen/native/cuda/Loss.cu:455: nll_loss_backward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [28,0,0] Assertion `t >= 0 && t < n_classes` failed.
/pytorch/aten/src/ATen/native/cuda/Loss.cu:455: nll_loss_backward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [29,0,0] Assertion `t >= 0 && t < n_classes` failed.
/pytorch/aten/src/ATen/native/cuda/Loss.cu:455: nll_loss_backward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [30,0,0] Assertion `t >= 0 && t < n_classes` failed.
/pytorch/aten/src/ATen/native/cuda/Loss.cu:455: nll_loss_backward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [31,0,0] Assertion `t >= 0 && t < n_classes` failed.
Epoch 1/100 - loss 54.882:   0%|          | 0/512 [00:01<?, ?it/s] 
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/home/dev/instance_seg/CPN/cells.ipynb Cell 24' in <cell line: 1>()
      1 for epoch in range(1, conf.epochs + 1):
----> 2     train_epoch(model, train_loader, conf.device, optimizer, f'Epoch {epoch}/{conf.epochs}', scaler, scheduler)
      3     if epoch % 10 == 0:
      4         show_results(model, test_loader, conf.device)

/home/dev/instance_seg/CPN/cells.ipynb Cell 22' in train_epoch(model, data_loader, device, optimizer, desc, scaler, scheduler, progress)
     13     tq.desc = ' - '.join(info)
     14 if scaler is None:
---> 15     loss.backward()
     16     optimizer.step()
     17 else:

File ~/dev/venv/lib/python3.8/site-packages/torch/_tensor.py:307, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    298 if has_torch_function_unary(self):
    299     return handle_torch_function(
    300         Tensor.backward,
    301         (self,),
   (...)
    305         create_graph=create_graph,
    306         inputs=inputs)
--> 307 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)

File ~/dev/venv/lib/python3.8/site-packages/torch/autograd/__init__.py:154, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    151 if retain_graph is None:
    152     retain_graph = create_graph
--> 154 Variable._execution_engine.run_backward(
    155     tensors, grad_tensors_, retain_graph, create_graph, inputs,
    156     allow_unreachable=True, accumulate_grad=True)

RuntimeError: CUDA error: device-side assert triggered 

Many thanks for your help in advance!

ericup (Collaborator) commented May 5, 2022

Hello @tommy2k0,
Thank you very much for posting and for providing more information!

To me it looks like the version combination of PyTorch, CUDA, and cuDNN might be causing this problem.

Could you try to reinstall PyTorch with the desired CUDA version, using the install command from the PyTorch website?
In case this is relevant: you might also consider using conda along with the binaries it provides, so that you don't need to rely on the local CUDA and cuDNN installations.
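
For example, at the time the official install selector produced a conda command of roughly this form (illustrative only; please take the exact command and versions from the selector itself):

conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch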

tommy2k0 commented
Seems that was indeed the problem. Many thanks!
