Commit

Merge branch 'pytest-new' into pytest
InnocentBug committed Nov 8, 2024
2 parents 2ae818d + 0b596f9 commit 0b20f3c
Showing 3 changed files with 112 additions and 28 deletions.
84 changes: 81 additions & 3 deletions python/pytest/conftest.py
@@ -1,5 +1,5 @@
import os

import torch
import pytest


@@ -23,8 +23,6 @@ def device(ptens_cuda_support):

if "cuda" in device:
assert ptens_cuda_support
import torch

assert torch.cuda.is_available()

return device
@@ -33,3 +31,83 @@ def device(ptens_cuda_support):
@pytest.fixture(scope="session")
def float_epsilon():
return 1e-5


def numerical_grad_sum(fn, x, h):
grad = torch.zeros_like(x)
for i in range(x.numel()):
xp = x.clone()
xp.view(-1)[i] += h
xm = x.clone()
xm.view(-1)[i] -= h

# Using torch.sum here because torch autograd calculates the partial derivatives of a scalar-valued function.
# Summing gives us a scalar-valued function, and the summed parts factorize.
num_diff = torch.sum(fn(xp)) - torch.sum(fn(xm))
grad_value = num_diff / (2 * float(h))
grad.view(-1)[i] = grad_value
return grad

@pytest.mark.parametrize("m,c", [(0., 3.), (0.5, -0.3), (-0.8, 0.2)])
def test_numerical_grad_linear(m, c):
def linear(x):
return m*x + c

x = torch.randn((5,10))
grad = numerical_grad_sum(linear, x, 1e-2)
ana_grad = torch.ones_like(x) * m

allclose = torch.allclose(ana_grad, grad, rtol=1e-3, atol=1e-5)
if not allclose:
print(f"Max absolute difference: {torch.max(torch.abs(ana_grad - grad))}")
print(f"Mean absolute difference: {torch.mean(torch.abs(ana_grad - grad))}")
print(f"Numerical grad range: [{grad.min()}, {grad.max()}]")
print(f"Analytical grad range: [{ana_grad.min()}, {ana_grad.max()}]")

assert allclose

@pytest.mark.parametrize("a,b,c", [(1. ,2., 3.), (-0.5, 0.4, -0.3), (1.2, -0.8, 0.2)])
def test_numerical_grad_square(a, b, c):
from torch.autograd.gradcheck import gradcheck
def square(x):
return a*x**2 + b*x + c

x = torch.randn((5,10))
grad = numerical_grad_sum(square, x, 1e-3)
ana_grad = 2*a*x + b

allclose = torch.allclose(ana_grad, grad, rtol=1e-2, atol=1e-2)

if not allclose:
print(f"Max absolute difference: {torch.max(torch.abs(ana_grad - grad))}")
print(f"Mean absolute difference: {torch.mean(torch.abs(ana_grad - grad))}")
print(f"Numerical grad range: [{grad.min()}, {grad.max()}]")
print(f"Analytical grad range: [{ana_grad.min()}, {ana_grad.max()}]")

assert allclose
x.requires_grad_()
assert gradcheck(square, (x,), eps=1e-2, rtol=1e-2, atol=1e-2)


# Add a test against autograd for validation
def test_against_autograd():
def complex_function(x):
return torch.sum(torch.sin(x) + x**2)

x = torch.randn(5, 10, requires_grad=True)

# Compute gradient using autograd
y = complex_function(x)
y.backward()
autograd_grad = x.grad

# Compute gradient using numerical method
numerical_grad = numerical_grad_sum(complex_function, x.detach(), 1e-3)

allclose = torch.allclose(autograd_grad, numerical_grad, rtol=1e-2, atol=1e-2)
if not allclose:
print(f"Max absolute difference: {torch.max(torch.abs(autograd_grad - numerical_grad))}")
print(f"Mean absolute difference: {torch.mean(torch.abs(autograd_grad - numerical_grad))}")


assert allclose
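The numerical_grad_sum helper above approximates each partial derivative with a central difference, (f(x + h*e_i) - f(x - h*e_i)) / (2h), applied to the scalar sum of the function's outputs. A minimal sketch of how it can be checked against autograd, in the spirit of the tests in this file (the choice of torch.sin and the tolerances are illustrative, not part of the commit):

# Illustrative check of numerical_grad_sum against autograd (not part of the commit).
# Assumes it runs alongside conftest.py, as the tests in this commit do.
import torch
from conftest import numerical_grad_sum

x = torch.randn(5, 10, requires_grad=True)
torch.sin(x).sum().backward()                          # analytic gradient: cos(x)
num_grad = numerical_grad_sum(torch.sin, x.detach(), 1e-3)
assert torch.allclose(x.grad, num_grad, rtol=1e-2, atol=1e-2)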
51 changes: 28 additions & 23 deletions python/pytest/test_tests.py
@@ -2,14 +2,14 @@
import ptens
import pytest
import ptens_base
from conftest import numerical_grad_sum

from torch.autograd.gradcheck import gradcheck


def test_bug1(device):
nnodes = 15
graph = ptens.ggraph.random(nnodes, 0.5)
print(graph)
subgraphs = [ptens.subgraph.trivial(), ptens.subgraph.edge()]
node_values = torch.rand(nnodes, 1, requires_grad=True)

@@ -19,41 +19,46 @@ def test_bug1(device):
gather_features = ptens.subgraphlayer0.gather(sg, node_attributes)
result = torch.sum(gather_features)
result.backward()
print(node_values.grad)

# linmap_features = ptens.subgraphlayer0.linmaps(node_attributes)
result = torch.sum(node_attributes)
result.backward()
print(node_attributes.grad)

check = gradcheck(ptens.subgraphlayer0.gather, (sg, node_attributes), eps=1e-3)
print(check)
assert check



class TestGather(object):

def backprop(self,cls,fn,N,_nc):
if(cls==ptens.ptensor0):
x=cls.randn(N,_nc)
else:
atoms=ptens_base.atomspack.random(N,0.3)
x=cls.randn(atoms,_nc)
h=1e-3

def backprop(self,cls, N,nc, device):
atoms=ptens_base.atomspack.random(N, nc, 0.3)
x=cls.randn(atoms,nc).to(device)
x.requires_grad_()
G=ptens.ggraph.random(N,0.3)
z=fn(x,G)
atoms2 = G.subgraphs(ptens.subgraph.trivial())

check = gradcheck(cls.gather, (atoms2, x), eps=self.h)
assert check

z = cls.gather(atoms2, x)
loss=torch.sum(z)
loss.backward()
xgrad=x.grad


fn = lambda x: cls.gather(atoms2, x)
xgrad2 = numerical_grad_sum(fn, x, self.h)

testvec=z.randn_like()
loss=z.inp(testvec).to('cuda')
loss.backward(torch.tensor(1.0))
xgrad=x.get_grad()
assert torch.allclose(xgrad, xgrad2, rtol=1e-2, atol=1e-2)

xeps=x.randn_like()
z=fn(x+xeps,G)
xloss=z.inp(testvec).to('cuda')
assert(torch.allclose(xloss-loss,xeps.inp(xgrad),rtol=1e-3, atol=1e-4))


@pytest.mark.parametrize(('N', 'nc'), [(8, 1), (1, 2), (16, 4)])
def test_gather0(self,N, nc, device):
self.backprop(ptens.ptensorlayer0,N,nc, device)

@pytest.mark.parametrize('nc', [1, 2, 4])
def test_gather(self,nc):
self.backprop(ptens.ptensor0,ptens.gather,8,nc)
@pytest.mark.parametrize(('N', 'nc'), [(8, 1), (1, 2), (16, 4)])
def test_gather1(self,N, nc, device):
self.backprop(ptens.ptensorlayer0,N,nc, device)
5 changes: 3 additions & 2 deletions python/src/ptens/ptensorlayer.py
@@ -19,15 +19,16 @@

class ptensorlayer(torch.Tensor):

covariant_functions=[torch.Tensor.to,torch.Tensor.add,torch.Tensor.sub,torch.relu,torch.nn.functional.linear]
covariant_functions=[torch.Tensor.to,torch.Tensor.add,torch.Tensor.sub,torch.relu,torch.nn.functional.linear, torch.Tensor.clone]

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func in ptensorlayer.covariant_functions:
r= super().__torch_function__(func, types, args, kwargs)
r.atoms=args[0].atoms
if hasattr(args[0], "atoms"):
r.atoms=args[0].atoms
else:
r= super().__torch_function__(func, types, args, kwargs)
if isinstance(r,torch.Tensor):
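The ptensorlayer change whitelists torch.Tensor.clone as a covariant function and only copies the atoms attribute when the first argument actually carries one. A rough standalone illustration of that pattern with a toy Tensor subclass (invented names, not the ptens implementation):

# Illustrative only: a Tensor subclass that forwards whitelisted ops to torch
# and copies a metadata attribute ("atoms" in ptensorlayer, "tag" here) from
# the first argument when it is present.
import torch

class TaggedTensor(torch.Tensor):
    # Ops that should carry the metadata attribute along.
    covariant_functions = [torch.Tensor.to, torch.Tensor.add, torch.Tensor.clone]

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        r = super().__torch_function__(func, types, args, kwargs)
        if func in cls.covariant_functions and hasattr(args[0], "tag"):
            r.tag = args[0].tag  # propagate metadata only when the source has it
        return r

x = torch.randn(3).as_subclass(TaggedTensor)
x.tag = "example"
y = x.clone()  # clone is whitelisted, so the tag survives
assert isinstance(y, TaggedTensor) and y.tag == "example"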
