
torch.jit.script profile guided optimisations produce errors in aev_computer gradients #628

Open
sef43 opened this issue Apr 4, 2023 · 1 comment


sef43 commented Apr 4, 2023

Hi, I have found that with PyTorch 1.13 and 2.0 (but not with PyTorch <= 1.12), the torch.jit.script profile guided optimisations (which are on by default) cause significant errors in the position gradients computed by backpropagating through aev_computer when using a CUDA device. This is demonstrated in issue openmm/openmm-ml#50.

An example is shown below; manually turning off the JIT optimisations gives accurate forces:

from torchani.neurochem import parse_neurochem_resources, Constants
from torchani.aev import AEVComputer
import torch
import numpy as np


class Model(torch.nn.Module):
    def __init__(self, device):
        super().__init__()
        info_file_path = 'ani-2x_8x.info'
        const_file, _, _, _ = parse_neurochem_resources(info_file_path)
        consts = Constants(const_file)
        self.aev_computer = AEVComputer(**consts)
        self.aev_computer.to(device)

    def forward(self, species, positions):
        aev = self.aev_computer((species.unsqueeze(0), positions.unsqueeze(0)))
        return torch.mean(aev.aevs)


# setup
N = 100
species = torch.randint(0, 7, (N,), device="cuda")
pos = np.random.random((N, 3))

for optimize in [True, False]:
    print("JIT optimize = ", optimize)

    torch._C._jit_set_profiling_executor(optimize)
    torch._C._jit_set_profiling_mode(optimize)

    model = Model("cuda")
    model = torch.jit.script(model)

    grads = []
    for i in range(10):
        incoords = torch.tensor(pos, dtype=torch.float32, requires_grad=True, device="cuda")
        result = model(species, incoords)
        result.backward(retain_graph=True)
        grads.append(incoords.grad.cpu().numpy())
        print(i, "max percentage error: ", np.max(100.0 * np.abs((grads[0] - grads[-1]) / grads[0])))

The output I get on an RTX 3090 is:

JIT optimize =  True
Downloading ANI model parameters ...
0 max percentage error:  0.0
1 max percentage error:  0.00055674225
2 max percentage error:  217.80972
3 max percentage error:  217.80959
4 max percentage error:  217.81003
5 max percentage error:  217.80975
6 max percentage error:  217.80972
7 max percentage error:  217.81082
8 max percentage error:  217.80956
9 max percentage error:  217.81024
JIT optimize =  False
0 max percentage error:  0.0
1 max percentage error:  0.0003876826
2 max percentage error:  0.0002178617
3 max percentage error:  0.00021537923
4 max percentage error:  0.0005815239
5 max percentage error:  0.0010768962
6 max percentage error:  0.00017895782
7 max percentage error:  0.00035465648
8 max percentage error:  0.00039845158
9 max percentage error:  0.00018266498

A workaround I have found that removes the errors is to replace a ** operation with torch.float_power: 172b6fe
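For reference, the substitution looks roughly like the following. This is a minimal sketch, not the actual commit; the exponent 2 and the tensor values are illustrative. torch.float_power computes in float64, so the result is cast back to the input dtype:

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Original form: the ** operator, which ends up inside the fused kernel
y_pow = (x ** 2).sum()

# Workaround form: torch.float_power computes in double precision;
# cast the result back to the input dtype to keep the rest of the graph unchanged
y_fp = torch.float_power(x, 2).sum().to(x.dtype)

# Both forms agree numerically on CPU
assert torch.allclose(y_pow, y_fp)
```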

@yueyericardo (Contributor)

Thanks for reporting the issue!

This is a problem with NVFuser. A bug report has been filed at pytorch/pytorch#84510

The minimal reproducible example I extracted from the angular function is the following:

def angular_terms(Rca: float, ShfZ: Tensor, EtaA: Tensor, Zeta: Tensor,
                  ShfA: Tensor, vectors12: Tensor) -> Tensor:
    vectors12 = vectors12.view(2, -1, 3, 1, 1, 1, 1)
    cos_angles = vectors12.prod(0).sum(1)

    ret = (cos_angles + ShfZ) * Zeta * ShfA * 2
    return ret.flatten(start_dim=1)

Replacing a ** operation with torch.float_power will not solve the root cause of the problem.

At this moment, I would recommend disabling NVFuser by running the following:

torch._C._jit_set_nvfuser_enabled(False)

This switches to the NNC fuser (https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#fusers) instead of NVFuser, which I have tested and which works correctly.
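In full, the fallback looks something like this. A sketch only: `_jit_set_nvfuser_enabled` is an internal, non-public API, so it is guarded with `hasattr` here in case the running PyTorch build does not expose it; the scripted function is a stand-in for any TorchScript model:

```python
import torch

# Internal switch (PyTorch 1.12-2.0 era); guarded because it is not a
# public API and may be absent in other builds
if hasattr(torch._C, "_jit_set_nvfuser_enabled"):
    torch._C._jit_set_nvfuser_enabled(False)

@torch.jit.script
def doubled(x: torch.Tensor) -> torch.Tensor:
    # With NVFuser disabled, scripted code falls back to the NNC fuser on CUDA
    return x * 2.0

print(doubled(torch.ones(3)))  # tensor([2., 2., 2.])
```

Call this before torch.jit.script so the flag is in effect when the graph is first profiled and optimised.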
