diff --git a/tests/models/test_mpo.py b/tests/models/test_mpo.py index daf7f03..76915b2 100644 --- a/tests/models/test_mpo.py +++ b/tests/models/test_mpo.py @@ -372,8 +372,8 @@ def test_all_algorithms_cuda(self): in_dim=2, out_dim=2, bond_dim=10, - boundary=boundary) - mpo = mpo.to(device) + boundary=boundary, + device=device) for auto_stack in [True, False]: for auto_unbind in [True, False]: diff --git a/tests/models/test_mps.py b/tests/models/test_mps.py index 41a054d..28b0ab5 100644 --- a/tests/models/test_mps.py +++ b/tests/models/test_mps.py @@ -438,8 +438,8 @@ def test_all_algorithms_cuda(self): mps = tk.models.MPS(n_features=n_features, phys_dim=5, bond_dim=2, - boundary=boundary) - mps = mps.to(device) + boundary=boundary, + device=device) for auto_stack in [True, False]: for auto_unbind in [True, False]: @@ -846,8 +846,8 @@ def test_all_algorithms_marginalize_with_matrix_cuda(self): phys_dim=5, bond_dim=2, boundary=boundary, - in_features=in_features) - mps = mps.to(device) + in_features=in_features, + device=device) embedding_matrix = torch.randn(5, 5, device=device) @@ -995,19 +995,20 @@ def test_all_algorithms_marginalize_with_mpo_cuda(self): phys_dim=5, bond_dim=2, boundary=mps_boundary, - in_features=in_features) - mps = mps.to(device) + in_features=in_features, + device=device) mpo = tk.models.MPO(n_features=n_features - len(in_features), in_dim=5, out_dim=5, bond_dim=2, - boundary=mpo_boundary) + boundary=mpo_boundary, + device=device) # Send mpo to cuda before deparameterizing, so that all # nodes are still in the state_dict of the model and are # automatically sent to cuda - mpo = mpo.to(device) + # mpo = mpo.to(device) # De-parameterize MPO nodes to only train MPS nodes mpo = mpo.parameterize(set_param=False, override=True) @@ -1211,6 +1212,58 @@ def test_partial_density(self): for node in mps.mats_env: assert node.grad is not None + def test_partial_density_cuda(self): + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + for n_features in [3, 4, 5]: + for boundary in ['obc', 'pbc']: + print(n_features, boundary) + phys_dim = torch.randint(low=2, high=12, + size=(n_features,)).tolist() + bond_dim = torch.randint(low=2, high=10, size=(n_features,)).tolist() + bond_dim = bond_dim[:-1] if boundary == 'obc' else bond_dim + + trace_sites = torch.randint(low=0, + high=n_features, + size=(n_features // 2,)).tolist() + + mps = tk.models.MPS(n_features=n_features, + phys_dim=phys_dim, + bond_dim=bond_dim, + boundary=boundary, + in_features=trace_sites, + device=device) + + in_dims = [phys_dim[i] for i in mps.in_features] + example = [torch.randn(1, d, device=device) for d in in_dims] + if example == []: + example = None + + mps.trace(example) + + assert mps.resultant_nodes + if trace_sites: + assert mps.data_nodes + assert set(mps.in_features) == set(trace_sites) + + # MPS has to be reset, otherwise partial_density automatically + # calls the forward method that was traced when contracting the + # MPS with example + mps.reset() + + # Here, trace_sites are now the out_features, + # not the in_features + density = mps.partial_density(trace_sites) + assert mps.resultant_nodes + assert mps.data_nodes + assert set(mps.out_features) == set(trace_sites) + + assert density.shape == \ + tuple([phys_dim[i] for i in mps.in_features] * 2) + + density.sum().backward() + for node in mps.mats_env: + assert node.grad is not None + def test_mutual_information(self): for n_features in [1, 2, 3, 4, 10]: for boundary in ['obc', 'pbc']: @@ -1869,6 +1922,52 @@ def test_partial_density(self): for node in mps.mats_env: assert node.grad is not None + def test_partial_density_cuda(self): + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + for n_features in [3, 4, 5]: + phys_dim = torch.randint(low=2, high=12, size=(1,)).item() + bond_dim = torch.randint(low=2, high=10, size=(1,)).item() + + trace_sites = torch.randint(low=0, + high=n_features, + size=(n_features // 2,)).tolist() + + mps = tk.models.UMPS(n_features=n_features, + phys_dim=phys_dim, + bond_dim=bond_dim, + in_features=trace_sites, + device=device) + + # batch x n_features x feature_dim + example = torch.randn(1, n_features // 2, phys_dim, device=device) + if example.numel() == 0: + example = None + + mps.trace(example) + + assert mps.resultant_nodes + if trace_sites: + assert mps.data_nodes + assert set(mps.in_features) == set(trace_sites) + + # MPS has to be reset, otherwise partial_density automatically + # calls the forward method that was traced when contracting the + # MPS with example + mps.reset() + + # Here, trace_sites are now the out_features, + # not the in_features + density = mps.partial_density(trace_sites) + assert mps.resultant_nodes + assert mps.data_nodes + assert set(mps.out_features) == set(trace_sites) + + assert density.shape == (phys_dim,) * 2 * len(mps.in_features) + + density.sum().backward() + for node in mps.mats_env: + assert node.grad is not None + def test_canonicalize_error(self): mps = tk.models.UMPS(n_features=10, phys_dim=2,