Skip to content

Commit

Permalink
Fix sending to devcie
Browse files Browse the repository at this point in the history
  • Loading branch information
joserapa98 committed Apr 13, 2024
1 parent 175d59a commit 5cb1291
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 10 deletions.
4 changes: 2 additions & 2 deletions tests/models/test_mpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,8 +372,8 @@ def test_all_algorithms_cuda(self):
in_dim=2,
out_dim=2,
bond_dim=10,
boundary=boundary)
mpo = mpo.to(device)
boundary=boundary,
device=device)

for auto_stack in [True, False]:
for auto_unbind in [True, False]:
Expand Down
115 changes: 107 additions & 8 deletions tests/models/test_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,8 +438,8 @@ def test_all_algorithms_cuda(self):
mps = tk.models.MPS(n_features=n_features,
phys_dim=5,
bond_dim=2,
boundary=boundary)
mps = mps.to(device)
boundary=boundary,
device=device)

for auto_stack in [True, False]:
for auto_unbind in [True, False]:
Expand Down Expand Up @@ -846,8 +846,8 @@ def test_all_algorithms_marginalize_with_matrix_cuda(self):
phys_dim=5,
bond_dim=2,
boundary=boundary,
in_features=in_features)
mps = mps.to(device)
in_features=in_features,
device=device)

embedding_matrix = torch.randn(5, 5, device=device)

Expand Down Expand Up @@ -995,19 +995,20 @@ def test_all_algorithms_marginalize_with_mpo_cuda(self):
phys_dim=5,
bond_dim=2,
boundary=mps_boundary,
in_features=in_features)
mps = mps.to(device)
in_features=in_features,
device=device)

mpo = tk.models.MPO(n_features=n_features - len(in_features),
in_dim=5,
out_dim=5,
bond_dim=2,
boundary=mpo_boundary)
boundary=mpo_boundary,
device=device)

# Send mpo to cuda before deparameterizing, so that all
# nodes are still in the state_dict of the model and are
# automatically sent to cuda
mpo = mpo.to(device)
# mpo = mpo.to(device)

# De-parameterize MPO nodes to only train MPS nodes
mpo = mpo.parameterize(set_param=False, override=True)
Expand Down Expand Up @@ -1211,6 +1212,58 @@ def test_partial_density(self):
for node in mps.mats_env:
assert node.grad is not None

def test_partial_density_cuda(self):
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
for n_features in [3, 4, 5]:
for boundary in ['obc', 'pbc']:
print(n_features, boundary)
phys_dim = torch.randint(low=2, high=12,
size=(n_features,)).tolist()
bond_dim = torch.randint(low=2, high=10, size=(n_features,)).tolist()
bond_dim = bond_dim[:-1] if boundary == 'obc' else bond_dim

trace_sites = torch.randint(low=0,
high=n_features,
size=(n_features // 2,)).tolist()

mps = tk.models.MPS(n_features=n_features,
phys_dim=phys_dim,
bond_dim=bond_dim,
boundary=boundary,
in_features=trace_sites,
device=device)

in_dims = [phys_dim[i] for i in mps.in_features]
example = [torch.randn(1, d, device=device) for d in in_dims]
if example == []:
example = None

mps.trace(example)

assert mps.resultant_nodes
if trace_sites:
assert mps.data_nodes
assert set(mps.in_features) == set(trace_sites)

# MPS has to be reset, otherwise partial_density automatically
# calls the forward method that was traced when contracting the
# MPS with example
mps.reset()

# Here, trace_sites are now the out_features,
# not the in_features
density = mps.partial_density(trace_sites)
assert mps.resultant_nodes
assert mps.data_nodes
assert set(mps.out_features) == set(trace_sites)

assert density.shape == \
tuple([phys_dim[i] for i in mps.in_features] * 2)

density.sum().backward()
for node in mps.mats_env:
assert node.grad is not None

def test_mutual_information(self):
for n_features in [1, 2, 3, 4, 10]:
for boundary in ['obc', 'pbc']:
Expand Down Expand Up @@ -1869,6 +1922,52 @@ def test_partial_density(self):
for node in mps.mats_env:
assert node.grad is not None

def test_partial_density_cuda(self):
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
for n_features in [3, 4, 5]:
phys_dim = torch.randint(low=2, high=12, size=(1,)).item()
bond_dim = torch.randint(low=2, high=10, size=(1,)).item()

trace_sites = torch.randint(low=0,
high=n_features,
size=(n_features // 2,)).tolist()

mps = tk.models.UMPS(n_features=n_features,
phys_dim=phys_dim,
bond_dim=bond_dim,
in_features=trace_sites,
device=device)

# batch x n_features x feature_dim
example = torch.randn(1, n_features // 2, phys_dim, device=device)
if example.numel() == 0:
example = None

mps.trace(example)

assert mps.resultant_nodes
if trace_sites:
assert mps.data_nodes
assert set(mps.in_features) == set(trace_sites)

# MPS has to be reset, otherwise partial_density automatically
# calls the forward method that was traced when contracting the
# MPS with example
mps.reset()

# Here, trace_sites are now the out_features,
# not the in_features
density = mps.partial_density(trace_sites)
assert mps.resultant_nodes
assert mps.data_nodes
assert set(mps.out_features) == set(trace_sites)

assert density.shape == (phys_dim,) * 2 * len(mps.in_features)

density.sum().backward()
for node in mps.mats_env:
assert node.grad is not None

def test_canonicalize_error(self):
mps = tk.models.UMPS(n_features=10,
phys_dim=2,
Expand Down

0 comments on commit 5cb1291

Please sign in to comment.