diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2573642..4601762 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -8,7 +8,7 @@ on: workflow_dispatch: jobs: - tests: + tests_basics: runs-on: ubuntu-latest @@ -33,6 +33,21 @@ jobs: - name: Test decompositions if: always() run: pytest tests/decompositions/test_svd_decompositions.py --cov=tensorkrowch + + tests_models: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.x' + - name: Install dependencies + run: python -m pip install --upgrade pip torch opt_einsum + - name: Install pytest + run: pip install pytest pytest-cov - name: Test MPS if: always() run: pytest tests/models/test_mps.py --cov=tensorkrowch @@ -44,4 +59,4 @@ jobs: run: pytest tests/models/test_peps.py --cov=tensorkrowch - name: Test Tree if: always() - run: pytest tests/models/test_tree.py --cov=tensorkrowch + run: pytest tests/models/test_tree.py --cov=tensorkrow \ No newline at end of file diff --git a/docs/_build/doctest/output.txt b/docs/_build/doctest/output.txt index 1b925f4..ecdef6d 100644 --- a/docs/_build/doctest/output.txt +++ b/docs/_build/doctest/output.txt @@ -1,22 +1,6 @@ -Results of doctest builder run on 2024-05-13 15:19:16 +Results of doctest builder run on 2024-05-14 21:33:03 ===================================================== -Document: models ----------------- -1 items passed all tests: - 118 tests in default -118 tests in 1 items. -118 passed and 0 failed. -Test passed. - -Document: embeddings --------------------- -1 items passed all tests: - 45 tests in default -45 tests in 1 items. -45 passed and 0 failed. -Test passed. - Document: components -------------------- ********************************************************************** @@ -27,8 +11,8 @@ Expected: tensor([[-0.2799, -0.4383, -0.8387], [ 1.6225, -0.3370, -1.2316]]) Got: - tensor([[-0.3491, 1.6723, 0.1550], - [-1.2811, 1.7880, 1.4058]]) + tensor([[-1.0631, -0.4740, 0.1857], + [ 0.1684, 0.9431, -0.5282]]) ********************************************************************** File "../tensorkrowch/components.py", line ?, in default Failed example: @@ -36,7 +20,7 @@ Failed example: Expected: tensor(-1.5029) Got: - tensor(3.3910) + tensor(-0.7681) ********************************************************************** File "../tensorkrowch/components.py", line ?, in default Failed example: @@ -44,7 +28,7 @@ Failed example: Expected: tensor([ 1.3427, -0.7752, -2.0704]) Got: - tensor([-1.6301, 3.4603, 1.5608]) + tensor([-0.8947, 0.4691, -0.3424]) ********************************************************************** File "../tensorkrowch/components.py", line ?, in default Failed example: @@ -53,8 +37,8 @@ Expected: tensor([[ 1.4005, -0.0521, -1.2091], [ 1.9844, 0.3513, -0.5920]]) Got: - tensor([[ 0.1392, -1.0223, -1.4595], - [-0.8048, 0.5959, -0.3506]]) + tensor([[ 0.7595, 0.6039, -0.3528], + [-0.9391, 0.5732, -0.1683]]) ********************************************************************** File "../tensorkrowch/components.py", line ?, in default Failed example: @@ -62,7 +46,7 @@ Failed example: Expected: tensor(0.3139) Got: - tensor(-0.4837) + tensor(0.0794) ********************************************************************** File "../tensorkrowch/components.py", line ?, in default Failed example: @@ -70,7 +54,7 @@ Failed example: Expected: tensor([ 1.6925, 0.1496, -0.9006]) Got: - tensor([-0.3328, -0.2132, -0.9050]) + tensor([-0.0898, 0.5886, -0.2606]) ********************************************************************** File "../tensorkrowch/components.py", line ?, in default Failed example: @@ -79,8 +63,8 @@ Expected: tensor([[ 0.2111, -0.9551, -0.7812], [ 0.2254, 0.3381, -0.2461]]) Got: - tensor([[-1.0223, -0.8771, -2.0712], - [-0.8749, -0.0445, 0.0772]]) + tensor([[-0.5163, -1.9652, 0.5784], + [ 1.1581, 1.1339, -0.2820]]) ********************************************************************** File "../tensorkrowch/components.py", line ?, in default Failed example: @@ -88,7 +72,7 @@ Failed example: Expected: tensor(0.5567) Got: - tensor(0.7768) + tensor(1.1973) ********************************************************************** File "../tensorkrowch/components.py", line ?, in default Failed example: @@ -96,7 +80,7 @@ Failed example: Expected: tensor([0.0101, 0.9145, 0.3784]) Got: - tensor([0.1042, 0.5887, 1.5192]) + tensor([1.1840, 2.1914, 0.6084]) ********************************************************************** File "../tensorkrowch/components.py", line ?, in default Failed example: @@ -105,8 +89,8 @@ Expected: tensor([[ 1.5570, 1.8441, -0.0743], [ 0.4572, 0.7592, 0.6356]]) Got: - tensor([[-0.5176, -0.2405, -0.4334], - [-0.5865, -0.3753, 0.4047]]) + tensor([[ 0.4585, -0.9622, -1.3563], + [ 1.5910, -1.6511, 1.4884]]) ********************************************************************** File "../tensorkrowch/components.py", line ?, in default Failed example: @@ -114,7 +98,7 @@ Failed example: Expected: tensor(2.6495) Got: - tensor(1.0781) + tensor(3.2324) ********************************************************************** File "../tensorkrowch/components.py", line ?, in default Failed example: @@ -122,7 +106,7 @@ Failed example: Expected: tensor([1.6227, 1.9942, 0.6399]) Got: - tensor([0.7822, 0.4458, 0.5930]) + tensor([1.6557, 1.9110, 2.0137]) ********************************************************************** File "../tensorkrowch/components.py", line ?, in default Failed example: @@ -153,17 +137,17 @@ Got: Node( name: my_node tensor: - tensor([[[-0.6988, -1.1643], - [-0.4288, 1.4445], - [ 0.3650, 0.1796], - [-0.7360, -0.4975], - [ 0.2338, -0.2614]], + tensor([[[ 0.0756, 0.2978], + [-0.9519, -0.6720], + [-0.5188, -0.1702], + [-0.5977, 0.7331], + [ 0.2803, -0.6546]], - [[ 1.4763, -0.5471], - [ 0.1191, 0.6986], - [ 0.4719, 0.6936], - [-0.3989, -0.9837], - [-0.7574, -0.4083]]]) + [[ 0.1922, -1.1753], + [ 0.1176, 0.1418], + [ 0.3928, 0.8509], + [-0.5760, 0.9210], + [ 2.1306, -1.1455]]]) axes: [left input @@ -202,17 +186,17 @@ Got: Node( name: node tensor: - tensor([[[ 5.5549e-01, -7.1141e-01], - [ 1.3587e-02, -7.7982e-01], - [-2.4167e-01, 3.1204e-01], - [ 1.2386e-03, -4.0765e-02], - [ 1.2701e-01, -1.3453e-01]], + tensor([[[ 1.0910, -0.5809], + [ 0.1853, 0.3664], + [-0.0117, 0.5005], + [-0.0661, 0.0234], + [-0.8290, 1.9922]], - [[-1.8415e+00, 3.0231e-01], - [-1.0302e+00, 1.9352e+00], - [-2.3677e+00, 9.1492e-01], - [ 1.0165e+00, -4.1301e-01], - [-1.0887e+00, 2.1221e-01]]]) + [[ 0.1525, 0.9578], + [ 0.7693, 0.5078], + [ 0.1810, -1.0350], + [ 0.3174, -0.2801], + [ 0.4202, -0.4585]]]) axes: [axis_0 axis_1 @@ -264,17 +248,17 @@ Got: name: my_paramnode tensor: Parameter containing: - tensor([[[-0.7826, 0.0821], - [ 0.3175, 0.7941], - [ 1.5045, -0.8090], - [ 0.3632, -0.3924], - [ 1.7720, 2.3521]], + tensor([[[ 0.7591, 0.2141], + [ 0.4399, -2.5171], + [-3.7951, 0.2628], + [-0.2629, 2.4291], + [ 1.8276, 0.2167]], - [[ 0.7189, 0.9170], - [ 0.0166, 0.3871], - [ 1.1382, -1.4283], - [-0.7954, 0.0954], - [-0.3044, 0.9413]]], requires_grad=True) + [[-0.4733, 1.6149], + [ 0.0621, 0.6297], + [ 0.4117, 0.4973], + [-0.1463, -0.2648], + [-1.0274, -1.1514]]], requires_grad=True) axes: [left input @@ -315,17 +299,17 @@ Got: name: paramnode tensor: Parameter containing: - tensor([[[-0.2850, -0.0664], - [-0.2535, 1.0285], - [-0.4207, 0.8395], - [-0.7933, 0.1413], - [-1.4896, -0.7656]], + tensor([[[-1.3839, -1.1648], + [-0.2835, -0.6608], + [ 0.1896, -0.8977], + [-1.0581, 0.9863], + [-1.2442, -0.7461]], - [[-0.4232, 0.1107], - [-0.5041, -1.5471], - [ 0.6676, 1.5313], - [ 1.4140, -1.2289], - [ 1.2960, -1.4492]]], requires_grad=True) + [[ 0.2663, -0.2669], + [-0.9733, 1.0781], + [ 0.2842, -0.3121], + [-0.9311, -0.5546], + [ 0.4118, 1.2157]]], requires_grad=True) axes: [axis_0 axis_1 @@ -344,8 +328,8 @@ Expected: [ 1.3371, 1.4761, 0.6551]], requires_grad=True) Got: Parameter containing: - tensor([[ 0.5249, 0.1834, 1.2888], - [-0.4819, 0.8874, 0.0728]], requires_grad=True) + tensor([[-0.4446, 0.2519, 1.6840], + [ 1.2981, -0.4704, -0.1416]], requires_grad=True) ********************************************************************** File "../tensorkrowch/components.py", line ?, in default Failed example: @@ -396,11 +380,19 @@ Got: data_0[feature] <-> nodeA[input]]) ********************************************************************** 1 items had failures: - 21 of 398 in default -398 tests in 1 items. -377 passed and 21 failed. + 21 of 401 in default +401 tests in 1 items. +380 passed and 21 failed. ***Test Failed*** 21 failures. +Document: embeddings +-------------------- +1 items passed all tests: + 45 tests in default +45 tests in 1 items. +45 passed and 0 failed. +Test passed. + Document: operations -------------------- 1 items passed all tests: @@ -409,9 +401,17 @@ Document: operations 214 passed and 0 failed. Test passed. +Document: models +---------------- +1 items passed all tests: + 118 tests in default +118 tests in 1 items. +118 passed and 0 failed. +Test passed. + Doctest summary =============== - 775 tests + 778 tests 21 failures in tests 0 failures in setup code 0 failures in cleanup code diff --git a/docs/_build/doctrees/components.doctree b/docs/_build/doctrees/components.doctree index 2d1dbe4..1583617 100644 Binary files a/docs/_build/doctrees/components.doctree and b/docs/_build/doctrees/components.doctree differ diff --git a/docs/_build/doctrees/environment.pickle b/docs/_build/doctrees/environment.pickle index a7f1663..7de6fb4 100644 Binary files a/docs/_build/doctrees/environment.pickle and b/docs/_build/doctrees/environment.pickle differ diff --git a/docs/_build/doctrees/models.doctree b/docs/_build/doctrees/models.doctree index eb19824..3e21eec 100644 Binary files a/docs/_build/doctrees/models.doctree and b/docs/_build/doctrees/models.doctree differ diff --git a/docs/_build/doctrees/tutorials/0_first_steps.doctree b/docs/_build/doctrees/tutorials/0_first_steps.doctree index 92aae46..aabdd71 100644 Binary files a/docs/_build/doctrees/tutorials/0_first_steps.doctree and b/docs/_build/doctrees/tutorials/0_first_steps.doctree differ diff --git a/docs/_build/html/.buildinfo b/docs/_build/html/.buildinfo index 218778c..38ea4b1 100644 --- a/docs/_build/html/.buildinfo +++ b/docs/_build/html/.buildinfo @@ -1,4 +1,4 @@ # Sphinx build info version 1 # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. -config: e4485bdccc4a9ef1ab7141439481b17d +config: c711dcb750f5eb023ece44b19b251e46 tags: 645f666f9bcd5a90fca523b33c5a78b7 diff --git a/docs/_build/html/_modules/index.html b/docs/_build/html/_modules/index.html index b16a90d..4c08981 100644 --- a/docs/_build/html/_modules/index.html +++ b/docs/_build/html/_modules/index.html @@ -5,7 +5,7 @@ - Overview: module code — TensorKrowch 1.1.2 documentation + Overview: module code — TensorKrowch 1.1.3 documentation diff --git a/docs/_build/html/_modules/tensorkrowch/components.html b/docs/_build/html/_modules/tensorkrowch/components.html index b7bc16d..3f05e31 100644 --- a/docs/_build/html/_modules/tensorkrowch/components.html +++ b/docs/_build/html/_modules/tensorkrowch/components.html @@ -5,7 +5,7 @@ - tensorkrowch.components — TensorKrowch 1.1.2 documentation + tensorkrowch.components — TensorKrowch 1.1.3 documentation @@ -777,6 +777,7 @@

Source code for tensorkrowch.components

                  node1_list: Optional[List[bool]] = None,
                  init_method: Optional[Text] = None,
                  device: Optional[torch.device] = None,
+                 dtype: Optional[torch.dtype] = None,
                  **kwargs: float) -> None:
 
         super().__init__()
@@ -888,7 +889,10 @@ 

Source code for tensorkrowch.components

         if shape is not None:
             if init_method is not None:
                 self._unrestricted_set_tensor(
-                    init_method=init_method, device=device, **kwargs)
+                    init_method=init_method,
+                    device=device,
+                    dtype=dtype,
+                    **kwargs)
         else:
             self._unrestricted_set_tensor(tensor=tensor)
 
@@ -1096,6 +1100,36 @@ 

Source code for tensorkrowch.components

         their own tensors or use other node's tensor.
         """
         return not (self._leaf or self._data or self._virtual)
+ +
[docs] def is_conj(self) -> bool: + """ + Equivalent to `torch.is_conj() + <https://pytorch.org/docs/stable/generated/torch.is_conj.html>`_. + """ + tensor = self.tensor + if tensor is None: + return + return tensor.is_conj()
+ +
[docs] def is_complex(self) -> bool: + """ + Equivalent to `torch.is_complex() + <https://pytorch.org/docs/stable/generated/torch.is_complex.html>`_. + """ + tensor = self.tensor + if tensor is None: + return + return tensor.is_complex()
+ +
[docs] def is_floating_point(self) -> bool: + """ + Equivalent to `torch.is_floating_point() + <https://pytorch.org/docs/stable/generated/torch.is_floating_point.html>`_. + """ + tensor = self.tensor + if tensor is None: + return + return tensor.is_floating_point()
[docs] def size(self, axis: Optional[Ax] = None) -> Union[Size, int]: """ @@ -1497,9 +1531,10 @@

Source code for tensorkrowch.components

 
     @staticmethod
     def _make_copy_tensor(shape: Shape,
-                          device: torch.device = torch.device('cpu')) -> Tensor:
+                          device: Optional[torch.device] = None,
+                          dtype: Optional[torch.dtype] = None) -> Tensor:
         """Returns copy tensor (ones in the "diagonal", zeros elsewhere)."""
-        copy_tensor = torch.zeros(shape, device=device)
+        copy_tensor = torch.zeros(shape, device=device, dtype=dtype)
         rank = len(shape)
         if rank <= 1:
             i = 0
@@ -1512,7 +1547,8 @@ 

Source code for tensorkrowch.components

     def _make_rand_tensor(shape: Shape,
                           low: float = 0.,
                           high: float = 1.,
-                          device: torch.device = torch.device('cpu')) -> Tensor:
+                          device: Optional[torch.device] = None,
+                          dtype: Optional[torch.dtype] = None) -> Tensor:
         """Returns tensor whose entries are drawn from the uniform distribution."""
         if not isinstance(low, float):
             raise TypeError('`low` should be float type')
@@ -1520,13 +1556,14 @@ 

Source code for tensorkrowch.components

             raise TypeError('`high` should be float type')
         if low >= high:
             raise ValueError('`low` should be strictly smaller than `high`')
-        return torch.rand(shape, device=device) * (high - low) + low
+        return torch.rand(shape, device=device, dtype=dtype) * (high - low) + low
 
     @staticmethod
     def _make_randn_tensor(shape: Shape,
                            mean: float = 0.,
                            std: float = 1.,
-                           device: torch.device = torch.device('cpu')) -> Tensor:
+                           device: Optional[torch.device] = None,
+                           dtype: Optional[torch.dtype] = None) -> Tensor:
         """Returns tensor whose entries are drawn from the normal distribution."""
         if not isinstance(mean, float):
             raise TypeError('`mean` should be float type')
@@ -1534,12 +1571,13 @@ 

Source code for tensorkrowch.components

             raise TypeError('`std` should be float type')
         if std <= 0:
             raise ValueError('`std` should be positive')
-        return torch.randn(shape, device=device) * std + mean
+        return torch.randn(shape, device=device, dtype=dtype) * std + mean
 
 
[docs] def make_tensor(self, shape: Optional[Shape] = None, init_method: Text = 'zeros', - device: torch.device = torch.device('cpu'), + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, **kwargs: float) -> Tensor: """ Returns a tensor that can be put in the node, and is initialized according @@ -1553,6 +1591,8 @@

Source code for tensorkrowch.components

             Initialization method.
         device : torch.device, optional
             Device where to initialize the tensor.
+        dtype : torch.dtype, optional
+            Dtype of the tensor.
         kwargs : float
             Keyword arguments for the different initialization methods:
 
@@ -1574,15 +1614,17 @@ 

Source code for tensorkrowch.components

         if shape is None:
             shape = self._shape
         if init_method == 'zeros':
-            return torch.zeros(shape, device=device)
+            return torch.zeros(shape, device=device, dtype=dtype)
         elif init_method == 'ones':
-            return torch.ones(shape, device=device)
+            return torch.ones(shape, device=device, dtype=dtype)
         elif init_method == 'copy':
-            return self._make_copy_tensor(shape, device=device)
+            return self._make_copy_tensor(shape, device=device, dtype=dtype)
         elif init_method == 'rand':
-            return self._make_rand_tensor(shape, device=device, **kwargs)
+            return self._make_rand_tensor(shape, device=device, dtype=dtype,
+                                          **kwargs)
         elif init_method == 'randn':
-            return self._make_randn_tensor(shape, device=device, **kwargs)
+            return self._make_randn_tensor(shape, device=device, dtype=dtype,
+                                           **kwargs)
         else:
             raise ValueError('Choose a valid `init_method`: "zeros", '
                              '"ones", "copy", "rand", "randn"')
@@ -1655,6 +1697,7 @@

Source code for tensorkrowch.components

                                  tensor: Optional[Tensor] = None,
                                  init_method: Optional[Text] = 'zeros',
                                  device: Optional[torch.device] = None,
+                                 dtype: Optional[torch.dtype] = None,
                                  **kwargs: float) -> None:
         """
         Sets a new node's tensor or creates one with :meth:`make_tensor` and sets
@@ -1673,6 +1716,8 @@ 

Source code for tensorkrowch.components

             Initialization method.
         device : torch.device, optional
             Device where to initialize the tensor.
+        dtype : torch.dtype, optional
+            Dtype of the tensor.
         kwargs : float
             Keyword arguments for the different initialization methods. See
             :meth:`make_tensor`.
@@ -1680,10 +1725,14 @@ 

Source code for tensorkrowch.components

         if tensor is not None:
             if not isinstance(tensor, Tensor):
                 raise TypeError('`tensor` should be torch.Tensor type')
-            elif device is not None:
+            if device is not None:
                 warnings.warn('`device` was specified but is being ignored. '
                               'Provide a tensor that is already in the required'
                               ' device')
+            if dtype is not None:
+                warnings.warn('`dtype` was specified but is being ignored. '
+                              'Provide a tensor that already has the required'
+                              ' dtype')
 
             if not self._compatible_shape(tensor):
                 tensor = self._crop_tensor(tensor)
@@ -1691,10 +1740,13 @@ 

Source code for tensorkrowch.components

 
         elif init_method is not None:
             node_tensor = self.tensor
-            if (device is None) and (node_tensor is not None):
-                device = node_tensor.device
+            if node_tensor is not None:
+                if device is None:
+                    device = node_tensor.device
+                if dtype is None:
+                    dtype = node_tensor.dtype
             tensor = self.make_tensor(
-                init_method=init_method, device=device, **kwargs)
+                init_method=init_method, device=device, dtype=dtype, **kwargs)
             correct_format_tensor = self._set_tensor_format(tensor)
 
         else:
@@ -1707,6 +1759,7 @@ 

Source code for tensorkrowch.components

                    tensor: Optional[Tensor] = None,
                    init_method: Optional[Text] = 'zeros',
                    device: Optional[torch.device] = None,
+                   dtype: Optional[torch.dtype] = None,
                    **kwargs: float) -> None:
         """
         Sets new node's tensor or creates one with :meth:`make_tensor` and sets
@@ -1740,6 +1793,8 @@ 

Source code for tensorkrowch.components

             Initialization method.
         device : torch.device, optional
             Device where to initialize the tensor.
+        dtype : torch.dtype, optional
+            Dtype of the tensor.
         kwargs : float
             Keyword arguments for the different initialization methods. See
             :meth:`make_tensor`.
@@ -1780,6 +1835,7 @@ 

Source code for tensorkrowch.components

             self._unrestricted_set_tensor(tensor=tensor,
                                           init_method=init_method,
                                           device=device,
+                                          dtype=dtype,
                                           **kwargs)
         else:
             raise ValueError('Node\'s tensor can only be changed if it is not'
@@ -2332,6 +2388,8 @@ 

Source code for tensorkrowch.components

         Initialization method.
     device : torch.device, optional
         Device where to initialize the tensor if ``init_method`` is provided.
+    dtype : torch.dtype, optional
+        Dtype of the tensor if ``init_method`` is provided.
     kwargs : float
         Keyword arguments for the different initialization methods. See
         :meth:`AbstractNode.make_tensor`.
@@ -2656,6 +2714,8 @@ 

Source code for tensorkrowch.components

         Initialization method.
     device : torch.device, optional
         Device where to initialize the tensor if ``init_method`` is provided.
+    dtype : torch.dtype, optional
+        Dtype of the tensor if ``init_method`` is provided.
     kwargs : float
         Keyword arguments for the different initialization methods. See
         :meth:`AbstractNode.make_tensor`.
@@ -2737,6 +2797,7 @@ 

Source code for tensorkrowch.components

                  node1_list: Optional[List[bool]] = None,
                  init_method: Optional[Text] = None,
                  device: Optional[torch.device] = None,
+                 dtype: Optional[torch.dtype] = None,
                  **kwargs: float) -> None:
 
         super().__init__(shape=shape,
@@ -2752,6 +2813,7 @@ 

Source code for tensorkrowch.components

                          node1_list=node1_list,
                          init_method=init_method,
                          device=device,
+                         dtype=dtype,
                          **kwargs)
 
     # ----------
diff --git a/docs/_build/html/_modules/tensorkrowch/decompositions/svd_decompositions.html b/docs/_build/html/_modules/tensorkrowch/decompositions/svd_decompositions.html
index dfab7d8..faa1d44 100644
--- a/docs/_build/html/_modules/tensorkrowch/decompositions/svd_decompositions.html
+++ b/docs/_build/html/_modules/tensorkrowch/decompositions/svd_decompositions.html
@@ -5,7 +5,7 @@
   
     
     
-    tensorkrowch.decompositions.svd_decompositions — TensorKrowch 1.1.2 documentation
+    tensorkrowch.decompositions.svd_decompositions — TensorKrowch 1.1.3 documentation
     
   
   
@@ -472,6 +472,9 @@ 

Source code for tensorkrowch.decompositions.svd_decompositions

tensors.append(u) prev_bond = aux_rank + + if vh.is_complex(): + s = s.to(vh.dtype) vec = torch.diag_embed(s) @ vh tensors.append(vec) @@ -608,6 +611,9 @@

Source code for tensorkrowch.decompositions.svd_decompositions

tensors.append(u) prev_bond = aux_rank + + if vh.is_complex(): + s = s.to(vh.dtype) mat = torch.diag_embed(s) @ vh mat = mat.reshape(aux_rank, in_out_dims[-2], in_out_dims[-1]) diff --git a/docs/_build/html/_modules/tensorkrowch/embeddings.html b/docs/_build/html/_modules/tensorkrowch/embeddings.html index fea19d8..c21c6f0 100644 --- a/docs/_build/html/_modules/tensorkrowch/embeddings.html +++ b/docs/_build/html/_modules/tensorkrowch/embeddings.html @@ -5,7 +5,7 @@ - tensorkrowch.embeddings — TensorKrowch 1.1.2 documentation + tensorkrowch.embeddings — TensorKrowch 1.1.3 documentation diff --git a/docs/_build/html/_modules/tensorkrowch/initializers.html b/docs/_build/html/_modules/tensorkrowch/initializers.html index 7ae7d3d..a2bdba3 100644 --- a/docs/_build/html/_modules/tensorkrowch/initializers.html +++ b/docs/_build/html/_modules/tensorkrowch/initializers.html @@ -5,7 +5,7 @@ - tensorkrowch.initializers — TensorKrowch 1.1.2 documentation + tensorkrowch.initializers — TensorKrowch 1.1.3 documentation diff --git a/docs/_build/html/_modules/tensorkrowch/models/mpo.html b/docs/_build/html/_modules/tensorkrowch/models/mpo.html index 37300e8..92ac5ec 100644 --- a/docs/_build/html/_modules/tensorkrowch/models/mpo.html +++ b/docs/_build/html/_modules/tensorkrowch/models/mpo.html @@ -5,7 +5,7 @@ - tensorkrowch.models.mpo — TensorKrowch 1.1.2 documentation + tensorkrowch.models.mpo — TensorKrowch 1.1.3 documentation @@ -408,6 +408,8 @@

Source code for tensorkrowch.models.mpo

         explanation of the different initialization methods.
     device : torch.device, optional
         Device where to initialize the tensors if ``init_method`` is provided.
+    dtype : torch.dtype, optional
+        Dtype of the tensor if ``init_method`` is provided.
     kwargs : float
         Keyword arguments for the different initialization methods. See
         :meth:`~tensorkrowch.AbstractNode.make_tensor`.
@@ -448,6 +450,7 @@ 

Source code for tensorkrowch.models.mpo

                  n_batches: int = 1,
                  init_method: Text = 'randn',
                  device: Optional[torch.device] = None,
+                 dtype: Optional[torch.dtype] = None,
                  **kwargs) -> None:
 
         super().__init__(name='mpo')
@@ -594,6 +597,7 @@ 

Source code for tensorkrowch.models.mpo

         self.initialize(tensors=tensors,
                         init_method=init_method,
                         device=device,
+                        dtype=dtype,
                         **kwargs)
     
     # ----------
@@ -726,6 +730,7 @@ 

Source code for tensorkrowch.models.mpo

                    tensors: Optional[Sequence[torch.Tensor]] = None,
                    init_method: Optional[Text] = 'randn',
                    device: Optional[torch.device] = None,
+                   dtype: Optional[torch.dtype] = None,
                    **kwargs: float) -> None:
         """
         Initializes all the nodes of the :class:`MPO`. It can be called when
@@ -735,7 +740,7 @@ 

Source code for tensorkrowch.models.mpo

         
         * ``{"zeros", "ones", "copy", "rand", "randn"}``: Each node is
           initialized calling :meth:`~tensorkrowch.AbstractNode.set_tensor` with
-          the given method, ``device`` and ``kwargs``.
+          the given method, ``device``, ``dtype`` and ``kwargs``.
         
         Parameters
         ----------
@@ -749,6 +754,8 @@ 

Source code for tensorkrowch.models.mpo

             Initialization method.
         device : torch.device, optional
             Device where to initialize the tensors if ``init_method`` is provided.
+        dtype : torch.dtype, optional
+            Dtype of the tensor if ``init_method`` is provided.
         kwargs : float
             Keyword arguments for the different initialization methods. See
             :meth:`~tensorkrowch.AbstractNode.make_tensor`.
@@ -763,6 +770,8 @@ 

Source code for tensorkrowch.models.mpo

                 
                 if device is None:
                     device = tensors[0].device
+                if dtype is None:
+                    dtype = tensors[0].dtype
                 
                 if len(tensors) == 1:
                     tensors[0] = tensors[0].reshape(1,
@@ -773,13 +782,15 @@ 

Source code for tensorkrowch.models.mpo

                 else:
                     # Left node
                     aux_tensor = torch.zeros(*self._mats_env[0].shape,
-                                             device=device)
+                                             device=device,
+                                             dtype=dtype)
                     aux_tensor[0] = tensors[0]
                     tensors[0] = aux_tensor
                     
                     # Right node
                     aux_tensor = torch.zeros(*self._mats_env[-1].shape,
-                                             device=device)
+                                             device=device,
+                                             dtype=dtype)
                     aux_tensor[..., 0, :] = tensors[-1]
                     tensors[-1] = aux_tensor
                 
@@ -791,10 +802,13 @@ 

Source code for tensorkrowch.models.mpo

             for i, node in enumerate(self._mats_env):
                 node.set_tensor(init_method=init_method,
                                 device=device,
+                                dtype=dtype,
                                 **kwargs)
                 
                 if self._boundary == 'obc':
-                    aux_tensor = torch.zeros(*node.shape, device=device)
+                    aux_tensor = torch.zeros(*node.shape,
+                                             device=device,
+                                             dtype=dtype)
                     if i == 0:
                         # Left node
                         aux_tensor[0] = node.tensor[0]
@@ -804,8 +818,12 @@ 

Source code for tensorkrowch.models.mpo

                     node.tensor = aux_tensor
         
         if self._boundary == 'obc':
-            self._left_node.set_tensor(init_method='copy', device=device)
-            self._right_node.set_tensor(init_method='copy', device=device)
+ self._left_node.set_tensor(init_method='copy', + device=device, + dtype=dtype) + self._right_node.set_tensor(init_method='copy', + device=device, + dtype=dtype)
[docs] def set_data_nodes(self) -> None: """ @@ -842,7 +860,8 @@

Source code for tensorkrowch.models.mpo

                       tensors=None,
                       n_batches=self._n_batches,
                       init_method=None,
-                      device=None)
+                      device=None,
+                      dtype=None)
         new_mpo.name = self.name + '_copy'
         if share_tensors:
             for new_node, node in zip(new_mpo._mats_env, self._mats_env):
@@ -1352,6 +1371,8 @@ 

Source code for tensorkrowch.models.mpo

         explanation of the different initialization methods.
     device : torch.device, optional
         Device where to initialize the tensors if ``init_method`` is provided.
+    dtype : torch.dtype, optional
+        Dtype of the tensor if ``init_method`` is provided.
     kwargs : float
         Keyword arguments for the different initialization methods. See
         :meth:`~tensorkrowch.AbstractNode.make_tensor`.
@@ -1380,6 +1401,7 @@ 

Source code for tensorkrowch.models.mpo

                  n_batches: int = 1,
                  init_method: Text = 'randn',
                  device: Optional[torch.device] = None,
+                 dtype: Optional[torch.dtype] = None,
                  **kwargs) -> None:
 
         tensors = None
@@ -1424,6 +1446,7 @@ 

Source code for tensorkrowch.models.mpo

                          n_batches=n_batches,
                          init_method=init_method,
                          device=device,
+                         dtype=dtype,
                          **kwargs)
         self.name = 'umpo'
     
@@ -1449,6 +1472,7 @@ 

Source code for tensorkrowch.models.mpo

                    tensors: Optional[Sequence[torch.Tensor]] = None,
                    init_method: Optional[Text] = 'randn',
                    device: Optional[torch.device] = None,
+                   dtype: Optional[torch.dtype] = None,
                    **kwargs: float) -> None:
         """
         Initializes the common tensor of the :class:`UMPO`. It can be called
@@ -1458,7 +1482,7 @@ 

Source code for tensorkrowch.models.mpo

         
         * ``{"zeros", "ones", "copy", "rand", "randn"}``: The tensor is
           initialized calling :meth:`~tensorkrowch.AbstractNode.set_tensor` with
-          the given method, ``device`` and ``kwargs``.
+          the given method, ``device``, ``dtype`` and ``kwargs``.
         
         Parameters
         ----------
@@ -1470,6 +1494,8 @@ 

Source code for tensorkrowch.models.mpo

             Initialization method.
         device : torch.device, optional
             Device where to initialize the tensors if ``init_method`` is provided.
+        dtype : torch.dtype, optional
+            Dtype of the tensor if ``init_method`` is provided.
         kwargs : float
             Keyword arguments for the different initialization methods. See
             :meth:`~tensorkrowch.AbstractNode.make_tensor`.
@@ -1480,6 +1506,7 @@ 

Source code for tensorkrowch.models.mpo

         elif init_method is not None:
             self.uniform_memory.set_tensor(init_method=init_method,
                                            device=device,
+                                           dtype=dtype,
                                            **kwargs)
[docs] def copy(self, share_tensors: bool = False) -> 'UMPO': @@ -1507,7 +1534,8 @@

Source code for tensorkrowch.models.mpo

                        tensor=None,
                        n_batches=self._n_batches,
                        init_method=None,
-                       device=None)
+                       device=None,
+                       dtype=None)
         new_mpo.name = self.name + '_copy'
         if share_tensors:
             new_mpo.uniform_memory.tensor = self.uniform_memory.tensor
diff --git a/docs/_build/html/_modules/tensorkrowch/models/mps.html b/docs/_build/html/_modules/tensorkrowch/models/mps.html
index 563ef92..17a676b 100644
--- a/docs/_build/html/_modules/tensorkrowch/models/mps.html
+++ b/docs/_build/html/_modules/tensorkrowch/models/mps.html
@@ -5,7 +5,7 @@
   
     
     
-    tensorkrowch.models.mps — TensorKrowch 1.1.2 documentation
+    tensorkrowch.models.mps — TensorKrowch 1.1.3 documentation
     
   
   
@@ -438,6 +438,8 @@ 

Source code for tensorkrowch.models.mps

         explanation of the different initialization methods.
     device : torch.device, optional
         Device where to initialize the tensors if ``init_method`` is provided.
+    dtype : torch.dtype, optional
+        Dtype of the tensor if ``init_method`` is provided.
     kwargs : float
         Keyword arguments for the different initialization methods. See
         :meth:`~tensorkrowch.AbstractNode.make_tensor`.
@@ -497,6 +499,7 @@ 

Source code for tensorkrowch.models.mps

                  n_batches: int = 1,
                  init_method: Text = 'randn',
                  device: Optional[torch.device] = None,
+                 dtype: Optional[torch.dtype] = None,
                  **kwargs) -> None:
 
         super().__init__(name='mps')
@@ -690,6 +693,7 @@ 

Source code for tensorkrowch.models.mps

         self.initialize(tensors=tensors,
                         init_method=init_method,
                         device=device,
+                        dtype=dtype,
                         **kwargs)
     
     # ----------
@@ -906,7 +910,9 @@ 

Source code for tensorkrowch.models.mps

                 if i == self._n_features - 1:
                     self._mats_env[-1]['right'] ^ self._right_node['left']
     
-    def _make_canonical(self, device: Optional[torch.device] = None) -> List[torch.Tensor]:
+    def _make_canonical(self,
+                        device: Optional[torch.device] = None,
+                        dtype: Optional[torch.dtype] = None) -> List[torch.Tensor]:
         """
         Creates random unitaries to initialize the MPS in canonical form with
         orthogonality center at the rightmost node. Unitaries in nodes are
@@ -934,22 +940,22 @@ 

Source code for tensorkrowch.models.mps

                 phys_dim = node_shape[1]
             size = max(aux_shape[0], aux_shape[1])
             
-            tensor = random_unitary(size, device=device)
+            tensor = random_unitary(size, device=device, dtype=dtype)
             tensor = tensor[:min(aux_shape[0], size), :min(aux_shape[1], size)]
             tensor = tensor.reshape(*node_shape)
             
             if i == (self._n_features - 1):
-                if self._boundary == 'obc':
-                    tensor = tensor.t() / tensor.norm() * sqrt(phys_dim)
-                else:
-                    tensor = tensor / tensor.norm() * sqrt(phys_dim)
-            else:
-                tensor = tensor * sqrt(phys_dim)
+                if (self._boundary == 'obc') and (i == 0):
+                    tensor = tensor[:, 0]
+                tensor = tensor / tensor.norm()
+            tensor = tensor * sqrt(phys_dim)
             
             tensors.append(tensor)
         return tensors
     
-    def _make_unitaries(self, device: Optional[torch.device] = None) -> List[torch.Tensor]:
+    def _make_unitaries(self,
+                        device: Optional[torch.device] = None,
+                        dtype: Optional[torch.dtype] = None) -> List[torch.Tensor]:
         """
         Creates random unitaries to initialize the MPS nodes as stacks of
         unitaries.
@@ -973,7 +979,7 @@ 

Source code for tensorkrowch.models.mps

                 size_2 = min(node.shape[2], size)
             
             for _ in range(node.shape[1]):
-                tensor = random_unitary(size, device=device)
+                tensor = random_unitary(size, device=device, dtype=dtype)
                 tensor = tensor[:size_1, :size_2]
                 units.append(tensor)
             
@@ -981,19 +987,18 @@ 

Source code for tensorkrowch.models.mps

             
             if self._boundary == 'obc':
                 if i == 0:
-                    tensors.append(units.squeeze(0))
+                    units = units.squeeze(0)
                 elif i == (self._n_features - 1):
-                    tensors.append(units.squeeze(-1))
-                else:
-                    tensors.append(units)
-            else:    
-                tensors.append(units)
+                    units = units.squeeze(-1)
+            tensors.append(units)
+        
         return tensors
 
 
[docs] def initialize(self, tensors: Optional[Sequence[torch.Tensor]] = None, init_method: Optional[Text] = 'randn', device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, **kwargs: float) -> None: """ Initializes all the nodes of the :class:`MPS`. It can be called when @@ -1003,7 +1008,7 @@

Source code for tensorkrowch.models.mps

         
         * ``{"zeros", "ones", "copy", "rand", "randn"}``: Each node is
           initialized calling :meth:`~tensorkrowch.AbstractNode.set_tensor` with
-          the given method, ``device`` and ``kwargs``.
+          the given method, ``device``, ``dtype`` and ``kwargs``.
         
         * ``"randn_eye"``: Nodes are initialized as in this
           `paper <https://arxiv.org/abs/1605.03795>`_, adding identities at the
@@ -1033,14 +1038,16 @@ 

Source code for tensorkrowch.models.mps

             Initialization method.
         device : torch.device, optional
             Device where to initialize the tensors if ``init_method`` is provided.
+        dtype : torch.dtype, optional
+            Dtype of the tensor if ``init_method`` is provided.
         kwargs : float
             Keyword arguments for the different initialization methods. See
             :meth:`~tensorkrowch.AbstractNode.make_tensor`.
         """
         if init_method == 'unit':
-            tensors = self._make_unitaries(device=device)
+            tensors = self._make_unitaries(device=device, dtype=dtype)
         elif init_method == 'canonical':
-            tensors = self._make_canonical(device=device)
+            tensors = self._make_canonical(device=device, dtype=dtype)
 
         if tensors is not None:
             if len(tensors) != self._n_features:
@@ -1052,19 +1059,23 @@ 

Source code for tensorkrowch.models.mps

                 
                 if device is None:
                     device = tensors[0].device
+                if dtype is None:
+                    dtype = tensors[0].dtype
                 
                 if len(tensors) == 1:
                     tensors[0] = tensors[0].reshape(1, -1, 1)
                 else:
                     # Left node
                     aux_tensor = torch.zeros(*self._mats_env[0].shape,
-                                             device=device)
+                                             device=device,
+                                             dtype=dtype)
                     aux_tensor[0] = tensors[0]
                     tensors[0] = aux_tensor
                     
                     # Right node
                     aux_tensor = torch.zeros(*self._mats_env[-1].shape,
-                                             device=device)
+                                             device=device,
+                                             dtype=dtype)
                     aux_tensor[..., 0] = tensors[-1]
                     tensors[-1] = aux_tensor
                 
@@ -1080,16 +1091,20 @@ 

Source code for tensorkrowch.models.mps

             for i, node in enumerate(self._mats_env):
                 node.set_tensor(init_method=init_method,
                                 device=device,
+                                dtype=dtype,
                                 **kwargs)
                 if add_eye:
                     aux_tensor = node.tensor.detach()
                     aux_tensor[:, 0, :] += torch.eye(node.shape[0],
                                                      node.shape[2],
-                                                     device=device)
+                                                     device=device,
+                                                     dtype=dtype)
                     node.tensor = aux_tensor
                 
                 if self._boundary == 'obc':
-                    aux_tensor = torch.zeros(*node.shape, device=device)
+                    aux_tensor = torch.zeros(*node.shape,
+                                             device=device,
+                                             dtype=dtype)
                     if i == 0:
                         # Left node
                         aux_tensor[0] = node.tensor[0]
@@ -1100,8 +1115,12 @@ 

Source code for tensorkrowch.models.mps

                         node.tensor = aux_tensor
         
         if self._boundary == 'obc':
-            self._left_node.set_tensor(init_method='copy', device=device)
-            self._right_node.set_tensor(init_method='copy', device=device)
+ self._left_node.set_tensor(init_method='copy', + device=device, + dtype=dtype) + self._right_node.set_tensor(init_method='copy', + device=device, + dtype=dtype)
[docs] def set_data_nodes(self) -> None: """ @@ -1139,7 +1158,8 @@

Source code for tensorkrowch.models.mps

                       out_features=self._out_features,
                       n_batches=self._n_batches,
                       init_method=None,
-                      device=None)
+                      device=None,
+                      dtype=None)
         new_mps.name = self.name + '_copy'
         if share_tensors:
             for new_node, node in zip(new_mps._mats_env, self._mats_env):
@@ -1552,10 +1572,10 @@ 

Source code for tensorkrowch.models.mps

                 copied_nodes = []
                 for node in nodes_out_env:
                     copied_node = node.__class__(shape=node._shape,
-                                              axes_names=node.axes_names,
-                                              name='virtual_result_copy',
-                                              network=self,
-                                              virtual=True)
+                                                 axes_names=node.axes_names,
+                                                 name='virtual_result_copy',
+                                                 network=self,
+                                                 virtual=True)
                     copied_node.set_tensor_from(node)
                     copied_nodes.append(copied_node)
                     
@@ -1643,6 +1663,12 @@ 

Source code for tensorkrowch.models.mps

                         # Connect copies directly to output nodes
                         copied_node['input'] ^ node['input']
                 
+                # If MPS nodes are complex, copied nodes are their conjugates
+                is_complex = copied_nodes[0].is_complex()
+                if is_complex:
+                    for i, node in enumerate(copied_nodes):
+                        copied_nodes[i] = node.conj()
+                
                 # Contract output nodes with copies
                 mats_out_env = self._input_contraction(
                     nodes_env=nodes_out_env,
@@ -1764,6 +1790,12 @@ 

Source code for tensorkrowch.models.mps

             copied_nodes = []
             for node in all_nodes:
                 copied_nodes.append(node.neighbours('input'))
+        
+        # If MPS nodes are complex, copied nodes are their conjugates
+        is_complex = copied_nodes[0].is_complex()
+        if is_complex:
+            for i, node in enumerate(copied_nodes):
+                copied_nodes[i] = node.conj()
             
         # Contract output nodes with copies
         mats_out_env = self._input_contraction(
@@ -1872,10 +1904,14 @@ 

Source code for tensorkrowch.models.mps

         if n_dims >= 1:
             if n_dims == 1:
                 data = torch.cat(data, dim=-1)
-                data = basis(data, dim=dims[0]).float().to(self.in_env[0].device)
+                data = basis(data, dim=dims[0])\
+                    .to(self.in_env[0].dtype)\
+                    .to(self.in_env[0].device)
             elif n_dims > 1:
                 data = [
-                    basis(dat, dim=dim).squeeze(-2).float().to(self.in_env[0].device)
+                    basis(dat, dim=dim).squeeze(-2)\
+                        .to(self.in_env[0].dtype)\
+                        .to(self.in_env[0].device)
                     for dat, dim in zip(data, dims)
                     ]
             
@@ -2002,7 +2038,7 @@ 

Source code for tensorkrowch.models.mps

                                   middle_tensor.shape[-1]),         # right
             full_matrices=False)
         
-        s = s[s > 0]
+        s = s[s.pow(2) > 0]
         entropy = -(s.pow(2) * s.pow(2).log()).sum()
         
         # Rescale
@@ -2195,6 +2231,7 @@ 

Source code for tensorkrowch.models.mps

                              side: Text = 'right'):
         """Projects all nodes into a space of dimension ``bond_dim``."""
         device = nodes[0].tensor.device
+        dtype = nodes[0].tensor.dtype
 
         if side == 'left':
             nodes.reverse()
@@ -2219,7 +2256,7 @@ 

Source code for tensorkrowch.models.mps

 
                 proj_mat_node.tensor = torch.eye(
                     torch.tensor(phys_dim_lst).prod().int().item(),
-                    bond_dim).view(*phys_dim_lst, -1).to(device)
+                    bond_dim).view(*phys_dim_lst, -1).to(dtype).to(device)
                 for k in range(j + 1):
                     nodes[k]['input'] ^ proj_mat_node[k]
 
@@ -2239,7 +2276,7 @@ 

Source code for tensorkrowch.models.mps

 
             proj_mat_node.tensor = torch.eye(
                 torch.tensor(phys_dim_lst).prod().int().item(),
-                bond_dim).view(*phys_dim_lst, -1).to(device)
+                bond_dim).view(*phys_dim_lst, -1).to(dtype).to(device)
             for k in range(j + 1):
                 nodes[k]['input'] ^ proj_mat_node[k]
 
@@ -2256,7 +2293,8 @@ 

Source code for tensorkrowch.models.mps

                                  name=f'proj_vec_node_{side}_({k})',
                                  network=self)
 
-            proj_vec_node.tensor = torch.eye(phys_dim, 1).squeeze().to(device)
+            proj_vec_node.tensor = torch.eye(phys_dim, 1).squeeze()\
+                .to(dtype).to(device)
             nodes[k]['input'] ^ proj_vec_node['input']
             line_mat_nodes.append(proj_vec_node @ nodes[k])
 
@@ -2424,6 +2462,8 @@ 

Source code for tensorkrowch.models.mps

         explanation of the different initialization methods.
     device : torch.device, optional
         Device where to initialize the tensors if ``init_method`` is provided.
+    dtype : torch.dtype, optional
+        Dtype of the tensor if ``init_method`` is provided.
     kwargs : float
         Keyword arguments for the different initialization methods. See
         :meth:`~tensorkrowch.AbstractNode.make_tensor`.
@@ -2452,6 +2492,7 @@ 

Source code for tensorkrowch.models.mps

                  n_batches: int = 1,
                  init_method: Text = 'randn',
                  device: Optional[torch.device] = None,
+                 dtype: Optional[torch.dtype] = None,
                  **kwargs) -> None:
         
         tensors = None
@@ -2493,6 +2534,7 @@ 

Source code for tensorkrowch.models.mps

                          n_batches=n_batches,
                          init_method=init_method,
                          device=device,
+                         dtype=dtype,
                          **kwargs)
         self.name = 'umps'
 
@@ -2513,7 +2555,9 @@ 

Source code for tensorkrowch.models.mps

         for node in self._mats_env:
             node.set_tensor_from(uniform_memory)
     
-    def _make_canonical(self, device: Optional[torch.device] = None) -> List[torch.Tensor]:
+    def _make_canonical(self,
+                        device: Optional[torch.device] = None,
+                        dtype: Optional[torch.dtype] = None) -> List[torch.Tensor]:
         """
         Creates random unitaries to initialize the MPS in canonical form with
         orthogonality center at the rightmost node. Unitaries in nodes are
@@ -2527,13 +2571,15 @@ 

Source code for tensorkrowch.models.mps

         size = max(aux_shape[0], aux_shape[1])
         phys_dim = node_shape[1]
         
-        tensor = random_unitary(size, device=device)
+        tensor = random_unitary(size, device=device, dtype=dtype)
         tensor = tensor[:min(aux_shape[0], size), :min(aux_shape[1], size)]
         tensor = tensor.reshape(*node_shape)
         tensor = tensor * sqrt(phys_dim)
         return tensor
     
-    def _make_unitaries(self, device: Optional[torch.device] = None) -> List[torch.Tensor]:
+    def _make_unitaries(self,
+                        device: Optional[torch.device] = None,
+                        dtype: Optional[torch.dtype] = None) -> List[torch.Tensor]:
         """
         Creates random unitaries to initialize the MPS nodes as stacks of
         unitaries.
@@ -2543,7 +2589,7 @@ 

Source code for tensorkrowch.models.mps

         
         units = []
         for _ in range(node_shape[1]):
-            tensor = random_unitary(node_shape[0], device=device)
+            tensor = random_unitary(node_shape[0], device=device, dtype=dtype)
             units.append(tensor)
         tensor = torch.stack(units, dim=1)
         return tensor
@@ -2552,6 +2598,7 @@ 

Source code for tensorkrowch.models.mps

                    tensors: Optional[Sequence[torch.Tensor]] = None,
                    init_method: Optional[Text] = 'randn',
                    device: Optional[torch.device] = None,
+                   dtype: Optional[torch.dtype] = None,
                    **kwargs: float) -> None:
         """
         Initializes the common tensor of the :class:`UMPS`. It can be called
@@ -2561,7 +2608,7 @@ 

Source code for tensorkrowch.models.mps

         
         * ``{"zeros", "ones", "copy", "rand", "randn"}``: The tensor is
           initialized calling :meth:`~tensorkrowch.AbstractNode.set_tensor` with
-          the given method, ``device`` and ``kwargs``.
+          the given method, ``device``, ``dtype`` and ``kwargs``.
         
         * ``"randn_eye"``: Tensor is initialized as in this
           `paper <https://arxiv.org/abs/1605.03795>`_, adding identities at the
@@ -2588,6 +2635,8 @@ 

Source code for tensorkrowch.models.mps

             Initialization method.
         device : torch.device, optional
             Device where to initialize the tensors if ``init_method`` is provided.
+        dtype : torch.dtype, optional
+            Dtype of the tensor if ``init_method`` is provided.
         kwargs : float
             Keyword arguments for the different initialization methods. See
             :meth:`~tensorkrowch.AbstractNode.make_tensor`.
@@ -2595,9 +2644,9 @@ 

Source code for tensorkrowch.models.mps

         node = self.uniform_memory
         
         if init_method == 'unit':
-            tensors = [self._make_unitaries(device=device)]
+            tensors = [self._make_unitaries(device=device, dtype=dtype)]
         elif init_method == 'canonical':
-            tensors = [self._make_canonical(device=device)]
+            tensors = [self._make_canonical(device=device, dtype=dtype)]
         
         if tensors is not None:
             node.tensor = tensors[0]
@@ -2610,12 +2659,14 @@ 

Source code for tensorkrowch.models.mps

             
             node.set_tensor(init_method=init_method,
                             device=device,
+                            dtype=dtype,
                             **kwargs)
             if add_eye:
                 aux_tensor = node.tensor.detach()
                 aux_tensor[:, 0, :] += torch.eye(node.shape[0],
                                                  node.shape[2],
-                                                 device=device)
+                                                 device=device,
+                                                 dtype=dtype)
                 node.tensor = aux_tensor
[docs] def copy(self, share_tensors: bool = False) -> 'UMPS': @@ -2644,7 +2695,8 @@

Source code for tensorkrowch.models.mps

                        out_features=self._out_features,
                        n_batches=self._n_batches,
                        init_method=None,
-                       device=None)
+                       device=None,
+                       dtype=None)
         new_mps.name = self.name + '_copy'
         if share_tensors:
             new_mps.uniform_memory.tensor = self.uniform_memory.tensor
@@ -2788,6 +2840,8 @@ 

Source code for tensorkrowch.models.mps

         explanation of the different initialization methods.
     device : torch.device, optional
         Device where to initialize the tensors if ``init_method`` is provided.
+    dtype : torch.dtype, optional
+        Dtype of the tensor if ``init_method`` is provided.
     kwargs : float
         Keyword arguments for the different initialization methods. See
         :meth:`~tensorkrowch.AbstractNode.make_tensor`.
@@ -2829,6 +2883,7 @@ 

Source code for tensorkrowch.models.mps

                  n_batches: int = 1,
                  init_method: Text = 'randn',
                  device: Optional[torch.device] = None,
+                 dtype: Optional[torch.dtype] = None,
                  **kwargs) -> None:
         
         phys_dim = None
@@ -2890,6 +2945,7 @@ 

Source code for tensorkrowch.models.mps

                          n_batches=n_batches,
                          init_method=init_method,
                          device=device,
+                         dtype=dtype,
                          **kwargs)
         self.name = 'mpslayer'
         self._in_dim = self._phys_dim[:out_position] + \
@@ -2919,7 +2975,9 @@ 

Source code for tensorkrowch.models.mps

         """Returns the output node."""
         return self._mats_env[self._out_position]
     
-    def _make_canonical(self, device: Optional[torch.device] = None) -> List[torch.Tensor]:
+    def _make_canonical(self,
+                        device: Optional[torch.device] = None,
+                        dtype: Optional[torch.dtype] = None) -> List[torch.Tensor]:
         """
         Creates random unitaries to initialize the MPS in canonical form with
         orthogonality center at the rightmost node. Unitaries in nodes are
@@ -2944,14 +3002,16 @@ 

Source code for tensorkrowch.models.mps

                 phys_dim = node_shape[1]
             size = max(aux_shape[0], aux_shape[1])
             
-            tensor = random_unitary(size, device=device)
+            tensor = random_unitary(size, device=device, dtype=dtype)
             tensor = tensor[:min(aux_shape[0], size), :min(aux_shape[1], size)]
             tensor = tensor.reshape(*node_shape)
             
             left_tensors.append(tensor * sqrt(phys_dim))
         
         # Output node
-        out_tensor = torch.randn(self.out_node.shape, device=device)
+        out_tensor = torch.randn(self.out_node.shape,
+                                 device=device,
+                                 dtype=dtype)
         phys_dim = out_tensor.shape[1]
         if self._boundary == 'obc':
             if self._out_position == 0:
@@ -2978,7 +3038,7 @@ 

Source code for tensorkrowch.models.mps

                 phys_dim = node_shape[1]
             size = max(aux_shape[0], aux_shape[1])
             
-            tensor = random_unitary(size, device=device)
+            tensor = random_unitary(size, device=device, dtype=dtype)
             tensor = tensor[:min(aux_shape[0], size), :min(aux_shape[1], size)]
             tensor = tensor.reshape(*node_shape)
             
@@ -2989,7 +3049,9 @@ 

Source code for tensorkrowch.models.mps

         tensors = left_tensors + [out_tensor] + right_tensors
         return tensors
     
-    def _make_unitaries(self, device: Optional[torch.device] = None) -> List[torch.Tensor]:
+    def _make_unitaries(self,
+                        device: Optional[torch.device] = None,
+                        dtype: Optional[torch.dtype] = None) -> List[torch.Tensor]:
         """
         Creates random unitaries to initialize the MPS nodes as stacks of
         unitaries.
@@ -3011,7 +3073,7 @@ 

Source code for tensorkrowch.models.mps

                 size_2 = min(node.shape[2], size)
             
             for _ in range(node.shape[1]):
-                tensor = random_unitary(size, device=device)
+                tensor = random_unitary(size, device=device, dtype=dtype)
                 tensor = tensor[:size_1, :size_2]
                 units.append(tensor)
             
@@ -3026,7 +3088,9 @@ 

Source code for tensorkrowch.models.mps

                 left_tensors.append(units)
         
         # Output node
-        out_tensor = torch.randn(self.out_node.shape, device=device)
+        out_tensor = torch.randn(self.out_node.shape,
+                                 device=device,
+                                 dtype=dtype)
         if self._boundary == 'obc':
             if self._out_position == 0:
                 out_tensor = out_tensor[0]
@@ -3050,7 +3114,7 @@ 

Source code for tensorkrowch.models.mps

                 size_2 = min(node.shape[2], size)
             
             for _ in range(node.shape[1]):
-                tensor = random_unitary(size, device=device).t()
+                tensor = random_unitary(size, device=device, dtype=dtype).H
                 tensor = tensor[:size_1, :size_2]
                 units.append(tensor)
             
@@ -3068,63 +3132,12 @@ 

Source code for tensorkrowch.models.mps

         # All tensors
         tensors = left_tensors + [out_tensor] + right_tensors
         return tensors
-
-    def initialize(self,
-                   tensors: Optional[Sequence[torch.Tensor]] = None,
-                   init_method: Optional[Text] = 'randn',
-                   device: Optional[torch.device] = None,
-                   **kwargs: float) -> None:
-        """
-        Initializes all the nodes of the :class:`MPS`. It can be called when
-        instantiating the model, or to override the existing nodes' tensors.
-        
-        There are different methods to initialize the nodes:
-        
-        * ``{"zeros", "ones", "copy", "rand", "randn"}``: Each node is
-          initialized calling :meth:`~tensorkrowch.AbstractNode.set_tensor` with
-          the given method, ``device`` and ``kwargs``.
-        
-        * ``"randn_eye"``: Nodes are initialized as in this
-          `paper <https://arxiv.org/abs/1605.03795>`_, adding identities at the
-          top of random gaussian tensors. In this case, ``std`` should be
-          specified with a low value, e.g., ``std = 1e-9``.
-        
-        * ``"unit"``: Nodes are initialized as stacks of random unitaries. This,
-          combined (at least) with an embedding of the inputs as elements of
-          the computational basis (:func:`~tensorkrowch.embeddings.discretize`
-          combined with :func:`~tensorkrowch.embeddings.basis`)
-        
-        * ``"canonical"```: MPS is initialized in canonical form with a squared
-          norm `close` to the product of all the physical dimensions (if bond
-          dimensions are bigger than the powers of the physical dimensions,
-          the norm could vary). Th orthogonality center is at the rightmost
-          node.
-        
-        Parameters
-        ----------
-        tensors : list[torch.Tensor] or tuple[torch.Tensor], optional
-            Sequence of tensors to set in each of the MPS nodes. If ``boundary``
-            is ``"obc"``, all tensors should be rank-3, except the first and
-            last ones, which can be rank-2, or rank-1 (if the first and last are
-            the same). If ``boundary`` is ``"pbc"``, all tensors should be
-            rank-3.
-        init_method : {"zeros", "ones", "copy", "rand", "randn", "randn_eye", "unit", "canonical"}, optional
-            Initialization method.
-        device : torch.device, optional
-            Device where to initialize the tensors if ``init_method`` is provided.
-        kwargs : float
-            Keyword arguments for the different initialization methods. See
-            :meth:`~tensorkrowch.AbstractNode.make_tensor`.
-        """
-        if init_method == 'unit':
-            tensors = self._make_unitaries(device=device)
-        elif init_method == 'canonical':
-            tensors = self._make_canonical(device=device)
     
 
[docs] def initialize(self, tensors: Optional[Sequence[torch.Tensor]] = None, init_method: Optional[Text] = 'randn', device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, **kwargs: float) -> None: """ Initializes all the nodes of the :class:`MPSLayer`. It can be called when @@ -3134,7 +3147,7 @@

Source code for tensorkrowch.models.mps

         
         * ``{"zeros", "ones", "copy", "rand", "randn"}``: Each node is
           initialized calling :meth:`~tensorkrowch.AbstractNode.set_tensor` with
-          the given method, ``device`` and ``kwargs``.
+          the given method, ``device``, ``dtype`` and ``kwargs``.
         
         * ``"randn_eye"``: Nodes are initialized as in this
           `paper <https://arxiv.org/abs/1605.03795>`_, adding identities at the
@@ -3163,14 +3176,16 @@ 

Source code for tensorkrowch.models.mps

             Initialization method.
         device : torch.device, optional
             Device where to initialize the tensors if ``init_method`` is provided.
+        dtype : torch.dtype, optional
+            Dtype of the tensor if ``init_method`` is provided.
         kwargs : float
             Keyword arguments for the different initialization methods. See
             :meth:`~tensorkrowch.AbstractNode.make_tensor`.
         """
         if init_method == 'unit':
-            tensors = self._make_unitaries(device=device)
+            tensors = self._make_unitaries(device=device, dtype=dtype)
         elif init_method == 'canonical':
-            tensors = self._make_canonical(device=device)
+            tensors = self._make_canonical(device=device, dtype=dtype)
 
         if tensors is not None:
             if len(tensors) != self._n_features:
@@ -3182,19 +3197,23 @@ 

Source code for tensorkrowch.models.mps

                 
                 if device is None:
                     device = tensors[0].device
+                if dtype is None:
+                    dtype = tensors[0].dtype
                 
                 if len(tensors) == 1:
                     tensors[0] = tensors[0].reshape(1, -1, 1)
                 else:
                     # Left node
                     aux_tensor = torch.zeros(*self._mats_env[0].shape,
-                                             device=device)
+                                             device=device,
+                                             dtype=dtype)
                     aux_tensor[0] = tensors[0]
                     tensors[0] = aux_tensor
                     
                     # Right node
                     aux_tensor = torch.zeros(*self._mats_env[-1].shape,
-                                             device=device)
+                                             device=device,
+                                             dtype=dtype)
                     aux_tensor[..., 0] = tensors[-1]
                     tensors[-1] = aux_tensor
                 
@@ -3210,12 +3229,14 @@ 

Source code for tensorkrowch.models.mps

             for i, node in enumerate(self._mats_env):
                 node.set_tensor(init_method=init_method,
                                 device=device,
+                                dtype=dtype,
                                 **kwargs)
                 if add_eye:
                     aux_tensor = node.tensor.detach()
                     eye_tensor = torch.eye(node.shape[0],
                                            node.shape[2],
-                                           device=device)
+                                           device=device,
+                                           dtype=dtype)
                     if i == self._out_position:
                         eye_tensor = eye_tensor.unsqueeze(1)
                         eye_tensor = eye_tensor.expand(node.shape)
@@ -3225,7 +3246,9 @@ 

Source code for tensorkrowch.models.mps

                     node.tensor = aux_tensor
                 
                 if self._boundary == 'obc':
-                    aux_tensor = torch.zeros(*node.shape, device=device)
+                    aux_tensor = torch.zeros(*node.shape,
+                                             device=device,
+                                             dtype=dtype)
                     if i == 0:
                         # Left node
                         aux_tensor[0] = node.tensor[0]
@@ -3236,8 +3259,12 @@ 

Source code for tensorkrowch.models.mps

                         node.tensor = aux_tensor
         
         if self._boundary == 'obc':
-            self._left_node.set_tensor(init_method='copy', device=device)
-            self._right_node.set_tensor(init_method='copy', device=device)
+ self._left_node.set_tensor(init_method='copy', + device=device, + dtype=dtype) + self._right_node.set_tensor(init_method='copy', + device=device, + dtype=dtype)
[docs] def copy(self, share_tensors: bool = False) -> 'MPSLayer': """ @@ -3266,7 +3293,8 @@

Source code for tensorkrowch.models.mps

                            tensors=None,
                            n_batches=self._n_batches,
                            init_method=None,
-                           device=None)
+                           device=None,
+                           dtype=None)
         new_mps.name = self.name + '_copy'
         if share_tensors:
             for new_node, node in zip(new_mps._mats_env, self._mats_env):
@@ -3328,6 +3356,8 @@ 

Source code for tensorkrowch.models.mps

         explanation of the different initialization methods.
     device : torch.device, optional
         Device where to initialize the tensors if ``init_method`` is provided.
+    dtype : torch.dtype, optional
+        Dtype of the tensor if ``init_method`` is provided.
     kwargs : float
         Keyword arguments for the different initialization methods. See
         :meth:`~tensorkrowch.AbstractNode.make_tensor`.
@@ -3358,6 +3388,7 @@ 

Source code for tensorkrowch.models.mps

                  n_batches: int = 1,
                  init_method: Text = 'randn',
                  device: Optional[torch.device] = None,
+                 dtype: Optional[torch.dtype] = None,
                  **kwargs) -> None:
         
         phys_dim = None
@@ -3433,6 +3464,7 @@ 

Source code for tensorkrowch.models.mps

                          n_batches=n_batches,
                          init_method=init_method,
                          device=device,
+                         dtype=dtype,
                          **kwargs)
         self.name = 'umpslayer'
         self._in_dim = self._phys_dim[:out_position] + \
@@ -3481,7 +3513,9 @@ 

Source code for tensorkrowch.models.mps

         for node in in_nodes:
             node.set_tensor_from(uniform_memory)
     
-    def _make_canonical(self, device: Optional[torch.device] = None) -> List[torch.Tensor]:
+    def _make_canonical(self,
+                        device: Optional[torch.device] = None,
+                        dtype: Optional[torch.dtype] = None) -> List[torch.Tensor]:
         """
         Creates random unitaries to initialize the MPS in canonical form with
         orthogonality center at the rightmost node. Unitaries in nodes are
@@ -3496,18 +3530,22 @@ 

Source code for tensorkrowch.models.mps

         size = max(aux_shape[0], aux_shape[1])
         phys_dim = node_shape[1]
         
-        uni_tensor = random_unitary(size, device=device)
+        uni_tensor = random_unitary(size, device=device, dtype=dtype)
         uni_tensor = uni_tensor[:min(aux_shape[0], size), :min(aux_shape[1], size)]
         uni_tensor = uni_tensor.reshape(*node_shape)
         uni_tensor = uni_tensor * sqrt(phys_dim)
         
         # Output node
-        out_tensor = torch.randn(self.out_node.shape, device=device)
+        out_tensor = torch.randn(self.out_node.shape,
+                                 device=device,
+                                 dtype=dtype)
         out_tensor = out_tensor / out_tensor.norm() * sqrt(out_tensor.shape[1])
         
         return [uni_tensor, out_tensor]
     
-    def _make_unitaries(self, device: Optional[torch.device] = None) -> List[torch.Tensor]:
+    def _make_unitaries(self,
+                        device: Optional[torch.device] = None,
+                        dtype: Optional[torch.dtype] = None) -> List[torch.Tensor]:
         """
         Creates random unitaries to initialize the MPS nodes as stacks of
         unitaries.
@@ -3518,7 +3556,9 @@ 

Source code for tensorkrowch.models.mps

             
             units = []
             for _ in range(node_shape[1]):
-                tensor = random_unitary(node_shape[0], device=device)
+                tensor = random_unitary(node_shape[0],
+                                        device=device,
+                                        dtype=dtype)
                 units.append(tensor)
             
             tensors.append(torch.stack(units, dim=1))
@@ -3529,6 +3569,7 @@ 

Source code for tensorkrowch.models.mps

                    tensors: Optional[Sequence[torch.Tensor]] = None,
                    init_method: Optional[Text] = 'randn',
                    device: Optional[torch.device] = None,
+                   dtype: Optional[torch.dtype] = None,
                    **kwargs: float) -> None:
         """
         Initializes the common tensor of the :class:`UMPSLayer`. It can be called
@@ -3538,7 +3579,7 @@ 

Source code for tensorkrowch.models.mps

         
         * ``{"zeros", "ones", "copy", "rand", "randn"}``: The tensor is
           initialized calling :meth:`~tensorkrowch.AbstractNode.set_tensor` with
-          the given method, ``device`` and ``kwargs``.
+          the given method, ``device``, ``dtype`` and ``kwargs``.
         
         * ``"randn_eye"``: Tensor is initialized as in this
           `paper <https://arxiv.org/abs/1605.03795>`_, adding identities at the
@@ -3566,14 +3607,16 @@ 

Source code for tensorkrowch.models.mps

             Initialization method.
         device : torch.device, optional
             Device where to initialize the tensors if ``init_method`` is provided.
+        dtype : torch.dtype, optional
+            Dtype of the tensor if ``init_method`` is provided.
         kwargs : float
             Keyword arguments for the different initialization methods. See
             :meth:`~tensorkrowch.AbstractNode.make_tensor`.
         """
         if init_method == 'unit':
-            tensors = self._make_unitaries(device=device)
+            tensors = self._make_unitaries(device=device, dtype=dtype)
         elif init_method == 'canonical':
-            tensors = self._make_canonical(device=device)
+            tensors = self._make_canonical(device=device, dtype=dtype)
         
         if tensors is not None:
             self.uniform_memory.tensor = tensors[0]
@@ -3588,12 +3631,14 @@ 

Source code for tensorkrowch.models.mps

                 
                 node.set_tensor(init_method=init_method,
                                 device=device,
+                                dtype=dtype,
                                 **kwargs)
                 if add_eye:
                     aux_tensor = node.tensor.detach()
                     eye_tensor = torch.eye(node.shape[0],
                                            node.shape[2],
-                                           device=device)
+                                           device=device,
+                                           dtype=dtype)
                     if i == 0:
                         aux_tensor[:, 0, :] += eye_tensor
                     else:
@@ -3628,7 +3673,8 @@ 

Source code for tensorkrowch.models.mps

                             tensor=None,
                             n_batches=self._n_batches,
                             init_method=None,
-                            device=None)
+                            device=None,
+                            dtype=None)
         new_mps.name = self.name + '_copy'
         if share_tensors:
             new_mps.uniform_memory.tensor = self.uniform_memory.tensor
@@ -3869,6 +3915,8 @@ 

Source code for tensorkrowch.models.mps

         explanation of the different initialization methods.
     device : torch.device, optional
         Device where to initialize the tensors if ``init_method`` is provided.
+    dtype : torch.dtype, optional
+        Dtype of the tensor if ``init_method`` is provided.
     kwargs : float
         Keyword arguments for the different initialization methods. See
         :meth:`~tensorkrowch.AbstractNode.make_tensor`.
@@ -3895,6 +3943,7 @@ 

Source code for tensorkrowch.models.mps

                  tensors: Optional[Sequence[torch.Tensor]] = None,
                  init_method: Text = 'randn',
                  device: Optional[torch.device] = None,
+                 dtype: Optional[torch.dtype] = None,
                  **kwargs):
         
         unfold = self._set_attributes(in_channels=in_channels,
@@ -3912,6 +3961,7 @@ 

Source code for tensorkrowch.models.mps

                      n_batches=2,
                      init_method=init_method,
                      device=device,
+                     dtype=dtype,
                      **kwargs)
         
         self.unfold = unfold
@@ -3980,7 +4030,8 @@ 

Source code for tensorkrowch.models.mps

                           boundary=self._boundary,
                           tensors=None,
                           init_method=None,
-                          device=None)
+                          device=None,
+                          dtype=None)
         new_mps.name = self.name + '_copy'
         if share_tensors:
             for new_node, node in zip(new_mps._mats_env, self._mats_env):
@@ -4032,6 +4083,8 @@ 

Source code for tensorkrowch.models.mps

         explanation of the different initialization methods.
     device : torch.device, optional
         Device where to initialize the tensors if ``init_method`` is provided.
+    dtype : torch.dtype, optional
+        Dtype of the tensor if ``init_method`` is provided.
     kwargs : float
         Keyword arguments for the different initialization methods. See
         :meth:`~tensorkrowch.AbstractNode.make_tensor`.
@@ -4061,6 +4114,7 @@ 

Source code for tensorkrowch.models.mps

                  tensor: Optional[torch.Tensor] = None,
                  init_method: Text = 'randn',
                  device: Optional[torch.device] = None,
+                 dtype: Optional[torch.dtype] = None,
                  **kwargs):
 
         unfold = self._set_attributes(in_channels=in_channels,
@@ -4077,6 +4131,7 @@ 

Source code for tensorkrowch.models.mps

                       n_batches=2,
                       init_method=init_method,
                       device=device,
+                      dtype=dtype,
                       **kwargs)
         
         self.unfold = unfold
@@ -4144,7 +4199,8 @@ 

Source code for tensorkrowch.models.mps

                            dilation=self.dilation,
                            tensor=None,
                            init_method=None,
-                           device=None)
+                           device=None,
+                           dtype=None)
         new_mps.name = self.name + '_copy'
         if share_tensors:
             new_mps.uniform_memory.tensor = self.uniform_memory.tensor
@@ -4212,6 +4268,8 @@ 

Source code for tensorkrowch.models.mps

         explanation of the different initialization methods.
     device : torch.device, optional
         Device where to initialize the tensors if ``init_method`` is provided.
+    dtype : torch.dtype, optional
+        Dtype of the tensor if ``init_method`` is provided.
     kwargs : float
         Keyword arguments for the different initialization methods. See
         :meth:`~tensorkrowch.AbstractNode.make_tensor`.
@@ -4241,6 +4299,7 @@ 

Source code for tensorkrowch.models.mps

                  tensors: Optional[Sequence[torch.Tensor]] = None,
                  init_method: Text = 'randn',
                  device: Optional[torch.device] = None,
+                 dtype: Optional[torch.dtype] = None,
                  **kwargs):
 
         unfold = self._set_attributes(in_channels=in_channels,
@@ -4261,6 +4320,7 @@ 

Source code for tensorkrowch.models.mps

                           n_batches=2,
                           init_method=init_method,
                           device=device,
+                          dtype=dtype,
                           **kwargs)
         
         self._out_channels = out_channels
@@ -4336,7 +4396,8 @@ 

Source code for tensorkrowch.models.mps

                                boundary=self._boundary,
                                tensors=None,
                                init_method=None,
-                               device=None)
+                               device=None,
+                               dtype=None)
         new_mps.name = self.name + '_copy'
         if share_tensors:
             for new_node, node in zip(new_mps._mats_env, self._mats_env):
@@ -4397,6 +4458,8 @@ 

Source code for tensorkrowch.models.mps

         detailed explanation of the different initialization methods.
     device : torch.device, optional
         Device where to initialize the tensors if ``init_method`` is provided.
+    dtype : torch.dtype, optional
+        Dtype of the tensor if ``init_method`` is provided.
     kwargs : float
         Keyword arguments for the different initialization methods. See
         :meth:`~tensorkrowch.AbstractNode.make_tensor`.
@@ -4429,6 +4492,7 @@ 

Source code for tensorkrowch.models.mps

                  tensors: Optional[Sequence[torch.Tensor]] = None,
                  init_method: Text = 'randn',
                  device: Optional[torch.device] = None,
+                 dtype: Optional[torch.dtype] = None,
                  **kwargs):
 
         unfold = self._set_attributes(in_channels=in_channels,
@@ -4448,6 +4512,7 @@ 

Source code for tensorkrowch.models.mps

                            n_batches=2,
                            init_method=init_method,
                            device=device,
+                           dtype=dtype,
                            **kwargs)
         
         self._out_channels = out_channels
@@ -4527,7 +4592,8 @@ 

Source code for tensorkrowch.models.mps

                                 dilation=self.dilation,
                                 tensor=None,
                                 init_method=None,
-                                device=None)
+                                device=None,
+                                dtype=None)
         new_mps.name = self.name + '_copy'
         if share_tensors:
             new_mps.uniform_memory.tensor = self.uniform_memory.tensor
diff --git a/docs/_build/html/_modules/tensorkrowch/models/mps_data.html b/docs/_build/html/_modules/tensorkrowch/models/mps_data.html
index 07bb2a2..424d1f6 100644
--- a/docs/_build/html/_modules/tensorkrowch/models/mps_data.html
+++ b/docs/_build/html/_modules/tensorkrowch/models/mps_data.html
@@ -5,7 +5,7 @@
   
     
     
-    tensorkrowch.models.mps_data — TensorKrowch 1.1.2 documentation
+    tensorkrowch.models.mps_data — TensorKrowch 1.1.3 documentation
     
   
   
@@ -413,6 +413,8 @@ 

Source code for tensorkrowch.models.mps_data

        :meth:`add_data` to see how to initialize MPS nodes with data tensors.
     device : torch.device, optional
         Device where to initialize the tensors if ``init_method`` is provided.
+    dtype : torch.dtype, optional
+            Dtype of the tensor if ``init_method`` is provided.
     kwargs : float
         Keyword arguments for the different initialization methods. See
         :meth:`~tensorkrowch.AbstractNode.make_tensor`.
@@ -439,6 +441,7 @@ 

Source code for tensorkrowch.models.mps_data

tensors: Optional[Sequence[torch.Tensor]] = None,
                  init_method: Optional[Text] = None,
                  device: Optional[torch.device] = None,
+                 dtype: Optional[torch.dtype] = None,
                  **kwargs) -> None:
 
         super().__init__(name='mps_data')
@@ -568,6 +571,7 @@ 

Source code for tensorkrowch.models.mps_data

self._make_nodes()
         self.initialize(init_method=init_method,
                             device=device,
+                            dtype=dtype,
                             **kwargs)
         
         if tensors is not None:
@@ -616,6 +620,15 @@ 

Source code for tensorkrowch.models.mps_data

        """Returns the list of nodes in ``mats_env``."""
         return self._mats_env
     
+    @property
+    def tensors(self) -> List[torch.Tensor]:
+        """Returns the list of MPS tensors."""
+        mps_tensors = [node.tensor for node in self._mats_env]
+        if self._boundary == 'obc':
+            mps_tensors[0] = mps_tensors[0][0, :, :]
+            mps_tensors[-1] = mps_tensors[-1][:, :, 0]
+        return mps_tensors
+    
     # -------
     # Methods
     # -------
@@ -673,6 +686,7 @@ 

Source code for tensorkrowch.models.mps_data

[docs]    def initialize(self,
                    init_method: Optional[Text] = 'randn',
                    device: Optional[torch.device] = None,
+                   dtype: Optional[torch.dtype] = None,
                    **kwargs: float) -> None:
         """
         Initializes all the nodes of the :class:`MPSData`. It can be called when
@@ -682,7 +696,7 @@ 

Source code for tensorkrowch.models.mps_data

        
         * ``{"zeros", "ones", "copy", "rand", "randn"}``: Each node is
           initialized calling :meth:`~tensorkrowch.AbstractNode.set_tensor` with
-          the given method, ``device`` and ``kwargs``.
+          the given method, ``device``, ``dtype`` and ``kwargs``.
         
         Parameters
         ----------
@@ -691,22 +705,31 @@ 

Source code for tensorkrowch.models.mps_data

            initialize MPS nodes with data tensors.
         device : torch.device, optional
             Device where to initialize the tensors if ``init_method`` is provided.
+        dtype : torch.dtype, optional
+            Dtype of the tensor if ``init_method`` is provided.
         kwargs : float
             Keyword arguments for the different initialization methods. See
             :meth:`~tensorkrowch.AbstractNode.make_tensor`.
         """
         if self._boundary == 'obc':
-            self._left_node.set_tensor(init_method='copy', device=device)
-            self._right_node.set_tensor(init_method='copy', device=device)
+            self._left_node.set_tensor(init_method='copy',
+                                       device=device,
+                                       dtype=dtype)
+            self._right_node.set_tensor(init_method='copy',
+                                        device=device,
+                                        dtype=dtype)
                 
         if init_method is not None:
             for i, node in enumerate(self._mats_env):
                 node.set_tensor(init_method=init_method,
                                 device=device,
+                                dtype=dtype,
                                 **kwargs)
                 
                 if self._boundary == 'obc':
-                    aux_tensor = torch.zeros(*node.shape, device=device)
+                    aux_tensor = torch.zeros(*node.shape,
+                                             device=device,
+                                             dtype=dtype)
                     if i == 0:
                         # Left node
                         aux_tensor[..., 0, :, :] = node.tensor[..., 0, :, :]
@@ -762,37 +785,45 @@ 

Source code for tensorkrowch.models.mps_data

data = data[:]
         device = data[0].device
+        dtype = data[0].dtype
         for i, node in enumerate(self._mats_env):
             if self._boundary == 'obc':
-                aux_tensor = torch.zeros(*node.shape, device=device)
                 if (i == 0) and (i == (self._n_features - 1)):
                     aux_tensor = torch.zeros(*data[i].shape[:-1],
                                              *node.shape[-3:],
-                                             device=device)
+                                             device=device,
+                                             dtype=dtype)
                     aux_tensor[..., 0, :, 0] = data[i]
                     data[i] = aux_tensor
                 elif i == 0:
                     aux_tensor = torch.zeros(*data[i].shape[:-2],
                                              *node.shape[-3:-1],
                                              data[i].shape[-1],
-                                             device=device)
+                                             device=device,
+                                             dtype=dtype)
                     aux_tensor[..., 0, :, :] = data[i]
                     data[i] = aux_tensor
                 elif i == (self._n_features - 1):
                     aux_tensor = torch.zeros(*data[i].shape[:-1],
                                              *node.shape[-2:],
-                                             device=device)
+                                             device=device,
+                                             dtype=dtype)
                     aux_tensor[..., 0] = data[i]
                     data[i] = aux_tensor
                     
             node._direct_set_tensor(data[i])
         
-        # Send left and right nodes to correct device
+        # Send left and right nodes to correct device and dtype
         if self._boundary == 'obc':
             if self._left_node.device != device:
                 self._left_node.tensor = self._left_node.tensor.to(device)
+            if self._left_node.dtype != dtype:
+                self._left_node.tensor = self._left_node.tensor.to(dtype)
+            
             if self._right_node.device != device:
                 self._right_node.tensor = self._right_node.tensor.to(device)
+            if self._right_node.dtype != dtype:
+                self._right_node.tensor = self._right_node.tensor.to(dtype)
         
         # Update bond dim
         if self._boundary == 'obc':
diff --git a/docs/_build/html/_modules/tensorkrowch/models/peps.html b/docs/_build/html/_modules/tensorkrowch/models/peps.html
index 3c67433..1d26549 100644
--- a/docs/_build/html/_modules/tensorkrowch/models/peps.html
+++ b/docs/_build/html/_modules/tensorkrowch/models/peps.html
@@ -5,7 +5,7 @@
   
     
     
-    tensorkrowch.models.peps — TensorKrowch 1.1.2 documentation
+    tensorkrowch.models.peps — TensorKrowch 1.1.3 documentation
     
   
   
diff --git a/docs/_build/html/_modules/tensorkrowch/models/tree.html b/docs/_build/html/_modules/tensorkrowch/models/tree.html
index ec2b696..c57a6ce 100644
--- a/docs/_build/html/_modules/tensorkrowch/models/tree.html
+++ b/docs/_build/html/_modules/tensorkrowch/models/tree.html
@@ -5,7 +5,7 @@
   
     
     
-    tensorkrowch.models.tree — TensorKrowch 1.1.2 documentation
+    tensorkrowch.models.tree — TensorKrowch 1.1.3 documentation
     
   
   
diff --git a/docs/_build/html/_modules/tensorkrowch/operations.html b/docs/_build/html/_modules/tensorkrowch/operations.html
index 82dbc4b..ddf3449 100644
--- a/docs/_build/html/_modules/tensorkrowch/operations.html
+++ b/docs/_build/html/_modules/tensorkrowch/operations.html
@@ -5,7 +5,7 @@
   
     
     
-    tensorkrowch.operations — TensorKrowch 1.1.2 documentation
+    tensorkrowch.operations — TensorKrowch 1.1.3 documentation
     
   
   
@@ -336,8 +336,11 @@ 

Source code for tensorkrowch.operations

         * permute_           (in-place)
         * tprod
         * mul
+        * div
         * add
         * sub
+        * renormalize
+        * conj
         
     Node-like operations:
         * split
@@ -1793,6 +1796,120 @@ 

Source code for tensorkrowch.operations

 AbstractNode.renormalize = renormalize_node
 
 
+##################################   conj    ##################################
+# MARK: conj
+def _check_first_conj(node: AbstractNode) -> Optional[Successor]:
+    args = (node,)
+    successors = node._successors.get('conj')
+    if not successors:
+        return None
+    return successors.get(args)
+
+
+def _conj_first(node: AbstractNode) -> Node:
+    new_node = Node._create_resultant(axes_names=node.axes_names,
+                                      name='conj',
+                                      network=node._network,
+                                      tensor=node.tensor.conj(),
+                                      edges=node._edges,
+                                      node1_list=node.is_node1())
+    
+    # Create successor
+    net = node._network
+    args = (node,)
+    successor = Successor(node_ref=node.node_ref(),
+                          index=node._tensor_info['index'],
+                          child=new_node)
+
+    # Add successor to parent
+    if 'conj' in node._successors:
+        node._successors['conj'].update({args: successor})
+    else:
+        node._successors['conj'] = {args: successor}
+
+    # Add operation to list of performed operations of TN
+    net._seq_ops.append(('conj', args))
+
+    # Record in inverse_memory while tracing
+    if net._tracing:
+        node._record_in_inverse_memory()
+
+    return new_node
+
+
+def _conj_next(successor: Successor, node: AbstractNode) -> Node:
+    tensor = node._direct_get_tensor(successor.node_ref,
+                                     successor.index)
+    child = successor.child
+    child._direct_set_tensor(tensor.conj())
+
+    # Record in inverse_memory while contracting, if network is traced
+    # (to delete memory if possible)
+    if node._network._traced:
+        node._check_inverse_memory(successor.node_ref)
+    
+    return child
+
+
+conj_op = Operation('conj',
+                    _check_first_conj,
+                    _conj_first,
+                    _conj_next)
+
+
+def conj(node: AbstractNode) -> Node:
+    """
+    Returns a view of the node's tensor with a flipped conjugate bit. If the
+    node has a non-complex dtype, this function returns a new node with the
+    same tensor.
+
+    See `conj <https://pytorch.org/docs/stable/generated/torch.conj.html>`_
+    in the **PyTorch** documentation.
+
+    Parameters
+    ----------
+    node : Node or ParamNode
+        Node that is to be conjugated.
+
+    Returns
+    -------
+    Node
+    
+    Examples
+    --------
+    >>> nodeA = tk.randn((3, 3), dtype=torch.complex64)
+    >>> conjA = tk.conj(nodeA)
+    >>> conjA.is_conj()
+    True
+    """
+    return conj_op(node)
+
+
+conj_node = copy_func(conj)
+conj_node.__doc__ = \
+    """
+    Returns a view of the node's tensor with a flipped conjugate bit. If the
+    node has a non-complex dtype, this function returns a new node with the
+    same tensor.
+
+    See `conj <https://pytorch.org/docs/stable/generated/torch.conj.html>`_
+    in the **PyTorch** documentation.
+
+    Returns
+    -------
+    Node
+    
+    Examples
+    --------
+    >>> nodeA = tk.randn((3, 3), dtype=torch.complex64)
+    >>> conjA = nodeA.conj()
+    >>> conjA.is_conj()
+    True
+    """
+
+AbstractNode.conj = conj_node
+
+
 ###############################################################################
 #                            NODE-LIKE OPERATIONS                             #
 ###############################################################################
@@ -1926,9 +2043,12 @@ 

Source code for tensorkrowch.operations

         u = u[..., :rank]
         s = s[..., :rank]
         vh = vh[..., :rank, :]
+        
+        if u.is_complex():
+            s = s.to(u.dtype)
 
         if mode == 'svdr':
-            phase = torch.sign(torch.randn(s.shape))
+            phase = torch.sgn(torch.randn_like(s))
             phase = torch.diag_embed(phase)
             u = u @ phase
             vh = phase @ vh
@@ -2125,9 +2245,12 @@ 

Source code for tensorkrowch.operations

         u = u[..., :rank]
         s = s[..., :rank]
         vh = vh[..., :rank, :]
+        
+        if u.is_complex():
+            s = s.to(u.dtype)
 
         if mode == 'svdr':
-            phase = torch.sign(torch.randn(s.shape))
+            phase = torch.sgn(torch.randn_like(s))
             phase = torch.diag_embed(phase)
             u = u @ phase
             vh = phase @ vh
diff --git a/docs/_build/html/_sources/tutorials/0_first_steps.rst.txt b/docs/_build/html/_sources/tutorials/0_first_steps.rst.txt
index 05997e1..153719f 100644
--- a/docs/_build/html/_sources/tutorials/0_first_steps.rst.txt
+++ b/docs/_build/html/_sources/tutorials/0_first_steps.rst.txt
@@ -134,7 +134,7 @@ possible classes.
 ::
 
     # Check if GPU is available
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
 
     # Instantiate model
     mps = tk.models.MPSLayer(n_features=image_size[0] * image_size[1] + 1,
diff --git a/docs/_build/html/_static/documentation_options.js b/docs/_build/html/_static/documentation_options.js
index 5314893..3d088f5 100644
--- a/docs/_build/html/_static/documentation_options.js
+++ b/docs/_build/html/_static/documentation_options.js
@@ -1,6 +1,6 @@
 var DOCUMENTATION_OPTIONS = {
     URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'),
-    VERSION: '1.1.2',
+    VERSION: '1.1.3',
     LANGUAGE: 'None',
     COLLAPSE_INDEX: false,
     BUILDER: 'html',
diff --git a/docs/_build/html/api.html b/docs/_build/html/api.html
index 04904b7..3001f29 100644
--- a/docs/_build/html/api.html
+++ b/docs/_build/html/api.html
@@ -6,7 +6,7 @@
     
     
 
-    API Reference — TensorKrowch 1.1.2 documentation
+    API Reference — TensorKrowch 1.1.3 documentation
     
   
   
diff --git a/docs/_build/html/components.html b/docs/_build/html/components.html
index 79314d3..9888808 100644
--- a/docs/_build/html/components.html
+++ b/docs/_build/html/components.html
@@ -6,7 +6,7 @@
     
     
 
-    Components — TensorKrowch 1.1.2 documentation
+    Components — TensorKrowch 1.1.3 documentation
     
   
   
@@ -705,7 +705,7 @@ 

Nodes

AbstractNode#

-class tensorkrowch.AbstractNode(shape=None, axes_names=None, name=None, network=None, data=False, virtual=False, override_node=False, tensor=None, edges=None, override_edges=False, node1_list=None, init_method=None, device=None, **kwargs)[source]#
+class tensorkrowch.AbstractNode(shape=None, axes_names=None, name=None, network=None, data=False, virtual=False, override_node=False, tensor=None, edges=None, override_edges=False, node1_list=None, init_method=None, device=None, dtype=None, **kwargs)[source]#

Abstract class for all types of nodes. Defines what a node is and most of its properties and methods. Since it is an abstract class, cannot be instantiated.

Nodes are the elements that make up a TensorNetwork. At its most @@ -948,6 +948,24 @@

AbstractNode +
+is_conj()[source]#
+

Equivalent to torch.is_conj().

+

+ +
+
+is_complex()[source]#
+

Equivalent to torch.is_complex().

+
+ +
+
+is_floating_point()[source]#
+

Equivalent to torch.is_floating_point().

+
+
size(axis=None)[source]#
@@ -1134,7 +1152,7 @@

AbstractNode
-make_tensor(shape=None, init_method='zeros', device=device(type='cpu'), **kwargs)[source]#
+make_tensor(shape=None, init_method='zeros', device=None, dtype=None, **kwargs)[source]#

Returns a tensor that can be put in the node, and is initialized according to init_method. By default, it has the same shape as the node.

@@ -1143,6 +1161,7 @@

AbstractNodeNone, node’s shape will be used.

  • init_method ({"zeros", "ones", "copy", "rand", "randn"}, optional) – Initialization method.

  • device (torch.device, optional) – Device where to initialize the tensor.

  • +
  • dtype (torch.dtype, optional) – Dtype of the tensor.

  • kwargs (float) –

    Keyword arguments for the different initialization methods:

    @@ -1527,6 +1547,28 @@

    AbstractNode +
    +conj()#
    +

    Returns a view of the node’s tensor with a flipped conjugate bit. If the +node has a non-complex dtype, this function returns a new node with the +same tensor.

    +

    See conj +in the PyTorch documentation.

    +
    +
    Return type
    +

    Node

    +
    +
    +

    Examples

    +
    >>> nodeA = tk.randn((3, 3), dtype=torch.complex64)
    +>>> conjA = nodeA.conj()
    +>>> conjA.is_conj()
    +True
    +
    +
    +

  • +
    contract_between(node2)#
    @@ -1813,7 +1855,7 @@

    AbstractNode#

    -class tensorkrowch.Node(shape=None, axes_names=None, name=None, network=None, data=False, virtual=False, override_node=False, tensor=None, edges=None, override_edges=False, node1_list=None, init_method=None, device=None, **kwargs)[source]#
    +class tensorkrowch.Node(shape=None, axes_names=None, name=None, network=None, data=False, virtual=False, override_node=False, tensor=None, edges=None, override_edges=False, node1_list=None, init_method=None, device=None, dtype=None, **kwargs)[source]#

    Base class for non-trainable nodes. Should be subclassed by any class of nodes that are not intended to be trained (e.g. StackNode).

    Can be used for fixed nodes of the TensorNetwork, or intermediate @@ -1869,6 +1911,7 @@

    Node#< should also be provided.

  • init_method ({"zeros", "ones", "copy", "rand", "randn"}, optional) – Initialization method.

  • device (torch.device, optional) – Device where to initialize the tensor if init_method is provided.

  • +
  • dtype (torch.dtype, optional) – Dtype of the tensor if init_method is provided.

  • kwargs (float) – Keyword arguments for the different initialization methods. See AbstractNode.make_tensor().

  • @@ -2042,7 +2085,7 @@

    Node#<

    ParamNode#

    -class tensorkrowch.ParamNode(shape=None, axes_names=None, name=None, network=None, virtual=False, override_node=False, tensor=None, edges=None, override_edges=False, node1_list=None, init_method=None, device=None, **kwargs)[source]#
    +class tensorkrowch.ParamNode(shape=None, axes_names=None, name=None, network=None, virtual=False, override_node=False, tensor=None, edges=None, override_edges=False, node1_list=None, init_method=None, device=None, dtype=None, **kwargs)[source]#

    Class for trainable nodes. Should be subclassed by any class of nodes that are intended to be trained (e.g. ParamStackNode).

    Should be used as the initial nodes conforming the TensorNetwork, @@ -2104,6 +2147,7 @@

    ParamNodeinit_method is provided.

    +
  • dtype (torch.dtype, optional) – Dtype of the tensor if init_method is provided.

  • kwargs (float) – Keyword arguments for the different initialization methods. See AbstractNode.make_tensor().

  • diff --git a/docs/_build/html/contents.html b/docs/_build/html/contents.html index fcaef76..09518ae 100644 --- a/docs/_build/html/contents.html +++ b/docs/_build/html/contents.html @@ -6,7 +6,7 @@ - <no title> — TensorKrowch 1.1.2 documentation + <no title> — TensorKrowch 1.1.3 documentation diff --git a/docs/_build/html/decompositions.html b/docs/_build/html/decompositions.html index 4d073be..c66ed6d 100644 --- a/docs/_build/html/decompositions.html +++ b/docs/_build/html/decompositions.html @@ -6,7 +6,7 @@ - Decompositions — TensorKrowch 1.1.2 documentation + Decompositions — TensorKrowch 1.1.3 documentation diff --git a/docs/_build/html/embeddings.html b/docs/_build/html/embeddings.html index 5e8e34d..56de0af 100644 --- a/docs/_build/html/embeddings.html +++ b/docs/_build/html/embeddings.html @@ -6,7 +6,7 @@ - Embeddings — TensorKrowch 1.1.2 documentation + Embeddings — TensorKrowch 1.1.3 documentation diff --git a/docs/_build/html/examples.html b/docs/_build/html/examples.html index 41d6f03..73d815e 100644 --- a/docs/_build/html/examples.html +++ b/docs/_build/html/examples.html @@ -6,7 +6,7 @@ - Examples — TensorKrowch 1.1.2 documentation + Examples — TensorKrowch 1.1.3 documentation diff --git a/docs/_build/html/examples/hybrid_tnn_model.html b/docs/_build/html/examples/hybrid_tnn_model.html index 5e33e48..b6e00ca 100644 --- a/docs/_build/html/examples/hybrid_tnn_model.html +++ b/docs/_build/html/examples/hybrid_tnn_model.html @@ -6,7 +6,7 @@ - Hybrid Tensorial Neural Network model — TensorKrowch 1.1.2 documentation + Hybrid Tensorial Neural Network model — TensorKrowch 1.1.3 documentation diff --git a/docs/_build/html/examples/mps_dmrg.html b/docs/_build/html/examples/mps_dmrg.html index 7741b4c..76b90e9 100644 --- a/docs/_build/html/examples/mps_dmrg.html +++ b/docs/_build/html/examples/mps_dmrg.html @@ -6,7 +6,7 @@ - DMRG-like training of MPS — TensorKrowch 1.1.2 documentation + DMRG-like training of MPS — TensorKrowch 1.1.3 documentation diff --git a/docs/_build/html/examples/mps_dmrg.ipynb b/docs/_build/html/examples/mps_dmrg.ipynb deleted file mode 100644 index f681d39..0000000 --- a/docs/_build/html/examples/mps_dmrg.ipynb +++ /dev/null @@ -1,587 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# DMRG-like training of MPS\n", - "\n", - "Here we show how one can use MPS models to train via DMRG, as shown in\n", - "[[SS16']](https://arxiv.org/abs/1605.05775)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%mkdir data\n", - "%mkdir models" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import torch.nn as nn\n", - "import torch.optim as optim\n", - "from torch.utils.data import DataLoader\n", - "\n", - "import torchvision.transforms as transforms\n", - "import torchvision.datasets as datasets\n", - "\n", - "import matplotlib.pyplot as plt\n", - "\n", - "import tensorkrowch as tk" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "device(type='cuda', index=0)" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "device = torch.device('cpu')\n", - "\n", - "if torch.cuda.is_available():\n", - " device = torch.device('cuda:0')\n", - "elif torch.backends.mps.is_available():\n", - " device = torch.device('mps:0')\n", - "else:\n", - " device = torch.device('cpu')\n", - "\n", - "device" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Dataset" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "# MNIST Dataset\n", - "dataset_name = 'mnist'\n", - "batch_size = 64\n", - "image_size = 15\n", - "input_size = image_size ** 2\n", - "num_classes = 10\n", - "\n", - "transform = transforms.Compose([transforms.ToTensor(),\n", - " transforms.Resize(image_size, antialias=True),\n", - " ])\n", - "\n", - "# Load data\n", - "train_dataset = datasets.MNIST(root='data/',\n", - " train=True,\n", - " transform=transform,\n", - " download=True)\n", - "test_dataset = datasets.MNIST(root='data/',\n", - " train=False,\n", - " transform=transform,\n", - " download=True)\n", - "\n", - "train_loader = DataLoader(dataset=train_dataset,\n", - " batch_size=batch_size,\n", - " shuffle=True)\n", - "test_loader = DataLoader(dataset=test_dataset,\n", - " batch_size=batch_size,\n", - " shuffle=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAceElEQVR4nO3df2xV9f3H8dell166rlxsHW3vbKEaIgKVoQhRzAaxkTWIsgWdptYGjc6tCrWGlW4rbipU3OYq2BQxmbBE/PGHRSVT13UIGuVXa51kW4HYYZW1nYneCyVca+/5/rFw96386oVzePeW5yM5f/Tew+fzvmB5etrLqc9xHEcAAJxjI6wHAACcnwgQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAw4bce4OtisZgOHjyojIwM+Xw+63EAAAlyHEeHDh1SKBTSiBEnv84ZcgE6ePCg8vLyrMcAAJylzs5OXXTRRSd9fsgFKCMjQ9J/Bx89erTxNACAREUiEeXl5cX/Pj+ZIRegY192Gz16NAECgCR2um+j8CYEAIAJAgQAMEGAAAAmCBAAwAQBAgCY8CxA9fX1Gj9+vEaNGqWZM2dq586dXm0FAEhCngToxRdfVGVlpR566CG1trZq6tSpmjt3rnp6erzYDgCQhDwJ0BNPPKG7775bixYt0qRJk7R27Vp94xvf0B/+8AcvtgMAJCHXA/Tll1+qpaVFRUVF/9tkxAgVFRXpvffeO+78aDSqSCQy4AAADH+uB+izzz5Tf3+/srOzBzyenZ2trq6u486vra1VMBiMH9wHDgDOD+bvgquurlY4HI4fnZ2d1iMBAM4B1+8Fd+GFFyolJUXd3d0DHu/u7lZOTs5x5wcCAQUCAbfHAAAMca5fAaWmpurKK69Uc3Nz/LFYLKbm5mZdffXVbm8HAEhSntwNu7KyUmVlZZo+fbpmzJihuro69fb2atGiRV5sBwBIQp4E6Ec/+pH+85//aPny5erq6tJ3vvMdvfHGG8e9MQEAcP7yOY7jWA/x/0UiEQWDQYXDYX4eEAAkocH+PW7+LjgAwPmJAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAm/NYD4PwWjUY93yMcDnu+x9GjRz1d/5vf/Kan60tSRkaG53uMHDnS8z2QPLgCAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMCE6wGqra3VVVddpYyMDI0dO1YLFixQe3u729sAAJKc6wHaunWrysvLtX37djU1Namvr0/XX3+9ent73d4KAJDEXL8VzxtvvDHg4/Xr12vs2LFqaWnRd7/7Xbe3AwAkKc/vBXfsPlyZmZknfD4ajQ64H1gkEvF6JADAEODpmxBisZgqKio0a9YsTZky5YTn1NbWKhgMxo+8vDwvRwIADBGeBqi8vFx79uzRCy+8cNJzqqurFQ6H40dnZ6eXIwEAhgjPvgR33333afPmzdq2bZsuuuiik54XCAQUCAS8GgMAMES5HiDHcXT//fersbFRb731lgoKCtzeAgAwDLgeoPLycm3cuFGvvPKKMjIy1NXVJUkKBoNKS0tzezsAQJJy/XtADQ0NCofDmj17tnJzc+PHiy++6PZWAIAk5smX4AAAOB3uBQcAMEGAAAAmCBAAwAQBAgCYIEAAABOe34wU3jly5Ijne7z77ruerl9fX+/p+pK0f/9+z/dIT0/3dP3+/n5P15ek0tJSz/e4//77PV3f5/N5uj7cxRUQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJvzWAwxXjuN4vseKFSs836O9vd3T9RcuXOjp+pI0ffp0z/cIBoOerv/nP//Z0/Ul6dlnn/V8jzvuuMPT9ceMGePp+nAXV0AAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMCE5wF67LHH5PP5VFFR4fVWAIAk4mmAdu3apaefflqXX365l9sAAJKQZwE6fPiwSkpK9Mwzz+iCCy7wahsAQJLyLEDl5eWaN2+eioqKvNoCAJDEPLkX3AsvvKDW1lbt2rXrtOdGo1FFo9H4x5FIxIuRAABDjOtXQJ2dnVqyZImee+45jRo16rTn19bWKhgMxo+8vDy3RwIADEGuB6ilpUU9PT264oor5Pf75ff7tXXrVq1evVp+v1/9/f0Dzq+urlY4HI4fnZ2dbo8EABiCXP8S3HXXXacPP/xwwGOLFi3SxIkTVVVVpZSUlAHPBQIBBQIBt8cAAAxxrgcoIyNDU6ZMGfBYenq6srKyjnscAHD+4k4IAAAT5+Qnor711lvnYhsAQBLhCggAYIIAAQBMECAAgAkCBAAwQYAAACbOybvgzkc+n8/zPe666y7P90hLS/N0/ezsbE/Xl6QRI7z//6y+vj5P1//3v//t6fqS1N3d7fkeXv8+IblwBQQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJv/UAOHMXX3yx9QhJwXEcz/d4//33PV1//fr1nq4vSQsXLvR8jwsuuMDzPZA8uAICAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwIQnAfr00091++23KysrS2lpaSosLNTu3bu92AoAkKRcvxPC559/rlmzZmnOnDl6/fXX9a1vfUv79u3jX0ADAAZwPUCrVq1SXl6enn322fhjBQUFbm8DAEhyrn8J7tVXX9X06dN18803a+zYsZo2bZqeeeaZk54fjUYViUQGHACA4c/1AH300UdqaGjQhAkT9Oabb+onP/mJFi9erA0bNpzw/NraWgWDwfiRl5fn9kgAgCHI9QDFYjFdccUVWrlypaZNm6Z77rlHd999t9auXXvC86urqxUOh+NHZ2en2yMBAIYg1wOUm5urSZMmDXjssssu08cff3zC8wOBgEaPHj3gAAAMf64HaNasWWpvbx/w2N69ezVu3Di3twIAJDHXA/TAAw9o+/btWrlypfbv36+NGzdq3bp1Ki8vd3srAEAScz1AV111lRobG/X8889rypQpeuSRR1RXV6eSkhK3twIAJDFPfiT3DTfcoBtuuMGLpQEAwwT3ggMAmCBAAAATBAgAYIIAAQBMECAAgAlP3gUHDJbjOJ7v8cknn3i+x8MPP+zp+mlpaZ6uL0l33XWX53v4/fyVg//hCggAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmPBbD4ChzXEcT9fv6enxdH1Jqqqq8nyPDz/80NP116xZ4+n6khQKhTzfo6+vz9P1+/v7PV1fkr766ivP90hNTU3q9QeLKyAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATLgeoP7+ftXU1KigoEBpaWm65JJL9Mgjj3j+DxoBAMnF9TshrFq1Sg0NDdqwYYMmT56s3bt3a9GiRQoGg1q8eLHb2wEAkpTrAXr33Xd10003ad68eZKk8ePH6/nnn9fOnTvd3goAkMRc/xLcNddco+bmZu3du1eS9MEHH+idd95RcXHxCc+PRqOKRCIDDgDA8Of6FdCyZcsUiUQ0ceJEpaSkqL+/XytWrFBJSckJz6+trdWvf/1rt8cAAAxxrl8BvfTSS3ruuee0ceNGtba2asOGDfrtb3+rDRs2nPD86upqhcPh+NHZ2en2SACAIcj1K6ClS5dq2bJluvXWWyVJhYWFOnDggGpra1VWVnbc+YFAQIFAwO0xAABDnOtXQEeOHNGIEQOXTUlJUSwWc3srAEASc/0KaP78+VqxYoXy8/M1efJkvf/++3riiSd05513ur0VACCJuR6gNWvWqKamRj/96U/V09OjUCikH//4x1q+fLnbWwEAkpjrAcrIyFBdXZ3q6urcXhoAMIxwLzgAgAkCBAAwQYAAACYIEADABAECAJhw/V1wGF66u7s9Xb+8vNzT9SXptdde83yPH/zgB56uP2bMGE/Xl6SXX37Z8z1aWlo8Xf+TTz7xdH1J6urq8nyP73//+56uX1FR4en60Wh0UOdxBQQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJv/UAOHOxWMzzPerr6z1df9OmTZ6uL52b36fGxkZP1//Tn/7k6fqS5DiO53vk5+d7uv6kSZM8XV+SCgoKPN9j3Lhxnq7v8/mGxPpcAQEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMJB2jbtm2aP3++QqGQfD7fcf+Ow3EcLV++XLm5uUpLS1NRUZH27dvn1rwAgGEi4QD19vZq6tSpJ/0Hio8//rhWr16ttWvXaseOHUpPT9fcuXN19OjRsx4WADB8JHwnhOLiYhUXF5/wOcdxVFdXp1/+8pe66aabJEl//OMflZ2drU2bNunWW289u2kBAMOGq98D6ujoUFdXl4qKiuKPBYNBzZw5U++9994Jf000GlUkEhlwAACGP1cD1NXVJUnKzs4e8Hh2dnb8ua+rra1VMBiMH3l5eW6OBAAYoszfBVddXa1wOBw/Ojs7rUcCAJwDrgYoJydHktTd3T3g8e7u7vhzXxcIBDR69OgBBwBg+HM1QAUFBcrJyVFzc3P8sUgkoh07dujqq692cysAQJJL+F1whw8f1v79++Mfd3R0qK2tTZmZmcrPz1dFRYUeffRRTZgwQQUFBaqpqVEoFNKCBQvcnBsAkOQSDtDu3bs1Z86c+MeVlZWSpLKyMq1fv14/+9nP1Nvbq3vuuUdffPGFrr32Wr3xxhsaNWqUe1MDAJJewgGaPXv2KX9yos/n08MPP6yHH374rAYDAAxv5u+CAwCcnwgQAMAEAQIAmCBAAAATBAgAYCLhd8Fh6DjVuxHdEgwGPV2/pKTE0/UlnZO7a4wfP97T9QsLCz1dX5Iuvvhiz/cYO3asp+tnZGR4ur4kjRjB/7efTmpq6qDO43cSAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAE37rAXDmUlJSPN+joqLC8z28di5+n7zm8/msRwBcxxUQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwkXCAtm3bpvnz5ysUCsnn82nTpk3x5/r6+lRVVaXCwkKlp6crFArpjjvu0MGDB92cGQAwDCQcoN7eXk2dOlX19fXHPXfkyBG1traqpqZGra2tevnll9Xe3q4bb7zRlWEBAMNHwndCKC4uVnFx8QmfCwaDampqGvDYU089pRkzZujjjz9Wfn7+mU0JABh2PP8eUDgcls/n05gxY7zeCgCQRDy9F9zRo0dVVVWl2267TaNHjz7hOdFoVNFoNP5xJBLxciQAwBDh2RVQX1+fbrnlFjmOo4aGhpOeV1tbq2AwGD/y8vK8GgkAMIR4EqBj8Tlw4ICamppOevUjSdXV1QqHw/Gjs7PTi5EAAEOM61+COxafffv2acuWLcrKyjrl+YFAQIFAwO0xAABDXMIBOnz4sPbv3x//uKOjQ21tbcrMzFRubq4WLlyo1tZWbd68Wf39/erq6pIkZWZmKjU11b3JAQBJzec4jpPIL3jrrbc0Z86c4x4vKyvTr371KxUUFJzw123ZskWzZ88+7fqRSETBYFDhcPiUX7rDufHVV19Zj3DW+IF0wLk12L/HE74Cmj17tk7VrAR7BgA4T3EvOACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATnt6MFMnP7+c/EQDe4AoIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMJB2jbtm2aP3++QqGQfD6fNm3adNJz7733Xvl8PtXV1Z3FiACA4SjhAPX29mrq1Kmqr68/5XmNjY3avn27QqHQGQ8HABi+/In+guLiYhUXF5/ynE8//VT333+/3nzzTc2bN++MhwMADF8JB+h0YrGYSktLtXTpUk2ePPm050ejUUWj0fjHkUjE7ZEAAEOQ629CWLVqlfx+vxYvXjyo82traxUMBuNHXl6e2yMBAIYgVwPU0tKiJ598UuvXr5fP5xvUr6murlY4HI4fnZ2dbo4EABiiXA3Q22+/rZ6eHuXn58vv98vv9+vAgQN68MEHNX78+BP+mkAgoNGjRw84AADDn6vfAyotLVVRUdGAx+bOnavS0lItWrTIza0AAEku4QAdPnxY+/fvj3/c0dGhtrY2ZWZmKj8/X1lZWQPOHzlypHJycnTppZee/bQAgGEj4QDt3r1bc+bMiX9cWVkpSSorK9P69etdGwwAMLwlHKDZs2fLcZxBn/+vf/0r0S0AAOcB7gUHADBBgAAAJggQAMAEAQIAmHD9XnBn69gbHLgnHAAkp2N/f5/uDWtDLkCHDh2SJO4JBwBJ7tChQwoGgyd93uck8p7qcyAWi+ngwYPKyMgY9P3kIpGI8vLy1NnZmbS38uE1DB3D4XXwGoaG4fAapMRfh+M4OnTokEKhkEaMOPl3eobcFdCIESN00UUXndGvHQ73kuM1DB3D4XXwGoaG4fAapMRex6mufI7hTQgAABMECABgYlgEKBAI6KGHHlIgELAe5YzxGoaO4fA6eA1Dw3B4DZJ3r2PIvQkBAHB+GBZXQACA5EOAAAAmCBAAwAQBAgCYSPoA1dfXa/z48Ro1apRmzpypnTt3Wo+UkNraWl111VXKyMjQ2LFjtWDBArW3t1uPdVYee+wx+Xw+VVRUWI+SkE8//VS33367srKylJaWpsLCQu3evdt6rEHr7+9XTU2NCgoKlJaWpksuuUSPPPJIQj9A0sK2bds0f/58hUIh+Xw+bdq0acDzjuNo+fLlys3NVVpamoqKirRv3z6bYU/iVK+hr69PVVVVKiwsVHp6ukKhkO644w4dPHjQbuATON2fw/937733yufzqa6u7qz2TOoAvfjii6qsrNRDDz2k1tZWTZ06VXPnzlVPT4/1aIO2detWlZeXa/v27WpqalJfX5+uv/569fb2Wo92Rnbt2qWnn35al19+ufUoCfn88881a9YsjRw5Uq+//rr+/ve/63e/+50uuOAC69EGbdWqVWpoaNBTTz2lf/zjH1q1apUef/xxrVmzxnq0U+rt7dXUqVNVX19/wucff/xxrV69WmvXrtWOHTuUnp6uuXPn6ujRo+d40pM71Ws4cuSIWltbVVNTo9bWVr388stqb2/XjTfeaDDpyZ3uz+GYxsZGbd++XaFQ6Ow3dZLYjBkznPLy8vjH/f39TigUcmpraw2nOjs9PT2OJGfr1q3WoyTs0KFDzoQJE5ympibne9/7nrNkyRLrkQatqqrKufbaa63HOCvz5s1z7rzzzgGP/fCHP3RKSkqMJkqcJKexsTH+cSwWc3Jycpzf/OY38ce++OILJxAIOM8//7zBhKf39ddwIjt37nQkOQcOHDg3QyXoZK/hk08+cb797W87e/bsccaNG+f8/ve/P6t9kvYK6Msvv1RLS4uKiorij40YMUJFRUV67733DCc7O+FwWJKUmZlpPEniysvLNW/evAF/Jsni1Vdf1fTp03XzzTdr7NixmjZtmp555hnrsRJyzTXXqLm5WXv37pUkffDBB3rnnXdUXFxsPNmZ6+joUFdX14D/poLBoGbOnJn0n+c+n09jxoyxHmXQYrGYSktLtXTpUk2ePNmVNYfczUgH67PPPlN/f7+ys7MHPJ6dna1//vOfRlOdnVgspoqKCs2aNUtTpkyxHichL7zwglpbW7Vr1y7rUc7IRx99pIaGBlVWVurnP/+5du3apcWLFys1NVVlZWXW4w3KsmXLFIlENHHiRKWkpKi/v18rVqxQSUmJ9WhnrKurS5JO+Hl+7Llkc/ToUVVVVem2225LqhuUrlq1Sn6/X4sXL3ZtzaQN0HBUXl6uPXv26J133rEeJSGdnZ1asmSJmpqaNGrUKOtxzkgsFtP06dO1cuVKSdK0adO0Z88erV27NmkC9NJLL+m5557Txo0bNXnyZLW1tamiokKhUChpXsNw19fXp1tuuUWO46ihocF6nEFraWnRk08+qdbW1kH/mJzBSNovwV144YVKSUlRd3f3gMe7u7uVk5NjNNWZu++++7R582Zt2bLljH8chZWWlhb19PToiiuukN/vl9/v19atW7V69Wr5/X719/dbj3haubm5mjRp0oDHLrvsMn388cdGEyVu6dKlWrZsmW699VYVFhaqtLRUDzzwgGpra61HO2PHPpeHw+f5sfgcOHBATU1NSXX18/bbb6unp0f5+fnxz/EDBw7owQcf1Pjx48943aQNUGpqqq688ko1NzfHH4vFYmpubtbVV19tOFliHMfRfffdp8bGRv31r39VQUGB9UgJu+666/Thhx+qra0tfkyfPl0lJSVqa2tTSkqK9YinNWvWrOPe/r53716NGzfOaKLEHTly5Lgf/pWSkqJYLGY00dkrKChQTk7OgM/zSCSiHTt2JNXn+bH47Nu3T3/5y1+UlZVlPVJCSktL9be//W3A53goFNLSpUv15ptvnvG6Sf0luMrKSpWVlWn69OmaMWOG6urq1Nvbq0WLFlmPNmjl5eXauHGjXnnlFWVkZMS/rh0MBpWWlmY83eBkZGQc9z2r9PR0ZWVlJc33sh544AFdc801WrlypW655Rbt3LlT69at07p166xHG7T58+drxYoVys/P1+TJk/X+++/riSee0J133mk92ikdPnxY+/fvj3/c0dGhtrY2ZWZmKj8/XxUVFXr00Uc1YcIEFRQUqKamRqFQSAsWLLAb+mtO9Rpyc3O1cOFCtba2avPmzerv749/nmdmZio1NdVq7AFO9+fw9WiOHDlSOTk5uvTSS89807N6D90QsGbNGic/P99JTU11ZsyY4Wzfvt16pIRIOuHx7LPPWo92VpLtbdiO4zivvfaaM2XKFCcQCDgTJ0501q1bZz1SQiKRiLNkyRInPz/fGTVqlHPxxRc7v/jFL5xoNGo92ilt2bLlhJ8DZWVljuP8963YNTU1TnZ2thMIBJzrrrvOaW9vtx36a071Gjo6Ok76eb5lyxbr0eNO9+fwdW68DZsfxwAAMJG03wMCACQ3AgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMDE/wGCpihoW+pRwgAAAABJRU5ErkJggg==", - "text/plain": [ - "
    " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2\n" - ] - } - ], - "source": [ - "random_sample = torch.randint(low=0, high=len(train_dataset), size=(1,)).item()\n", - "\n", - "plt.imshow(train_dataset[random_sample][0].squeeze(0), cmap='Greys')\n", - "plt.show()\n", - "\n", - "print(train_dataset[random_sample][1])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Define model" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "class MPS_DMRG(tk.models.MPSLayer):\n", - " \n", - " def __init__(self, *args, **kwargs):\n", - " super().__init__(*args, **kwargs)\n", - " \n", - " self.parameterize(set_param=False, override=True)\n", - " \n", - " self.out_node.get_axis('input').name = 'output'\n", - " \n", - " self.block_position = None\n", - " self.block_length = None\n", - " \n", - " @property\n", - " def block(self):\n", - " if self.block_position is not None:\n", - " return self.mats_env[self.block_position]\n", - " return None\n", - " \n", - " def merge_block(self, block_position, block_length):\n", - " if block_position + block_length > self.n_features:\n", - " raise ValueError(\n", - " f'Last position of the block ({block_position + block_length}) '\n", - " f'exceeds the range of MPS sites ({self.n_features})')\n", - " elif block_length < 1:\n", - " raise ValueError(\n", - " '`block_length` should be greater than or equal to 1')\n", - " \n", - " if self.block_position is not None:\n", - " raise ValueError(\n", - " 'Cannot create block if there is already a merged block')\n", - " \n", - " block_nodes = self.mats_env[block_position:(block_position + block_length)]\n", - " \n", - " block = block_nodes[0]\n", - " for node in block_nodes[1:]:\n", - " block = tk.contract_between_(block, node)\n", - " block = block.parameterize(True)\n", - " block.name = 'block'\n", - " \n", - " self.block_position = block_position\n", - " self.block_length = block_length\n", - " self._mats_env = self._mats_env[:block_position] + [block] + \\\n", - " self._mats_env[(block_position + block_length):]\n", - " \n", - " def unmerge_block(self, side='right', rank=None, cum_percentage=None):\n", - " block = self.block\n", - " \n", - " block_nodes = []\n", - " for i in range(self.block_length - 1):\n", - " node1_axes = block.axes[:2]\n", - " node2_axes = block.axes[2:]\n", - " \n", - " node, block = tk.split_(block,\n", - " node1_axes,\n", - " node2_axes,\n", - " side=side,\n", - " rank=rank,\n", - " cum_percentage=cum_percentage)\n", - " block.get_axis('split').name = 'left'\n", - " node.get_axis('split').name = 'right'\n", - " node.name = f'mats_env_({self.block_position + i})'\n", - " \n", - " block_nodes.append(node)\n", - " \n", - " block.name = f'mats_env_({self.block_position + i + 1})'\n", - " block_nodes.append(block)\n", - " \n", - " self._mats_env = self._mats_env[:self.block_position] + block_nodes + \\\n", - " self._mats_env[(self.block_position + 1):]\n", - " \n", - " self.block_position = None\n", - " self.block_length = None\n", - " \n", - " def contract(self):\n", - " result_mats = []\n", - " for node in self.mats_env:\n", - " while any(['input' in name for name in node.axes_names]):\n", - " for axis in node.axes:\n", - " if 'input' in axis.name:\n", - " data_node = node.neighbours(axis)\n", - " node = node @ data_node\n", - " break\n", - " result_mats.append(node)\n", - " \n", - " result_mats = [self.left_node] + result_mats + [self.right_node]\n", - " \n", - " result = result_mats[0]\n", - " for node in result_mats[1:]:\n", - " result @= node\n", - " \n", - " return result" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "# Model hyperparameters\n", - "embedding_dim = 2\n", - "output_dim = num_classes\n", - "bond_dim = 50\n", - "init_method = 'unit'\n", - "block_length = 2\n", - "cum_percentage = 0.98" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [], - "source": [ - "# Initialize network\n", - "model_name = 'mps_dmrg'\n", - "mps = MPS_DMRG(n_features=input_size + 1,\n", - " in_dim=embedding_dim,\n", - " out_dim=num_classes,\n", - " bond_dim=bond_dim,\n", - " boundary='obc',\n", - " init_method=init_method,\n", - " device=device)\n", - "\n", - "# Important to set data nodes before merging nodes\n", - "mps.set_data_nodes()" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "def embedding(x):\n", - " x = tk.embeddings.unit(x, dim=embedding_dim)\n", - " return x" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Train" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [], - "source": [ - "# Hyperparameters\n", - "learning_rate = 1e-3\n", - "weight_decay = 1e-8\n", - "num_epochs = 100\n", - "move_block_epochs = 100\n", - "\n", - "# Loss and optimizer\n", - "criterion = nn.CrossEntropyLoss()" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [], - "source": [ - "# Check accuracy on training & test to see how good our model is\n", - "def check_accuracy(loader, model):\n", - " num_correct = 0\n", - " num_samples = 0\n", - " model.eval()\n", - " \n", - " with torch.no_grad():\n", - " for x, y in loader:\n", - " x = x.to(device)\n", - " y = y.to(device)\n", - " x = x.reshape(x.shape[0], -1)\n", - " \n", - " scores = model(embedding(x))\n", - " _, predictions = scores.max(1)\n", - " num_correct += (predictions == y).sum()\n", - " num_samples += predictions.size(0)\n", - " \n", - " accuracy = float(num_correct) / float(num_samples) * 100\n", - " model.train()\n", - " return accuracy" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "* Epoch 1 (block_position=9, direction=1)=> Train. Acc.: 12.75, Test Acc.: 13.15\n", - "* Epoch 2 (block_position=18, direction=1)=> Train. Acc.: 12.88, Test Acc.: 13.35\n", - "* Epoch 3 (block_position=27, direction=1)=> Train. Acc.: 13.07, Test Acc.: 13.63\n", - "* Epoch 4 (block_position=36, direction=1)=> Train. Acc.: 13.25, Test Acc.: 13.20\n", - "* Epoch 5 (block_position=45, direction=1)=> Train. Acc.: 15.93, Test Acc.: 15.27\n", - "* Epoch 6 (block_position=54, direction=1)=> Train. Acc.: 17.78, Test Acc.: 16.93\n", - "* Epoch 7 (block_position=63, direction=1)=> Train. Acc.: 18.97, Test Acc.: 17.51\n", - "* Epoch 8 (block_position=72, direction=1)=> Train. Acc.: 21.57, Test Acc.: 18.85\n", - "* Epoch 9 (block_position=81, direction=1)=> Train. Acc.: 22.80, Test Acc.: 19.67\n", - "* Epoch 10 (block_position=90, direction=1)=> Train. Acc.: 24.91, Test Acc.: 21.50\n", - "* Epoch 11 (block_position=99, direction=1)=> Train. Acc.: 27.40, Test Acc.: 23.48\n", - "* Epoch 12 (block_position=108, direction=1)=> Train. Acc.: 29.46, Test Acc.: 25.31\n", - "* Epoch 13 (block_position=117, direction=1)=> Train. Acc.: 38.99, Test Acc.: 33.37\n", - "* Epoch 14 (block_position=126, direction=1)=> Train. Acc.: 41.44, Test Acc.: 35.75\n", - "* Epoch 15 (block_position=135, direction=1)=> Train. Acc.: 44.87, Test Acc.: 38.75\n", - "* Epoch 16 (block_position=144, direction=1)=> Train. Acc.: 47.04, Test Acc.: 41.17\n", - "* Epoch 17 (block_position=153, direction=1)=> Train. Acc.: 48.40, Test Acc.: 41.93\n", - "* Epoch 18 (block_position=162, direction=1)=> Train. Acc.: 49.52, Test Acc.: 42.38\n", - "* Epoch 19 (block_position=171, direction=1)=> Train. Acc.: 50.02, Test Acc.: 42.80\n", - "* Epoch 20 (block_position=180, direction=1)=> Train. Acc.: 50.37, Test Acc.: 43.01\n", - "* Epoch 21 (block_position=189, direction=1)=> Train. Acc.: 50.22, Test Acc.: 42.90\n", - "* Epoch 22 (block_position=198, direction=1)=> Train. Acc.: 50.56, Test Acc.: 43.45\n", - "* Epoch 23 (block_position=207, direction=1)=> Train. Acc.: 50.45, Test Acc.: 43.09\n", - "* Epoch 24 (block_position=216, direction=1)=> Train. Acc.: 50.38, Test Acc.: 43.32\n", - "* Epoch 25 (block_position=223, direction=-1)=> Train. Acc.: 50.67, Test Acc.: 43.13\n", - "* Epoch 26 (block_position=214, direction=-1)=> Train. Acc.: 50.41, Test Acc.: 43.18\n", - "* Epoch 27 (block_position=205, direction=-1)=> Train. Acc.: 50.31, Test Acc.: 42.80\n", - "* Epoch 28 (block_position=196, direction=-1)=> Train. Acc.: 50.62, Test Acc.: 43.36\n", - "* Epoch 29 (block_position=187, direction=-1)=> Train. Acc.: 50.35, Test Acc.: 43.42\n", - "* Epoch 30 (block_position=178, direction=-1)=> Train. Acc.: 50.87, Test Acc.: 43.31\n", - "* Epoch 31 (block_position=169, direction=-1)=> Train. Acc.: 51.33, Test Acc.: 43.28\n", - "* Epoch 32 (block_position=160, direction=-1)=> Train. Acc.: 51.49, Test Acc.: 42.85\n", - "* Epoch 33 (block_position=151, direction=-1)=> Train. Acc.: 52.19, Test Acc.: 43.20\n", - "* Epoch 34 (block_position=142, direction=-1)=> Train. Acc.: 52.50, Test Acc.: 43.21\n", - "* Epoch 35 (block_position=133, direction=-1)=> Train. Acc.: 53.17, Test Acc.: 43.05\n", - "* Epoch 36 (block_position=124, direction=-1)=> Train. Acc.: 54.22, Test Acc.: 43.48\n", - "* Epoch 37 (block_position=115, direction=-1)=> Train. Acc.: 55.39, Test Acc.: 44.47\n", - "* Epoch 38 (block_position=106, direction=-1)=> Train. Acc.: 59.71, Test Acc.: 48.28\n", - "* Epoch 39 (block_position=97, direction=-1)=> Train. Acc.: 61.19, Test Acc.: 49.97\n", - "* Epoch 40 (block_position=88, direction=-1)=> Train. Acc.: 63.27, Test Acc.: 52.01\n", - "* Epoch 41 (block_position=79, direction=-1)=> Train. Acc.: 65.81, Test Acc.: 54.78\n", - "* Epoch 42 (block_position=70, direction=-1)=> Train. Acc.: 67.14, Test Acc.: 56.31\n", - "* Epoch 43 (block_position=61, direction=-1)=> Train. Acc.: 69.11, Test Acc.: 58.45\n", - "* Epoch 44 (block_position=52, direction=-1)=> Train. Acc.: 69.54, Test Acc.: 59.10\n", - "* Epoch 45 (block_position=43, direction=-1)=> Train. Acc.: 69.92, Test Acc.: 59.63\n", - "* Epoch 46 (block_position=34, direction=-1)=> Train. Acc.: 69.91, Test Acc.: 59.61\n", - "* Epoch 47 (block_position=25, direction=-1)=> Train. Acc.: 69.96, Test Acc.: 59.37\n", - "* Epoch 48 (block_position=16, direction=-1)=> Train. Acc.: 69.94, Test Acc.: 59.60\n", - "* Epoch 49 (block_position=7, direction=-1)=> Train. Acc.: 69.57, Test Acc.: 59.07\n", - "* Epoch 50 (block_position=2, direction=1)=> Train. Acc.: 70.16, Test Acc.: 59.82\n", - "* Epoch 51 (block_position=11, direction=1)=> Train. Acc.: 69.99, Test Acc.: 59.69\n", - "* Epoch 52 (block_position=20, direction=1)=> Train. Acc.: 69.91, Test Acc.: 59.58\n", - "* Epoch 53 (block_position=29, direction=1)=> Train. Acc.: 69.87, Test Acc.: 59.23\n", - "* Epoch 54 (block_position=38, direction=1)=> Train. Acc.: 69.94, Test Acc.: 59.44\n", - "* Epoch 55 (block_position=47, direction=1)=> Train. Acc.: 70.32, Test Acc.: 59.61\n", - "* Epoch 56 (block_position=56, direction=1)=> Train. Acc.: 70.44, Test Acc.: 59.36\n", - "* Epoch 57 (block_position=65, direction=1)=> Train. Acc.: 70.80, Test Acc.: 59.50\n", - "* Epoch 58 (block_position=74, direction=1)=> Train. Acc.: 71.57, Test Acc.: 59.42\n", - "* Epoch 59 (block_position=83, direction=1)=> Train. Acc.: 72.07, Test Acc.: 59.73\n", - "* Epoch 60 (block_position=92, direction=1)=> Train. Acc.: 72.71, Test Acc.: 60.39\n", - "* Epoch 61 (block_position=101, direction=1)=> Train. Acc.: 73.74, Test Acc.: 61.01\n", - "* Epoch 62 (block_position=110, direction=1)=> Train. Acc.: 74.56, Test Acc.: 61.42\n", - "* Epoch 63 (block_position=119, direction=1)=> Train. Acc.: 77.45, Test Acc.: 64.26\n", - "* Epoch 64 (block_position=128, direction=1)=> Train. Acc.: 78.30, Test Acc.: 65.43\n", - "* Epoch 65 (block_position=137, direction=1)=> Train. Acc.: 79.33, Test Acc.: 66.56\n", - "* Epoch 66 (block_position=146, direction=1)=> Train. Acc.: 80.41, Test Acc.: 67.67\n", - "* Epoch 67 (block_position=155, direction=1)=> Train. Acc.: 81.28, Test Acc.: 68.59\n", - "* Epoch 68 (block_position=164, direction=1)=> Train. Acc.: 82.38, Test Acc.: 69.93\n", - "* Epoch 69 (block_position=173, direction=1)=> Train. Acc.: 82.69, Test Acc.: 70.40\n", - "* Epoch 70 (block_position=182, direction=1)=> Train. Acc.: 82.86, Test Acc.: 69.84\n", - "* Epoch 71 (block_position=191, direction=1)=> Train. Acc.: 82.85, Test Acc.: 70.54\n", - "* Epoch 72 (block_position=200, direction=1)=> Train. Acc.: 82.70, Test Acc.: 70.19\n", - "* Epoch 73 (block_position=209, direction=1)=> Train. Acc.: 82.95, Test Acc.: 70.56\n", - "* Epoch 74 (block_position=218, direction=1)=> Train. Acc.: 82.81, Test Acc.: 70.34\n", - "* Epoch 75 (block_position=221, direction=-1)=> Train. Acc.: 83.05, Test Acc.: 70.74\n", - "* Epoch 76 (block_position=212, direction=-1)=> Train. Acc.: 82.64, Test Acc.: 70.36\n", - "* Epoch 77 (block_position=203, direction=-1)=> Train. Acc.: 82.80, Test Acc.: 70.40\n", - "* Epoch 78 (block_position=194, direction=-1)=> Train. Acc.: 83.10, Test Acc.: 70.45\n", - "* Epoch 79 (block_position=185, direction=-1)=> Train. Acc.: 82.98, Test Acc.: 70.27\n", - "* Epoch 80 (block_position=176, direction=-1)=> Train. Acc.: 83.10, Test Acc.: 70.23\n", - "* Epoch 81 (block_position=167, direction=-1)=> Train. Acc.: 83.75, Test Acc.: 70.21\n", - "* Epoch 82 (block_position=158, direction=-1)=> Train. Acc.: 84.00, Test Acc.: 70.05\n", - "* Epoch 83 (block_position=149, direction=-1)=> Train. Acc.: 84.41, Test Acc.: 70.18\n", - "* Epoch 84 (block_position=140, direction=-1)=> Train. Acc.: 84.52, Test Acc.: 70.06\n", - "* Epoch 85 (block_position=131, direction=-1)=> Train. Acc.: 84.89, Test Acc.: 69.79\n", - "* Epoch 86 (block_position=122, direction=-1)=> Train. Acc.: 85.35, Test Acc.: 69.74\n", - "* Epoch 87 (block_position=113, direction=-1)=> Train. Acc.: 85.68, Test Acc.: 70.67\n", - "* Epoch 88 (block_position=104, direction=-1)=> Train. Acc.: 87.23, Test Acc.: 71.58\n", - "* Epoch 89 (block_position=95, direction=-1)=> Train. Acc.: 87.77, Test Acc.: 72.49\n", - "* Epoch 90 (block_position=86, direction=-1)=> Train. Acc.: 88.09, Test Acc.: 72.60\n", - "* Epoch 91 (block_position=77, direction=-1)=> Train. Acc.: 89.02, Test Acc.: 73.84\n", - "* Epoch 92 (block_position=68, direction=-1)=> Train. Acc.: 89.36, Test Acc.: 74.31\n", - "* Epoch 93 (block_position=59, direction=-1)=> Train. Acc.: 89.81, Test Acc.: 75.01\n", - "* Epoch 94 (block_position=50, direction=-1)=> Train. Acc.: 89.46, Test Acc.: 75.64\n", - "* Epoch 95 (block_position=41, direction=-1)=> Train. Acc.: 89.64, Test Acc.: 75.40\n", - "* Epoch 96 (block_position=32, direction=-1)=> Train. Acc.: 89.38, Test Acc.: 74.92\n", - "* Epoch 97 (block_position=23, direction=-1)=> Train. Acc.: 89.54, Test Acc.: 75.21\n", - "* Epoch 98 (block_position=14, direction=-1)=> Train. Acc.: 89.53, Test Acc.: 75.20\n", - "* Epoch 99 (block_position=5, direction=-1)=> Train. Acc.: 88.86, Test Acc.: 74.91\n", - "* Epoch 100 (block_position=4, direction=1)=> Train. Acc.: 89.37, Test Acc.: 75.13\n" - ] - } - ], - "source": [ - "# Train network\n", - "block_position = 0\n", - "direction = 1\n", - "mps.merge_block(block_position, block_length)\n", - "mps.trace(torch.zeros(1, input_size, embedding_dim, device=device))\n", - "optimizer = optim.Adam(mps.parameters(),\n", - " lr=learning_rate,\n", - " weight_decay=weight_decay)\n", - "\n", - "for epoch in range(num_epochs):\n", - " for batch_idx, (data, targets) in enumerate(train_loader):\n", - " # Get data to cuda if possible\n", - " data = data.to(device)\n", - " targets = targets.to(device)\n", - " \n", - " # Get to correct shape\n", - " data = data.reshape(data.shape[0], -1)\n", - " \n", - " # Forward\n", - " scores = mps(embedding(data))\n", - " loss = criterion(scores, targets)\n", - " \n", - " # Backward\n", - " optimizer.zero_grad()\n", - " loss.backward()\n", - " \n", - " # Gradient descent\n", - " optimizer.step()\n", - " \n", - " if (batch_idx + 1) % move_block_epochs == 0:\n", - " if block_position + direction + block_length > mps.n_features:\n", - " direction *= -1\n", - " if block_position + direction < 0:\n", - " direction *= -1\n", - " if block_length == mps.n_features:\n", - " direction = 0\n", - " \n", - " if direction >= 0:\n", - " mps.unmerge_block(side='left',\n", - " rank=bond_dim, \n", - " cum_percentage=cum_percentage)\n", - " else:\n", - " mps.unmerge_block(side='right',\n", - " rank=bond_dim, \n", - " cum_percentage=cum_percentage)\n", - " \n", - " block_position += direction\n", - " mps.merge_block(block_position, block_length)\n", - " mps.trace(torch.zeros(1, input_size, embedding_dim, device=device))\n", - " optimizer = optim.Adam(mps.parameters(),\n", - " lr=learning_rate,\n", - " weight_decay=weight_decay)\n", - " \n", - " train_acc = check_accuracy(train_loader, mps)\n", - " test_acc = check_accuracy(test_loader, mps)\n", - " \n", - " print(f'* Epoch {epoch + 1:<3} ({block_position=}, {direction=})=>'\n", - " f' Train. Acc.: {train_acc:.2f},'\n", - " f' Test Acc.: {test_acc:.2f}')\n", - "\n", - "# Reset before saving the model\n", - "mps.reset()\n", - "torch.save(mps.state_dict(), f'models/{model_name}_{dataset_name}.pt')" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [], - "source": [ - "mps.unmerge_block(rank=bond_dim, cum_percentage=cum_percentage)" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [], - "source": [ - "mps.update_bond_dim()" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAAGdCAYAAACyzRGfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAcG0lEQVR4nO3df2zX9Z3A8VdraUGh7QBp4WgF75zoHNzGZv1m22WHPSshRkf/cIbkmCEz86o56X6cTW6yLbvA7RLdeal4WRjckmNs/KELemNxddR4tkyrZv5YiCx4sMG3bJq2gKMg/dwfC9/kO1AptO/+4PFIPinfz+fTz/f19bN++9yn3/ZbkmVZFgAAiZSO9QAAwMVFfAAASYkPACAp8QEAJCU+AICkxAcAkJT4AACSEh8AQFJlYz3AnxsaGoqDBw/GjBkzoqSkZKzHAQDOQZZlceTIkZg3b16Ulr7/tY1xFx8HDx6Murq6sR4DADgPBw4ciPnz57/vPuMuPmbMmBERfxq+srJyjKcBAM7FwMBA1NXVFb6Pv59xFx+nf9RSWVkpPgBggjmXl0x4wSkAkJT4AACSEh8AQFLiAwBISnwAAEmJDwAgKfEBACQlPgCApMQHAJCU+AAAkhpWfHzjG9+IkpKSomXRokWF7cePH4+WlpaYNWtWTJ8+PZqbm6O3t3fEhwYAJq5hX/n4yEc+EocOHSoszz77bGHb2rVrY8eOHbF9+/bo7OyMgwcPxsqVK0d0YABgYhv2G8uVlZVFbW3tGev7+/tj06ZNsXXr1li2bFlERGzevDmuueaa6O7ujhtuuOHCpwUAJrxhX/l44403Yt68eXHllVfGqlWrYv/+/RER0dPTEydPnozGxsbCvosWLYr6+vro6up6z+MNDg7GwMBA0QIATF7DuvLR0NAQW7ZsiauvvjoOHToU3/zmN+Mzn/lMvPrqq5HP56O8vDyqq6uLPqempiby+fx7HnP9+vXxzW9+87yGT2nB/U+O9Qgwbry5YUUsuP9JH31834+k8+aGFWM9wrAMKz6WL19e+PfixYujoaEhrrjiivjxj38c06ZNO68B2traorW1tXB7YGAg6urqzutYAMD4d0G/altdXR0f/vCHY+/evVFbWxsnTpyIvr6+on16e3vP+hqR0yoqKqKysrJoAQAmrwuKj6NHj8ZvfvObmDt3bixdujSmTJkSHR0dhe179uyJ/fv3Ry6Xu+BBAYDJYVg/dvnKV74St9xyS1xxxRVx8ODBWLduXVxyySVxxx13RFVVVaxZsyZaW1tj5syZUVlZGffee2/kcjm/6QIAFAwrPn7729/GHXfcEW+99VZcfvnl8elPfzq6u7vj8ssvj4iIhx56KEpLS6O5uTkGBwejqakpHnnkkVEZHACYmIYVH9u2bXvf7VOnTo329vZob2+/oKEAgMnLe7sAAEmJDwAgKfEBACQlPgCApMQHAJCU+AAAkhIfAEBSw/o7Hxcj78wIACPLlQ8AICnxAQAkJT4AgKTEBwCQlPgAAJISHwBAUuIDAEhKfAAASYkPACAp8QEAJCU+AICkxAcAkJT4AACSEh8AQFLiAwBISnwAAEmJDwAgKfEBACQlPgCApMQHAJCU+AAAkhIfAEBS4gMASEp8AABJiQ8AICnxAQAkJT4AgKTEBwCQVNlYDzBeLbj/ybEeAQDOyenvWW9uWDHGk5wbVz4AgKTEBwCQlPgAAJISHwBAUuIDAEhKfAAASYkPACAp8QEAJCU+AICkxAcAkJT4AACSEh8AQFLiAwBISnwAAEmJDwAgKfEBACQlPgCApMQHAJCU+AAAkhIfAEBS4gMASEp8AABJXVB8bNiwIUpKSuK+++4rrDt+/Hi0tLTErFmzYvr06dHc3By9vb0XOicAMEmcd3w8//zz8Z//+Z+xePHiovVr166NHTt2xPbt26OzszMOHjwYK1euvOBBAYDJ4bzi4+jRo7Fq1ar43ve+Fx/60IcK6/v7+2PTpk3x4IMPxrJly2Lp0qWxefPmeO6556K7u3vEhgYAJq7zio+WlpZYsWJFNDY2Fq3v6emJkydPFq1ftGhR1NfXR1dX14VNCgBMCmXD/YRt27bFiy++GM8///wZ2/L5fJSXl0d1dXXR+pqamsjn82c93uDgYAwODhZuDwwMDHckAGACGdaVjwMHDsQ//uM/xn//93/H1KlTR2SA9evXR1VVVWGpq6sbkeMCAOPTsOKjp6cnDh8+HB//+MejrKwsysrKorOzMx5++OEoKyuLmpqaOHHiRPT19RV9Xm9vb9TW1p71mG1tbdHf319YDhw4cN4PBgAY/4b1Y5cbb7wxXnnllaJ1d955ZyxatCj+6Z/+Kerq6mLKlCnR0dERzc3NERGxZ8+e2L9/f+RyubMes6KiIioqKs5zfABgohlWfMyYMSOuu+66onWXXXZZzJo1q7B+zZo10draGjNnzozKysq49957I5fLxQ033DByUwMAE9awX3D6QR566KEoLS2N5ubmGBwcjKampnjkkUdG+m4AgAnqguNj165dRbenTp0a7e3t0d7efqGHBgAmIe/tAgAkJT4AgKTEBwCQlPgAAJISHwBAUuIDAEhKfAAASYkPACAp8QEAJCU+AICkxAcAkJT4AACSEh8AQFLiAwBISnwAAEmJDwAgKfEBACQlPgCApMQHAJCU+AAAkhIfAEBS4gMASEp8AABJiQ8AICnxAQAkJT4AgKTEBwCQlPgAAJISHwBAUuIDAEhKfAAASYkPACAp8QEAJCU+AICkxAcAkJT4AACSEh8AQFLiAwBISnwAAEmJDwAgKfEBACQlPgCApMQHAJCU+AAAkhIfAEBS4gMASEp8AABJiQ8AICnxAQAkJT4AgKTEBwCQlPgAAJISHwBAUuIDAEhKfAAASYkPACAp8QEAJCU+AICkxAcAkJT4AACSGlZ8bNy4MRYvXhyVlZVRWVkZuVwufvrTnxa2Hz9+PFpaWmLWrFkxffr0aG5ujt7e3hEfGgCYuIYVH/Pnz48NGzZET09PvPDCC7Fs2bK49dZb47XXXouIiLVr18aOHTti+/bt0dnZGQcPHoyVK1eOyuAAwMRUNpydb7nllqLb//Iv/xIbN26M7u7umD9/fmzatCm2bt0ay5Yti4iIzZs3xzXXXBPd3d1xww03jNzUAMCEdd6v+Th16lRs27Ytjh07FrlcLnp6euLkyZPR2NhY2GfRokVRX18fXV1d73mcwcHBGBgYKFoAgMlr2PHxyiuvxPTp06OioiK+9KUvxWOPPRbXXntt5PP5KC8vj+rq6qL9a2pqIp/Pv+fx1q9fH1VVVYWlrq5u2A8CAJg4hh0fV199dbz88suxe/fuuPvuu2P16tXx+uuvn/cAbW1t0d/fX1gOHDhw3scCAMa/Yb3mIyKivLw8/uqv/ioiIpYuXRrPP/98/Pu//3vcfvvtceLEiejr6yu6+tHb2xu1tbXvebyKioqoqKgY/uQAwIR0wX/nY2hoKAYHB2Pp0qUxZcqU6OjoKGzbs2dP7N+/P3K53IXeDQAwSQzrykdbW1ssX7486uvr48iRI7F169bYtWtX/OxnP4uqqqpYs2ZNtLa2xsyZM6OysjLuvffeyOVyftMFACgYVnwcPnw4/v7v/z4OHToUVVVVsXjx4vjZz34Wf/d3fxcREQ899FCUlpZGc3NzDA4ORlNTUzzyyCOjMjgAMDENKz42bdr0vtunTp0a7e3t0d7efkFDAQCTl/d2AYBJYsH9T8aC+58c6zE+kPgAAJISHwBAUuIDAEhKfAAASYkPACAp8QEAJCU+AICkxAcAkJT4AACSEh8AQFLiAwBISnwAAEmJDwAgqbKxHmC8mQjvBggAE5krHwBAUuIDAEhKfAAASYkPACAp8QEAJCU+AICkxAcAkJT4AACSEh8AQFLiAwBISnwAAEmJDwAgKfEBACQlPgCApMQHAJCU+AAAkhIfAEBS4gMASEp8AABJiQ8AICnxAQAkJT4AgKTEBwCQlPgAAJISHwBAUuIDAEhKfAAASYkPACAp8QEAJCU+AICkxAcAkJT4AACSEh8AQFLiAwBISnwAAEmJDwAgKfEBACQlPgCApMQHAJCU+AAAkhIfAEBS4gMASEp8AABJiQ8AIKlhxcf69evjk5/8ZMyYMSPmzJkTt912W+zZs6don+PHj0dLS0vMmjUrpk+fHs3NzdHb2zuiQwMAE9ew4qOzszNaWlqiu7s7nnrqqTh58mTcdNNNcezYscI+a9eujR07dsT27dujs7MzDh48GCtXrhzxwQGAialsODvv3Lmz6PaWLVtizpw50dPTE3/zN38T/f39sWnTpti6dWssW7YsIiI2b94c11xzTXR3d8cNN9wwcpMDABPSBb3mo7+/PyIiZs6cGRERPT09cfLkyWhsbCzss2jRoqivr4+urq6zHmNwcDAGBgaKFgBg8jrv+BgaGor77rsvPvWpT8V1110XERH5fD7Ky8ujurq6aN+amprI5/NnPc769eujqqqqsNTV1Z3vSADABHDe8dHS0hKvvvpqbNu27YIGaGtri/7+/sJy4MCBCzoeADC+Des1H6fdc8898cQTT8QzzzwT8+fPL6yvra2NEydORF9fX9HVj97e3qitrT3rsSoqKqKiouJ8xgAAJqBhXfnIsizuueeeeOyxx+Lpp5+OhQsXFm1funRpTJkyJTo6Ogrr9uzZE/v3749cLjcyEwMAE9qwrny0tLTE1q1b4yc/+UnMmDGj8DqOqqqqmDZtWlRVVcWaNWuitbU1Zs6cGZWVlXHvvfdGLpfzmy4AQEQMMz42btwYERGf/exni9Zv3rw5vvCFL0RExEMPPRSlpaXR3Nwcg4OD0dTUFI888siIDAsATHzDio8syz5wn6lTp0Z7e3u0t7ef91AAwOTlvV0AgKTEBwCQlPgAAJISHwBAUuIDAEhKfAAASYkPACAp8QEAJCU+AICkxAcAkJT4AACSEh8AQFLiAwBISnwAAEmJDwAgKfEBACQlPgCApMQHAJCU+AAAkhIfAEBS4gMASEp8AABJlY31AOPFgvufHOsRAGBEnP6e9uaGFWM8ydm58gEAJCU+AICkxAcAkJT4AACSEh8AQFLiAwBISnwAAEmJDwAgKfEBACQlPgCApMQHAJCU+AAAkhIfAEBS4gMASEp8AABJiQ8AICnxAQAkJT4AgKTEBwCQlPgAAJISHwBAUuIDAEhKfAAASYkPACAp8QEAJCU+AICkxAcAkJT4AACSEh8AQFLiAwBISnwAAEmJDwAgKfEBACQlPgCApMQHAJCU+AAAkhIfAEBSw46PZ555Jm655ZaYN29elJSUxOOPP160PcuyeOCBB2Lu3Lkxbdq0aGxsjDfeeGOk5gUAJrhhx8exY8diyZIl0d7eftbt3/nOd+Lhhx+ORx99NHbv3h2XXXZZNDU1xfHjxy94WABg4isb7icsX748li9fftZtWZbFd7/73fjnf/7nuPXWWyMi4gc/+EHU1NTE448/Hp///OcvbFoAYMIb0dd87Nu3L/L5fDQ2NhbWVVVVRUNDQ3R1dZ31cwYHB2NgYKBoAQAmrxGNj3w+HxERNTU1RetramoK2/7c+vXro6qqqrDU1dWN5EgAwDgz5r/t0tbWFv39/YXlwIEDYz0SADCKRjQ+amtrIyKit7e3aH1vb29h25+rqKiIysrKogUAmLxGND4WLlwYtbW10dHRUVg3MDAQu3fvjlwuN5J3BQBMUMP+bZejR4/G3r17C7f37dsXL7/8csycOTPq6+vjvvvui29/+9tx1VVXxcKFC+PrX/96zJs3L2677baRnBsAmKCGHR8vvPBC/O3f/m3hdmtra0RErF69OrZs2RJf+9rX4tixY3HXXXdFX19ffPrTn46dO3fG1KlTR25qAGDCGnZ8fPazn40sy95ze0lJSXzrW9+Kb33rWxc0GAAwOY35b7sAABcX8QEAJCU+AICkxAcAkJT4AACSEh8AQFLiAwBIath/52OyWXD/k2M9AgCMitPf497csGKMJynmygcAkJT4AACSEh8AQFLiAwBISnwAAEmJDwAgKfEBACQlPgCApMQHAJCU+AAAkhIfAEBS4gMASEp8AABJiQ8AICnxAQAkJT4AgKTEBwCQlPgAAJISHwBAUuIDAEhKfAAASYkPACAp8QEAJCU+AICkxAcAkJT4AACSEh8AQFJlYz3AWFlw/5NjPQIAJHH6e96bG1aM8SR/4soHAJCU+AAAkhIfAEBS4gMASEp8AABJiQ8AICnxAQAkJT4AgKTEBwCQlPgAAJISHwBAUuIDAEhKfAAASV108eHdbAG4WC24/8lx8X3woosPAGBsiQ8AICnxAQAkJT4AgKTEBwCQlPgAAJISHwBAUuIDAEhKfAAASY1afLS3t8eCBQti6tSp0dDQEL/85S9H664AgAlkVOLjRz/6UbS2tsa6devixRdfjCVLlkRTU1McPnx4NO4OAJhARiU+HnzwwfjiF78Yd955Z1x77bXx6KOPxqWXXhrf//73R+PuAIAJpGykD3jixIno6emJtra2wrrS0tJobGyMrq6uM/YfHByMwcHBwu3+/v6IiBgYGBjp0SIiYmjwnRgYGIihwXdG5fhwMTj9NeSjj+/3kfFrNL7Hnj5mlmUfvHM2wn73u99lEZE999xzReu/+tWvZtdff/0Z+69bty6LCIvFYrFYLJNgOXDgwAe2wohf+Riutra2aG1tLdweGhqKt99+O2bNmhUlJSUjdj8DAwNRV1cXBw4ciMrKyhE7LufH+RhfnI/xxzkZX5yPD5ZlWRw5ciTmzZv3gfuOeHzMnj07Lrnkkujt7S1a39vbG7W1tWfsX1FRERUVFUXrqqurR3qsgsrKSv/DGUecj/HF+Rh/nJPxxfl4f1VVVee034i/4LS8vDyWLl0aHR0dhXVDQ0PR0dERuVxupO8OAJhgRuXHLq2trbF69er4xCc+Eddff31897vfjWPHjsWdd945GncHAEwgoxIft99+e/z+97+PBx54IPL5fPz1X/917Ny5M2pqakbj7s5JRUVFrFu37owf8TA2nI/xxfkYf5yT8cX5GFklWXYuvxMDADAyvLcLAJCU+AAAkhIfAEBS4gMASOqiiI/29vZYsGBBTJ06NRoaGuKXv/zlWI90UfjGN74RJSUlRcuiRYsK248fPx4tLS0xa9asmD59ejQ3N5/xx+m4MM8880zccsstMW/evCgpKYnHH3+8aHuWZfHAAw/E3LlzY9q0adHY2BhvvPFG0T5vv/12rFq1KiorK6O6ujrWrFkTR48eTfgoJo8POh9f+MIXzviaufnmm4v2cT5Gzvr16+OTn/xkzJgxI+bMmRO33XZb7Nmzp2ifc3me2r9/f6xYsSIuvfTSmDNnTnz1q1+Nd999N+VDmXAmfXz86Ec/itbW1li3bl28+OKLsWTJkmhqaorDhw+P9WgXhY985CNx6NChwvLss88Wtq1duzZ27NgR27dvj87Ozjh48GCsXLlyDKedfI4dOxZLliyJ9vb2s27/zne+Ew8//HA8+uijsXv37rjsssuiqakpjh8/Xthn1apV8dprr8VTTz0VTzzxRDzzzDNx1113pXoIk8oHnY+IiJtvvrnoa+aHP/xh0XbnY+R0dnZGS0tLdHd3x1NPPRUnT56Mm266KY4dO1bY54Oep06dOhUrVqyIEydOxHPPPRf/9V//FVu2bIkHHnhgLB7SxDEi7yY3jl1//fVZS0tL4fapU6eyefPmZevXrx/DqS4O69aty5YsWXLWbX19fdmUKVOy7du3F9b9+te/ziIi6+rqSjThxSUisscee6xwe2hoKKutrc3+7d/+rbCur68vq6ioyH74wx9mWZZlr7/+ehYR2fPPP1/Y56c//WlWUlKS/e53v0s2+2T05+cjy7Js9erV2a233vqen+N8jK7Dhw9nEZF1dnZmWXZuz1P/8z//k5WWlmb5fL6wz8aNG7PKyspscHAw7QOYQCb1lY8TJ05ET09PNDY2FtaVlpZGY2NjdHV1jeFkF4833ngj5s2bF1deeWWsWrUq9u/fHxERPT09cfLkyaJzs2jRoqivr3duEtm3b1/k8/mic1BVVRUNDQ2Fc9DV1RXV1dXxiU98orBPY2NjlJaWxu7du5PPfDHYtWtXzJkzJ66++uq4++6746233ipscz5GV39/f0REzJw5MyLO7Xmqq6srPvrRjxb9Ec2mpqYYGBiI1157LeH0E8ukjo8//OEPcerUqTP+smpNTU3k8/kxmuri0dDQEFu2bImdO3fGxo0bY9++ffGZz3wmjhw5Evl8PsrLy894E0HnJp3T/53f7+sjn8/HnDlziraXlZXFzJkznadRcPPNN8cPfvCD6OjoiH/913+Nzs7OWL58eZw6dSoinI/RNDQ0FPfdd1986lOfiuuuuy4i4pyep/L5/Fm/hk5v4+xG5c+rQ0TE8uXLC/9evHhxNDQ0xBVXXBE//vGPY9q0aWM4GYxPn//85wv//uhHPxqLFy+Ov/zLv4xdu3bFjTfeOIaTTX4tLS3x6quvFr0ujdEzqa98zJ49Oy655JIzXpnc29sbtbW1YzTVxau6ujo+/OEPx969e6O2tjZOnDgRfX19Rfs4N+mc/u/8fl8ftbW1Z7w4+9133423337beUrgyiuvjNmzZ8fevXsjwvkYLffcc0888cQT8Ytf/CLmz59fWH8uz1O1tbVn/Ro6vY2zm9TxUV5eHkuXLo2Ojo7CuqGhoejo6IhcLjeGk12cjh49Gr/5zW9i7ty5sXTp0pgyZUrRudmzZ0/s37/fuUlk4cKFUVtbW3QOBgYGYvfu3YVzkMvloq+vL3p6egr7PP300zE0NBQNDQ3JZ77Y/Pa3v4233nor5s6dGxHOx0jLsizuueeeeOyxx+Lpp5+OhQsXFm0/l+epXC4Xr7zySlEUPvXUU1FZWRnXXnttmgcyEY31K15H27Zt27KKiopsy5Yt2euvv57dddddWXV1ddErkxkdX/7yl7Ndu3Zl+/bty/73f/83a2xszGbPnp0dPnw4y7Is+9KXvpTV19dnTz/9dPbCCy9kuVwuy+VyYzz15HLkyJHspZdeyl566aUsIrIHH3wwe+mll7L/+7//y7IsyzZs2JBVV1dnP/nJT7Jf/epX2a233potXLgw++Mf/1g4xs0335x97GMfy3bv3p09++yz2VVXXZXdcccdY/WQJrT3Ox9HjhzJvvKVr2RdXV3Zvn37sp///OfZxz/+8eyqq67Kjh8/XjiG8zFy7r777qyqqirbtWtXdujQocLyzjvvFPb5oOepd999N7vuuuuym266KXv55ZeznTt3ZpdffnnW1tY2Fg9pwpj08ZFlWfYf//EfWX19fVZeXp5df/31WXd391iPdFG4/fbbs7lz52bl5eXZX/zFX2S33357tnfv3sL2P/7xj9k//MM/ZB/60IeySy+9NPvc5z6XHTp0aAwnnnx+8YtfZBFxxrJ69eosy/7067Zf//rXs5qamqyioiK78cYbsz179hQd46233sruuOOObPr06VllZWV25513ZkeOHBmDRzPxvd/5eOedd7Kbbropu/zyy7MpU6ZkV1xxRfbFL37xjP+j5HyMnLOdi4jINm/eXNjnXJ6n3nzzzWz58uXZtGnTstmzZ2df/vKXs5MnTyZ+NBNLSZZlWeqrLQDAxWtSv+YDABh/xAcAkJT4AACSEh8AQFLiAwBISnwAAEmJDwAgKfEBACQlPgCApMQHAJCU+AAAkhIfAEBS/w/4YEFhFFWs1gAAAABJRU5ErkJggg==", - "text/plain": [ - "
    " - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "plt.bar(torch.arange(mps.n_features - 1) + 1, torch.tensor(mps.bond_dim))\n", - "plt.show()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "test_tk", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.14" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/docs/_build/html/examples/mps_dmrg_hybrid.html b/docs/_build/html/examples/mps_dmrg_hybrid.html index 4cb3f09..56c5566 100644 --- a/docs/_build/html/examples/mps_dmrg_hybrid.html +++ b/docs/_build/html/examples/mps_dmrg_hybrid.html @@ -6,7 +6,7 @@ - Hybrid DMRG-like training of MPS — TensorKrowch 1.1.2 documentation + Hybrid DMRG-like training of MPS — TensorKrowch 1.1.3 documentation diff --git a/docs/_build/html/examples/tensorizing_nn.html b/docs/_build/html/examples/tensorizing_nn.html index 24b9032..437aacd 100644 --- a/docs/_build/html/examples/tensorizing_nn.html +++ b/docs/_build/html/examples/tensorizing_nn.html @@ -6,7 +6,7 @@ - Tensorizing Neural Networks — TensorKrowch 1.1.2 documentation + Tensorizing Neural Networks — TensorKrowch 1.1.3 documentation diff --git a/docs/_build/html/examples/training_mps.html b/docs/_build/html/examples/training_mps.html index ec54c00..dc665e7 100644 --- a/docs/_build/html/examples/training_mps.html +++ b/docs/_build/html/examples/training_mps.html @@ -6,7 +6,7 @@ - Training MPS in different ways — TensorKrowch 1.1.2 documentation + Training MPS in different ways — TensorKrowch 1.1.3 documentation diff --git a/docs/_build/html/genindex.html b/docs/_build/html/genindex.html index 1b7c1e1..26779b1 100644 --- a/docs/_build/html/genindex.html +++ b/docs/_build/html/genindex.html @@ -5,7 +5,7 @@ - Index — TensorKrowch 1.1.2 documentation + Index — TensorKrowch 1.1.3 documentation @@ -446,6 +446,8 @@

    C

  • (tensorkrowch.ParamNode method)
  • +
  • conj() (tensorkrowch.AbstractNode method) +
  • connect() (in module tensorkrowch)
  • + + - +
  • is_complex() (tensorkrowch.AbstractNode method) +
  • +
  • is_conj() (tensorkrowch.AbstractNode method) +
  • is_connected_to() (tensorkrowch.AbstractNode method)
  • is_dangling() (tensorkrowch.Edge method)
  • is_data() (tensorkrowch.AbstractNode method) +
  • +
  • is_floating_point() (tensorkrowch.AbstractNode method)
  • is_leaf() (tensorkrowch.AbstractNode method)
  • @@ -1246,6 +1254,8 @@

    T

    diff --git a/docs/_build/html/index.html b/docs/_build/html/index.html index c0a4036..11e3955 100644 --- a/docs/_build/html/index.html +++ b/docs/_build/html/index.html @@ -6,7 +6,7 @@ - TensorKrowch documentation — TensorKrowch 1.1.2 documentation + TensorKrowch documentation — TensorKrowch 1.1.3 documentation diff --git a/docs/_build/html/initializers.html b/docs/_build/html/initializers.html index acaeaa5..a59d24c 100644 --- a/docs/_build/html/initializers.html +++ b/docs/_build/html/initializers.html @@ -6,7 +6,7 @@ - Initializers — TensorKrowch 1.1.2 documentation + Initializers — TensorKrowch 1.1.3 documentation diff --git a/docs/_build/html/installation.html b/docs/_build/html/installation.html index e3f3bf3..a51f62b 100644 --- a/docs/_build/html/installation.html +++ b/docs/_build/html/installation.html @@ -6,7 +6,7 @@ - Installation — TensorKrowch 1.1.2 documentation + Installation — TensorKrowch 1.1.3 documentation diff --git a/docs/_build/html/models.html b/docs/_build/html/models.html index 6933718..fbe0f1e 100644 --- a/docs/_build/html/models.html +++ b/docs/_build/html/models.html @@ -6,7 +6,7 @@ - Models — TensorKrowch 1.1.2 documentation + Models — TensorKrowch 1.1.3 documentation @@ -670,7 +670,7 @@

    MPS#MPS#

    -class tensorkrowch.models.MPS(n_features=None, phys_dim=None, bond_dim=None, boundary='obc', tensors=None, in_features=None, out_features=None, n_batches=1, init_method='randn', device=None, **kwargs)[source]#
    +class tensorkrowch.models.MPS(n_features=None, phys_dim=None, bond_dim=None, boundary='obc', tensors=None, in_features=None, out_features=None, n_batches=1, init_method='randn', device=None, dtype=None, **kwargs)[source]#

    Class for Matrix Product States. This is the base class from which UMPS, MPSLayer and UMPSLayer inherit.

    Matrix Product States are formed by:

    @@ -737,6 +737,7 @@

    MPS#

    init_method ({"zeros", "ones", "copy", "rand", "randn", "randn_eye", "unit"}, optional) – Initialization method. Check initialize() for a more detailed explanation of the different initialization methods.

  • device (torch.device, optional) – Device where to initialize the tensors if init_method is provided.

  • +
  • dtype (torch.dtype, optional) – Dtype of the tensor if init_method is provided.

  • kwargs (float) – Keyword arguments for the different initialization methods. See make_tensor().

  • @@ -885,14 +886,14 @@

    MPS#
    -initialize(tensors=None, init_method='randn', device=None, **kwargs)[source]#
    +initialize(tensors=None, init_method='randn', device=None, dtype=None, **kwargs)[source]#

    Initializes all the nodes of the MPS. It can be called when instantiating the model, or to override the existing nodes’ tensors.

    There are different methods to initialize the nodes:

    • {"zeros", "ones", "copy", "rand", "randn"}: Each node is initialized calling set_tensor() with -the given method, device and kwargs.

    • +the given method, device, dtype and kwargs.

    • "randn_eye": Nodes are initialized as in this paper, adding identities at the top of random gaussian tensors. In this case, std should be @@ -917,6 +918,7 @@

      MPS#

    • init_method ({"zeros", "ones", "copy", "rand", "randn", "randn_eye", "unit", "canonical"}, optional) – Initialization method.

    • device (torch.device, optional) – Device where to initialize the tensors if init_method is provided.

    • +
    • dtype (torch.dtype, optional) – Dtype of the tensor if init_method is provided.

    • kwargs (float) – Keyword arguments for the different initialization methods. See make_tensor().

    @@ -1243,7 +1245,7 @@

    MPS#UMPS#

    -class tensorkrowch.models.UMPS(n_features, phys_dim=None, bond_dim=None, tensor=None, in_features=None, out_features=None, n_batches=1, init_method='randn', device=None, **kwargs)[source]#
    +class tensorkrowch.models.UMPS(n_features, phys_dim=None, bond_dim=None, tensor=None, in_features=None, out_features=None, n_batches=1, init_method='randn', device=None, dtype=None, **kwargs)[source]#

    Class for Uniform (translationally invariant) Matrix Product States. It is the uniform version of MPS, that is, all nodes share the same tensor. Thus this class cannot have different physical or bond dimensions @@ -1282,6 +1284,7 @@

    UMPS#<
  • init_method ({"zeros", "ones", "copy", "rand", "randn", "randn_eye", "unit"}, optional) – Initialization method. Check initialize() for a more detailed explanation of the different initialization methods.

  • device (torch.device, optional) – Device where to initialize the tensors if init_method is provided.

  • +
  • dtype (torch.dtype, optional) – Dtype of the tensor if init_method is provided.

  • kwargs (float) – Keyword arguments for the different initialization methods. See make_tensor().

  • @@ -1302,14 +1305,14 @@

    UMPS#<

    -initialize(tensors=None, init_method='randn', device=None, **kwargs)[source]#
    +initialize(tensors=None, init_method='randn', device=None, dtype=None, **kwargs)[source]#

    Initializes the common tensor of the UMPS. It can be called when instantiating the model, or to override the existing nodes’ tensors.

    There are different methods to initialize the nodes:

    @@ -1381,7 +1385,7 @@

    UMPS#<

    MPSLayer#

    -class tensorkrowch.models.MPSLayer(n_features=None, in_dim=None, out_dim=None, bond_dim=None, out_position=None, boundary='obc', tensors=None, n_batches=1, init_method='randn', device=None, **kwargs)[source]#
    +class tensorkrowch.models.MPSLayer(n_features=None, in_dim=None, out_dim=None, bond_dim=None, out_position=None, boundary='obc', tensors=None, n_batches=1, init_method='randn', device=None, dtype=None, **kwargs)[source]#

    Class for Matrix Product States with a single output node. That is, this MPS has \(n\) nodes, being \(n-1\) input nodes connected to data nodes (nodes that will contain the data tensors), and one output node, @@ -1447,6 +1451,7 @@

    MPSLayerinitialize() for a more detailed explanation of the different initialization methods.

  • device (torch.device, optional) – Device where to initialize the tensors if init_method is provided.

  • +
  • dtype (torch.dtype, optional) – Dtype of the tensor if init_method is provided.

  • kwargs (float) – Keyword arguments for the different initialization methods. See make_tensor().

  • @@ -1503,14 +1508,14 @@

    MPSLayer
    -initialize(tensors=None, init_method='randn', device=None, **kwargs)[source]#
    +initialize(tensors=None, init_method='randn', device=None, dtype=None, **kwargs)[source]#

    Initializes all the nodes of the MPSLayer. It can be called when instantiating the model, or to override the existing nodes’ tensors.

    There are different methods to initialize the nodes:

    @@ -1567,7 +1573,7 @@

    MPSLayer#

    -class tensorkrowch.models.UMPSLayer(n_features, in_dim=None, out_dim=None, bond_dim=None, out_position=None, tensors=None, n_batches=1, init_method='randn', device=None, **kwargs)[source]#
    +class tensorkrowch.models.UMPSLayer(n_features, in_dim=None, out_dim=None, bond_dim=None, out_position=None, tensors=None, n_batches=1, init_method='randn', device=None, dtype=None, **kwargs)[source]#

    Class for Uniform (translationally invariant) Matrix Product States with an output node. It is the uniform version of MPSLayer, with all input nodes sharing the same tensor, but with a different node for the output node. @@ -1607,6 +1613,7 @@

    UMPSLayerinitialize() for a more detailed explanation of the different initialization methods.

  • device (torch.device, optional) – Device where to initialize the tensors if init_method is provided.

  • +
  • dtype (torch.dtype, optional) – Dtype of the tensor if init_method is provided.

  • kwargs (float) – Keyword arguments for the different initialization methods. See make_tensor().

  • @@ -1654,14 +1661,14 @@

    UMPSLayer
    -initialize(tensors=None, init_method='randn', device=None, **kwargs)[source]#
    +initialize(tensors=None, init_method='randn', device=None, dtype=None, **kwargs)[source]#

    Initializes the common tensor of the UMPSLayer. It can be called when instantiating the model, or to override the existing nodes’ tensors.

    There are different methods to initialize the nodes:

    @@ -1734,7 +1742,7 @@

    UMPSLayer#

    -class tensorkrowch.models.ConvMPS(in_channels, bond_dim, kernel_size, stride=1, padding=0, dilation=1, boundary='obc', tensors=None, init_method='randn', device=None, **kwargs)[source]#
    +class tensorkrowch.models.ConvMPS(in_channels, bond_dim, kernel_size, stride=1, padding=0, dilation=1, boundary='obc', tensors=None, init_method='randn', device=None, dtype=None, **kwargs)[source]#

    Convolutional version of MPS, where the input data is assumed to be a batch of images.

    Input data as well as initialization parameters are described in torch.nn.Conv2d.

    @@ -1772,6 +1780,7 @@

    ConvMPSinitialize() for a more detailed explanation of the different initialization methods.

  • device (torch.device, optional) – Device where to initialize the tensors if init_method is provided.

  • +
  • dtype (torch.dtype, optional) – Dtype of the tensor if init_method is provided.

  • kwargs (float) – Keyword arguments for the different initialization methods. See make_tensor().

  • @@ -1869,7 +1878,7 @@

    ConvMPS#

    -class tensorkrowch.models.ConvUMPS(in_channels, bond_dim, kernel_size, stride=1, padding=0, dilation=1, tensor=None, init_method='randn', device=None, **kwargs)[source]#
    +class tensorkrowch.models.ConvUMPS(in_channels, bond_dim, kernel_size, stride=1, padding=0, dilation=1, tensor=None, init_method='randn', device=None, dtype=None, **kwargs)[source]#

    Convolutional version of UMPS, where the input data is assumed to be a batch of images.

    Input data as well as initialization parameters are described in torch.nn.Conv2d.

    @@ -1897,6 +1906,7 @@

    ConvUMPSinitialize() for a more detailed explanation of the different initialization methods.

  • device (torch.device, optional) – Device where to initialize the tensors if init_method is provided.

  • +
  • dtype (torch.dtype, optional) – Dtype of the tensor if init_method is provided.

  • kwargs (float) – Keyword arguments for the different initialization methods. See make_tensor().

  • @@ -1997,7 +2007,7 @@

    ConvUMPS#

    -class tensorkrowch.models.ConvMPSLayer(in_channels, out_channels, bond_dim, kernel_size, stride=1, padding=0, dilation=1, out_position=None, boundary='obc', tensors=None, init_method='randn', device=None, **kwargs)[source]#
    +class tensorkrowch.models.ConvMPSLayer(in_channels, out_channels, bond_dim, kernel_size, stride=1, padding=0, dilation=1, out_position=None, boundary='obc', tensors=None, init_method='randn', device=None, dtype=None, **kwargs)[source]#

    Convolutional version of MPSLayer, where the input data is assumed to be a batch of images.

    Input data as well as initialization parameters are described in torch.nn.Conv2d.

    @@ -2040,6 +2050,7 @@

    ConvMPSLayerinitialize() for a more detailed explanation of the different initialization methods.

  • device (torch.device, optional) – Device where to initialize the tensors if init_method is provided.

  • +
  • dtype (torch.dtype, optional) – Dtype of the tensor if init_method is provided.

  • kwargs (float) – Keyword arguments for the different initialization methods. See make_tensor().

  • @@ -2144,7 +2155,7 @@

    ConvMPSLayer#

    -class tensorkrowch.models.ConvUMPSLayer(in_channels, out_channels, bond_dim, kernel_size, stride=1, padding=0, dilation=1, out_position=None, tensors=None, init_method='randn', device=None, **kwargs)[source]#
    +class tensorkrowch.models.ConvUMPSLayer(in_channels, out_channels, bond_dim, kernel_size, stride=1, padding=0, dilation=1, out_position=None, tensors=None, init_method='randn', device=None, dtype=None, **kwargs)[source]#

    Convolutional version of UMPSLayer, where the input data is assumed to be a batch of images.

    Input data as well as initialization parameters are described in torch.nn.Conv2d.

    @@ -2178,6 +2189,7 @@

    ConvUMPSLayerinitialize() for a more detailed explanation of the different initialization methods.

  • device (torch.device, optional) – Device where to initialize the tensors if init_method is provided.

  • +
  • dtype (torch.dtype, optional) – Dtype of the tensor if init_method is provided.

  • kwargs (float) – Keyword arguments for the different initialization methods. See make_tensor().

  • @@ -2289,7 +2301,7 @@

    MPSData#

    -class tensorkrowch.models.MPSData(n_features=None, phys_dim=None, bond_dim=None, boundary='obc', n_batches=1, tensors=None, init_method=None, device=None, **kwargs)[source]#
    +class tensorkrowch.models.MPSData(n_features=None, phys_dim=None, bond_dim=None, boundary='obc', n_batches=1, tensors=None, init_method=None, device=None, dtype=None, **kwargs)[source]#

    Class for data vectors in the form of Matrix Product States. That is, this is a class similar to MPS, but where all nodes can have additional batch edges.

    @@ -2344,6 +2356,7 @@

    MPSDataadd_data() to see how to initialize MPS nodes with data tensors.

  • device (torch.device, optional) – Device where to initialize the tensors if init_method is provided.

  • +
  • dtype (torch.dtype, optional) – Dtype of the tensor if init_method is provided.

  • kwargs (float) – Keyword arguments for the different initialization methods. See make_tensor().

  • @@ -2409,16 +2422,22 @@

    MPSData

    Returns the list of nodes in mats_env.

    +
    +
    +property tensors#
    +

    Returns the list of MPS tensors.

    +
    +
    -initialize(init_method='randn', device=None, **kwargs)[source]#
    +initialize(init_method='randn', device=None, dtype=None, **kwargs)[source]#

    Initializes all the nodes of the MPSData. It can be called when instantiating the model, or to override the existing nodes’ tensors.

    There are different methods to initialize the nodes:

    • {"zeros", "ones", "copy", "rand", "randn"}: Each node is initialized calling set_tensor() with -the given method, device and kwargs.

    • +the given method, device, dtype and kwargs.

    Parameters
    @@ -2426,6 +2445,7 @@

    MPSData

    init_method ({"zeros", "ones", "copy", "rand", "randn"}, optional) – Initialization method. Check add_data() to see how to initialize MPS nodes with data tensors.

  • device (torch.device, optional) – Device where to initialize the tensors if init_method is provided.

  • +
  • dtype (torch.dtype, optional) – Dtype of the tensor if init_method is provided.

  • kwargs (float) – Keyword arguments for the different initialization methods. See make_tensor().

  • @@ -2461,7 +2481,7 @@

    MPO#MPO#

    -class tensorkrowch.models.MPO(n_features=None, in_dim=None, out_dim=None, bond_dim=None, boundary='obc', tensors=None, n_batches=1, init_method='randn', device=None, **kwargs)[source]#
    +class tensorkrowch.models.MPO(n_features=None, in_dim=None, out_dim=None, bond_dim=None, boundary='obc', tensors=None, n_batches=1, init_method='randn', device=None, dtype=None, **kwargs)[source]#

    Class for Matrix Product Operators. This is the base class from which UMPO inherits.

    Matrix Product Operators are formed by:

    @@ -2511,6 +2531,7 @@

    MPO#

    init_method ({"zeros", "ones", "copy", "rand", "randn"}, optional) – Initialization method. Check initialize() for a more detailed explanation of the different initialization methods.

  • device (torch.device, optional) – Device where to initialize the tensors if init_method is provided.

  • +
  • dtype (torch.dtype, optional) – Dtype of the tensor if init_method is provided.

  • kwargs (float) – Keyword arguments for the different initialization methods. See make_tensor().

  • @@ -2604,14 +2625,14 @@

    MPO#
    -initialize(tensors=None, init_method='randn', device=None, **kwargs)[source]#
    +initialize(tensors=None, init_method='randn', device=None, dtype=None, **kwargs)[source]#

    Initializes all the nodes of the MPO. It can be called when instantiating the model, or to override the existing nodes’ tensors.

    There are different methods to initialize the nodes:

    • {"zeros", "ones", "copy", "rand", "randn"}: Each node is initialized calling set_tensor() with -the given method, device and kwargs.

    • +the given method, device, dtype and kwargs.

    Parameters
    @@ -2623,6 +2644,7 @@

    MPO#
  • init_method ({"zeros", "ones", "copy", "rand", "randn"}, optional) – Initialization method.

  • device (torch.device, optional) – Device where to initialize the tensors if init_method is provided.

  • +
  • dtype (torch.dtype, optional) – Dtype of the tensor if init_method is provided.

  • kwargs (float) – Keyword arguments for the different initialization methods. See make_tensor().

  • @@ -2808,7 +2830,7 @@

    MPO#UMPO#

    -class tensorkrowch.models.UMPO(n_features=None, in_dim=None, out_dim=None, bond_dim=None, tensor=None, n_batches=1, init_method='randn', device=None, **kwargs)[source]#
    +class tensorkrowch.models.UMPO(n_features=None, in_dim=None, out_dim=None, bond_dim=None, tensor=None, n_batches=1, init_method='randn', device=None, dtype=None, **kwargs)[source]#

    Class for Uniform (translationally invariant) Matrix Product Operators. It is the uniform version of MPO, that is, all nodes share the same tensor. Thus this class cannot have different input/output or bond dimensions @@ -2836,6 +2858,7 @@

    UMPO#<
  • init_method ({"zeros", "ones", "copy", "rand", "randn"}, optional) – Initialization method. Check initialize() for a more detailed explanation of the different initialization methods.

  • device (torch.device, optional) – Device where to initialize the tensors if init_method is provided.

  • +
  • dtype (torch.dtype, optional) – Dtype of the tensor if init_method is provided.

  • kwargs (float) – Keyword arguments for the different initialization methods. See make_tensor().

  • @@ -2857,14 +2880,14 @@

    UMPO#<

    -initialize(tensors=None, init_method='randn', device=None, **kwargs)[source]#
    +initialize(tensors=None, init_method='randn', device=None, dtype=None, **kwargs)[source]#

    Initializes the common tensor of the UMPO. It can be called when instantiating the model, or to override the existing nodes’ tensors.

    There are different methods to initialize the nodes:

    • {"zeros", "ones", "copy", "rand", "randn"}: The tensor is initialized calling set_tensor() with -the given method, device and kwargs.

    • +the given method, device, dtype and kwargs.

    Parameters
    @@ -2874,6 +2897,7 @@

    UMPO#< equal.

  • init_method ({"zeros", "ones", "copy", "rand", "randn"}, optional) – Initialization method.

  • device (torch.device, optional) – Device where to initialize the tensors if init_method is provided.

  • +
  • dtype (torch.dtype, optional) – Dtype of the tensor if init_method is provided.

  • kwargs (float) – Keyword arguments for the different initialization methods. See make_tensor().

  • diff --git a/docs/_build/html/objects.inv b/docs/_build/html/objects.inv index 08d7c19..519ff1c 100644 Binary files a/docs/_build/html/objects.inv and b/docs/_build/html/objects.inv differ diff --git a/docs/_build/html/operations.html b/docs/_build/html/operations.html index 6e068cd..0521ed6 100644 --- a/docs/_build/html/operations.html +++ b/docs/_build/html/operations.html @@ -6,7 +6,7 @@ - Operations — TensorKrowch 1.1.2 documentation + Operations — TensorKrowch 1.1.3 documentation diff --git a/docs/_build/html/search.html b/docs/_build/html/search.html index 38aa2fe..5a1402c 100644 --- a/docs/_build/html/search.html +++ b/docs/_build/html/search.html @@ -5,7 +5,7 @@ - Search — TensorKrowch 1.1.2 documentation + Search — TensorKrowch 1.1.3 documentation diff --git a/docs/_build/html/searchindex.js b/docs/_build/html/searchindex.js index f4ec712..5dfcfda 100644 --- a/docs/_build/html/searchindex.js +++ b/docs/_build/html/searchindex.js @@ -1 +1 @@ -Search.setIndex({docnames:["api","components","contents","decompositions","embeddings","examples","examples/hybrid_tnn_model","examples/mps_dmrg","examples/mps_dmrg_hybrid","examples/tensorizing_nn","examples/training_mps","index","initializers","installation","models","operations","tutorials","tutorials/0_first_steps","tutorials/1_creating_tensor_network","tutorials/2_contracting_tensor_network","tutorials/3_memory_management","tutorials/4_types_of_nodes","tutorials/5_subclass_tensor_network","tutorials/6_mix_with_pytorch"],envversion:{"sphinx.domains.c":2,"sphinx.domains.changeset":1,"sphinx.domains.citation":1,"sphinx.domains.cpp":5,"sphinx.domains.index":1,"sphinx.domains.javascript":2,"sphinx.domains.math":2,"sphinx.domains.python":3,"sphinx.domains.rst":2,"sphinx.domains.std":2,"sphinx.ext.viewcode":1,nbsphinx:4,sphinx:56},filenames:["api.rst","components.rst","contents.rst","decompositions.rst","embeddings.rst","examples.rst","examples/hybrid_tnn_model.ipynb","examples/mps_dmrg.ipynb","examples/mps_dmrg_hybrid.ipynb","examples/tensorizing_nn.ipynb","examples/training_mps.ipynb","index.rst","initializers.rst","installation.rst","models.rst","operations.rst","tutorials.rst","tutorials/0_first_steps.rst","tutorials/1_creating_tensor_network.rst","tutorials/2_contracting_tensor_network.rst","tutorials/3_memory_management.rst","tutorials/4_types_of_nodes.rst","tutorials/5_subclass_tensor_network.rst","tutorials/6_mix_with_pytorch.rst"],objects:{"tensorkrowch.AbstractNode":[[1,1,1,"","axes"],[1,1,1,"","axes_names"],[1,2,1,"","contract_between"],[1,2,1,"","contract_between_"],[1,1,1,"","device"],[1,2,1,"","disconnect"],[1,1,1,"","dtype"],[1,1,1,"","edges"],[1,2,1,"","get_axis"],[1,2,1,"","get_axis_num"],[1,2,1,"","get_edge"],[1,2,1,"","in_which_axis"],[1,2,1,"","is_connected_to"],[1,2,1,"","is_data"],[1,2,1,"","is_leaf"],[1,2,1,"","is_node1"],[1,2,1,"","is_resultant"],[1,2,1,"","is_virtual"],[1,2,1,"","make_tensor"],[1,2,1,"","mean"],[1,2,1,"","move_to_network"],[1,1,1,"","name"],[1,2,1,"","neighbours"],[1,1,1,"","network"],[1,2,1,"","node_ref"],[1,2,1,"","norm"],[1,2,1,"","numel"],[1,2,1,"","permute"],[1,2,1,"","permute_"],[1,1,1,"","rank"],[1,2,1,"","reattach_edges"],[1,2,1,"","renormalize"],[1,2,1,"","reset_tensor_address"],[1,2,1,"","set_tensor"],[1,2,1,"","set_tensor_from"],[1,1,1,"","shape"],[1,2,1,"","size"],[1,2,1,"","split"],[1,2,1,"","split_"],[1,2,1,"","std"],[1,1,1,"","successors"],[1,2,1,"","sum"],[1,1,1,"","tensor"],[1,2,1,"","tensor_address"],[1,2,1,"","unset_tensor"]],"tensorkrowch.Axis":[[1,2,1,"","is_batch"],[1,2,1,"","is_node1"],[1,1,1,"","name"],[1,1,1,"","node"],[1,1,1,"","num"]],"tensorkrowch.Edge":[[1,1,1,"","axes"],[1,1,1,"","axis1"],[1,1,1,"","axis2"],[1,2,1,"","change_size"],[1,2,1,"","connect"],[1,2,1,"","contract"],[1,2,1,"","contract_"],[1,2,1,"","copy"],[1,2,1,"","disconnect"],[1,2,1,"","is_attached_to"],[1,2,1,"","is_batch"],[1,2,1,"","is_dangling"],[1,1,1,"","name"],[1,1,1,"","node1"],[1,1,1,"","node2"],[1,1,1,"","nodes"],[1,2,1,"","qr"],[1,2,1,"","qr_"],[1,2,1,"","rq"],[1,2,1,"","rq_"],[1,2,1,"","size"],[1,2,1,"","svd"],[1,2,1,"","svd_"],[1,2,1,"","svdr"],[1,2,1,"","svdr_"]],"tensorkrowch.Node":[[1,2,1,"","change_type"],[1,2,1,"","copy"],[1,2,1,"","parameterize"]],"tensorkrowch.ParamNode":[[1,2,1,"","change_type"],[1,2,1,"","copy"],[1,1,1,"","grad"],[1,2,1,"","parameterize"]],"tensorkrowch.ParamStackNode":[[1,1,1,"","edges_dict"],[1,1,1,"","node1_lists_dict"],[1,2,1,"","reconnect"],[1,2,1,"","unbind"]],"tensorkrowch.StackEdge":[[1,2,1,"","connect"],[1,1,1,"","edges"],[1,1,1,"","node1_list"]],"tensorkrowch.StackNode":[[1,1,1,"","edges_dict"],[1,1,1,"","node1_lists_dict"],[1,2,1,"","reconnect"],[1,2,1,"","unbind"]],"tensorkrowch.TensorNetwork":[[1,2,1,"","add_data"],[1,1,1,"","auto_stack"],[1,1,1,"","auto_unbind"],[1,2,1,"","contract"],[1,2,1,"","copy"],[1,1,1,"","data_nodes"],[1,2,1,"","delete_node"],[1,1,1,"","edges"],[1,2,1,"","forward"],[1,1,1,"","leaf_nodes"],[1,1,1,"","nodes"],[1,1,1,"","nodes_names"],[1,2,1,"","parameterize"],[1,2,1,"","reset"],[1,1,1,"","resultant_nodes"],[1,2,1,"","set_data_nodes"],[1,2,1,"","trace"],[1,2,1,"","unset_data_nodes"],[1,1,1,"","virtual_nodes"]],"tensorkrowch.decompositions":[[3,3,1,"","mat_to_mpo"],[3,3,1,"","vec_to_mps"]],"tensorkrowch.embeddings":[[4,3,1,"","add_ones"],[4,3,1,"","basis"],[4,3,1,"","discretize"],[4,3,1,"","poly"],[4,3,1,"","unit"]],"tensorkrowch.models":[[14,0,1,"","ConvMPS"],[14,0,1,"","ConvMPSLayer"],[14,0,1,"","ConvPEPS"],[14,0,1,"","ConvTree"],[14,0,1,"","ConvUMPS"],[14,0,1,"","ConvUMPSLayer"],[14,0,1,"","ConvUPEPS"],[14,0,1,"","ConvUTree"],[14,0,1,"","MPO"],[14,0,1,"","MPS"],[14,0,1,"","MPSData"],[14,0,1,"","MPSLayer"],[14,0,1,"","PEPS"],[14,0,1,"","Tree"],[14,0,1,"","UMPO"],[14,0,1,"","UMPS"],[14,0,1,"","UMPSLayer"],[14,0,1,"","UPEPS"],[14,0,1,"","UTree"]],"tensorkrowch.models.ConvMPS":[[14,2,1,"","copy"],[14,1,1,"","dilation"],[14,2,1,"","forward"],[14,1,1,"","in_channels"],[14,1,1,"","kernel_size"],[14,1,1,"","padding"],[14,1,1,"","stride"]],"tensorkrowch.models.ConvMPSLayer":[[14,2,1,"","copy"],[14,1,1,"","dilation"],[14,2,1,"","forward"],[14,1,1,"","in_channels"],[14,1,1,"","kernel_size"],[14,1,1,"","out_channels"],[14,1,1,"","padding"],[14,1,1,"","stride"]],"tensorkrowch.models.ConvPEPS":[[14,1,1,"","dilation"],[14,2,1,"","forward"],[14,1,1,"","in_channels"],[14,1,1,"","kernel_size"],[14,1,1,"","padding"],[14,1,1,"","stride"]],"tensorkrowch.models.ConvTree":[[14,1,1,"","dilation"],[14,2,1,"","forward"],[14,1,1,"","in_channels"],[14,1,1,"","kernel_size"],[14,1,1,"","padding"],[14,1,1,"","stride"]],"tensorkrowch.models.ConvUMPS":[[14,2,1,"","copy"],[14,1,1,"","dilation"],[14,2,1,"","forward"],[14,1,1,"","in_channels"],[14,1,1,"","kernel_size"],[14,1,1,"","padding"],[14,1,1,"","stride"]],"tensorkrowch.models.ConvUMPSLayer":[[14,2,1,"","copy"],[14,1,1,"","dilation"],[14,2,1,"","forward"],[14,1,1,"","in_channels"],[14,1,1,"","kernel_size"],[14,1,1,"","out_channels"],[14,1,1,"","padding"],[14,1,1,"","stride"]],"tensorkrowch.models.ConvUPEPS":[[14,1,1,"","dilation"],[14,2,1,"","forward"],[14,1,1,"","in_channels"],[14,1,1,"","kernel_size"],[14,1,1,"","padding"],[14,1,1,"","stride"]],"tensorkrowch.models.ConvUTree":[[14,1,1,"","dilation"],[14,2,1,"","forward"],[14,1,1,"","in_channels"],[14,1,1,"","kernel_size"],[14,1,1,"","padding"],[14,1,1,"","stride"]],"tensorkrowch.models.MPO":[[14,1,1,"","bond_dim"],[14,1,1,"","boundary"],[14,2,1,"","canonicalize"],[14,2,1,"","contract"],[14,2,1,"","copy"],[14,1,1,"","in_dim"],[14,2,1,"","initialize"],[14,1,1,"","left_node"],[14,1,1,"","mats_env"],[14,1,1,"","n_batches"],[14,1,1,"","n_features"],[14,1,1,"","out_dim"],[14,2,1,"","parameterize"],[14,1,1,"","right_node"],[14,2,1,"","set_data_nodes"],[14,1,1,"","tensors"],[14,2,1,"","update_bond_dim"]],"tensorkrowch.models.MPS":[[14,1,1,"","bond_dim"],[14,1,1,"","boundary"],[14,2,1,"","canonicalize"],[14,2,1,"","canonicalize_univocal"],[14,2,1,"","contract"],[14,2,1,"","copy"],[14,2,1,"","entropy"],[14,1,1,"","in_env"],[14,1,1,"","in_features"],[14,1,1,"","in_regions"],[14,2,1,"","initialize"],[14,1,1,"","left_node"],[14,1,1,"","mats_env"],[14,1,1,"","n_batches"],[14,1,1,"","n_features"],[14,2,1,"","norm"],[14,1,1,"","out_env"],[14,1,1,"","out_features"],[14,1,1,"","out_regions"],[14,2,1,"","parameterize"],[14,2,1,"","partial_density"],[14,1,1,"","phys_dim"],[14,1,1,"","right_node"],[14,2,1,"","set_data_nodes"],[14,1,1,"","tensors"],[14,2,1,"","update_bond_dim"]],"tensorkrowch.models.MPSData":[[14,2,1,"","add_data"],[14,1,1,"","bond_dim"],[14,1,1,"","boundary"],[14,2,1,"","initialize"],[14,1,1,"","left_node"],[14,1,1,"","mats_env"],[14,1,1,"","n_batches"],[14,1,1,"","n_features"],[14,1,1,"","phys_dim"],[14,1,1,"","right_node"]],"tensorkrowch.models.MPSLayer":[[14,2,1,"","copy"],[14,1,1,"","in_dim"],[14,2,1,"","initialize"],[14,1,1,"","out_dim"],[14,1,1,"","out_node"],[14,1,1,"","out_position"]],"tensorkrowch.models.PEPS":[[14,1,1,"","bond_dim"],[14,1,1,"","boundary"],[14,2,1,"","contract"],[14,1,1,"","in_dim"],[14,2,1,"","initialize"],[14,1,1,"","n_batches"],[14,1,1,"","n_cols"],[14,1,1,"","n_rows"],[14,2,1,"","set_data_nodes"]],"tensorkrowch.models.Tree":[[14,1,1,"","bond_dim"],[14,2,1,"","canonicalize"],[14,2,1,"","contract"],[14,2,1,"","initialize"],[14,1,1,"","n_batches"],[14,2,1,"","set_data_nodes"],[14,1,1,"","sites_per_layer"]],"tensorkrowch.models.UMPO":[[14,2,1,"","copy"],[14,2,1,"","initialize"],[14,2,1,"","parameterize"]],"tensorkrowch.models.UMPS":[[14,2,1,"","copy"],[14,2,1,"","initialize"],[14,2,1,"","parameterize"]],"tensorkrowch.models.UMPSLayer":[[14,2,1,"","copy"],[14,1,1,"","in_dim"],[14,2,1,"","initialize"],[14,1,1,"","out_dim"],[14,1,1,"","out_node"],[14,1,1,"","out_position"],[14,2,1,"","parameterize"]],"tensorkrowch.models.UPEPS":[[14,1,1,"","bond_dim"],[14,2,1,"","contract"],[14,1,1,"","in_dim"],[14,2,1,"","initialize"],[14,1,1,"","n_batches"],[14,1,1,"","n_cols"],[14,1,1,"","n_rows"],[14,2,1,"","set_data_nodes"]],"tensorkrowch.models.UTree":[[14,1,1,"","bond_dim"],[14,2,1,"","contract"],[14,2,1,"","initialize"],[14,1,1,"","n_batches"],[14,2,1,"","set_data_nodes"],[14,1,1,"","sites_per_layer"]],tensorkrowch:[[1,0,1,"","AbstractNode"],[1,0,1,"","Axis"],[1,0,1,"","Edge"],[1,0,1,"","Node"],[15,0,1,"","Operation"],[1,0,1,"","ParamNode"],[1,0,1,"","ParamStackNode"],[1,0,1,"","StackEdge"],[1,0,1,"","StackNode"],[1,0,1,"","Successor"],[1,0,1,"","TensorNetwork"],[15,3,1,"","add"],[15,3,1,"","connect"],[15,3,1,"","connect_stack"],[15,3,1,"","contract"],[15,3,1,"","contract_"],[15,3,1,"","contract_between"],[15,3,1,"","contract_between_"],[15,3,1,"","contract_edges"],[12,3,1,"","copy"],[15,3,1,"","disconnect"],[15,3,1,"","div"],[15,3,1,"","einsum"],[12,3,1,"","empty"],[15,3,1,"","mul"],[12,3,1,"","ones"],[15,3,1,"","permute"],[15,3,1,"","permute_"],[15,3,1,"","qr"],[15,3,1,"","qr_"],[12,3,1,"","rand"],[12,3,1,"","randn"],[15,3,1,"","renormalize"],[15,3,1,"","rq"],[15,3,1,"","rq_"],[15,3,1,"","split"],[15,3,1,"","split_"],[15,3,1,"","stack"],[15,3,1,"","stacked_einsum"],[15,3,1,"","sub"],[15,3,1,"","svd"],[15,3,1,"","svd_"],[15,3,1,"","svdr"],[15,3,1,"","svdr_"],[15,3,1,"","tprod"],[15,3,1,"","unbind"],[12,3,1,"","zeros"]]},objnames:{"0":["py","class","Python class"],"1":["py","property","Python property"],"2":["py","method","Python method"],"3":["py","function","Python function"]},objtypes:{"0":"py:class","1":"py:property","2":"py:method","3":"py:function"},terms:{"0":[1,3,4,6,7,8,9,10,11,14,15,17,19,21,22,23],"00":[4,7],"0000":4,"0000e":4,"0086":1,"01":[7,10],"0101":1,"0188":1,"02":7,"0234":1,"03":[9,10],"0371":1,"04":[7,9],"0412":1,"0440":1,"0445":1,"0465":1,"05":7,"0501":1,"0521":1,"06":[7,9],"0633":1,"0639":17,"07":[7,9],"0704":1,"0728":17,"0743":1,"0750":17,"08":[4,9],"0806":1,"0820":17,"08595":11,"09":[7,10,14],"0913":17,"0983":17,"0998":17,"1":[1,3,4,6,7,8,9,10,11,14,15,23],"10":[1,4,6,7,8,9,10,11,14,15,17,19,22,23],"100":[1,4,6,7,8,9,10,11,14,15,18,19,20,21],"10000":[17,23],"101":7,"1024":[17,23],"104":7,"106":7,"108":7,"1091":1,"11":[6,7,8,9,10],"110":7,"1123":17,"113":7,"115":7,"117":7,"119":7,"12":[6,7,9,10],"1211":1,"122":7,"124":7,"126":7,"128":7,"13":[6,7,9],"1308":17,"131":7,"133":7,"135":7,"137":7,"1371":1,"14":[6,7,8,9,10,23],"140":7,"142":7,"144":7,"146":7,"149":7,"1496":1,"15":[1,6,7,8,9,10,15,19],"151":7,"153":7,"155":7,"1573":17,"158":7,"159674":10,"159962":10,"16":[6,7,8,9,10],"160":7,"161204":17,"162":7,"164":7,"16406":9,"1654":1,"167":7,"169":7,"17":[6,7,9,10],"171":7,"173":7,"176":7,"178":7,"18":[6,7,10],"180":7,"182":7,"185":7,"187":7,"1878":23,"189":7,"19":[7,9,10],"191":7,"194":7,"1955":17,"196":7,"198":7,"1990":17,"1998":1,"1e":[6,7,8,9,10,14,17,22,23],"2":[1,4,6,7,8,9,10,14,15,23],"20":[1,7,10,14,15,23],"200":7,"203":7,"205":7,"2052":23,"207":7,"209":7,"2091":1,"2095":1,"21":[7,10],"2111":1,"212":7,"214":7,"216":7,"218":7,"22":[6,7,10],"2203":23,"221":7,"223":7,"2254":1,"23":[7,10],"2306":11,"2316":1,"2357":23,"236220":[10,17],"2380":1,"2390":1,"24":[7,8],"2449":1,"2461":1,"2477":1,"25":[7,8],"2503":1,"2517":1,"26":7,"2618":23,"27":[7,8],"2730":1,"2775":1,"27944":9,"2799":1,"28":[6,7,8,9,10,17,23],"2808":1,"2840":23,"2856":1,"2866":1,"2898":1,"29":[7,8],"2d":14,"2f":[6,7,8,9,10],"3":[1,4,6,7,8,9,10,11,14,15,21,23],"30":[7,9,23],"3083":17,"3088":1,"31":7,"3139":1,"3149":23,"32":[7,14,17],"3222":1,"33":[6,7,9],"3340":1,"3370":1,"3371":1,"3381":1,"3393":1,"34":7,"3427":1,"3489":1,"35":[6,7,9],"3508":1,"3513":1,"3585":1,"36":7,"37":[7,8],"3711e":4,"3714":23,"3760":1,"3784":1,"38":[6,7,8],"3821":1,"39":[6,7,8,9,10],"39760":9,"4":[1,4,6,7,8,9,10,14,15,19,21,23],"40":[6,7,10,23],"4005":1,"41":7,"4181":1,"4184":1,"42":[7,8],"4216":1,"43":7,"4383":1,"4385":1,"44":7,"4402":1,"4431":1,"4461":1,"45":[6,7,8,9],"4572":1,"4588":1,"46":7,"462803":6,"47":7,"4731":1,"4761":1,"48":7,"49":[7,8],"4974":1,"499320":23,"4f":[17,23],"5":[1,4,6,7,8,9,10,11,14,15,18,19,20,21,23],"50":[6,7,9,23],"500":[11,17,22,23],"5000":4,"5021":1,"5029":1,"5069":1,"51":7,"5161":1,"52":[7,8,9,10],"5224":23,"53":7,"54":7,"5401":1,"5406":1,"55":[6,7,10],"553036":6,"553186":23,"5567":1,"5570":1,"56":[7,9],"5676":17,"57":[7,9,10],"5731":1,"5760":1,"5797":1,"58":[7,10],"59":[7,8],"5920":1,"6":[1,4,6,7,8,9,10,14,23],"60":[7,23],"60000":[17,23],"6023":1,"61":[6,7],"6149":17,"61999999999999":9,"62":[7,8,9],"6225":1,"6227":1,"6295":1,"63":7,"6356":1,"6399":1,"64":[6,7,8,9,10],"6492":1,"6495":1,"65":[7,8,10],"6524":1,"6545":1,"6551":1,"66":7,"67":[7,10],"68":7,"6811":1,"69":[7,8,9,10],"6925":1,"6982":1,"7":[1,6,7,8,9,10,11,14,15,19,20,23],"70":[7,8,9,23],"71":7,"72":[7,10],"73":[7,9],"74":[7,10],"75":[4,7],"7500":4,"7592":1,"76":[7,10],"7688":1,"77":7,"7752":1,"78":7,"7812":1,"79":[6,7,10],"7997":1,"8":[6,7,8,9,10,11,14,17],"80":[6,7,23],"8090":1,"81":[6,7,8,9],"8147":1,"82":7,"8227":1,"83":[6,7,9,10],"8361":1,"8387":1,"84":[6,7,8,9,10],"8441":1,"8442":1,"85":[6,7,9],"8502":23,"86":[6,7,9],"8627":23,"8633":1,"8649":1,"87":[6,7,9],"8795":23,"88":[6,7,9],"8815":1,"8820":17,"8848":23,"8851":23,"8859":1,"89":[6,7,9],"8901":23,"8911":1,"8915":23,"8948":23,"8968":23,"8984":23,"9":[6,7,8,9,10,11,14,17,22,23],"90":[7,9,23],"9006":1,"9009":23,"9011":[17,23],"9026":23,"9048":23,"9053":17,"91":[7,10],"9125":23,"9145":1,"9174":23,"92":[7,10],"9231":23,"9265":1,"9284":23,"93":[7,10],"9320":1,"9371":17,"9396":17,"94":[7,9],"9400":17,"9432":1,"95":[6,7,8,9,10],"9509":17,"9526":17,"9551":1,"9561":1,"9585":17,"96":[7,8,9,10],"9600":17,"9618":1,"9621":17,"9625":17,"9668":17,"9677":17,"9696":17,"97":[7,8,10],"9700":17,"9721":17,"9729":17,"9731":17,"9734":17,"9736":17,"9738":17,"9743":17,"9768":17,"9775":17,"9793":17,"98":[6,7,8,10,23],"9844":1,"99":[6,7,17],"9925":1,"9942":1,"9957":1,"abstract":1,"boolean":[1,12,14],"break":[7,8],"byte":[17,23],"case":[1,3,11,14,20,21,22],"class":[0,1,6,7,8,9,14,17,19,22,23],"default":[1,4,14,20,21],"do":[1,14,17,18,20,21,22],"final":[1,3,14,19,22],"float":[1,3,6,7,8,9,10,14,15],"function":[1,14,15,19,23],"garc\u00eda":11,"import":[1,6,7,8,9,10,11,18,19,20,21,22,23],"int":[1,3,4,10,12,14,15],"n\u00ba":[17,23],"new":[1,4,14,15,17,18,19,20,22,23],"p\u00e9rez":11,"return":[1,3,4,6,7,8,9,10,12,14,15,17,18,19,20,22,23],"super":[1,6,7,8,9,22,23],"true":[1,6,7,8,9,10,11,12,14,15,17,19,20,21,22,23],"try":[10,19,23],"while":[1,7,8,11,19,20,21],A:[1,11,14,15,18],And:[20,23],As:[11,14,17,18,19,20],At:[1,18],Be:[1,19],But:[19,20,22,23],By:[1,4,11,14,17,18],For:[1,14,15,18,19,20,21,22],If:[1,3,4,11,14,15,18,20,22],In:[1,3,8,11,14,15,17,18,19,20,21,22,23],Is:14,It:[1,3,11,14,15,17,18,19,23],Its:[1,15],Not:22,Of:[20,22],On:1,One:[1,10,18],That:[1,3,4,14,15,19,20,22,23],The:[1,3,4,11,14,15,16,18,19,20,23],Then:[1,14,20],There:[1,14,19,21,22],These:[1,3,11,14,18,19,21],To:[1,3,4,11,13,14,15,17,18,19,20,21,22],With:[1,10,11,15,18,19,22],_:[1,6,7,8,9,10,14,15,17,18,22,23],__call__:1,__init__:[1,6,7,8,9,22,23],_channel:14,_copi:1,_dim:1,_mats_env:[7,8],_percentag:[1,3,14,15],_size:14,_size_0:14,_size_1:14,_size_:1,a_:15,abil:11,abl:[1,11,18,19,22],about:[1,15,18,19,20,21,22],abov:[1,18,22],abstractnod:[0,12,15],acc:[6,7,8,9,10,17,23],acceler:[1,9,10],accept:19,access:[1,18,19,20],accomplish:[1,3,14,21],accord:[1,15],accordingli:14,account:[1,14,21],accross:[1,15],accumul:[3,14],accuraci:[6,7,8,9,10,17,23],achiev:1,act:[1,14,18,20],action:1,activ:1,actual:[1,14,18,19,20,23],ad:[1,4,14,15,19,22],adam:[6,7,8,9,10,17,23],adapt:[11,15],add:[0,1,14,17,19,22,23],add_data:[1,9,14,21,22],add_on:[0,23],addit:[1,11,14,15],addition:1,address:[1,18,20],admit:[14,15],advanc:[1,11,16],advantag:[8,20,22],advis:11,affect:[1,15],after:[1,3,8,14,17,23],afterward:[1,14,15],again:[8,10,14,17,20,22],against:14,aim:[3,14,17],alejandro:11,algorithm:[1,14,20,22,23],all:[1,3,4,8,10,14,15,17,18,19,20,21,22,23],allow:[1,11,14],almost:[1,19,20,21,23],along:[1,15],alreadi:[1,7,14,15,17,18,19,22],also:[1,3,8,11,13,14,15,17,18,19,20,21,22],alter:14,although:[1,10,18,19,20,21,22],alwai:[1,11,14,15,17,18,19,20,21,22],among:[3,14],amount:[15,17],an:[1,3,4,6,9,10,11,12,14,15,17,18,19,20,21,22],ancillari:[1,20,21],ani:[1,7,8,14,15,18,19,22],anomali:11,anoth:[1,14,15,18,19],antialia:[7,8,9,10],anymor:1,api:2,appear:[1,15,18],append:[6,7,8,19,20,21,22,23],appli:[1,11,14,15,19],approach:[8,10,11],appropi:[14,15,22,23],approxim:1,ar:[1,3,4,8,11,13,14,15,17,18,19,21,22,23],arang:[4,7,8],arbitrari:1,architectur:11,archiveprefix:11,arg:[1,6,7,8,9,12,14],argument:[1,3,12,14,15,22],aros:1,around:17,arrai:18,arrow:15,arxiv:11,ascertain:11,assert:[1,11,14,15,18,19,20,21],assign:14,assum:[1,4,14,15],attach:[1,15,18],attribut:[1,14,22],author:11,auto_stack:[1,15,17,20,21,22,23],auto_unbind:[1,15,17,20,22,23],automat:[1,15,18,19,20],auxiliari:15,avail:[11,17,18],avoid:[1,3,14,15,20],awar:[1,19],ax:[1,7,8,14,15,18,19],axes_nam:[1,7,8,11,15,18,19,20,21,22],axi:[0,4,7,8,15,18,22,23],axis1:1,axis2:1,axis_0:1,axis_1:1,axis_2:1,axis_n:1,b:[4,14,15,19],b_1:3,b_:15,b_m:3,back:1,backend:[6,7,8,9,10],backpropag:[17,23],backward:[1,6,7,8,9,10,11,17,19,23],bar:[7,8],base:[1,4,10,14,15,19],basi:[0,10,14],basic:[1,11,18,19],batch:[1,3,4,10,11,14,15,17,18,19,21,22],batch_0:[4,14],batch_1:[1,14,22],batch_idx:[6,7,8,9,10],batch_m:1,batch_n:[4,14,22],batch_siz:[6,7,8,9,10,14,17,23],becaus:[1,14,20,23],becom:[1,11,14,15],been:[1,14,15,18,22],befor:[1,6,7,8,9,10,14,17,18,22,23],beforehand:18,beggin:8,begin:[4,17,18,23],behav:[1,20],behavior:[3,14],behaviour:[1,14,15,18,20,21],being:[1,3,4,14,15,19],belong:[1,14,17,18,19],besid:[1,14,15,18],best:[10,11],better:18,between:[1,3,4,10,11,14,15,18,20,21,23],bia:9,big:[14,19],bigger:14,binom:4,blank:1,block:[7,8],block_:8,block_length:[7,8],block_nod:[7,8],block_posit:7,bmatrix:4,bodi:18,bond:[3,8,14],bond_dim:[6,7,8,9,10,11,14,17,23],bool:[1,3,12,14],border:14,both:[1,14,15,18,19,20,21,22],bottom:[1,14,23],boudari:14,bound:[1,3,14,15],boundari:[3,7,8,9,10,14],bracket:1,bridg:11,bring:17,build:[1,6,11,16,19],built:[11,17,19,22],bunch:[1,19,22],c_:15,cach:[1,20],calcul:20,call:[1,14,15,17,19,22,23],callabl:15,can:[1,3,4,6,7,8,9,10,11,13,14,15,17,18,19,20,21,22,23],cannot:[1,7,14,15,19,21],canon:[6,10,14,17,23],canonic:[6,9,10,14,17,23],canonicalize_univoc:14,carri:[1,14,15,18],cast:1,cat:[6,23],caus:14,cdot:[3,4,14],center:14,central:[1,18,23],certain:[1,4,10,11,15],chang:[1,6,10,11,14,15,18,20,22,23],change_s:1,change_typ:1,channel:14,charact:1,check:[1,6,7,8,9,10,14,15,17,18,19,20,21,23],check_accuraci:[6,7,8,9,10],check_first:15,child:1,choic:10,choos:10,chose:21,classif:[14,17],classifi:[17,22],clone:[11,13,14],close:14,cmap:[6,7,8,9,10],cnn:[6,23],cnn_snake:[6,23],cnn_snakesb:[6,23],co:4,code:[1,10,11,18,23],coincid:[14,15,18],collect:[1,11,15,19,20],column:14,com:[11,13],combin:[1,6,10,14,18,22,23],come:[1,15,17,18,21,22],comma:15,command:[11,13],common:[1,14,17,22],compar:20,complementari:14,complet:[1,15],compon:[0,4,11,15,17,19],compos:[7,8,9,10,17,23],comput:[1,3,4,11,13,14,15,17,19,20,22,23],concaten:4,concept:11,condit:[3,14,15],configur:[10,11],conform:1,connect:[0,1,9,11,14,18,19,20,21,22],connect_stack:0,consecut:[3,14],consid:[11,14],consist:[1,15],constraint:11,construct:[1,11,19,20],contain:[1,14,18,19,20,21],context:11,continu:[14,17],contract:[0,1,6,7,8,9,10,11,14,16,17,18,20,21,22],contract_:[0,1,19],contract_between:[0,1,19],contract_between_:[0,1,7,8,19],contract_edg:[0,1,19],contract_edges_ip:[1,15],contrast:[10,14],contrat:14,control:[1,18,20],conv2d:[6,14,23],conv_mp:14,conv_mps_lay:14,conv_pep:14,conv_tre:14,conveni:[14,21],convent:[1,15,19,22],converg:10,convmp:0,convmpslay:[0,6,23],convolut:[6,14,22,23],convpep:0,convtre:0,convump:0,convumpslay:0,convupep:0,convutre:0,copi:[0,1,14],core:[8,10,14,18,23],corner:14,correct:[1,7,8,9,10,20],correctli:15,correspond:[1,3,4,14,15,18,19,20,22],costli:[1,14,20],could:[14,17,18,20,22],count:17,coupl:[1,18,21],cours:[20,22],cpu:[1,6,7,8,9,10,17,22,23],creat:[1,6,7,11,14,15,16,19,20,21,22],creation:[11,15],criterion:[6,7,8,9,10],crop:[1,14],crossentropyloss:[6,7,8,9,10,17,23],cuda:[1,6,7,8,9,10,17,22,23],cum:[1,3,14,15],cum_percentag:[1,3,6,7,8,9,10,14,15,17,23],current:[1,14],custom:[1,11,14,16],cut:[3,17,23],cutoff:[1,3,14,15],d:[4,11],d_1:3,d_n:3,dagger:[1,15],dangl:[1,15,17,18,19],data:[1,4,6,7,8,9,10,11,14,15,19,21,22,23],data_0:[1,21],data_1:1,data_nod:[1,7,8,19,21,22],data_node_:19,dataload:[6,7,8,9,10,17,23],dataset:[14,17,23],dataset_nam:[6,7,8,9,10],david:[4,11],ddot:4,de:[1,14,19],decompos:[3,15],decomposit:[0,1,9,14,15,17,19,23],decreas:1,dedic:14,deep:[11,17],deepcopi:1,def:[1,6,7,8,9,10,17,22,23],defin:[1,4,11,14,18,19,21,22,23],degre:[4,8,10],del:[1,15],delet:[1,15],delete_nod:1,demonstr:11,densiti:14,depend:[1,3,10],deriv:14,descent:[6,7,8,9,10,18,19],describ:[1,9,14,22],design:14,desir:[4,8,19,20,21,22],detach:9,detail:[1,11,14,19],determin:11,develop:11,devic:[1,6,7,8,9,10,14,17,22,23],devot:20,df:1,diagon:[1,12,15],diagram:18,dictionari:1,did:23,differ:[1,5,6,11,14,15,16,17,18,19,20,22,23],differenti:[11,16,18],dilat:14,dim:[4,6,7,10,17,23],dimens:[1,3,4,8,14,15,17,18,22],dimension:[18,23],direct:7,directli:[1,11,13,15,18,19,20,21],disconnect:[0,1,14,18],discret:[0,10,14],distinct:1,distinguish:21,distribut:[3,12,14],div:[0,19],divers:11,divid:[1,15],divis:15,divisor:15,dmrg:[5,10],document:[1,15],doe:[1,14,18,20],don:15,done:[1,3,14,18],down:14,down_bord:14,download:[6,7,8,9,10,11,13,23],drawn:12,drop_last:[17,23],dtype:1,due:14,duplic:14,dure:[1,8,17,20,21],dynam:14,e:[1,4,11,14,15,21],each:[1,3,4,8,14,15,17,18,19,20,21,22,23],earlier:18,easi:[1,11],easier:18,easili:[1,18,19,22,23],edg:[0,11,14,17,18,19,20,21,22,23],edge1:[1,15],edge2:[1,15],edgea:1,edgeb:1,edges_dict:1,effect:[1,19],effici:[1,11,18],einstein:19,einsum:[0,19],either:[14,15],element:[1,4,11,12,14,15,18],element_s:[17,23],elif:[6,7,8,9,10],els:[6,7,8,9,10,17,22,23],emb:[4,17,22],emb_a:4,emb_b:4,embed:[0,6,7,8,10,14,17,22,23],embedd:4,embedding_dim:[7,8,10],embedding_matric:14,empti:[0,1,14,18,19,20],enabl:[1,11,14,18,19,20],end:[1,4,8,15],engin:18,enhanc:11,enough:[1,14],ensur:[1,11],entail:[1,20],entangl:14,entir:1,entropi:14,enumer:[1,6,7,8,9,10,14],environ:14,epoch:[1,6,7,8,9,10,17,22,23],epoch_num:[17,23],eprint:11,equal:[1,7,14,22],equilibrium:17,equival:[1,4,14,18,23],error:[11,14],essenti:[1,11],etc:[1,15,18,22],eval:[6,7,8,9,10],evalu:19,even:[1,3,8,11,14,18,20,22],evenli:[3,14],everi:[1,15,18,20,23],exact:1,exactli:1,exampl:[1,2,4,6,9,14,15,17,19,21,22],exce:7,exceed:14,excel:11,except:[12,14],exclud:[1,21],execut:15,exist:[1,14,15,19],exp:14,expand:[1,4,22],expect:[3,14],experi:[1,11,18,20],experiment:11,explain:[1,18,19,20],explan:[1,14,15],explicitli:[1,15,21,22],explod:[3,14],explor:[11,18],express:15,extend:1,extra:[3,14,15],extract:[3,14,18],extractor:23,ey:22,eye_tensor:22,f:[6,7,8,9,10,17,18,19,20,21,23],facilit:[1,11,15,18],fact:[1,21,22],factor:[3,14],fail:[14,15],failur:11,fals:[1,3,6,7,8,9,10,12,14,15,17,19,20,22,23],familiar:[11,18],fashion:[8,19],fashion_mnist:[6,9],fashionmnist:[6,9,23],faster:[1,8,22],fastest:11,fc1:9,fc2:9,featur:[1,4,14,15,18,19,20,21,23],feature_dim:[1,22],feature_s:14,fed:23,feed:9,few:[8,11],fffc:9,file:11,fill:[1,12,18,19,22],find:11,finish:[18,22],finit:14,first:[1,4,10,11,14,15,16,18,19,20,21,23],fix:[1,18,19],flat:14,flatten:14,flavour:1,flip:[6,23],flips_x:6,fn_first:15,fn_next:15,folder:[11,13],follow:[1,11,13,14,15],forget:1,form:[1,14,17,18,19,21,23],former:[1,14,15],forward:[1,6,7,8,9,10,14,15,22,23],four:22,frac:[1,3,4,14,15],framework:11,free:14,friendli:11,from:[1,3,6,7,8,9,10,11,12,13,14,15,17,18,19,21,22,23],from_sid:14,from_siz:14,full:17,fulli:[9,11],func:[1,14],functool:23,fundament:11,further:[1,11],furthermor:[1,20,22],futur:[1,20],g:[1,14,15,21],gain:11,gap:11,garc:11,gaussian:14,ge:[1,3,14,15],gener:[1,20],get:[1,6,7,8,9,10,17,18,22,23],get_axi:[1,7,8,18],get_axis_num:[1,18],get_edg:1,git:[11,13],github:[11,13],give:[1,18,20,21],given:[1,4,6,14],glimps:17,go:[1,15,17,19,20,23],goal:11,goe:14,good:[1,6,7,8,9,10,17],gpc20:6,gpu:[17,22],grad:[1,11,19],gradient:[1,6,7,8,9,10,11,14,18,19],graph:[1,18,20,22],grasp:11,greater:[1,7,22],greatli:11,grei:[6,7,8,9,10],grid:14,grid_env:14,group:15,gt:[6,7,8,9,10],guid:11,ha:[1,3,8,14,15,17,18,20,21,22,23],had:[1,15,20],hand:[1,11,14,22],happen:[1,3,14,20,23],har:11,hardwar:11,hat:4,have:[1,3,4,14,15,17,18,19,20,21,22,23],heavi:22,height:[14,17,22],help:[1,20],henc:[1,15,17,18,19,20,22],here:[1,7,8,10,14,19],hidden:[1,21],high:[1,4,6,7,8,9,10,12],highli:11,hint:1,hold:[19,21],hood:20,horizont:14,how:[1,6,7,8,9,10,11,14,15,16,17,19,21,23],howev:[1,11,14,15,18,19,20],http:[11,13],hundread:19,hybrid:[5,11,16],hyperparamet:[6,7,8,9,10],i:[1,3,6,7,8,11,14,15,18,19,20,21,22],i_1:12,i_n:12,ideal:11,ident:14,identif:11,idx:18,ignor:[6,10],ijb:[15,19],ijk:15,im:15,imag:[6,14,17,22,23],image_s:[6,7,8,9,10,17,22,23],immers:11,implement:[1,11,17,23],improv:17,imshow:[6,7,8,9,10],in_1:3,in_channel:[6,14,23],in_dim:[7,8,10,11,14,17],in_env:14,in_featur:14,in_n:3,in_region:14,in_which_axi:1,includ:[1,11,14,18],incorpor:11,inde:[1,14],independ:[1,15],index:[1,6,7,8,9,10,18,20],indic:[1,3,12,14,15,18,19,20,21],infer:[1,14,15,18,20,22],inform:[1,11,14,15,18,19,20],ingredi:[1,18],inherit:[1,14,15,21],inic:22,init:19,init_method:[1,6,7,8,10,14,17,18,19,23],initi:[0,1,7,8,9,10,14,17,18,19,22,23],inlin:14,inline_input:[6,9,10,14],inline_mat:[6,9,10,14],inner:1,input:[1,3,7,8,14,15,17,18,19,20,21,22,23],input_edg:[1,21,22],input_nod:22,input_s:[6,7,8,9,10],insid:[1,11,13],insight:19,instal:[2,17,18],instanc:[1,15,20,21,22],instanti:[1,3,14,18,20,21,22],instati:18,instead:[1,14,19,20,21,22,23],integ:4,integr:11,intend:[1,14],interior:14,intermedi:[1,19,21,22],intern:[14,15],intric:11,intricaci:11,introduc:[6,14],introduct:11,invari:[1,14,20,21],inverse_memori:1,involv:[1,15,20],is_attached_to:1,is_avail:[6,7,8,9,10,17,22,23],is_batch:[1,18],is_connected_to:1,is_dangl:[1,18],is_data:[1,21],is_leaf:[1,21],is_node1:1,is_result:[1,21],is_virtu:[1,21],isinst:[1,19],isn:[17,18],isometri:14,issu:1,item:[6,7,8,9,10,17,23],iter:[1,8,14,20],its:[1,11,12,14,15,18,20,22,23],itself:[1,14,22],itselv:14,j:[1,4,8,11,15],jkb:[15,19],jl:15,jo:11,joserapa98:[11,13],just:[1,11,14,18,20,21,22,23],k:[11,15],keep:[1,3,14,15,20],keepdim:1,kei:[1,10,11,18],kept:[1,3,14,15],kernel:14,kernel_s:[6,14,23],kerstjen:11,keyword:[1,12,14],kib:[15,19],kind:[1,17],klm:15,know:[1,18,19,20],known:22,kwarg:[1,6,7,8,9,12,14],l2_reg:[17,23],l:[4,15],label:[14,17,21,23],lambda:[17,23],larger:1,largest:1,last:[1,7,14,15,22,23],later:22,latter:[1,14],layer:[1,6,11,14,17,22,23],ldot:12,lead:[1,15,21],leaf:[1,15,21,22],leaf_nod:[1,21],learn:[1,8,11,15,17,18,19,20,21,22,23],learn_rat:[17,23],learnabl:19,learning_r:[6,7,8,9,10],least:[1,14,15,17,20,23],leav:[8,14,20],left1:19,left2:19,left:[1,7,8,11,14,15,18,19,20,21,22,23],left_bord:14,left_down_corn:14,left_nod:[7,8,14],left_up_corn:14,leg:1,len:[6,7,8,9,10,15,22],length:[1,14],let:[17,18,19,20,22,23],level:[1,4,10,18],leverag:18,lfloor:4,li:[1,11,18],librari:[11,18,23],like:[0,1,5,10,11,14,17,18,19,20,21,22,23],line:[11,14,23],linear:[9,11,22],link:1,list:[1,3,12,14,15,18,19,22],load:[6,7,8,9,10,17,22],load_state_dict:[9,10,22],loader:[6,7,8,9,10,17,23],local:[4,14],locat:[1,14],log_norm:14,log_scal:14,logarithm:[3,14],logarithmoc:14,logit:[17,23],longer:[1,20],look:[1,10],loop:[17,22],lose:17,loss:[6,7,8,9,10,23],loss_fun:[17,23],lot:[1,17],low:[1,4,6,7,8,9,10,12,14],lower:[1,3,14,15],lr:[6,7,8,9,10,17,23],lst_y:6,lt:10,lvert:4,m:[1,3,11,13,15],machin:[11,17],made:[1,15],mai:[1,3,10,11,14],main:[1,18,19,20,22],mainli:1,maintain:1,make:[1,8,18,19,20,21],make_tensor:[1,14],manag:[1,22],mandatori:[1,15],mani:[1,8,14,17,18,20,23],manipul:18,manual:[1,14,22],manual_se:[17,23],map:4,margin:14,marginalize_output:14,master:[11,13],mat:3,mat_to_mpo:[0,9],match:[1,10,14,15],matplotlib:[6,7,8,9,10],matric:[14,15,19,22],matrix:[1,3,4,9,11,14,15,17,18,20,22],mats_env:[7,8,14],mats_env_:[7,8],max:[6,7,8,9,10,17,23],max_bond:14,maximum:[4,14],maxpool2d:[6,23],mayb:15,mb:[17,23],mean:[1,12,14],megabyt:[17,23],mehtod:18,member:[1,14],memori:[1,11,16,17,18,21,22,23],mention:[14,22],merg:[7,8],merge_block:7,messag:11,method:[1,14,15,18,19,20,21,22],middl:14,middle_sit:14,might:[1,14,20,21,22],mile:4,mimic:22,minuend:15,misc:11,miscellan:[17,23],mit:11,mkdir:[6,7,8,9,10],mnist:[7,8,10,17],mod:4,mode:[1,6,14,15,17,21,22,23],model:[0,1,5,9,11,16,18,19,20,21],model_nam:[6,7,8,9,10],modif:1,modifi:[1,14,19],modul:[1,6,9,11,13,14,17,18,19,22,23],modulelist:[6,23],monomi:4,monturiol:11,mooth:11,more:[1,3,11,14,15,18,19,20,21],most:[1,11,18,21],move:[1,14,18],move_block_epoch:[7,8],move_nam:1,move_to_network:[1,15],movec:15,mp:[0,3,5,6,9,11,17,18,19,20,21,22,23],mpo:[0,3,9,22],mpo_tensor:9,mps_data:[9,14],mps_dmrg:7,mps_dmrg_hybrid:8,mps_hdmrg:8,mps_layer:[6,14],mps_tensor:9,mpsdata:[0,3,9],mpslayer:[0,7,8,10,11,17,22],mpss:6,much:[1,8,20,21,22,23],mul:[0,19],multi:18,multipl:[14,18,20],multipli:[14,15],must:[1,3,15,18,22],my_model:11,my_nod:[1,20],my_paramnod:1,my_stack:1,mybatch:1,n:[1,4,6,9,10,11,12,14,17,23],n_:[1,4],n_batch:[3,9,14],n_block:8,n_col:14,n_epoch:22,n_featur:[7,8,9,10,11,14,17,22],n_param:[6,9,10,17,23],n_row:14,name:[1,7,8,11,15,17,18,19,20,22,23],name_:1,necessari:[1,14,15,17,18,22,23],need:[1,11,14,15,17,18,19,20,22],neighbour:[1,7,8,14,19],nelement:[17,23],net2:1,net:[1,11,15,18,19,20,21],network:[0,5,7,8,10,11,14,15,16,20,21,22],neumann:14,neural:[5,11,14,16,17],never:[1,15,21],nevertheless:11,new_bond_dim:10,new_edg:[1,15],new_edgea:1,new_edgeb:1,new_mp:22,new_nodea:[1,15],new_nodeb:[1,15],new_tensor:20,next:[1,14,15,18,19,20,21],nn:[1,6,7,8,9,10,11,14,17,18,19,22,23],no_grad:[6,7,8,9,10,17,23],node1:[1,11,15,18,19,20,21],node1_ax:[1,7,8,15],node1_list:1,node1_lists_dict:1,node2:[1,11,15,18,19,20,21],node2_ax:[1,7,8,15],node3:[19,20],node:[0,3,7,8,11,12,14,16,17,20,22,23],node_0:1,node_1:1,node_:[18,19,20,21],node_left:[1,15],node_n:1,node_ref:1,node_right:[1,15],nodea:[1,15],nodeb:[1,15],nodec:[1,15],nodes_list:15,nodes_nam:1,nodesa:15,nodesb:15,nodesc:15,non:[1,14,15,19,22],none:[1,3,7,8,11,14,15,19,22],norm:[1,3,14,15],normal:[1,3,12,14,15],notabl:11,notat:18,note:[1,11,15,18,19,22],noth:[1,18,20],now:[1,18,19,20,22,23],npov15:9,nto16:10,nu_batch:14,num:1,num_batch:[17,23],num_batch_edg:[1,21,22],num_class:[6,7,8,9,10],num_correct:[6,7,8,9,10],num_epoch:[6,7,8,9,10,17,23],num_epochs_canon:17,num_sampl:[6,7,8,9,10],num_test:[17,23],num_train:[17,23],number:[1,3,14,15,17,19,22,23],numel:[1,6,9,10],numer:15,nummber:1,o:11,obc:[3,7,8,9,10,14],object:[1,18,23],obtain:[14,17],oc:14,occur:[1,14,20],occurr:11,off:[17,20,23],offer:11,often:[1,20],onc:[1,8,11,14,15,20],one:[1,3,4,6,7,9,10,14,15,17,18,19,20,21,22,23],ones:[0,1,4,6,14,18,19,21,23],ones_lik:[6,17,23],onli:[1,11,14,15,18,19,20,21,22],open:14,oper:[0,1,14,18,21,22],operand:[1,18,19],opposit:14,opt_einsum:[11,15,19],optim:[1,6,7,8,9,10,11,14,23],option:[1,3,4,6,10,11,14,15,18,22],order:[1,3,6,14,15,18,19,20,21,22],orient:14,origin:[1,4,14,15,17,21,23],orthogon:14,ot:14,other:[1,8,14,15,18,19,20,21,22,23],other_left:18,other_nod:20,otherwis:[1,14,15,23],otim:4,our:[6,7,8,9,10,17,18,19,20,23],out:[1,14,15,22],out_1:3,out_channel:[6,14,23],out_dim:[7,8,10,11,14,17],out_env:14,out_featur:14,out_n:3,out_nod:[7,8,14],out_posit:14,out_region:14,outdat:14,output:[1,3,4,6,7,8,14,15,17,22],output_dim:[6,7,8,10],output_nod:22,outsid:[11,13],over:[1,15,19],overcom:1,overhead:20,overload:1,overrid:[1,7,14,22],override_edg:1,override_nod:1,overriden:[1,19,22],own:[1,14,15,20,21],ownership:1,p:[1,6,9,10,11,15,17,23],packag:[11,13],pad:[6,14,23],pair:[1,8,14,23],pairwis:14,paper:[4,11,14,23],parallel:[1,14],param:[1,15,17,23],param_nod:[1,11,12,19,21],paramedg:1,paramet:[1,3,4,6,7,8,9,10,12,14,15,17,19,22,23],parameter:[1,7,8,14,19],parametr:[1,14,17,19,22,23],paramnod:[0,12,14,15,20,21,22],paramnode1:19,paramnode2:19,paramnode3:19,paramnodea:1,paramnodeb:1,paramstacknod:[0,15,20,21],pareja2023tensorkrowch:11,pareja:11,parent:[1,22],parenthesi:1,part:[1,14,19,21,22],partial:[14,23],partial_dens:14,particular:[1,22],pass:[14,17,22],patch:[14,22],pattern:23,pave:17,pbc:14,pep:[0,11,22],perform:[1,10,11,14,15,17,18,19,20,22,23],period:14,permut:[0,1,9,19,20],permute_:[0,1,19],persist:11,phase:15,phi:4,phys_dim:[9,14],physic:[3,14,18],pi:4,piec:[18,22],pip:[11,13,17,18],pipelin:[11,17],pixel:[14,17,22,23],place:[1,14,15,19,20,22],plai:21,pleas:11,plt:[6,7,8,9,10],plu:14,plug:14,point:[1,17],poli:[0,8,10,22],pool:[6,23],posit:[7,14,22],possibl:[1,6,7,8,9,10,14,17,18,19,20],potenti:11,power:[4,14,17],poza:11,practic:[1,11],practition:17,pre:9,precis:4,pred:[17,23],predict:[6,7,8,9,10,17,23],present:23,previou:[1,14,19,20,21,22,23],previous:19,primari:11,principl:[1,21],print:[1,6,7,8,9,10,14,15,17,23],probabl:17,problem:[1,20],process:[1,8,10,11,17],prod:6,product:[11,14,15,17,18,20,22],profici:11,project:14,properli:22,properti:[1,7,14,15,18],proport:[1,3,14,15],prove:14,provid:[1,3,11,14,15,18],prune:23,pt:[6,7,8,9,10,22],purpos:[1,18,20,21],put:[1,14,17,23],pyplot:[6,7,8,9,10],pytest:[11,13],python:[11,13,17],pytorch:[1,11,15,17,18,19,20,22,23],q:15,qr:[0,1,3,14,19],qr_:[0,1,14],quantiti:[1,3,14,15],quantum:18,qubit:14,quit:[11,23],r:[11,15],r_1:15,r_2:15,rais:[1,7],ram:11,rand:[0,1,4,6,10,14],randint:[4,6,7,8,9,10],randn:[0,1,4,6,10,11,14,15,18,19,20,21,22],randn_ey:[6,8,10,14,17,23],random:[14,15,19],random_ey:22,random_sampl:[6,7,8,9,10],rang:[1,6,7,8,9,10,11,14,15,17,18,19,20,21,22,23],rangl:4,rank:[1,3,4,7,8,14,15,19],rapid:11,rather:[1,14,15,18,20,21],re:[1,15,17,23],reach:17,readi:22,realli:[1,21],reamin:8,reattach:1,reattach_edg:1,recal:[1,22],recogn:19,recommend:[1,11,14],reconnect:[1,19],recov:[1,3,14],reduc:[1,14,15,17,23],redund:[1,15],refer:[1,2,10,15,18,19,20,21,22],referenc:1,regard:[1,15,19],regio:14,relat:18,relev:[1,18],reli:[11,15],relu:[6,9,11,23],remain:[8,14],remov:[1,14,15],renorm:[0,1,3,6,9,10,14,19],renorma:[1,15],repeat:[1,8,10,15],repeatedli:20,replac:1,repositori:[11,13],repres:[14,18],reproduc:[1,10,11,20],requir:[1,14,15,21],requires_grad:1,rerun:11,res1:19,res2:19,rescal:14,research:11,reserv:[1,15,20],reset:[1,6,7,8,9,10,14,15,20,21,22],reset_tensor_address:1,reshap:[3,7,8,9,10,15,20],resiz:[7,8,9,10,17,23],respect:[1,3,14,15,19,20],rest:[1,19],restrict:[1,15,21],result2:15,result:[1,3,7,8,9,10,11,14,15,19,20,21,22],result_mat:[7,8],resultant_nod:[1,21],retain:1,retriev:[1,18,20,21],reus:14,review:11,rez:11,rfloor:4,rho_a:14,rid:1,right1:19,right2:19,right:[1,7,8,11,14,15,18,19,20,21,22,23],right_bord:14,right_down_corn:14,right_nod:[7,8,14],right_up_corn:14,rightmost:14,ring:19,role:[1,14,21,22],root:[6,7,8,9,10],row:14,rowch:11,rq:[0,1],rq_:[0,1,14],run:[11,13,14,17,23],running_acc:17,running_test_acc:[17,23],running_train_acc:[17,23],running_train_loss:[17,23],s:[1,4,11,12,14,15,17,18,19,20,21,22,23],s_1:4,s_i:[1,3,14,15],s_j:4,s_n:4,sai:[1,20],same:[1,8,10,14,15,18,19,20,23],sampler:[17,23],satisfi:[1,3,14,15],save:[1,6,7,8,9,10,11,14,16,19,22],scalar:14,scale:[3,14],scenario:11,scheme:10,schmidt:14,schwab:4,score:[6,7,8,9,10,17,23],second:[1,14,15,20],secondli:22,section:[18,19,21,23],see:[1,6,7,8,9,10,11,14,15,21,23],seen:21,select:[1,15,18,19,22],self:[1,6,7,8,9,14,22,23],send:[17,22,23],sens:[17,21],sento:22,separ:[14,15],sepcifi:1,sequenc:[1,3,14,15,19,22],sequenti:[11,22],serv:[11,17],set:[1,7,8,14,15,20,21,22,23],set_data_nod:[1,7,8,9,14,21,22],set_param:[1,7,14],set_tensor:[1,14,20],set_tensor_from:[1,20,21,22],sever:[1,15],shall:[1,22],shape:[1,3,4,7,8,9,10,11,12,14,15,17,18,19,20,21,22],shape_all_nodes_layer_1:14,shape_all_nodes_layer_n:14,shape_node_1_layer_1:14,shape_node_1_layer_n:14,shape_node_i1_layer_1:14,shape_node_in_layer_n:14,share:[1,14,15,18,20,21],share_tensor:[1,14],should:[1,3,4,7,11,12,14,15,18,19,20,21,22],show:[6,7,8,9,10,18],showcas:11,shown:7,shuffl:[6,7,8,9,10],side:[1,7,8,14,15,23],similar:[1,14,15,18,19,21,22],simpl:[11,18,19],simpli:[1,18,20,22],simplifi:[1,11],simul:18,sin:4,sinc:[1,8,14,15,19,20,21,22,23],singl:[1,14,15,19,22],singular:[1,3,14,15,17,19,23],site:[7,14,18],sites_per_lay:14,situat:[1,15,21],size:[1,4,6,7,8,9,10,12,14,15,18,19,21,22,23],skip:1,skipp:20,slice:[1,15,20,21],slide:1,slower:[1,20],small:19,smaller:1,smooth:11,snake:[6,14,23],so:[1,14,15,20,22],some:[1,3,14,15,17,18,19,20,21,22],someth:18,sometim:1,sort:[1,14,18,20,21],sourc:[1,3,4,12,14,15],space:[1,17,22,23],special:[1,19,21],specif:[1,11,21],specifi:[1,3,14,15,18,19,21],split:[0,1,3,7,8,14,18,19],split_0:[1,15],split_1:[1,15],split_:[0,1,7,8,19],split_ip:[1,15],split_ip_0:[1,15],split_ip_1:[1,15],sqrt:4,squar:[14,15],squeez:[4,6,7,8,9,10],ss16:[7,10],stack:[0,1,4,6,14,17,19,20,21,22],stack_0:1,stack_1:1,stack_data:[1,15,22],stack_data_memori:[1,21],stack_data_nod:19,stack_input:22,stack_nod:[1,15,19,20,21],stack_result:[19,22],stackabl:15,stacked_einsum:[0,19],stackedg:[0,15,19],stacknod:[0,15,19,21],stai:1,stand:4,start:[1,14,18,20,22,23],state:[1,11,14,17,18,20,22],state_dict:[6,7,8,9,10,14,22],staticmethod:[6,23],statist:11,std:[1,6,8,10,12,14,17,22,23],stem:11,step:[1,6,7,8,9,10,11,16,23],stick:1,still:[1,14,15,19],store:[1,14,15,18,21,22],stoudenmir:4,str:[1,15],straightforward:11,strength:11,stride:[6,14,23],string:[1,14,15,21],structur:[1,11,14,18,22,23],sub:[0,19],subclass:[1,11,16,18,21,23],subsequ:[1,15,20],subsetrandomsampl:[17,23],substitut:1,subsystem:14,subtract:15,subtrahend:15,successfulli:10,successor:[0,15],suffix:1,suitabl:11,sum:[1,3,6,7,8,9,10,14,15,17,19,23],sum_:[1,3,14,15],summat:19,support:11,sv:[1,15],svd:[0,1,3,14,17],svd_:[0,1,14,19],svdr:[0,1,14],svdr_:[0,1,14],t:[14,15,17,18,21,22],t_:[12,15],tailor:11,take:[1,14,15,17,20,21,22],taken:[1,15],target:[6,7,8,9,10],task:[1,14,21],temporari:[1,21],tensor:[0,3,4,5,6,7,8,11,12,14,16,21,22],tensor_address:[1,14,20,21],tensorb:15,tensori:5,tensorkrowch:[1,3,4,6,7,8,9,10,12,13,14,15,16,19,21,22,23],tensornetwork:[1,11,15,16,19,21,23],termin:[11,13],test:[6,7,8,9,10,11,13,17,23],test_acc:[6,7,8,9,10],test_dataset:[6,7,8,9,10],test_load:[6,7,8,9,10],test_set:[17,23],texttt:14,th:14,than:[1,3,7,8,14,15,18,22],thank:18,thei:[1,11,13,14,15,18,19,21],them:[1,14,15,17,18,19,20,22,23],themselv:[14,15],therefor:[1,11,13,14,18,22],thi:[1,3,4,6,8,9,10,11,14,15,17,18,19,20,21,22,23],thing:[1,15,17,21,23],think:20,third:14,those:[1,20,21,22],though:[1,14,15,19,22],thought:1,three:[1,15],through:[1,11,14,15,17,22,23],thu:[1,3,8,11,14,15,18,22],time:[1,3,4,8,10,11,14,15,16,18,19,22],titl:11,tk:[1,4,6,7,8,9,10,11,14,15,17,18,19,20,21,22,23],tn1:9,tn:[9,22],tn_fffc:9,tn_linear:9,tn_model:9,tn_nn:9,togeth:[1,11,13,20,21],tool:[11,18,19],top:[11,14,17,19,23],torch:[1,3,4,6,7,8,9,10,11,12,14,15,17,18,19,20,21,22,23],torchvis:[6,7,8,9,10,17,23],total:[1,3,14,15,17,23],total_num:[17,23],totensor:[6,7,8,9,10,17,23],tprod:[0,19],trace:[1,6,7,8,9,10,14,15,17,21,22,23],trace_sit:14,track:[1,14,20],tradit:8,train:[1,5,11,14,18,19,20,22,23],train_acc:[6,7,8,9,10],train_dataset:[6,7,8,9,10],train_load:[6,7,8,9,10],train_set:[17,23],trainabl:[1,14,18,21,22,23],transform:[4,6,7,8,9,10,14,17,19,23],transit:14,translation:[1,14,20,21],transpos:[6,17,22,23],treat:1,tree:[0,22],tri:[1,15],triangular:15,trick:[20,23],tupl:[1,12,14,15],turn:[1,14,20],tutori:[1,2,17,18,19,20,21,22,23],two:[1,14,15,18,19,20],two_0:15,type:[1,3,4,6,7,8,9,10,11,12,14,15,16,18,19,20],typeerror:1,u:[1,12,15],ump:0,umpo:0,umpslay:0,unbind:[0,1,19,20,22],unbind_0:[1,15],unbind_result:19,unbinded_nod:[19,20],unbound:[1,15],under:[11,20],underli:20,underscor:[1,15],understand:[11,18],undesir:[1,3,14,15,21],unfold:14,uniform:[1,12,14,20,21,22],uniform_memori:22,uniform_nod:[20,21],uninform:17,uniqu:[14,20],unit:[0,6,7,10,14],unitari:[14,15],univoc:14,unix:[11,13],unless:1,unmerg:8,unmerge_block:7,unset:1,unset_data_nod:[1,14],unset_tensor:1,unsqueez:23,until:[1,10,14,15,17],up:[1,14,15,18],up_bord:14,updat:[8,14,17,23],update_bond_dim:[7,8,14],upep:0,upon:14,upper:[14,15],ur_1sr_2v:15,us:[1,3,4,7,11,14,15,17,18,19,20,21,22,23],usag:1,useless:17,user:[1,11],usual:[1,14,15,21,22],usv:15,util:[6,7,8,9,10,17,23],utre:0,v:[1,11,13,15],valid:1,valu:[1,3,4,14,15,17,19,22,23],valueerror:[1,7],vanilla:[18,19,20],vanish:14,vari:14,variabl:4,variant:[19,22],variat:1,variou:[11,14],vdot:4,vec:3,vec_to_mp:[0,9],vector:[3,4,14,17,22,23],veri:[1,17,18,20,22],verifi:[1,15],versa:14,versatil:11,version:[1,14,15,19],vertic:[1,14],via:[1,3,7,14,15,18,19,20,21],vice:14,view:[6,17,22,23],virtual:[1,15,20,21,22],virtual_nod:[1,21],virtual_result:[1,21],virtual_result_stack:[1,15,21],virtual_uniform:[1,14,21,22],virtual_uniform_mp:1,visit:1,von:14,wa:[1,6,14,22,23],wai:[1,5,8,14,17,18,22],want:[1,14,18,19,20,22],we:[1,6,7,8,10,11,17,18,19,20,21,22,23],weight:9,weight_decai:[6,7,8,9,10,17,23],well:[1,10,14,15],were:[1,14,15,19,21,22],what:[1,18,20,22],when:[1,14,15,17,18,19,20,21,22],whenev:[1,15],where:[1,4,8,14,15,17,18,19,20,22],whether:[1,3,12,14,18,20],which:[1,4,10,12,14,15,17,18,19,20,21,22],whilst:14,who:11,whole:[1,8,14,17,19,20,21,22],whose:[1,6,14,15,20,21],why:20,wide:[11,22],width:[14,17,22],wise:[1,15],wish:22,without:[1,12,14,15,17],won:[14,22],word:[1,18],work:[1,11,15,19,23],workflow:[1,22],worri:[1,20],would:[1,14,15,18,20,23],wouldn:21,wow:23,wrap:[1,18],written:[11,13],x:[4,6,7,8,9,10,11,14,17,22,23],x_1:4,x_j:4,x_n:4,y1:23,y2:23,y3:23,y4:23,y:[6,7,8,9,10,23],yet:[18,22],you:[11,13,17,18,19,20,21,22],your:[11,13,17,18,22],your_nod:20,yourself:11,zero:[0,1,6,7,8,9,10,14,17,22,23],zero_grad:[6,7,8,9,10,17,23],zeros_lik:20},titles:["API Reference","Components","<no title>","Decompositions","Embeddings","Examples","Hybrid Tensorial Neural Network model","DMRG-like training of MPS","Hybrid DMRG-like training of MPS","Tensorizing Neural Networks","Training MPS in different ways","TensorKrowch documentation","Initializers","Installation","Models","Operations","Tutorials","First Steps with TensorKrowch","Creating a Tensor Network in TensorKrowch","Contracting and Differentiating the Tensor Network","How to save Memory and Time with TensorKrowch (ADVANCED)","The different Types of Nodes (ADVANCED)","How to subclass TensorNetwork to build Custom Models","Creating a Hybrid Neural-Tensor Network Model"],titleterms:{"1":[17,18,19,20,21,22],"2":[17,18,19,20,21,22],"3":[17,18,19,20,22],"4":17,"5":17,"6":17,"7":17,"class":15,"function":17,"import":17,The:[21,22],abstractnod:1,add:15,add_on:4,advanc:[20,21],api:0,ar:20,axi:1,basi:4,between:19,build:[18,22],choos:17,cite:11,compon:[1,18,22],connect:15,connect_stack:15,content:2,contract:[15,19],contract_:15,contract_between:15,contract_between_:15,contract_edg:15,convmp:14,convmpslay:14,convpep:14,convtre:14,convump:14,convumpslay:14,convupep:14,convutre:14,copi:12,creat:[18,23],custom:22,data:17,dataset:[6,7,8,9,10],decomposit:3,defin:[6,7,8,9],differ:[10,21],differenti:19,disconnect:15,discret:4,distinguish:19,div:15,dmrg:[7,8],document:11,download:17,edg:[1,15],einsum:15,embed:4,empti:12,everyth:22,exampl:[5,11],faster:20,first:[17,22],how:[18,20,22],hybrid:[6,8,23],hyperparamet:17,initi:12,instal:[11,13],instanti:[10,17],introduct:[17,18,19,20,21,22],layer:9,librari:17,licens:11,like:[7,8,15],loss:17,manag:20,mat_to_mpo:3,matrix:19,memori:20,mode:20,model:[6,7,8,10,14,17,22,23],mp:[7,8,10,14],mpo:14,mpsdata:14,mpslayer:14,mul:15,name:21,network:[1,6,9,17,18,19,23],neural:[6,9,23],node:[1,15,18,19,21],notebook:11,ones:12,oper:[15,19,20],optim:17,our:22,paramnod:[1,19],paramstacknod:1,pep:14,permut:15,permute_:15,poli:4,product:19,prune:[10,17],put:22,qr:15,qr_:15,rand:12,randn:12,refer:0,renorm:15,requir:11,reserv:21,retrain:10,rq:15,rq_:15,run:20,save:20,set:17,setup:[17,18],skip:20,split:15,split_:15,stack:15,stacked_einsum:15,stackedg:1,stacknod:1,start:17,state:19,step:[17,18,19,20,21,22],store:20,sub:15,subclass:22,successor:1,svd:15,svd_:15,svdr:15,svdr_:15,tensor:[1,9,15,17,18,19,20,23],tensori:6,tensorkrowch:[11,17,18,20],tensornetwork:[18,20,22],time:20,togeth:22,tprod:15,train:[6,7,8,9,10,17],tree:14,tutori:[11,16],type:21,ump:14,umpo:14,umpslay:14,unbind:15,unit:4,upep:14,utre:14,vec_to_mp:3,wai:10,zero:12}}) \ No newline at end of file +Search.setIndex({docnames:["api","components","contents","decompositions","embeddings","examples","examples/hybrid_tnn_model","examples/mps_dmrg","examples/mps_dmrg_hybrid","examples/tensorizing_nn","examples/training_mps","index","initializers","installation","models","operations","tutorials","tutorials/0_first_steps","tutorials/1_creating_tensor_network","tutorials/2_contracting_tensor_network","tutorials/3_memory_management","tutorials/4_types_of_nodes","tutorials/5_subclass_tensor_network","tutorials/6_mix_with_pytorch"],envversion:{"sphinx.domains.c":2,"sphinx.domains.changeset":1,"sphinx.domains.citation":1,"sphinx.domains.cpp":5,"sphinx.domains.index":1,"sphinx.domains.javascript":2,"sphinx.domains.math":2,"sphinx.domains.python":3,"sphinx.domains.rst":2,"sphinx.domains.std":2,"sphinx.ext.viewcode":1,nbsphinx:4,sphinx:56},filenames:["api.rst","components.rst","contents.rst","decompositions.rst","embeddings.rst","examples.rst","examples/hybrid_tnn_model.ipynb","examples/mps_dmrg.ipynb","examples/mps_dmrg_hybrid.ipynb","examples/tensorizing_nn.ipynb","examples/training_mps.ipynb","index.rst","initializers.rst","installation.rst","models.rst","operations.rst","tutorials.rst","tutorials/0_first_steps.rst","tutorials/1_creating_tensor_network.rst","tutorials/2_contracting_tensor_network.rst","tutorials/3_memory_management.rst","tutorials/4_types_of_nodes.rst","tutorials/5_subclass_tensor_network.rst","tutorials/6_mix_with_pytorch.rst"],objects:{"tensorkrowch.AbstractNode":[[1,1,1,"","axes"],[1,1,1,"","axes_names"],[1,2,1,"","conj"],[1,2,1,"","contract_between"],[1,2,1,"","contract_between_"],[1,1,1,"","device"],[1,2,1,"","disconnect"],[1,1,1,"","dtype"],[1,1,1,"","edges"],[1,2,1,"","get_axis"],[1,2,1,"","get_axis_num"],[1,2,1,"","get_edge"],[1,2,1,"","in_which_axis"],[1,2,1,"","is_complex"],[1,2,1,"","is_conj"],[1,2,1,"","is_connected_to"],[1,2,1,"","is_data"],[1,2,1,"","is_floating_point"],[1,2,1,"","is_leaf"],[1,2,1,"","is_node1"],[1,2,1,"","is_resultant"],[1,2,1,"","is_virtual"],[1,2,1,"","make_tensor"],[1,2,1,"","mean"],[1,2,1,"","move_to_network"],[1,1,1,"","name"],[1,2,1,"","neighbours"],[1,1,1,"","network"],[1,2,1,"","node_ref"],[1,2,1,"","norm"],[1,2,1,"","numel"],[1,2,1,"","permute"],[1,2,1,"","permute_"],[1,1,1,"","rank"],[1,2,1,"","reattach_edges"],[1,2,1,"","renormalize"],[1,2,1,"","reset_tensor_address"],[1,2,1,"","set_tensor"],[1,2,1,"","set_tensor_from"],[1,1,1,"","shape"],[1,2,1,"","size"],[1,2,1,"","split"],[1,2,1,"","split_"],[1,2,1,"","std"],[1,1,1,"","successors"],[1,2,1,"","sum"],[1,1,1,"","tensor"],[1,2,1,"","tensor_address"],[1,2,1,"","unset_tensor"]],"tensorkrowch.Axis":[[1,2,1,"","is_batch"],[1,2,1,"","is_node1"],[1,1,1,"","name"],[1,1,1,"","node"],[1,1,1,"","num"]],"tensorkrowch.Edge":[[1,1,1,"","axes"],[1,1,1,"","axis1"],[1,1,1,"","axis2"],[1,2,1,"","change_size"],[1,2,1,"","connect"],[1,2,1,"","contract"],[1,2,1,"","contract_"],[1,2,1,"","copy"],[1,2,1,"","disconnect"],[1,2,1,"","is_attached_to"],[1,2,1,"","is_batch"],[1,2,1,"","is_dangling"],[1,1,1,"","name"],[1,1,1,"","node1"],[1,1,1,"","node2"],[1,1,1,"","nodes"],[1,2,1,"","qr"],[1,2,1,"","qr_"],[1,2,1,"","rq"],[1,2,1,"","rq_"],[1,2,1,"","size"],[1,2,1,"","svd"],[1,2,1,"","svd_"],[1,2,1,"","svdr"],[1,2,1,"","svdr_"]],"tensorkrowch.Node":[[1,2,1,"","change_type"],[1,2,1,"","copy"],[1,2,1,"","parameterize"]],"tensorkrowch.ParamNode":[[1,2,1,"","change_type"],[1,2,1,"","copy"],[1,1,1,"","grad"],[1,2,1,"","parameterize"]],"tensorkrowch.ParamStackNode":[[1,1,1,"","edges_dict"],[1,1,1,"","node1_lists_dict"],[1,2,1,"","reconnect"],[1,2,1,"","unbind"]],"tensorkrowch.StackEdge":[[1,2,1,"","connect"],[1,1,1,"","edges"],[1,1,1,"","node1_list"]],"tensorkrowch.StackNode":[[1,1,1,"","edges_dict"],[1,1,1,"","node1_lists_dict"],[1,2,1,"","reconnect"],[1,2,1,"","unbind"]],"tensorkrowch.TensorNetwork":[[1,2,1,"","add_data"],[1,1,1,"","auto_stack"],[1,1,1,"","auto_unbind"],[1,2,1,"","contract"],[1,2,1,"","copy"],[1,1,1,"","data_nodes"],[1,2,1,"","delete_node"],[1,1,1,"","edges"],[1,2,1,"","forward"],[1,1,1,"","leaf_nodes"],[1,1,1,"","nodes"],[1,1,1,"","nodes_names"],[1,2,1,"","parameterize"],[1,2,1,"","reset"],[1,1,1,"","resultant_nodes"],[1,2,1,"","set_data_nodes"],[1,2,1,"","trace"],[1,2,1,"","unset_data_nodes"],[1,1,1,"","virtual_nodes"]],"tensorkrowch.decompositions":[[3,3,1,"","mat_to_mpo"],[3,3,1,"","vec_to_mps"]],"tensorkrowch.embeddings":[[4,3,1,"","add_ones"],[4,3,1,"","basis"],[4,3,1,"","discretize"],[4,3,1,"","poly"],[4,3,1,"","unit"]],"tensorkrowch.models":[[14,0,1,"","ConvMPS"],[14,0,1,"","ConvMPSLayer"],[14,0,1,"","ConvPEPS"],[14,0,1,"","ConvTree"],[14,0,1,"","ConvUMPS"],[14,0,1,"","ConvUMPSLayer"],[14,0,1,"","ConvUPEPS"],[14,0,1,"","ConvUTree"],[14,0,1,"","MPO"],[14,0,1,"","MPS"],[14,0,1,"","MPSData"],[14,0,1,"","MPSLayer"],[14,0,1,"","PEPS"],[14,0,1,"","Tree"],[14,0,1,"","UMPO"],[14,0,1,"","UMPS"],[14,0,1,"","UMPSLayer"],[14,0,1,"","UPEPS"],[14,0,1,"","UTree"]],"tensorkrowch.models.ConvMPS":[[14,2,1,"","copy"],[14,1,1,"","dilation"],[14,2,1,"","forward"],[14,1,1,"","in_channels"],[14,1,1,"","kernel_size"],[14,1,1,"","padding"],[14,1,1,"","stride"]],"tensorkrowch.models.ConvMPSLayer":[[14,2,1,"","copy"],[14,1,1,"","dilation"],[14,2,1,"","forward"],[14,1,1,"","in_channels"],[14,1,1,"","kernel_size"],[14,1,1,"","out_channels"],[14,1,1,"","padding"],[14,1,1,"","stride"]],"tensorkrowch.models.ConvPEPS":[[14,1,1,"","dilation"],[14,2,1,"","forward"],[14,1,1,"","in_channels"],[14,1,1,"","kernel_size"],[14,1,1,"","padding"],[14,1,1,"","stride"]],"tensorkrowch.models.ConvTree":[[14,1,1,"","dilation"],[14,2,1,"","forward"],[14,1,1,"","in_channels"],[14,1,1,"","kernel_size"],[14,1,1,"","padding"],[14,1,1,"","stride"]],"tensorkrowch.models.ConvUMPS":[[14,2,1,"","copy"],[14,1,1,"","dilation"],[14,2,1,"","forward"],[14,1,1,"","in_channels"],[14,1,1,"","kernel_size"],[14,1,1,"","padding"],[14,1,1,"","stride"]],"tensorkrowch.models.ConvUMPSLayer":[[14,2,1,"","copy"],[14,1,1,"","dilation"],[14,2,1,"","forward"],[14,1,1,"","in_channels"],[14,1,1,"","kernel_size"],[14,1,1,"","out_channels"],[14,1,1,"","padding"],[14,1,1,"","stride"]],"tensorkrowch.models.ConvUPEPS":[[14,1,1,"","dilation"],[14,2,1,"","forward"],[14,1,1,"","in_channels"],[14,1,1,"","kernel_size"],[14,1,1,"","padding"],[14,1,1,"","stride"]],"tensorkrowch.models.ConvUTree":[[14,1,1,"","dilation"],[14,2,1,"","forward"],[14,1,1,"","in_channels"],[14,1,1,"","kernel_size"],[14,1,1,"","padding"],[14,1,1,"","stride"]],"tensorkrowch.models.MPO":[[14,1,1,"","bond_dim"],[14,1,1,"","boundary"],[14,2,1,"","canonicalize"],[14,2,1,"","contract"],[14,2,1,"","copy"],[14,1,1,"","in_dim"],[14,2,1,"","initialize"],[14,1,1,"","left_node"],[14,1,1,"","mats_env"],[14,1,1,"","n_batches"],[14,1,1,"","n_features"],[14,1,1,"","out_dim"],[14,2,1,"","parameterize"],[14,1,1,"","right_node"],[14,2,1,"","set_data_nodes"],[14,1,1,"","tensors"],[14,2,1,"","update_bond_dim"]],"tensorkrowch.models.MPS":[[14,1,1,"","bond_dim"],[14,1,1,"","boundary"],[14,2,1,"","canonicalize"],[14,2,1,"","canonicalize_univocal"],[14,2,1,"","contract"],[14,2,1,"","copy"],[14,2,1,"","entropy"],[14,1,1,"","in_env"],[14,1,1,"","in_features"],[14,1,1,"","in_regions"],[14,2,1,"","initialize"],[14,1,1,"","left_node"],[14,1,1,"","mats_env"],[14,1,1,"","n_batches"],[14,1,1,"","n_features"],[14,2,1,"","norm"],[14,1,1,"","out_env"],[14,1,1,"","out_features"],[14,1,1,"","out_regions"],[14,2,1,"","parameterize"],[14,2,1,"","partial_density"],[14,1,1,"","phys_dim"],[14,1,1,"","right_node"],[14,2,1,"","set_data_nodes"],[14,1,1,"","tensors"],[14,2,1,"","update_bond_dim"]],"tensorkrowch.models.MPSData":[[14,2,1,"","add_data"],[14,1,1,"","bond_dim"],[14,1,1,"","boundary"],[14,2,1,"","initialize"],[14,1,1,"","left_node"],[14,1,1,"","mats_env"],[14,1,1,"","n_batches"],[14,1,1,"","n_features"],[14,1,1,"","phys_dim"],[14,1,1,"","right_node"],[14,1,1,"","tensors"]],"tensorkrowch.models.MPSLayer":[[14,2,1,"","copy"],[14,1,1,"","in_dim"],[14,2,1,"","initialize"],[14,1,1,"","out_dim"],[14,1,1,"","out_node"],[14,1,1,"","out_position"]],"tensorkrowch.models.PEPS":[[14,1,1,"","bond_dim"],[14,1,1,"","boundary"],[14,2,1,"","contract"],[14,1,1,"","in_dim"],[14,2,1,"","initialize"],[14,1,1,"","n_batches"],[14,1,1,"","n_cols"],[14,1,1,"","n_rows"],[14,2,1,"","set_data_nodes"]],"tensorkrowch.models.Tree":[[14,1,1,"","bond_dim"],[14,2,1,"","canonicalize"],[14,2,1,"","contract"],[14,2,1,"","initialize"],[14,1,1,"","n_batches"],[14,2,1,"","set_data_nodes"],[14,1,1,"","sites_per_layer"]],"tensorkrowch.models.UMPO":[[14,2,1,"","copy"],[14,2,1,"","initialize"],[14,2,1,"","parameterize"]],"tensorkrowch.models.UMPS":[[14,2,1,"","copy"],[14,2,1,"","initialize"],[14,2,1,"","parameterize"]],"tensorkrowch.models.UMPSLayer":[[14,2,1,"","copy"],[14,1,1,"","in_dim"],[14,2,1,"","initialize"],[14,1,1,"","out_dim"],[14,1,1,"","out_node"],[14,1,1,"","out_position"],[14,2,1,"","parameterize"]],"tensorkrowch.models.UPEPS":[[14,1,1,"","bond_dim"],[14,2,1,"","contract"],[14,1,1,"","in_dim"],[14,2,1,"","initialize"],[14,1,1,"","n_batches"],[14,1,1,"","n_cols"],[14,1,1,"","n_rows"],[14,2,1,"","set_data_nodes"]],"tensorkrowch.models.UTree":[[14,1,1,"","bond_dim"],[14,2,1,"","contract"],[14,2,1,"","initialize"],[14,1,1,"","n_batches"],[14,2,1,"","set_data_nodes"],[14,1,1,"","sites_per_layer"]],tensorkrowch:[[1,0,1,"","AbstractNode"],[1,0,1,"","Axis"],[1,0,1,"","Edge"],[1,0,1,"","Node"],[15,0,1,"","Operation"],[1,0,1,"","ParamNode"],[1,0,1,"","ParamStackNode"],[1,0,1,"","StackEdge"],[1,0,1,"","StackNode"],[1,0,1,"","Successor"],[1,0,1,"","TensorNetwork"],[15,3,1,"","add"],[15,3,1,"","connect"],[15,3,1,"","connect_stack"],[15,3,1,"","contract"],[15,3,1,"","contract_"],[15,3,1,"","contract_between"],[15,3,1,"","contract_between_"],[15,3,1,"","contract_edges"],[12,3,1,"","copy"],[15,3,1,"","disconnect"],[15,3,1,"","div"],[15,3,1,"","einsum"],[12,3,1,"","empty"],[15,3,1,"","mul"],[12,3,1,"","ones"],[15,3,1,"","permute"],[15,3,1,"","permute_"],[15,3,1,"","qr"],[15,3,1,"","qr_"],[12,3,1,"","rand"],[12,3,1,"","randn"],[15,3,1,"","renormalize"],[15,3,1,"","rq"],[15,3,1,"","rq_"],[15,3,1,"","split"],[15,3,1,"","split_"],[15,3,1,"","stack"],[15,3,1,"","stacked_einsum"],[15,3,1,"","sub"],[15,3,1,"","svd"],[15,3,1,"","svd_"],[15,3,1,"","svdr"],[15,3,1,"","svdr_"],[15,3,1,"","tprod"],[15,3,1,"","unbind"],[12,3,1,"","zeros"]]},objnames:{"0":["py","class","Python class"],"1":["py","property","Python property"],"2":["py","method","Python method"],"3":["py","function","Python function"]},objtypes:{"0":"py:class","1":"py:property","2":"py:method","3":"py:function"},terms:{"0":[1,3,4,6,7,8,9,10,11,14,15,17,19,21,22,23],"00":[4,7],"0000":4,"0000e":4,"0086":1,"01":[7,10],"0101":1,"0188":1,"02":7,"0234":1,"03":[9,10],"0371":1,"04":[7,9],"0412":1,"0440":1,"0445":1,"0465":1,"05":7,"0501":1,"0521":1,"06":[7,9],"0633":1,"0639":17,"07":[7,9],"0704":1,"0728":17,"0743":1,"0750":17,"08":[4,9],"0806":1,"0820":17,"08595":11,"09":[7,10,14],"0913":17,"0983":17,"0998":17,"1":[1,3,4,6,7,8,9,10,11,14,15,23],"10":[1,4,6,7,8,9,10,11,14,15,17,19,22,23],"100":[1,4,6,7,8,9,10,11,14,15,18,19,20,21],"10000":[17,23],"101":7,"1024":[17,23],"104":7,"106":7,"108":7,"1091":1,"11":[6,7,8,9,10],"110":7,"1123":17,"113":7,"115":7,"117":7,"119":7,"12":[6,7,9,10],"1211":1,"122":7,"124":7,"126":7,"128":7,"13":[6,7,9],"1308":17,"131":7,"133":7,"135":7,"137":7,"1371":1,"14":[6,7,8,9,10,23],"140":7,"142":7,"144":7,"146":7,"149":7,"1496":1,"15":[1,6,7,8,9,10,15,19],"151":7,"153":7,"155":7,"1573":17,"158":7,"159674":10,"159962":10,"16":[6,7,8,9,10],"160":7,"161204":17,"162":7,"164":7,"16406":9,"1654":1,"167":7,"169":7,"17":[6,7,9,10],"171":7,"173":7,"176":7,"178":7,"18":[6,7,10],"180":7,"182":7,"185":7,"187":7,"1878":23,"189":7,"19":[7,9,10],"191":7,"194":7,"1955":17,"196":7,"198":7,"1990":17,"1998":1,"1e":[6,7,8,9,10,14,17,22,23],"2":[1,4,6,7,8,9,10,14,15,23],"20":[1,7,10,14,15,23],"200":7,"203":7,"205":7,"2052":23,"207":7,"209":7,"2091":1,"2095":1,"21":[7,10],"2111":1,"212":7,"214":7,"216":7,"218":7,"22":[6,7,10],"2203":23,"221":7,"223":7,"2254":1,"23":[7,10],"2306":11,"2316":1,"2357":23,"236220":[10,17],"2380":1,"2390":1,"24":[7,8],"2449":1,"2461":1,"2477":1,"25":[7,8],"2503":1,"2517":1,"26":7,"2618":23,"27":[7,8],"2730":1,"2775":1,"27944":9,"2799":1,"28":[6,7,8,9,10,17,23],"2808":1,"2840":23,"2856":1,"2866":1,"2898":1,"29":[7,8],"2d":14,"2f":[6,7,8,9,10],"3":[1,4,6,7,8,9,10,11,14,15,21,23],"30":[7,9,23],"3083":17,"3088":1,"31":7,"3139":1,"3149":23,"32":[7,14,17],"3222":1,"33":[6,7,9],"3340":1,"3370":1,"3371":1,"3381":1,"3393":1,"34":7,"3427":1,"3489":1,"35":[6,7,9],"3508":1,"3513":1,"3585":1,"36":7,"37":[7,8],"3711e":4,"3714":23,"3760":1,"3784":1,"38":[6,7,8],"3821":1,"39":[6,7,8,9,10],"39760":9,"4":[1,4,6,7,8,9,10,14,15,19,21,23],"40":[6,7,10,23],"4005":1,"41":7,"4181":1,"4184":1,"42":[7,8],"4216":1,"43":7,"4383":1,"4385":1,"44":7,"4402":1,"4431":1,"4461":1,"45":[6,7,8,9],"4572":1,"4588":1,"46":7,"462803":6,"47":7,"4731":1,"4761":1,"48":7,"49":[7,8],"4974":1,"499320":23,"4f":[17,23],"5":[1,4,6,7,8,9,10,11,14,15,18,19,20,21,23],"50":[6,7,9,23],"500":[11,17,22,23],"5000":4,"5021":1,"5029":1,"5069":1,"51":7,"5161":1,"52":[7,8,9,10],"5224":23,"53":7,"54":7,"5401":1,"5406":1,"55":[6,7,10],"553036":6,"553186":23,"5567":1,"5570":1,"56":[7,9],"5676":17,"57":[7,9,10],"5731":1,"5760":1,"5797":1,"58":[7,10],"59":[7,8],"5920":1,"6":[1,4,6,7,8,9,10,14,23],"60":[7,23],"60000":[17,23],"6023":1,"61":[6,7],"6149":17,"61999999999999":9,"62":[7,8,9],"6225":1,"6227":1,"6295":1,"63":7,"6356":1,"6399":1,"64":[6,7,8,9,10],"6492":1,"6495":1,"65":[7,8,10],"6524":1,"6545":1,"6551":1,"66":7,"67":[7,10],"68":7,"6811":1,"69":[7,8,9,10],"6925":1,"6982":1,"7":[1,6,7,8,9,10,11,14,15,19,20,23],"70":[7,8,9,23],"71":7,"72":[7,10],"73":[7,9],"74":[7,10],"75":[4,7],"7500":4,"7592":1,"76":[7,10],"7688":1,"77":7,"7752":1,"78":7,"7812":1,"79":[6,7,10],"7997":1,"8":[6,7,8,9,10,11,14,17],"80":[6,7,23],"8090":1,"81":[6,7,8,9],"8147":1,"82":7,"8227":1,"83":[6,7,9,10],"8361":1,"8387":1,"84":[6,7,8,9,10],"8441":1,"8442":1,"85":[6,7,9],"8502":23,"86":[6,7,9],"8627":23,"8633":1,"8649":1,"87":[6,7,9],"8795":23,"88":[6,7,9],"8815":1,"8820":17,"8848":23,"8851":23,"8859":1,"89":[6,7,9],"8901":23,"8911":1,"8915":23,"8948":23,"8968":23,"8984":23,"9":[6,7,8,9,10,11,14,17,22,23],"90":[7,9,23],"9006":1,"9009":23,"9011":[17,23],"9026":23,"9048":23,"9053":17,"91":[7,10],"9125":23,"9145":1,"9174":23,"92":[7,10],"9231":23,"9265":1,"9284":23,"93":[7,10],"9320":1,"9371":17,"9396":17,"94":[7,9],"9400":17,"9432":1,"95":[6,7,8,9,10],"9509":17,"9526":17,"9551":1,"9561":1,"9585":17,"96":[7,8,9,10],"9600":17,"9618":1,"9621":17,"9625":17,"9668":17,"9677":17,"9696":17,"97":[7,8,10],"9700":17,"9721":17,"9729":17,"9731":17,"9734":17,"9736":17,"9738":17,"9743":17,"9768":17,"9775":17,"9793":17,"98":[6,7,8,10,23],"9844":1,"99":[6,7,17],"9925":1,"9942":1,"9957":1,"abstract":1,"boolean":[1,12,14],"break":[7,8],"byte":[17,23],"case":[1,3,11,14,20,21,22],"class":[0,1,6,7,8,9,14,17,19,22,23],"default":[1,4,14,20,21],"do":[1,14,17,18,20,21,22],"final":[1,3,14,19,22],"float":[1,3,6,7,8,9,10,14,15],"function":[1,14,15,19,23],"garc\u00eda":11,"import":[1,6,7,8,9,10,11,18,19,20,21,22,23],"int":[1,3,4,10,12,14,15],"n\u00ba":[17,23],"new":[1,4,14,15,17,18,19,20,22,23],"p\u00e9rez":11,"return":[1,3,4,6,7,8,9,10,12,14,15,17,18,19,20,22,23],"super":[1,6,7,8,9,22,23],"true":[1,6,7,8,9,10,11,12,14,15,17,19,20,21,22,23],"try":[10,19,23],"while":[1,7,8,11,19,20,21],A:[1,11,14,15,18],And:[20,23],As:[11,14,17,18,19,20],At:[1,18],Be:[1,19],But:[19,20,22,23],By:[1,4,11,14,17,18],For:[1,14,15,18,19,20,21,22],If:[1,3,4,11,14,15,18,20,22],In:[1,3,8,11,14,15,17,18,19,20,21,22,23],Is:14,It:[1,3,11,14,15,17,18,19,23],Its:[1,15],Not:22,Of:[20,22],On:1,One:[1,10,18],That:[1,3,4,14,15,19,20,22,23],The:[1,3,4,11,14,15,16,18,19,20,23],Then:[1,14,20],There:[1,14,19,21,22],These:[1,3,11,14,18,19,21],To:[1,3,4,11,13,14,15,17,18,19,20,21,22],With:[1,10,11,15,18,19,22],_:[1,6,7,8,9,10,14,15,17,18,22,23],__call__:1,__init__:[1,6,7,8,9,22,23],_channel:14,_copi:1,_dim:1,_mats_env:[7,8],_percentag:[1,3,14,15],_size:14,_size_0:14,_size_1:14,_size_:1,a_:15,abil:11,abl:[1,11,18,19,22],about:[1,15,18,19,20,21,22],abov:[1,18,22],abstractnod:[0,12,15],acc:[6,7,8,9,10,17,23],acceler:[1,9,10],accept:19,access:[1,18,19,20],accomplish:[1,3,14,21],accord:[1,15],accordingli:14,account:[1,14,21],accross:[1,15],accumul:[3,14],accuraci:[6,7,8,9,10,17,23],achiev:1,act:[1,14,18,20],action:1,activ:1,actual:[1,14,18,19,20,23],ad:[1,4,14,15,19,22],adam:[6,7,8,9,10,17,23],adapt:[11,15],add:[0,1,14,17,19,22,23],add_data:[1,9,14,21,22],add_on:[0,23],addit:[1,11,14,15],addition:1,address:[1,18,20],admit:[14,15],advanc:[1,11,16],advantag:[8,20,22],advis:11,affect:[1,15],after:[1,3,8,14,17,23],afterward:[1,14,15],again:[8,10,14,17,20,22],against:14,aim:[3,14,17],alejandro:11,algorithm:[1,14,20,22,23],all:[1,3,4,8,10,14,15,17,18,19,20,21,22,23],allow:[1,11,14],almost:[1,19,20,21,23],along:[1,15],alreadi:[1,7,14,15,17,18,19,22],also:[1,3,8,11,13,14,15,17,18,19,20,21,22],alter:14,although:[1,10,18,19,20,21,22],alwai:[1,11,14,15,17,18,19,20,21,22],among:[3,14],amount:[15,17],an:[1,3,4,6,9,10,11,12,14,15,17,18,19,20,21,22],ancillari:[1,20,21],ani:[1,7,8,14,15,18,19,22],anomali:11,anoth:[1,14,15,18,19],antialia:[7,8,9,10],anymor:1,api:2,appear:[1,15,18],append:[6,7,8,19,20,21,22,23],appli:[1,11,14,15,19],approach:[8,10,11],appropi:[14,15,22,23],approxim:1,ar:[1,3,4,8,11,13,14,15,17,18,19,21,22,23],arang:[4,7,8],arbitrari:1,architectur:11,archiveprefix:11,arg:[1,6,7,8,9,12,14],argument:[1,3,12,14,15,22],aros:1,around:17,arrai:18,arrow:15,arxiv:11,ascertain:11,assert:[1,11,14,15,18,19,20,21],assign:14,assum:[1,4,14,15],attach:[1,15,18],attribut:[1,14,22],author:11,auto_stack:[1,15,17,20,21,22,23],auto_unbind:[1,15,17,20,22,23],automat:[1,15,18,19,20],auxiliari:15,avail:[11,17,18],avoid:[1,3,14,15,20],awar:[1,19],ax:[1,7,8,14,15,18,19],axes_nam:[1,7,8,11,15,18,19,20,21,22],axi:[0,4,7,8,15,18,22,23],axis1:1,axis2:1,axis_0:1,axis_1:1,axis_2:1,axis_n:1,b:[4,14,15,19],b_1:3,b_:15,b_m:3,back:1,backend:[6,7,8,9,10],backpropag:[17,23],backward:[1,6,7,8,9,10,11,17,19,23],bar:[7,8],base:[1,4,10,14,15,19],basi:[0,10,14],basic:[1,11,18,19],batch:[1,3,4,10,11,14,15,17,18,19,21,22],batch_0:[4,14],batch_1:[1,14,22],batch_idx:[6,7,8,9,10],batch_m:1,batch_n:[4,14,22],batch_siz:[6,7,8,9,10,14,17,23],becaus:[1,14,20,23],becom:[1,11,14,15],been:[1,14,15,18,22],befor:[1,6,7,8,9,10,14,17,18,22,23],beforehand:18,beggin:8,begin:[4,17,18,23],behav:[1,20],behavior:[3,14],behaviour:[1,14,15,18,20,21],being:[1,3,4,14,15,19],belong:[1,14,17,18,19],besid:[1,14,15,18],best:[10,11],better:18,between:[1,3,4,10,11,14,15,18,20,21,23],bia:9,big:[14,19],bigger:14,binom:4,bit:1,blank:1,block:[7,8],block_:8,block_length:[7,8],block_nod:[7,8],block_posit:7,bmatrix:4,bodi:18,bond:[3,8,14],bond_dim:[6,7,8,9,10,11,14,17,23],bool:[1,3,12,14],border:14,both:[1,14,15,18,19,20,21,22],bottom:[1,14,23],boudari:14,bound:[1,3,14,15],boundari:[3,7,8,9,10,14],bracket:1,bridg:11,bring:17,build:[1,6,11,16,19],built:[11,17,19,22],bunch:[1,19,22],c_:15,cach:[1,20],calcul:20,call:[1,14,15,17,19,22,23],callabl:15,can:[1,3,4,6,7,8,9,10,11,13,14,15,17,18,19,20,21,22,23],cannot:[1,7,14,15,19,21],canon:[6,10,14,17,23],canonic:[6,9,10,14,17,23],canonicalize_univoc:14,carri:[1,14,15,18],cast:1,cat:[6,23],caus:14,cdot:[3,4,14],center:14,central:[1,18,23],certain:[1,4,10,11,15],chang:[1,6,10,11,14,15,18,20,22,23],change_s:1,change_typ:1,channel:14,charact:1,check:[1,6,7,8,9,10,14,15,17,18,19,20,21,23],check_accuraci:[6,7,8,9,10],check_first:15,child:1,choic:10,choos:10,chose:21,classif:[14,17],classifi:[17,22],clone:[11,13,14],close:14,cmap:[6,7,8,9,10],cnn:[6,23],cnn_snake:[6,23],cnn_snakesb:[6,23],co:4,code:[1,10,11,18,23],coincid:[14,15,18],collect:[1,11,15,19,20],column:14,com:[11,13],combin:[1,6,10,14,18,22,23],come:[1,15,17,18,21,22],comma:15,command:[11,13],common:[1,14,17,22],compar:20,complementari:14,complet:[1,15],complex64:1,complex:1,compon:[0,4,11,15,17,19],compos:[7,8,9,10,17,23],comput:[1,3,4,11,13,14,15,17,19,20,22,23],concaten:4,concept:11,condit:[3,14,15],configur:[10,11],conform:1,conj:1,conja:1,conjug:1,connect:[0,1,9,11,14,18,19,20,21,22],connect_stack:0,consecut:[3,14],consid:[11,14],consist:[1,15],constraint:11,construct:[1,11,19,20],contain:[1,14,18,19,20,21],context:11,continu:[14,17],contract:[0,1,6,7,8,9,10,11,14,16,17,18,20,21,22],contract_:[0,1,19],contract_between:[0,1,19],contract_between_:[0,1,7,8,19],contract_edg:[0,1,19],contract_edges_ip:[1,15],contrast:[10,14],contrat:14,control:[1,18,20],conv2d:[6,14,23],conv_mp:14,conv_mps_lay:14,conv_pep:14,conv_tre:14,conveni:[14,21],convent:[1,15,19,22],converg:10,convmp:0,convmpslay:[0,6,23],convolut:[6,14,22,23],convpep:0,convtre:0,convump:0,convumpslay:0,convupep:0,convutre:0,copi:[0,1,14],core:[8,10,14,18,23],corner:14,correct:[1,7,8,9,10,20],correctli:15,correspond:[1,3,4,14,15,18,19,20,22],costli:[1,14,20],could:[14,17,18,20,22],count:17,coupl:[1,18,21],cours:[20,22],cpu:[6,7,8,9,10,17,22,23],creat:[1,6,7,11,14,15,16,19,20,21,22],creation:[11,15],criterion:[6,7,8,9,10],crop:[1,14],crossentropyloss:[6,7,8,9,10,17,23],cuda:[1,6,7,8,9,10,17,22,23],cum:[1,3,14,15],cum_percentag:[1,3,6,7,8,9,10,14,15,17,23],current:[1,14],custom:[1,11,14,16],cut:[3,17,23],cutoff:[1,3,14,15],d:[4,11],d_1:3,d_n:3,dagger:[1,15],dangl:[1,15,17,18,19],data:[1,4,6,7,8,9,10,11,14,15,19,21,22,23],data_0:[1,21],data_1:1,data_nod:[1,7,8,19,21,22],data_node_:19,dataload:[6,7,8,9,10,17,23],dataset:[14,17,23],dataset_nam:[6,7,8,9,10],david:[4,11],ddot:4,de:[1,14,19],decompos:[3,15],decomposit:[0,1,9,14,15,17,19,23],decreas:1,dedic:14,deep:[11,17],deepcopi:1,def:[1,6,7,8,9,10,17,22,23],defin:[1,4,11,14,18,19,21,22,23],degre:[4,8,10],del:[1,15],delet:[1,15],delete_nod:1,demonstr:11,densiti:14,depend:[1,3,10],deriv:14,descent:[6,7,8,9,10,18,19],describ:[1,9,14,22],design:14,desir:[4,8,19,20,21,22],detach:9,detail:[1,11,14,19],determin:11,develop:11,devic:[1,6,7,8,9,10,14,17,22,23],devot:20,df:1,diagon:[1,12,15],diagram:18,dictionari:1,did:23,differ:[1,5,6,11,14,15,16,17,18,19,20,22,23],differenti:[11,16,18],dilat:14,dim:[4,6,7,10,17,23],dimens:[1,3,4,8,14,15,17,18,22],dimension:[18,23],direct:7,directli:[1,11,13,15,18,19,20,21],disconnect:[0,1,14,18],discret:[0,10,14],distinct:1,distinguish:21,distribut:[3,12,14],div:[0,19],divers:11,divid:[1,15],divis:15,divisor:15,dmrg:[5,10],document:[1,15],doe:[1,14,18,20],don:15,done:[1,3,14,18],down:14,down_bord:14,download:[6,7,8,9,10,11,13,23],drawn:12,drop_last:[17,23],dtype:[1,14],due:14,duplic:14,dure:[1,8,17,20,21],dynam:14,e:[1,4,11,14,15,21],each:[1,3,4,8,14,15,17,18,19,20,21,22,23],earlier:18,easi:[1,11],easier:18,easili:[1,18,19,22,23],edg:[0,11,14,17,18,19,20,21,22,23],edge1:[1,15],edge2:[1,15],edgea:1,edgeb:1,edges_dict:1,effect:[1,19],effici:[1,11,18],einstein:19,einsum:[0,19],either:[14,15],element:[1,4,11,12,14,15,18],element_s:[17,23],elif:[6,7,8,9,10],els:[6,7,8,9,10,17,22,23],emb:[4,17,22],emb_a:4,emb_b:4,embed:[0,6,7,8,10,14,17,22,23],embedd:4,embedding_dim:[7,8,10],embedding_matric:14,empti:[0,1,14,18,19,20],enabl:[1,11,14,18,19,20],end:[1,4,8,15],engin:18,enhanc:11,enough:[1,14],ensur:[1,11],entail:[1,20],entangl:14,entir:1,entropi:14,enumer:[1,6,7,8,9,10,14],environ:14,epoch:[1,6,7,8,9,10,17,22,23],epoch_num:[17,23],eprint:11,equal:[1,7,14,22],equilibrium:17,equival:[1,4,14,18,23],error:[11,14],essenti:[1,11],etc:[1,15,18,22],eval:[6,7,8,9,10],evalu:19,even:[1,3,8,11,14,18,20,22],evenli:[3,14],everi:[1,15,18,20,23],exact:1,exactli:1,exampl:[1,2,4,6,9,14,15,17,19,21,22],exce:7,exceed:14,excel:11,except:[12,14],exclud:[1,21],execut:15,exist:[1,14,15,19],exp:14,expand:[1,4,22],expect:[3,14],experi:[1,11,18,20],experiment:11,explain:[1,18,19,20],explan:[1,14,15],explicitli:[1,15,21,22],explod:[3,14],explor:[11,18],express:15,extend:1,extra:[3,14,15],extract:[3,14,18],extractor:23,ey:22,eye_tensor:22,f:[6,7,8,9,10,17,18,19,20,21,23],facilit:[1,11,15,18],fact:[1,21,22],factor:[3,14],fail:[14,15],failur:11,fals:[1,3,6,7,8,9,10,12,14,15,17,19,20,22,23],familiar:[11,18],fashion:[8,19],fashion_mnist:[6,9],fashionmnist:[6,9,23],faster:[1,8,22],fastest:11,fc1:9,fc2:9,featur:[1,4,14,15,18,19,20,21,23],feature_dim:[1,22],feature_s:14,fed:23,feed:9,few:[8,11],fffc:9,file:11,fill:[1,12,18,19,22],find:11,finish:[18,22],finit:14,first:[1,4,10,11,14,15,16,18,19,20,21,23],fix:[1,18,19],flat:14,flatten:14,flavour:1,flip:[1,6,23],flips_x:6,fn_first:15,fn_next:15,folder:[11,13],follow:[1,11,13,14,15],forget:1,form:[1,14,17,18,19,21,23],former:[1,14,15],forward:[1,6,7,8,9,10,14,15,22,23],four:22,frac:[1,3,4,14,15],framework:11,free:14,friendli:11,from:[1,3,6,7,8,9,10,11,12,13,14,15,17,18,19,21,22,23],from_sid:14,from_siz:14,full:17,fulli:[9,11],func:[1,14],functool:23,fundament:11,further:[1,11],furthermor:[1,20,22],futur:[1,20],g:[1,14,15,21],gain:11,gap:11,garc:11,gaussian:14,ge:[1,3,14,15],gener:[1,20],get:[1,6,7,8,9,10,17,18,22,23],get_axi:[1,7,8,18],get_axis_num:[1,18],get_edg:1,git:[11,13],github:[11,13],give:[1,18,20,21],given:[1,4,6,14],glimps:17,go:[1,15,17,19,20,23],goal:11,goe:14,good:[1,6,7,8,9,10,17],gpc20:6,gpu:[17,22],grad:[1,11,19],gradient:[1,6,7,8,9,10,11,14,18,19],graph:[1,18,20,22],grasp:11,greater:[1,7,22],greatli:11,grei:[6,7,8,9,10],grid:14,grid_env:14,group:15,gt:[6,7,8,9,10],guid:11,ha:[1,3,8,14,15,17,18,20,21,22,23],had:[1,15,20],hand:[1,11,14,22],happen:[1,3,14,20,23],har:11,hardwar:11,hat:4,have:[1,3,4,14,15,17,18,19,20,21,22,23],heavi:22,height:[14,17,22],help:[1,20],henc:[1,15,17,18,19,20,22],here:[1,7,8,10,14,19],hidden:[1,21],high:[1,4,6,7,8,9,10,12],highli:11,hint:1,hold:[19,21],hood:20,horizont:14,how:[1,6,7,8,9,10,11,14,15,16,17,19,21,23],howev:[1,11,14,15,18,19,20],http:[11,13],hundread:19,hybrid:[5,11,16],hyperparamet:[6,7,8,9,10],i:[1,3,6,7,8,11,14,15,18,19,20,21,22],i_1:12,i_n:12,ideal:11,ident:14,identif:11,idx:18,ignor:[6,10],ijb:[15,19],ijk:15,im:15,imag:[6,14,17,22,23],image_s:[6,7,8,9,10,17,22,23],immers:11,implement:[1,11,17,23],improv:17,imshow:[6,7,8,9,10],in_1:3,in_channel:[6,14,23],in_dim:[7,8,10,11,14,17],in_env:14,in_featur:14,in_n:3,in_region:14,in_which_axi:1,includ:[1,11,14,18],incorpor:11,inde:[1,14],independ:[1,15],index:[1,6,7,8,9,10,18,20],indic:[1,3,12,14,15,18,19,20,21],infer:[1,14,15,18,20,22],inform:[1,11,14,15,18,19,20],ingredi:[1,18],inherit:[1,14,15,21],inic:22,init:19,init_method:[1,6,7,8,10,14,17,18,19,23],initi:[0,1,7,8,9,10,14,17,18,19,22,23],inlin:14,inline_input:[6,9,10,14],inline_mat:[6,9,10,14],inner:1,input:[1,3,7,8,14,15,17,18,19,20,21,22,23],input_edg:[1,21,22],input_nod:22,input_s:[6,7,8,9,10],insid:[1,11,13],insight:19,instal:[2,17,18],instanc:[1,15,20,21,22],instanti:[1,3,14,18,20,21,22],instati:18,instead:[1,14,19,20,21,22,23],integ:4,integr:11,intend:[1,14],interior:14,intermedi:[1,19,21,22],intern:[14,15],intric:11,intricaci:11,introduc:[6,14],introduct:11,invari:[1,14,20,21],inverse_memori:1,involv:[1,15,20],is_attached_to:1,is_avail:[6,7,8,9,10,17,22,23],is_batch:[1,18],is_complex:1,is_conj:1,is_connected_to:1,is_dangl:[1,18],is_data:[1,21],is_floating_point:1,is_leaf:[1,21],is_node1:1,is_result:[1,21],is_virtu:[1,21],isinst:[1,19],isn:[17,18],isometri:14,issu:1,item:[6,7,8,9,10,17,23],iter:[1,8,14,20],its:[1,11,12,14,15,18,20,22,23],itself:[1,14,22],itselv:14,j:[1,4,8,11,15],jkb:[15,19],jl:15,jo:11,joserapa98:[11,13],just:[1,11,14,18,20,21,22,23],k:[11,15],keep:[1,3,14,15,20],keepdim:1,kei:[1,10,11,18],kept:[1,3,14,15],kernel:14,kernel_s:[6,14,23],kerstjen:11,keyword:[1,12,14],kib:[15,19],kind:[1,17],klm:15,know:[1,18,19,20],known:22,kwarg:[1,6,7,8,9,12,14],l2_reg:[17,23],l:[4,15],label:[14,17,21,23],lambda:[17,23],larger:1,largest:1,last:[1,7,14,15,22,23],later:22,latter:[1,14],layer:[1,6,11,14,17,22,23],ldot:12,lead:[1,15,21],leaf:[1,15,21,22],leaf_nod:[1,21],learn:[1,8,11,15,17,18,19,20,21,22,23],learn_rat:[17,23],learnabl:19,learning_r:[6,7,8,9,10],least:[1,14,15,17,20,23],leav:[8,14,20],left1:19,left2:19,left:[1,7,8,11,14,15,18,19,20,21,22,23],left_bord:14,left_down_corn:14,left_nod:[7,8,14],left_up_corn:14,leg:1,len:[6,7,8,9,10,15,22],length:[1,14],let:[17,18,19,20,22,23],level:[1,4,10,18],leverag:18,lfloor:4,li:[1,11,18],librari:[11,18,23],like:[0,1,5,10,11,14,17,18,19,20,21,22,23],line:[11,14,23],linear:[9,11,22],link:1,list:[1,3,12,14,15,18,19,22],load:[6,7,8,9,10,17,22],load_state_dict:[9,10,22],loader:[6,7,8,9,10,17,23],local:[4,14],locat:[1,14],log_norm:14,log_scal:14,logarithm:[3,14],logarithmoc:14,logit:[17,23],longer:[1,20],look:[1,10],loop:[17,22],lose:17,loss:[6,7,8,9,10,23],loss_fun:[17,23],lot:[1,17],low:[1,4,6,7,8,9,10,12,14],lower:[1,3,14,15],lr:[6,7,8,9,10,17,23],lst_y:6,lt:10,lvert:4,m:[1,3,11,13,15],machin:[11,17],made:[1,15],mai:[1,3,10,11,14],main:[1,18,19,20,22],mainli:1,maintain:1,make:[1,8,18,19,20,21],make_tensor:[1,14],manag:[1,22],mandatori:[1,15],mani:[1,8,14,17,18,20,23],manipul:18,manual:[1,14,22],manual_se:[17,23],map:4,margin:14,marginalize_output:14,master:[11,13],mat:3,mat_to_mpo:[0,9],match:[1,10,14,15],matplotlib:[6,7,8,9,10],matric:[14,15,19,22],matrix:[1,3,4,9,11,14,15,17,18,20,22],mats_env:[7,8,14],mats_env_:[7,8],max:[6,7,8,9,10,17,23],max_bond:14,maximum:[4,14],maxpool2d:[6,23],mayb:15,mb:[17,23],mean:[1,12,14],megabyt:[17,23],mehtod:18,member:[1,14],memori:[1,11,16,17,18,21,22,23],mention:[14,22],merg:[7,8],merge_block:7,messag:11,method:[1,14,15,18,19,20,21,22],middl:14,middle_sit:14,might:[1,14,20,21,22],mile:4,mimic:22,minuend:15,misc:11,miscellan:[17,23],mit:11,mkdir:[6,7,8,9,10],mnist:[7,8,10,17],mod:4,mode:[1,6,14,15,17,21,22,23],model:[0,1,5,9,11,16,18,19,20,21],model_nam:[6,7,8,9,10],modif:1,modifi:[1,14,19],modul:[1,6,9,11,13,14,17,18,19,22,23],modulelist:[6,23],monomi:4,monturiol:11,mooth:11,more:[1,3,11,14,15,18,19,20,21],most:[1,11,18,21],move:[1,14,18],move_block_epoch:[7,8],move_nam:1,move_to_network:[1,15],movec:15,mp:[0,3,5,6,9,11,17,18,19,20,21,22,23],mpo:[0,3,9,22],mpo_tensor:9,mps_data:[9,14],mps_dmrg:7,mps_dmrg_hybrid:8,mps_hdmrg:8,mps_layer:[6,14],mps_tensor:9,mpsdata:[0,3,9],mpslayer:[0,7,8,10,11,17,22],mpss:6,much:[1,8,20,21,22,23],mul:[0,19],multi:18,multipl:[14,18,20],multipli:[14,15],must:[1,3,15,18,22],my_model:11,my_nod:[1,20],my_paramnod:1,my_stack:1,mybatch:1,n:[1,4,6,9,10,11,12,14,17,23],n_:[1,4],n_batch:[3,9,14],n_block:8,n_col:14,n_epoch:22,n_featur:[7,8,9,10,11,14,17,22],n_param:[6,9,10,17,23],n_row:14,name:[1,7,8,11,15,17,18,19,20,22,23],name_:1,necessari:[1,14,15,17,18,22,23],need:[1,11,14,15,17,18,19,20,22],neighbour:[1,7,8,14,19],nelement:[17,23],net2:1,net:[1,11,15,18,19,20,21],network:[0,5,7,8,10,11,14,15,16,20,21,22],neumann:14,neural:[5,11,14,16,17],never:[1,15,21],nevertheless:11,new_bond_dim:10,new_edg:[1,15],new_edgea:1,new_edgeb:1,new_mp:22,new_nodea:[1,15],new_nodeb:[1,15],new_tensor:20,next:[1,14,15,18,19,20,21],nn:[1,6,7,8,9,10,11,14,17,18,19,22,23],no_grad:[6,7,8,9,10,17,23],node1:[1,11,15,18,19,20,21],node1_ax:[1,7,8,15],node1_list:1,node1_lists_dict:1,node2:[1,11,15,18,19,20,21],node2_ax:[1,7,8,15],node3:[19,20],node:[0,3,7,8,11,12,14,16,17,20,22,23],node_0:1,node_1:1,node_:[18,19,20,21],node_left:[1,15],node_n:1,node_ref:1,node_right:[1,15],nodea:[1,15],nodeb:[1,15],nodec:[1,15],nodes_list:15,nodes_nam:1,nodesa:15,nodesb:15,nodesc:15,non:[1,14,15,19,22],none:[1,3,7,8,11,14,15,19,22],norm:[1,3,14,15],normal:[1,3,12,14,15],notabl:11,notat:18,note:[1,11,15,18,19,22],noth:[1,18,20],now:[1,18,19,20,22,23],npov15:9,nto16:10,nu_batch:14,num:1,num_batch:[17,23],num_batch_edg:[1,21,22],num_class:[6,7,8,9,10],num_correct:[6,7,8,9,10],num_epoch:[6,7,8,9,10,17,23],num_epochs_canon:17,num_sampl:[6,7,8,9,10],num_test:[17,23],num_train:[17,23],number:[1,3,14,15,17,19,22,23],numel:[1,6,9,10],numer:15,nummber:1,o:11,obc:[3,7,8,9,10,14],object:[1,18,23],obtain:[14,17],oc:14,occur:[1,14,20],occurr:11,off:[17,20,23],offer:11,often:[1,20],onc:[1,8,11,14,15,20],one:[1,3,4,6,7,9,10,14,15,17,18,19,20,21,22,23],ones:[0,1,4,6,14,18,19,21,23],ones_lik:[6,17,23],onli:[1,11,14,15,18,19,20,21,22],open:14,oper:[0,1,14,18,21,22],operand:[1,18,19],opposit:14,opt_einsum:[11,15,19],optim:[1,6,7,8,9,10,11,14,23],option:[1,3,4,6,10,11,14,15,18,22],order:[1,3,6,14,15,18,19,20,21,22],orient:14,origin:[1,4,14,15,17,21,23],orthogon:14,ot:14,other:[1,8,14,15,18,19,20,21,22,23],other_left:18,other_nod:20,otherwis:[1,14,15,23],otim:4,our:[6,7,8,9,10,17,18,19,20,23],out:[1,14,15,22],out_1:3,out_channel:[6,14,23],out_dim:[7,8,10,11,14,17],out_env:14,out_featur:14,out_n:3,out_nod:[7,8,14],out_posit:14,out_region:14,outdat:14,output:[1,3,4,6,7,8,14,15,17,22],output_dim:[6,7,8,10],output_nod:22,outsid:[11,13],over:[1,15,19],overcom:1,overhead:20,overload:1,overrid:[1,7,14,22],override_edg:1,override_nod:1,overriden:[1,19,22],own:[1,14,15,20,21],ownership:1,p:[1,6,9,10,11,15,17,23],packag:[11,13],pad:[6,14,23],pair:[1,8,14,23],pairwis:14,paper:[4,11,14,23],parallel:[1,14],param:[1,15,17,23],param_nod:[1,11,12,19,21],paramedg:1,paramet:[1,3,4,6,7,8,9,10,12,14,15,17,19,22,23],parameter:[1,7,8,14,19],parametr:[1,14,17,19,22,23],paramnod:[0,12,14,15,20,21,22],paramnode1:19,paramnode2:19,paramnode3:19,paramnodea:1,paramnodeb:1,paramstacknod:[0,15,20,21],pareja2023tensorkrowch:11,pareja:11,parent:[1,22],parenthesi:1,part:[1,14,19,21,22],partial:[14,23],partial_dens:14,particular:[1,22],pass:[14,17,22],patch:[14,22],pattern:23,pave:17,pbc:14,pep:[0,11,22],perform:[1,10,11,14,15,17,18,19,20,22,23],period:14,permut:[0,1,9,19,20],permute_:[0,1,19],persist:11,phase:15,phi:4,phys_dim:[9,14],physic:[3,14,18],pi:4,piec:[18,22],pip:[11,13,17,18],pipelin:[11,17],pixel:[14,17,22,23],place:[1,14,15,19,20,22],plai:21,pleas:11,plt:[6,7,8,9,10],plu:14,plug:14,point:[1,17],poli:[0,8,10,22],pool:[6,23],posit:[7,14,22],possibl:[1,6,7,8,9,10,14,17,18,19,20],potenti:11,power:[4,14,17],poza:11,practic:[1,11],practition:17,pre:9,precis:4,pred:[17,23],predict:[6,7,8,9,10,17,23],present:23,previou:[1,14,19,20,21,22,23],previous:19,primari:11,principl:[1,21],print:[1,6,7,8,9,10,14,15,17,23],probabl:17,problem:[1,20],process:[1,8,10,11,17],prod:6,product:[11,14,15,17,18,20,22],profici:11,project:14,properli:22,properti:[1,7,14,15,18],proport:[1,3,14,15],prove:14,provid:[1,3,11,14,15,18],prune:23,pt:[6,7,8,9,10,22],purpos:[1,18,20,21],put:[1,14,17,23],pyplot:[6,7,8,9,10],pytest:[11,13],python:[11,13,17],pytorch:[1,11,15,17,18,19,20,22,23],q:15,qr:[0,1,3,14,19],qr_:[0,1,14],quantiti:[1,3,14,15],quantum:18,qubit:14,quit:[11,23],r:[11,15],r_1:15,r_2:15,rais:[1,7],ram:11,rand:[0,1,4,6,10,14],randint:[4,6,7,8,9,10],randn:[0,1,4,6,10,11,14,15,18,19,20,21,22],randn_ey:[6,8,10,14,17,23],random:[14,15,19],random_ey:22,random_sampl:[6,7,8,9,10],rang:[1,6,7,8,9,10,11,14,15,17,18,19,20,21,22,23],rangl:4,rank:[1,3,4,7,8,14,15,19],rapid:11,rather:[1,14,15,18,20,21],re:[1,15,17,23],reach:17,readi:22,realli:[1,21],reamin:8,reattach:1,reattach_edg:1,recal:[1,22],recogn:19,recommend:[1,11,14],reconnect:[1,19],recov:[1,3,14],reduc:[1,14,15,17,23],redund:[1,15],refer:[1,2,10,15,18,19,20,21,22],referenc:1,regard:[1,15,19],regio:14,relat:18,relev:[1,18],reli:[11,15],relu:[6,9,11,23],remain:[8,14],remov:[1,14,15],renorm:[0,1,3,6,9,10,14,19],renorma:[1,15],repeat:[1,8,10,15],repeatedli:20,replac:1,repositori:[11,13],repres:[14,18],reproduc:[1,10,11,20],requir:[1,14,15,21],requires_grad:1,rerun:11,res1:19,res2:19,rescal:14,research:11,reserv:[1,15,20],reset:[1,6,7,8,9,10,14,15,20,21,22],reset_tensor_address:1,reshap:[3,7,8,9,10,15,20],resiz:[7,8,9,10,17,23],respect:[1,3,14,15,19,20],rest:[1,19],restrict:[1,15,21],result2:15,result:[1,3,7,8,9,10,11,14,15,19,20,21,22],result_mat:[7,8],resultant_nod:[1,21],retain:1,retriev:[1,18,20,21],reus:14,review:11,rez:11,rfloor:4,rho_a:14,rid:1,right1:19,right2:19,right:[1,7,8,11,14,15,18,19,20,21,22,23],right_bord:14,right_down_corn:14,right_nod:[7,8,14],right_up_corn:14,rightmost:14,ring:19,role:[1,14,21,22],root:[6,7,8,9,10],row:14,rowch:11,rq:[0,1],rq_:[0,1,14],run:[11,13,14,17,23],running_acc:17,running_test_acc:[17,23],running_train_acc:[17,23],running_train_loss:[17,23],s:[1,4,11,12,14,15,17,18,19,20,21,22,23],s_1:4,s_i:[1,3,14,15],s_j:4,s_n:4,sai:[1,20],same:[1,8,10,14,15,18,19,20,23],sampler:[17,23],satisfi:[1,3,14,15],save:[1,6,7,8,9,10,11,14,16,19,22],scalar:14,scale:[3,14],scenario:11,scheme:10,schmidt:14,schwab:4,score:[6,7,8,9,10,17,23],second:[1,14,15,20],secondli:22,section:[18,19,21,23],see:[1,6,7,8,9,10,11,14,15,21,23],seen:21,select:[1,15,18,19,22],self:[1,6,7,8,9,14,22,23],send:[17,22,23],sens:[17,21],sento:22,separ:[14,15],sepcifi:1,sequenc:[1,3,14,15,19,22],sequenti:[11,22],serv:[11,17],set:[1,7,8,14,15,20,21,22,23],set_data_nod:[1,7,8,9,14,21,22],set_param:[1,7,14],set_tensor:[1,14,20],set_tensor_from:[1,20,21,22],sever:[1,15],shall:[1,22],shape:[1,3,4,7,8,9,10,11,12,14,15,17,18,19,20,21,22],shape_all_nodes_layer_1:14,shape_all_nodes_layer_n:14,shape_node_1_layer_1:14,shape_node_1_layer_n:14,shape_node_i1_layer_1:14,shape_node_in_layer_n:14,share:[1,14,15,18,20,21],share_tensor:[1,14],should:[1,3,4,7,11,12,14,15,18,19,20,21,22],show:[6,7,8,9,10,18],showcas:11,shown:7,shuffl:[6,7,8,9,10],side:[1,7,8,14,15,23],similar:[1,14,15,18,19,21,22],simpl:[11,18,19],simpli:[1,18,20,22],simplifi:[1,11],simul:18,sin:4,sinc:[1,8,14,15,19,20,21,22,23],singl:[1,14,15,19,22],singular:[1,3,14,15,17,19,23],site:[7,14,18],sites_per_lay:14,situat:[1,15,21],size:[1,4,6,7,8,9,10,12,14,15,18,19,21,22,23],skip:1,skipp:20,slice:[1,15,20,21],slide:1,slower:[1,20],small:19,smaller:1,smooth:11,snake:[6,14,23],so:[1,14,15,20,22],some:[1,3,14,15,17,18,19,20,21,22],someth:18,sometim:1,sort:[1,14,18,20,21],sourc:[1,3,4,12,14,15],space:[1,17,22,23],special:[1,19,21],specif:[1,11,21],specifi:[1,3,14,15,18,19,21],split:[0,1,3,7,8,14,18,19],split_0:[1,15],split_1:[1,15],split_:[0,1,7,8,19],split_ip:[1,15],split_ip_0:[1,15],split_ip_1:[1,15],sqrt:4,squar:[14,15],squeez:[4,6,7,8,9,10],ss16:[7,10],stack:[0,1,4,6,14,17,19,20,21,22],stack_0:1,stack_1:1,stack_data:[1,15,22],stack_data_memori:[1,21],stack_data_nod:19,stack_input:22,stack_nod:[1,15,19,20,21],stack_result:[19,22],stackabl:15,stacked_einsum:[0,19],stackedg:[0,15,19],stacknod:[0,15,19,21],stai:1,stand:4,start:[1,14,18,20,22,23],state:[1,11,14,17,18,20,22],state_dict:[6,7,8,9,10,14,22],staticmethod:[6,23],statist:11,std:[1,6,8,10,12,14,17,22,23],stem:11,step:[1,6,7,8,9,10,11,16,23],stick:1,still:[1,14,15,19],store:[1,14,15,18,21,22],stoudenmir:4,str:[1,15],straightforward:11,strength:11,stride:[6,14,23],string:[1,14,15,21],structur:[1,11,14,18,22,23],sub:[0,19],subclass:[1,11,16,18,21,23],subsequ:[1,15,20],subsetrandomsampl:[17,23],substitut:1,subsystem:14,subtract:15,subtrahend:15,successfulli:10,successor:[0,15],suffix:1,suitabl:11,sum:[1,3,6,7,8,9,10,14,15,17,19,23],sum_:[1,3,14,15],summat:19,support:11,sv:[1,15],svd:[0,1,3,14,17],svd_:[0,1,14,19],svdr:[0,1,14],svdr_:[0,1,14],t:[14,15,17,18,21,22],t_:[12,15],tailor:11,take:[1,14,15,17,20,21,22],taken:[1,15],target:[6,7,8,9,10],task:[1,14,21],temporari:[1,21],tensor:[0,3,4,5,6,7,8,11,12,14,16,21,22],tensor_address:[1,14,20,21],tensorb:15,tensori:5,tensorkrowch:[1,3,4,6,7,8,9,10,12,13,14,15,16,19,21,22,23],tensornetwork:[1,11,15,16,19,21,23],termin:[11,13],test:[6,7,8,9,10,11,13,17,23],test_acc:[6,7,8,9,10],test_dataset:[6,7,8,9,10],test_load:[6,7,8,9,10],test_set:[17,23],texttt:14,th:14,than:[1,3,7,8,14,15,18,22],thank:18,thei:[1,11,13,14,15,18,19,21],them:[1,14,15,17,18,19,20,22,23],themselv:[14,15],therefor:[1,11,13,14,18,22],thi:[1,3,4,6,8,9,10,11,14,15,17,18,19,20,21,22,23],thing:[1,15,17,21,23],think:20,third:14,those:[1,20,21,22],though:[1,14,15,19,22],thought:1,three:[1,15],through:[1,11,14,15,17,22,23],thu:[1,3,8,11,14,15,18,22],time:[1,3,4,8,10,11,14,15,16,18,19,22],titl:11,tk:[1,4,6,7,8,9,10,11,14,15,17,18,19,20,21,22,23],tn1:9,tn:[9,22],tn_fffc:9,tn_linear:9,tn_model:9,tn_nn:9,togeth:[1,11,13,20,21],tool:[11,18,19],top:[11,14,17,19,23],torch:[1,3,4,6,7,8,9,10,11,12,14,15,17,18,19,20,21,22,23],torchvis:[6,7,8,9,10,17,23],total:[1,3,14,15,17,23],total_num:[17,23],totensor:[6,7,8,9,10,17,23],tprod:[0,19],trace:[1,6,7,8,9,10,14,15,17,21,22,23],trace_sit:14,track:[1,14,20],tradit:8,train:[1,5,11,14,18,19,20,22,23],train_acc:[6,7,8,9,10],train_dataset:[6,7,8,9,10],train_load:[6,7,8,9,10],train_set:[17,23],trainabl:[1,14,18,21,22,23],transform:[4,6,7,8,9,10,14,17,19,23],transit:14,translation:[1,14,20,21],transpos:[6,17,22,23],treat:1,tree:[0,22],tri:[1,15],triangular:15,trick:[20,23],tupl:[1,12,14,15],turn:[1,14,20],tutori:[1,2,17,18,19,20,21,22,23],two:[1,14,15,18,19,20],two_0:15,type:[1,3,4,6,7,8,9,10,11,12,14,15,16,18,19,20],typeerror:1,u:[1,12,15],ump:0,umpo:0,umpslay:0,unbind:[0,1,19,20,22],unbind_0:[1,15],unbind_result:19,unbinded_nod:[19,20],unbound:[1,15],under:[11,20],underli:20,underscor:[1,15],understand:[11,18],undesir:[1,3,14,15,21],unfold:14,uniform:[1,12,14,20,21,22],uniform_memori:22,uniform_nod:[20,21],uninform:17,uniqu:[14,20],unit:[0,6,7,10,14],unitari:[14,15],univoc:14,unix:[11,13],unless:1,unmerg:8,unmerge_block:7,unset:1,unset_data_nod:[1,14],unset_tensor:1,unsqueez:23,until:[1,10,14,15,17],up:[1,14,15,18],up_bord:14,updat:[8,14,17,23],update_bond_dim:[7,8,14],upep:0,upon:14,upper:[14,15],ur_1sr_2v:15,us:[1,3,4,7,11,14,15,17,18,19,20,21,22,23],usag:1,useless:17,user:[1,11],usual:[1,14,15,21,22],usv:15,util:[6,7,8,9,10,17,23],utre:0,v:[1,11,13,15],valid:1,valu:[1,3,4,14,15,17,19,22,23],valueerror:[1,7],vanilla:[18,19,20],vanish:14,vari:14,variabl:4,variant:[19,22],variat:1,variou:[11,14],vdot:4,vec:3,vec_to_mp:[0,9],vector:[3,4,14,17,22,23],veri:[1,17,18,20,22],verifi:[1,15],versa:14,versatil:11,version:[1,14,15,19],vertic:[1,14],via:[1,3,7,14,15,18,19,20,21],vice:14,view:[1,6,17,22,23],virtual:[1,15,20,21,22],virtual_nod:[1,21],virtual_result:[1,21],virtual_result_stack:[1,15,21],virtual_uniform:[1,14,21,22],virtual_uniform_mp:1,visit:1,von:14,wa:[1,6,14,22,23],wai:[1,5,8,14,17,18,22],want:[1,14,18,19,20,22],we:[1,6,7,8,10,11,17,18,19,20,21,22,23],weight:9,weight_decai:[6,7,8,9,10,17,23],well:[1,10,14,15],were:[1,14,15,19,21,22],what:[1,18,20,22],when:[1,14,15,17,18,19,20,21,22],whenev:[1,15],where:[1,4,8,14,15,17,18,19,20,22],whether:[1,3,12,14,18,20],which:[1,4,10,12,14,15,17,18,19,20,21,22],whilst:14,who:11,whole:[1,8,14,17,19,20,21,22],whose:[1,6,14,15,20,21],why:20,wide:[11,22],width:[14,17,22],wise:[1,15],wish:22,without:[1,12,14,15,17],won:[14,22],word:[1,18],work:[1,11,15,19,23],workflow:[1,22],worri:[1,20],would:[1,14,15,18,20,23],wouldn:21,wow:23,wrap:[1,18],written:[11,13],x:[4,6,7,8,9,10,11,14,17,22,23],x_1:4,x_j:4,x_n:4,y1:23,y2:23,y3:23,y4:23,y:[6,7,8,9,10,23],yet:[18,22],you:[11,13,17,18,19,20,21,22],your:[11,13,17,18,22],your_nod:20,yourself:11,zero:[0,1,6,7,8,9,10,14,17,22,23],zero_grad:[6,7,8,9,10,17,23],zeros_lik:20},titles:["API Reference","Components","<no title>","Decompositions","Embeddings","Examples","Hybrid Tensorial Neural Network model","DMRG-like training of MPS","Hybrid DMRG-like training of MPS","Tensorizing Neural Networks","Training MPS in different ways","TensorKrowch documentation","Initializers","Installation","Models","Operations","Tutorials","First Steps with TensorKrowch","Creating a Tensor Network in TensorKrowch","Contracting and Differentiating the Tensor Network","How to save Memory and Time with TensorKrowch (ADVANCED)","The different Types of Nodes (ADVANCED)","How to subclass TensorNetwork to build Custom Models","Creating a Hybrid Neural-Tensor Network Model"],titleterms:{"1":[17,18,19,20,21,22],"2":[17,18,19,20,21,22],"3":[17,18,19,20,22],"4":17,"5":17,"6":17,"7":17,"class":15,"function":17,"import":17,The:[21,22],abstractnod:1,add:15,add_on:4,advanc:[20,21],api:0,ar:20,axi:1,basi:4,between:19,build:[18,22],choos:17,cite:11,compon:[1,18,22],connect:15,connect_stack:15,content:2,contract:[15,19],contract_:15,contract_between:15,contract_between_:15,contract_edg:15,convmp:14,convmpslay:14,convpep:14,convtre:14,convump:14,convumpslay:14,convupep:14,convutre:14,copi:12,creat:[18,23],custom:22,data:17,dataset:[6,7,8,9,10],decomposit:3,defin:[6,7,8,9],differ:[10,21],differenti:19,disconnect:15,discret:4,distinguish:19,div:15,dmrg:[7,8],document:11,download:17,edg:[1,15],einsum:15,embed:4,empti:12,everyth:22,exampl:[5,11],faster:20,first:[17,22],how:[18,20,22],hybrid:[6,8,23],hyperparamet:17,initi:12,instal:[11,13],instanti:[10,17],introduct:[17,18,19,20,21,22],layer:9,librari:17,licens:11,like:[7,8,15],loss:17,manag:20,mat_to_mpo:3,matrix:19,memori:20,mode:20,model:[6,7,8,10,14,17,22,23],mp:[7,8,10,14],mpo:14,mpsdata:14,mpslayer:14,mul:15,name:21,network:[1,6,9,17,18,19,23],neural:[6,9,23],node:[1,15,18,19,21],notebook:11,ones:12,oper:[15,19,20],optim:17,our:22,paramnod:[1,19],paramstacknod:1,pep:14,permut:15,permute_:15,poli:4,product:19,prune:[10,17],put:22,qr:15,qr_:15,rand:12,randn:12,refer:0,renorm:15,requir:11,reserv:21,retrain:10,rq:15,rq_:15,run:20,save:20,set:17,setup:[17,18],skip:20,split:15,split_:15,stack:15,stacked_einsum:15,stackedg:1,stacknod:1,start:17,state:19,step:[17,18,19,20,21,22],store:20,sub:15,subclass:22,successor:1,svd:15,svd_:15,svdr:15,svdr_:15,tensor:[1,9,15,17,18,19,20,23],tensori:6,tensorkrowch:[11,17,18,20],tensornetwork:[18,20,22],time:20,togeth:22,tprod:15,train:[6,7,8,9,10,17],tree:14,tutori:[11,16],type:21,ump:14,umpo:14,umpslay:14,unbind:15,unit:4,upep:14,utre:14,vec_to_mp:3,wai:10,zero:12}}) \ No newline at end of file diff --git a/docs/_build/html/tutorials.html b/docs/_build/html/tutorials.html index 46ed10e..07e2f16 100644 --- a/docs/_build/html/tutorials.html +++ b/docs/_build/html/tutorials.html @@ -6,7 +6,7 @@ - Tutorials — TensorKrowch 1.1.2 documentation + Tutorials — TensorKrowch 1.1.3 documentation diff --git a/docs/_build/html/tutorials/0_first_steps.html b/docs/_build/html/tutorials/0_first_steps.html index 1373161..a914266 100644 --- a/docs/_build/html/tutorials/0_first_steps.html +++ b/docs/_build/html/tutorials/0_first_steps.html @@ -6,7 +6,7 @@ - First Steps with TensorKrowch — TensorKrowch 1.1.2 documentation + First Steps with TensorKrowch — TensorKrowch 1.1.3 documentation @@ -619,7 +619,7 @@

    4. Instantiate the Tensor Network Model
    # Check if GPU is available
    -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    +device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
     
     # Instantiate model
     mps = tk.models.MPSLayer(n_features=image_size[0] * image_size[1] + 1,
    diff --git a/docs/_build/html/tutorials/1_creating_tensor_network.html b/docs/_build/html/tutorials/1_creating_tensor_network.html
    index 63483dc..cab5a5b 100644
    --- a/docs/_build/html/tutorials/1_creating_tensor_network.html
    +++ b/docs/_build/html/tutorials/1_creating_tensor_network.html
    @@ -6,7 +6,7 @@
         
         
     
    -    Creating a Tensor Network in TensorKrowch — TensorKrowch 1.1.2 documentation
    +    Creating a Tensor Network in TensorKrowch — TensorKrowch 1.1.3 documentation
         
       
       
    diff --git a/docs/_build/html/tutorials/2_contracting_tensor_network.html b/docs/_build/html/tutorials/2_contracting_tensor_network.html
    index e9e6c2d..3dacbac 100644
    --- a/docs/_build/html/tutorials/2_contracting_tensor_network.html
    +++ b/docs/_build/html/tutorials/2_contracting_tensor_network.html
    @@ -6,7 +6,7 @@
         
         
     
    -    Contracting and Differentiating the Tensor Network — TensorKrowch 1.1.2 documentation
    +    Contracting and Differentiating the Tensor Network — TensorKrowch 1.1.3 documentation
         
       
       
    diff --git a/docs/_build/html/tutorials/3_memory_management.html b/docs/_build/html/tutorials/3_memory_management.html
    index 4a9c611..4181205 100644
    --- a/docs/_build/html/tutorials/3_memory_management.html
    +++ b/docs/_build/html/tutorials/3_memory_management.html
    @@ -6,7 +6,7 @@
         
         
     
    -    How to save Memory and Time with TensorKrowch (ADVANCED) — TensorKrowch 1.1.2 documentation
    +    How to save Memory and Time with TensorKrowch (ADVANCED) — TensorKrowch 1.1.3 documentation
         
       
       
    diff --git a/docs/_build/html/tutorials/4_types_of_nodes.html b/docs/_build/html/tutorials/4_types_of_nodes.html
    index e9f766a..a5aa028 100644
    --- a/docs/_build/html/tutorials/4_types_of_nodes.html
    +++ b/docs/_build/html/tutorials/4_types_of_nodes.html
    @@ -6,7 +6,7 @@
         
         
     
    -    The different Types of Nodes (ADVANCED) — TensorKrowch 1.1.2 documentation
    +    The different Types of Nodes (ADVANCED) — TensorKrowch 1.1.3 documentation
         
       
       
    diff --git a/docs/_build/html/tutorials/5_subclass_tensor_network.html b/docs/_build/html/tutorials/5_subclass_tensor_network.html
    index 14bb3c3..607a572 100644
    --- a/docs/_build/html/tutorials/5_subclass_tensor_network.html
    +++ b/docs/_build/html/tutorials/5_subclass_tensor_network.html
    @@ -6,7 +6,7 @@
         
         
     
    -    How to subclass TensorNetwork to build Custom Models — TensorKrowch 1.1.2 documentation
    +    How to subclass TensorNetwork to build Custom Models — TensorKrowch 1.1.3 documentation
         
       
       
    diff --git a/docs/_build/html/tutorials/6_mix_with_pytorch.html b/docs/_build/html/tutorials/6_mix_with_pytorch.html
    index bc91b62..5555ac7 100644
    --- a/docs/_build/html/tutorials/6_mix_with_pytorch.html
    +++ b/docs/_build/html/tutorials/6_mix_with_pytorch.html
    @@ -6,7 +6,7 @@
         
         
     
    -    Creating a Hybrid Neural-Tensor Network Model — TensorKrowch 1.1.2 documentation
    +    Creating a Hybrid Neural-Tensor Network Model — TensorKrowch 1.1.3 documentation
         
       
       
    diff --git a/docs/tutorials/0_first_steps.rst b/docs/tutorials/0_first_steps.rst
    index 05997e1..153719f 100644
    --- a/docs/tutorials/0_first_steps.rst
    +++ b/docs/tutorials/0_first_steps.rst
    @@ -134,7 +134,7 @@ possible classes.
     ::
     
         # Check if GPU is available
    -    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    +    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
     
         # Instantiate model
         mps = tk.models.MPSLayer(n_features=image_size[0] * image_size[1] + 1,
    diff --git a/tensorkrowch/__init__.py b/tensorkrowch/__init__.py
    index 1f07f5b..52e9162 100644
    --- a/tensorkrowch/__init__.py
    +++ b/tensorkrowch/__init__.py
    @@ -6,7 +6,7 @@
     """
     
     # Version
    -__version__ = '1.1.3'
    +__version__ = '1.1.4'
     
     # Network components
     from tensorkrowch.components import Axis
    @@ -34,7 +34,7 @@
     from tensorkrowch.operations import get_shared_edges  # Not in docs
     
     from tensorkrowch.operations import (permute, permute_, tprod, mul, div, add,
    -                                     sub, renormalize)
    +                                     sub, renormalize, conj)
     from tensorkrowch.operations import (split, split_, contract_edges,
                                          contract_between, contract_between_,
                                          stack, unbind, einsum, stacked_einsum)
    diff --git a/tensorkrowch/components.py b/tensorkrowch/components.py
    index c2a1478..78885cd 100644
    --- a/tensorkrowch/components.py
    +++ b/tensorkrowch/components.py
    @@ -450,6 +450,7 @@ def __init__(self,
                      node1_list: Optional[List[bool]] = None,
                      init_method: Optional[Text] = None,
                      device: Optional[torch.device] = None,
    +                 dtype: Optional[torch.dtype] = None,
                      **kwargs: float) -> None:
     
             super().__init__()
    @@ -561,7 +562,10 @@ def __init__(self,
             if shape is not None:
                 if init_method is not None:
                     self._unrestricted_set_tensor(
    -                    init_method=init_method, device=device, **kwargs)
    +                    init_method=init_method,
    +                    device=device,
    +                    dtype=dtype,
    +                    **kwargs)
             else:
                 self._unrestricted_set_tensor(tensor=tensor)
     
    @@ -769,6 +773,36 @@ def is_resultant(self) -> bool:
             their own tensors or use other node's tensor.
             """
             return not (self._leaf or self._data or self._virtual)
    +    
    +    def is_conj(self) -> bool:
    +        """
    +        Equivalent to `torch.is_conj()
    +        `_.
    +        """
    +        tensor = self.tensor
    +        if tensor is None:
    +            return
    +        return tensor.is_conj()
    +    
    +    def is_complex(self) -> bool:
    +        """
    +        Equivalent to `torch.is_complex()
    +        `_.
    +        """
    +        tensor = self.tensor
    +        if tensor is None:
    +            return
    +        return tensor.is_complex()
    +    
    +    def is_floating_point(self) -> bool:
    +        """
    +        Equivalent to `torch.is_floating_point()
    +        `_.
    +        """
    +        tensor = self.tensor
    +        if tensor is None:
    +            return
    +        return tensor.is_floating_point()
     
         def size(self, axis: Optional[Ax] = None) -> Union[Size, int]:
             """
    @@ -1170,9 +1204,10 @@ def disconnect(self, axis: Optional[Ax] = None) -> None:
     
         @staticmethod
         def _make_copy_tensor(shape: Shape,
    -                          device: torch.device = torch.device('cpu')) -> Tensor:
    +                          device: Optional[torch.device] = None,
    +                          dtype: Optional[torch.dtype] = None) -> Tensor:
             """Returns copy tensor (ones in the "diagonal", zeros elsewhere)."""
    -        copy_tensor = torch.zeros(shape, device=device)
    +        copy_tensor = torch.zeros(shape, device=device, dtype=dtype)
             rank = len(shape)
             if rank <= 1:
                 i = 0
    @@ -1185,7 +1220,8 @@ def _make_copy_tensor(shape: Shape,
         def _make_rand_tensor(shape: Shape,
                               low: float = 0.,
                               high: float = 1.,
    -                          device: torch.device = torch.device('cpu')) -> Tensor:
    +                          device: Optional[torch.device] = None,
    +                          dtype: Optional[torch.dtype] = None) -> Tensor:
             """Returns tensor whose entries are drawn from the uniform distribution."""
             if not isinstance(low, float):
                 raise TypeError('`low` should be float type')
    @@ -1193,13 +1229,14 @@ def _make_rand_tensor(shape: Shape,
                 raise TypeError('`high` should be float type')
             if low >= high:
                 raise ValueError('`low` should be strictly smaller than `high`')
    -        return torch.rand(shape, device=device) * (high - low) + low
    +        return torch.rand(shape, device=device, dtype=dtype) * (high - low) + low
     
         @staticmethod
         def _make_randn_tensor(shape: Shape,
                                mean: float = 0.,
                                std: float = 1.,
    -                           device: torch.device = torch.device('cpu')) -> Tensor:
    +                           device: Optional[torch.device] = None,
    +                           dtype: Optional[torch.dtype] = None) -> Tensor:
             """Returns tensor whose entries are drawn from the normal distribution."""
             if not isinstance(mean, float):
                 raise TypeError('`mean` should be float type')
    @@ -1207,12 +1244,13 @@ def _make_randn_tensor(shape: Shape,
                 raise TypeError('`std` should be float type')
             if std <= 0:
                 raise ValueError('`std` should be positive')
    -        return torch.randn(shape, device=device) * std + mean
    +        return torch.randn(shape, device=device, dtype=dtype) * std + mean
     
         def make_tensor(self,
                         shape: Optional[Shape] = None,
                         init_method: Text = 'zeros',
    -                    device: torch.device = torch.device('cpu'),
    +                    device: Optional[torch.device] = None,
    +                    dtype: Optional[torch.dtype] = None,
                         **kwargs: float) -> Tensor:
             """
             Returns a tensor that can be put in the node, and is initialized according
    @@ -1226,6 +1264,8 @@ def make_tensor(self,
                 Initialization method.
             device : torch.device, optional
                 Device where to initialize the tensor.
    +        dtype : torch.dtype, optional
    +            Dtype of the tensor.
             kwargs : float
                 Keyword arguments for the different initialization methods:
     
    @@ -1247,15 +1287,17 @@ def make_tensor(self,
             if shape is None:
                 shape = self._shape
             if init_method == 'zeros':
    -            return torch.zeros(shape, device=device)
    +            return torch.zeros(shape, device=device, dtype=dtype)
             elif init_method == 'ones':
    -            return torch.ones(shape, device=device)
    +            return torch.ones(shape, device=device, dtype=dtype)
             elif init_method == 'copy':
    -            return self._make_copy_tensor(shape, device=device)
    +            return self._make_copy_tensor(shape, device=device, dtype=dtype)
             elif init_method == 'rand':
    -            return self._make_rand_tensor(shape, device=device, **kwargs)
    +            return self._make_rand_tensor(shape, device=device, dtype=dtype,
    +                                          **kwargs)
             elif init_method == 'randn':
    -            return self._make_randn_tensor(shape, device=device, **kwargs)
    +            return self._make_randn_tensor(shape, device=device, dtype=dtype,
    +                                           **kwargs)
             else:
                 raise ValueError('Choose a valid `init_method`: "zeros", '
                                  '"ones", "copy", "rand", "randn"')
    @@ -1328,6 +1370,7 @@ def _unrestricted_set_tensor(self,
                                      tensor: Optional[Tensor] = None,
                                      init_method: Optional[Text] = 'zeros',
                                      device: Optional[torch.device] = None,
    +                                 dtype: Optional[torch.dtype] = None,
                                      **kwargs: float) -> None:
             """
             Sets a new node's tensor or creates one with :meth:`make_tensor` and sets
    @@ -1346,6 +1389,8 @@ def _unrestricted_set_tensor(self,
                 Initialization method.
             device : torch.device, optional
                 Device where to initialize the tensor.
    +        dtype : torch.dtype, optional
    +            Dtype of the tensor.
             kwargs : float
                 Keyword arguments for the different initialization methods. See
                 :meth:`make_tensor`.
    @@ -1353,10 +1398,14 @@ def _unrestricted_set_tensor(self,
             if tensor is not None:
                 if not isinstance(tensor, Tensor):
                     raise TypeError('`tensor` should be torch.Tensor type')
    -            elif device is not None:
    +            if device is not None:
                     warnings.warn('`device` was specified but is being ignored. '
                                   'Provide a tensor that is already in the required'
                                   ' device')
    +            if dtype is not None:
    +                warnings.warn('`dtype` was specified but is being ignored. '
    +                              'Provide a tensor that already has the required'
    +                              ' dtype')
     
                 if not self._compatible_shape(tensor):
                     tensor = self._crop_tensor(tensor)
    @@ -1364,10 +1413,13 @@ def _unrestricted_set_tensor(self,
     
             elif init_method is not None:
                 node_tensor = self.tensor
    -            if (device is None) and (node_tensor is not None):
    -                device = node_tensor.device
    +            if node_tensor is not None:
    +                if device is None:
    +                    device = node_tensor.device
    +                if dtype is None:
    +                    dtype = node_tensor.dtype
                 tensor = self.make_tensor(
    -                init_method=init_method, device=device, **kwargs)
    +                init_method=init_method, device=device, dtype=dtype, **kwargs)
                 correct_format_tensor = self._set_tensor_format(tensor)
     
             else:
    @@ -1380,6 +1432,7 @@ def set_tensor(self,
                        tensor: Optional[Tensor] = None,
                        init_method: Optional[Text] = 'zeros',
                        device: Optional[torch.device] = None,
    +                   dtype: Optional[torch.dtype] = None,
                        **kwargs: float) -> None:
             """
             Sets new node's tensor or creates one with :meth:`make_tensor` and sets
    @@ -1413,6 +1466,8 @@ def set_tensor(self,
                 Initialization method.
             device : torch.device, optional
                 Device where to initialize the tensor.
    +        dtype : torch.dtype, optional
    +            Dtype of the tensor.
             kwargs : float
                 Keyword arguments for the different initialization methods. See
                 :meth:`make_tensor`.
    @@ -1453,6 +1508,7 @@ def set_tensor(self,
                 self._unrestricted_set_tensor(tensor=tensor,
                                               init_method=init_method,
                                               device=device,
    +                                          dtype=dtype,
                                               **kwargs)
             else:
                 raise ValueError('Node\'s tensor can only be changed if it is not'
    @@ -2005,6 +2061,8 @@ class Node(AbstractNode):  # MARK: Node
             Initialization method.
         device : torch.device, optional
             Device where to initialize the tensor if ``init_method`` is provided.
    +    dtype : torch.dtype, optional
    +        Dtype of the tensor if ``init_method`` is provided.
         kwargs : float
             Keyword arguments for the different initialization methods. See
             :meth:`AbstractNode.make_tensor`.
    @@ -2329,6 +2387,8 @@ class ParamNode(AbstractNode):  # MARK: ParamNode
             Initialization method.
         device : torch.device, optional
             Device where to initialize the tensor if ``init_method`` is provided.
    +    dtype : torch.dtype, optional
    +        Dtype of the tensor if ``init_method`` is provided.
         kwargs : float
             Keyword arguments for the different initialization methods. See
             :meth:`AbstractNode.make_tensor`.
    @@ -2410,6 +2470,7 @@ def __init__(self,
                      node1_list: Optional[List[bool]] = None,
                      init_method: Optional[Text] = None,
                      device: Optional[torch.device] = None,
    +                 dtype: Optional[torch.dtype] = None,
                      **kwargs: float) -> None:
     
             super().__init__(shape=shape,
    @@ -2425,6 +2486,7 @@ def __init__(self,
                              node1_list=node1_list,
                              init_method=init_method,
                              device=device,
    +                         dtype=dtype,
                              **kwargs)
     
         # ----------
    diff --git a/tensorkrowch/decompositions/svd_decompositions.py b/tensorkrowch/decompositions/svd_decompositions.py
    index afb31fe..4d65dc2 100644
    --- a/tensorkrowch/decompositions/svd_decompositions.py
    +++ b/tensorkrowch/decompositions/svd_decompositions.py
    @@ -145,6 +145,9 @@ def vec_to_mps(vec: torch.Tensor,
             
             tensors.append(u)
             prev_bond = aux_rank
    +        
    +        if vh.is_complex():
    +            s = s.to(vh.dtype)
             vec = torch.diag_embed(s) @ vh
             
         tensors.append(vec)
    @@ -281,6 +284,9 @@ def mat_to_mpo(mat: torch.Tensor,
             
             tensors.append(u)
             prev_bond = aux_rank
    +        
    +        if vh.is_complex():
    +            s = s.to(vh.dtype)
             mat = torch.diag_embed(s) @ vh
         
         mat = mat.reshape(aux_rank, in_out_dims[-2], in_out_dims[-1])
    diff --git a/tensorkrowch/models/mpo.py b/tensorkrowch/models/mpo.py
    index 332e907..c817474 100644
    --- a/tensorkrowch/models/mpo.py
    +++ b/tensorkrowch/models/mpo.py
    @@ -81,6 +81,8 @@ class MPO(TensorNetwork):  # MARK: MPO
             explanation of the different initialization methods.
         device : torch.device, optional
             Device where to initialize the tensors if ``init_method`` is provided.
    +    dtype : torch.dtype, optional
    +        Dtype of the tensor if ``init_method`` is provided.
         kwargs : float
             Keyword arguments for the different initialization methods. See
             :meth:`~tensorkrowch.AbstractNode.make_tensor`.
    @@ -121,6 +123,7 @@ def __init__(self,
                      n_batches: int = 1,
                      init_method: Text = 'randn',
                      device: Optional[torch.device] = None,
    +                 dtype: Optional[torch.dtype] = None,
                      **kwargs) -> None:
     
             super().__init__(name='mpo')
    @@ -267,6 +270,7 @@ def __init__(self,
             self.initialize(tensors=tensors,
                             init_method=init_method,
                             device=device,
    +                        dtype=dtype,
                             **kwargs)
         
         # ----------
    @@ -399,6 +403,7 @@ def initialize(self,
                        tensors: Optional[Sequence[torch.Tensor]] = None,
                        init_method: Optional[Text] = 'randn',
                        device: Optional[torch.device] = None,
    +                   dtype: Optional[torch.dtype] = None,
                        **kwargs: float) -> None:
             """
             Initializes all the nodes of the :class:`MPO`. It can be called when
    @@ -408,7 +413,7 @@ def initialize(self,
             
             * ``{"zeros", "ones", "copy", "rand", "randn"}``: Each node is
               initialized calling :meth:`~tensorkrowch.AbstractNode.set_tensor` with
    -          the given method, ``device`` and ``kwargs``.
    +          the given method, ``device``, ``dtype`` and ``kwargs``.
             
             Parameters
             ----------
    @@ -422,6 +427,8 @@ def initialize(self,
                 Initialization method.
             device : torch.device, optional
                 Device where to initialize the tensors if ``init_method`` is provided.
    +        dtype : torch.dtype, optional
    +            Dtype of the tensor if ``init_method`` is provided.
             kwargs : float
                 Keyword arguments for the different initialization methods. See
                 :meth:`~tensorkrowch.AbstractNode.make_tensor`.
    @@ -436,6 +443,8 @@ def initialize(self,
                     
                     if device is None:
                         device = tensors[0].device
    +                if dtype is None:
    +                    dtype = tensors[0].dtype
                     
                     if len(tensors) == 1:
                         tensors[0] = tensors[0].reshape(1,
    @@ -446,13 +455,15 @@ def initialize(self,
                     else:
                         # Left node
                         aux_tensor = torch.zeros(*self._mats_env[0].shape,
    -                                             device=device)
    +                                             device=device,
    +                                             dtype=dtype)
                         aux_tensor[0] = tensors[0]
                         tensors[0] = aux_tensor
                         
                         # Right node
                         aux_tensor = torch.zeros(*self._mats_env[-1].shape,
    -                                             device=device)
    +                                             device=device,
    +                                             dtype=dtype)
                         aux_tensor[..., 0, :] = tensors[-1]
                         tensors[-1] = aux_tensor
                     
    @@ -464,10 +475,13 @@ def initialize(self,
                 for i, node in enumerate(self._mats_env):
                     node.set_tensor(init_method=init_method,
                                     device=device,
    +                                dtype=dtype,
                                     **kwargs)
                     
                     if self._boundary == 'obc':
    -                    aux_tensor = torch.zeros(*node.shape, device=device)
    +                    aux_tensor = torch.zeros(*node.shape,
    +                                             device=device,
    +                                             dtype=dtype)
                         if i == 0:
                             # Left node
                             aux_tensor[0] = node.tensor[0]
    @@ -477,8 +491,12 @@ def initialize(self,
                         node.tensor = aux_tensor
             
             if self._boundary == 'obc':
    -            self._left_node.set_tensor(init_method='copy', device=device)
    -            self._right_node.set_tensor(init_method='copy', device=device)
    +            self._left_node.set_tensor(init_method='copy',
    +                                       device=device,
    +                                       dtype=dtype)
    +            self._right_node.set_tensor(init_method='copy',
    +                                        device=device,
    +                                        dtype=dtype)
         
         def set_data_nodes(self) -> None:
             """
    @@ -515,7 +533,8 @@ def copy(self, share_tensors: bool = False) -> 'MPO':
                           tensors=None,
                           n_batches=self._n_batches,
                           init_method=None,
    -                      device=None)
    +                      device=None,
    +                      dtype=None)
             new_mpo.name = self.name + '_copy'
             if share_tensors:
                 for new_node, node in zip(new_mpo._mats_env, self._mats_env):
    @@ -1025,6 +1044,8 @@ class UMPO(MPO):  # MARK: UMPO
             explanation of the different initialization methods.
         device : torch.device, optional
             Device where to initialize the tensors if ``init_method`` is provided.
    +    dtype : torch.dtype, optional
    +        Dtype of the tensor if ``init_method`` is provided.
         kwargs : float
             Keyword arguments for the different initialization methods. See
             :meth:`~tensorkrowch.AbstractNode.make_tensor`.
    @@ -1053,6 +1074,7 @@ def __init__(self,
                      n_batches: int = 1,
                      init_method: Text = 'randn',
                      device: Optional[torch.device] = None,
    +                 dtype: Optional[torch.dtype] = None,
                      **kwargs) -> None:
     
             tensors = None
    @@ -1097,6 +1119,7 @@ def __init__(self,
                              n_batches=n_batches,
                              init_method=init_method,
                              device=device,
    +                         dtype=dtype,
                              **kwargs)
             self.name = 'umpo'
         
    @@ -1122,6 +1145,7 @@ def initialize(self,
                        tensors: Optional[Sequence[torch.Tensor]] = None,
                        init_method: Optional[Text] = 'randn',
                        device: Optional[torch.device] = None,
    +                   dtype: Optional[torch.dtype] = None,
                        **kwargs: float) -> None:
             """
             Initializes the common tensor of the :class:`UMPO`. It can be called
    @@ -1131,7 +1155,7 @@ def initialize(self,
             
             * ``{"zeros", "ones", "copy", "rand", "randn"}``: The tensor is
               initialized calling :meth:`~tensorkrowch.AbstractNode.set_tensor` with
    -          the given method, ``device`` and ``kwargs``.
    +          the given method, ``device``, ``dtype`` and ``kwargs``.
             
             Parameters
             ----------
    @@ -1143,6 +1167,8 @@ def initialize(self,
                 Initialization method.
             device : torch.device, optional
                 Device where to initialize the tensors if ``init_method`` is provided.
    +        dtype : torch.dtype, optional
    +            Dtype of the tensor if ``init_method`` is provided.
             kwargs : float
                 Keyword arguments for the different initialization methods. See
                 :meth:`~tensorkrowch.AbstractNode.make_tensor`.
    @@ -1153,6 +1179,7 @@ def initialize(self,
             elif init_method is not None:
                 self.uniform_memory.set_tensor(init_method=init_method,
                                                device=device,
    +                                           dtype=dtype,
                                                **kwargs)
         
         def copy(self, share_tensors: bool = False) -> 'UMPO':
    @@ -1180,7 +1207,8 @@ def copy(self, share_tensors: bool = False) -> 'UMPO':
                            tensor=None,
                            n_batches=self._n_batches,
                            init_method=None,
    -                       device=None)
    +                       device=None,
    +                       dtype=None)
             new_mpo.name = self.name + '_copy'
             if share_tensors:
                 new_mpo.uniform_memory.tensor = self.uniform_memory.tensor
    diff --git a/tensorkrowch/models/mps.py b/tensorkrowch/models/mps.py
    index feea9b2..6f91a76 100644
    --- a/tensorkrowch/models/mps.py
    +++ b/tensorkrowch/models/mps.py
    @@ -111,6 +111,8 @@ class MPS(TensorNetwork):  # MARK: MPS
             explanation of the different initialization methods.
         device : torch.device, optional
             Device where to initialize the tensors if ``init_method`` is provided.
    +    dtype : torch.dtype, optional
    +        Dtype of the tensor if ``init_method`` is provided.
         kwargs : float
             Keyword arguments for the different initialization methods. See
             :meth:`~tensorkrowch.AbstractNode.make_tensor`.
    @@ -170,6 +172,7 @@ def __init__(self,
                      n_batches: int = 1,
                      init_method: Text = 'randn',
                      device: Optional[torch.device] = None,
    +                 dtype: Optional[torch.dtype] = None,
                      **kwargs) -> None:
     
             super().__init__(name='mps')
    @@ -363,6 +366,7 @@ def __init__(self,
             self.initialize(tensors=tensors,
                             init_method=init_method,
                             device=device,
    +                        dtype=dtype,
                             **kwargs)
         
         # ----------
    @@ -579,7 +583,9 @@ def _make_nodes(self) -> None:
                     if i == self._n_features - 1:
                         self._mats_env[-1]['right'] ^ self._right_node['left']
         
    -    def _make_canonical(self, device: Optional[torch.device] = None) -> List[torch.Tensor]:
    +    def _make_canonical(self,
    +                        device: Optional[torch.device] = None,
    +                        dtype: Optional[torch.dtype] = None) -> List[torch.Tensor]:
             """
             Creates random unitaries to initialize the MPS in canonical form with
             orthogonality center at the rightmost node. Unitaries in nodes are
    @@ -607,22 +613,22 @@ def _make_canonical(self, device: Optional[torch.device] = None) -> List[torch.T
                     phys_dim = node_shape[1]
                 size = max(aux_shape[0], aux_shape[1])
                 
    -            tensor = random_unitary(size, device=device)
    +            tensor = random_unitary(size, device=device, dtype=dtype)
                 tensor = tensor[:min(aux_shape[0], size), :min(aux_shape[1], size)]
                 tensor = tensor.reshape(*node_shape)
                 
                 if i == (self._n_features - 1):
    -                if self._boundary == 'obc':
    -                    tensor = tensor.t() / tensor.norm() * sqrt(phys_dim)
    -                else:
    -                    tensor = tensor / tensor.norm() * sqrt(phys_dim)
    -            else:
    -                tensor = tensor * sqrt(phys_dim)
    +                if (self._boundary == 'obc') and (i == 0):
    +                    tensor = tensor[:, 0]
    +                tensor = tensor / tensor.norm()
    +            tensor = tensor * sqrt(phys_dim)
                 
                 tensors.append(tensor)
             return tensors
         
    -    def _make_unitaries(self, device: Optional[torch.device] = None) -> List[torch.Tensor]:
    +    def _make_unitaries(self,
    +                        device: Optional[torch.device] = None,
    +                        dtype: Optional[torch.dtype] = None) -> List[torch.Tensor]:
             """
             Creates random unitaries to initialize the MPS nodes as stacks of
             unitaries.
    @@ -646,7 +652,7 @@ def _make_unitaries(self, device: Optional[torch.device] = None) -> List[torch.T
                     size_2 = min(node.shape[2], size)
                 
                 for _ in range(node.shape[1]):
    -                tensor = random_unitary(size, device=device)
    +                tensor = random_unitary(size, device=device, dtype=dtype)
                     tensor = tensor[:size_1, :size_2]
                     units.append(tensor)
                 
    @@ -654,19 +660,18 @@ def _make_unitaries(self, device: Optional[torch.device] = None) -> List[torch.T
                 
                 if self._boundary == 'obc':
                     if i == 0:
    -                    tensors.append(units.squeeze(0))
    +                    units = units.squeeze(0)
                     elif i == (self._n_features - 1):
    -                    tensors.append(units.squeeze(-1))
    -                else:
    -                    tensors.append(units)
    -            else:    
    -                tensors.append(units)
    +                    units = units.squeeze(-1)
    +            tensors.append(units)
    +        
             return tensors
     
         def initialize(self,
                        tensors: Optional[Sequence[torch.Tensor]] = None,
                        init_method: Optional[Text] = 'randn',
                        device: Optional[torch.device] = None,
    +                   dtype: Optional[torch.dtype] = None,
                        **kwargs: float) -> None:
             """
             Initializes all the nodes of the :class:`MPS`. It can be called when
    @@ -676,7 +681,7 @@ def initialize(self,
             
             * ``{"zeros", "ones", "copy", "rand", "randn"}``: Each node is
               initialized calling :meth:`~tensorkrowch.AbstractNode.set_tensor` with
    -          the given method, ``device`` and ``kwargs``.
    +          the given method, ``device``, ``dtype`` and ``kwargs``.
             
             * ``"randn_eye"``: Nodes are initialized as in this
               `paper `_, adding identities at the
    @@ -706,14 +711,16 @@ def initialize(self,
                 Initialization method.
             device : torch.device, optional
                 Device where to initialize the tensors if ``init_method`` is provided.
    +        dtype : torch.dtype, optional
    +            Dtype of the tensor if ``init_method`` is provided.
             kwargs : float
                 Keyword arguments for the different initialization methods. See
                 :meth:`~tensorkrowch.AbstractNode.make_tensor`.
             """
             if init_method == 'unit':
    -            tensors = self._make_unitaries(device=device)
    +            tensors = self._make_unitaries(device=device, dtype=dtype)
             elif init_method == 'canonical':
    -            tensors = self._make_canonical(device=device)
    +            tensors = self._make_canonical(device=device, dtype=dtype)
     
             if tensors is not None:
                 if len(tensors) != self._n_features:
    @@ -725,19 +732,23 @@ def initialize(self,
                     
                     if device is None:
                         device = tensors[0].device
    +                if dtype is None:
    +                    dtype = tensors[0].dtype
                     
                     if len(tensors) == 1:
                         tensors[0] = tensors[0].reshape(1, -1, 1)
                     else:
                         # Left node
                         aux_tensor = torch.zeros(*self._mats_env[0].shape,
    -                                             device=device)
    +                                             device=device,
    +                                             dtype=dtype)
                         aux_tensor[0] = tensors[0]
                         tensors[0] = aux_tensor
                         
                         # Right node
                         aux_tensor = torch.zeros(*self._mats_env[-1].shape,
    -                                             device=device)
    +                                             device=device,
    +                                             dtype=dtype)
                         aux_tensor[..., 0] = tensors[-1]
                         tensors[-1] = aux_tensor
                     
    @@ -753,16 +764,20 @@ def initialize(self,
                 for i, node in enumerate(self._mats_env):
                     node.set_tensor(init_method=init_method,
                                     device=device,
    +                                dtype=dtype,
                                     **kwargs)
                     if add_eye:
                         aux_tensor = node.tensor.detach()
                         aux_tensor[:, 0, :] += torch.eye(node.shape[0],
                                                          node.shape[2],
    -                                                     device=device)
    +                                                     device=device,
    +                                                     dtype=dtype)
                         node.tensor = aux_tensor
                     
                     if self._boundary == 'obc':
    -                    aux_tensor = torch.zeros(*node.shape, device=device)
    +                    aux_tensor = torch.zeros(*node.shape,
    +                                             device=device,
    +                                             dtype=dtype)
                         if i == 0:
                             # Left node
                             aux_tensor[0] = node.tensor[0]
    @@ -773,8 +788,12 @@ def initialize(self,
                             node.tensor = aux_tensor
             
             if self._boundary == 'obc':
    -            self._left_node.set_tensor(init_method='copy', device=device)
    -            self._right_node.set_tensor(init_method='copy', device=device)
    +            self._left_node.set_tensor(init_method='copy',
    +                                       device=device,
    +                                       dtype=dtype)
    +            self._right_node.set_tensor(init_method='copy',
    +                                        device=device,
    +                                        dtype=dtype)
     
         def set_data_nodes(self) -> None:
             """
    @@ -812,7 +831,8 @@ def copy(self, share_tensors: bool = False) -> 'MPS':
                           out_features=self._out_features,
                           n_batches=self._n_batches,
                           init_method=None,
    -                      device=None)
    +                      device=None,
    +                      dtype=None)
             new_mps.name = self.name + '_copy'
             if share_tensors:
                 for new_node, node in zip(new_mps._mats_env, self._mats_env):
    @@ -1225,10 +1245,10 @@ def contract(self,
                     copied_nodes = []
                     for node in nodes_out_env:
                         copied_node = node.__class__(shape=node._shape,
    -                                              axes_names=node.axes_names,
    -                                              name='virtual_result_copy',
    -                                              network=self,
    -                                              virtual=True)
    +                                                 axes_names=node.axes_names,
    +                                                 name='virtual_result_copy',
    +                                                 network=self,
    +                                                 virtual=True)
                         copied_node.set_tensor_from(node)
                         copied_nodes.append(copied_node)
                         
    @@ -1316,6 +1336,12 @@ def contract(self,
                             # Connect copies directly to output nodes
                             copied_node['input'] ^ node['input']
                     
    +                # If MPS nodes are complex, copied nodes are their conjugates
    +                is_complex = copied_nodes[0].is_complex()
    +                if is_complex:
    +                    for i, node in enumerate(copied_nodes):
    +                        copied_nodes[i] = node.conj()
    +                
                     # Contract output nodes with copies
                     mats_out_env = self._input_contraction(
                         nodes_env=nodes_out_env,
    @@ -1437,6 +1463,12 @@ def norm(self,
                 copied_nodes = []
                 for node in all_nodes:
                     copied_nodes.append(node.neighbours('input'))
    +        
    +        # If MPS nodes are complex, copied nodes are their conjugates
    +        is_complex = copied_nodes[0].is_complex()
    +        if is_complex:
    +            for i, node in enumerate(copied_nodes):
    +                copied_nodes[i] = node.conj()
                 
             # Contract output nodes with copies
             mats_out_env = self._input_contraction(
    @@ -1545,10 +1577,14 @@ def partial_density(self,
             if n_dims >= 1:
                 if n_dims == 1:
                     data = torch.cat(data, dim=-1)
    -                data = basis(data, dim=dims[0]).float().to(self.in_env[0].device)
    +                data = basis(data, dim=dims[0])\
    +                    .to(self.in_env[0].dtype)\
    +                    .to(self.in_env[0].device)
                 elif n_dims > 1:
                     data = [
    -                    basis(dat, dim=dim).squeeze(-2).float().to(self.in_env[0].device)
    +                    basis(dat, dim=dim).squeeze(-2)\
    +                        .to(self.in_env[0].dtype)\
    +                        .to(self.in_env[0].device)
                         for dat, dim in zip(data, dims)
                         ]
                 
    @@ -1675,7 +1711,7 @@ def entropy(self,
                                       middle_tensor.shape[-1]),         # right
                 full_matrices=False)
             
    -        s = s[s > 0]
    +        s = s[s.pow(2) > 0]
             entropy = -(s.pow(2) * s.pow(2).log()).sum()
             
             # Rescale
    @@ -1868,6 +1904,7 @@ def _project_to_bond_dim(self,
                                  side: Text = 'right'):
             """Projects all nodes into a space of dimension ``bond_dim``."""
             device = nodes[0].tensor.device
    +        dtype = nodes[0].tensor.dtype
     
             if side == 'left':
                 nodes.reverse()
    @@ -1892,7 +1929,7 @@ def _project_to_bond_dim(self,
     
                     proj_mat_node.tensor = torch.eye(
                         torch.tensor(phys_dim_lst).prod().int().item(),
    -                    bond_dim).view(*phys_dim_lst, -1).to(device)
    +                    bond_dim).view(*phys_dim_lst, -1).to(dtype).to(device)
                     for k in range(j + 1):
                         nodes[k]['input'] ^ proj_mat_node[k]
     
    @@ -1912,7 +1949,7 @@ def _project_to_bond_dim(self,
     
                 proj_mat_node.tensor = torch.eye(
                     torch.tensor(phys_dim_lst).prod().int().item(),
    -                bond_dim).view(*phys_dim_lst, -1).to(device)
    +                bond_dim).view(*phys_dim_lst, -1).to(dtype).to(device)
                 for k in range(j + 1):
                     nodes[k]['input'] ^ proj_mat_node[k]
     
    @@ -1929,7 +1966,8 @@ def _project_to_bond_dim(self,
                                      name=f'proj_vec_node_{side}_({k})',
                                      network=self)
     
    -            proj_vec_node.tensor = torch.eye(phys_dim, 1).squeeze().to(device)
    +            proj_vec_node.tensor = torch.eye(phys_dim, 1).squeeze()\
    +                .to(dtype).to(device)
                 nodes[k]['input'] ^ proj_vec_node['input']
                 line_mat_nodes.append(proj_vec_node @ nodes[k])
     
    @@ -2097,6 +2135,8 @@ class UMPS(MPS):  # MARK: UMPS
             explanation of the different initialization methods.
         device : torch.device, optional
             Device where to initialize the tensors if ``init_method`` is provided.
    +    dtype : torch.dtype, optional
    +        Dtype of the tensor if ``init_method`` is provided.
         kwargs : float
             Keyword arguments for the different initialization methods. See
             :meth:`~tensorkrowch.AbstractNode.make_tensor`.
    @@ -2125,6 +2165,7 @@ def __init__(self,
                      n_batches: int = 1,
                      init_method: Text = 'randn',
                      device: Optional[torch.device] = None,
    +                 dtype: Optional[torch.dtype] = None,
                      **kwargs) -> None:
             
             tensors = None
    @@ -2166,6 +2207,7 @@ def __init__(self,
                              n_batches=n_batches,
                              init_method=init_method,
                              device=device,
    +                         dtype=dtype,
                              **kwargs)
             self.name = 'umps'
     
    @@ -2186,7 +2228,9 @@ def _make_nodes(self) -> None:
             for node in self._mats_env:
                 node.set_tensor_from(uniform_memory)
         
    -    def _make_canonical(self, device: Optional[torch.device] = None) -> List[torch.Tensor]:
    +    def _make_canonical(self,
    +                        device: Optional[torch.device] = None,
    +                        dtype: Optional[torch.dtype] = None) -> List[torch.Tensor]:
             """
             Creates random unitaries to initialize the MPS in canonical form with
             orthogonality center at the rightmost node. Unitaries in nodes are
    @@ -2200,13 +2244,15 @@ def _make_canonical(self, device: Optional[torch.device] = None) -> List[torch.T
             size = max(aux_shape[0], aux_shape[1])
             phys_dim = node_shape[1]
             
    -        tensor = random_unitary(size, device=device)
    +        tensor = random_unitary(size, device=device, dtype=dtype)
             tensor = tensor[:min(aux_shape[0], size), :min(aux_shape[1], size)]
             tensor = tensor.reshape(*node_shape)
             tensor = tensor * sqrt(phys_dim)
             return tensor
         
    -    def _make_unitaries(self, device: Optional[torch.device] = None) -> List[torch.Tensor]:
    +    def _make_unitaries(self,
    +                        device: Optional[torch.device] = None,
    +                        dtype: Optional[torch.dtype] = None) -> List[torch.Tensor]:
             """
             Creates random unitaries to initialize the MPS nodes as stacks of
             unitaries.
    @@ -2216,7 +2262,7 @@ def _make_unitaries(self, device: Optional[torch.device] = None) -> List[torch.T
             
             units = []
             for _ in range(node_shape[1]):
    -            tensor = random_unitary(node_shape[0], device=device)
    +            tensor = random_unitary(node_shape[0], device=device, dtype=dtype)
                 units.append(tensor)
             tensor = torch.stack(units, dim=1)
             return tensor
    @@ -2225,6 +2271,7 @@ def initialize(self,
                        tensors: Optional[Sequence[torch.Tensor]] = None,
                        init_method: Optional[Text] = 'randn',
                        device: Optional[torch.device] = None,
    +                   dtype: Optional[torch.dtype] = None,
                        **kwargs: float) -> None:
             """
             Initializes the common tensor of the :class:`UMPS`. It can be called
    @@ -2234,7 +2281,7 @@ def initialize(self,
             
             * ``{"zeros", "ones", "copy", "rand", "randn"}``: The tensor is
               initialized calling :meth:`~tensorkrowch.AbstractNode.set_tensor` with
    -          the given method, ``device`` and ``kwargs``.
    +          the given method, ``device``, ``dtype`` and ``kwargs``.
             
             * ``"randn_eye"``: Tensor is initialized as in this
               `paper `_, adding identities at the
    @@ -2261,6 +2308,8 @@ def initialize(self,
                 Initialization method.
             device : torch.device, optional
                 Device where to initialize the tensors if ``init_method`` is provided.
    +        dtype : torch.dtype, optional
    +            Dtype of the tensor if ``init_method`` is provided.
             kwargs : float
                 Keyword arguments for the different initialization methods. See
                 :meth:`~tensorkrowch.AbstractNode.make_tensor`.
    @@ -2268,9 +2317,9 @@ def initialize(self,
             node = self.uniform_memory
             
             if init_method == 'unit':
    -            tensors = [self._make_unitaries(device=device)]
    +            tensors = [self._make_unitaries(device=device, dtype=dtype)]
             elif init_method == 'canonical':
    -            tensors = [self._make_canonical(device=device)]
    +            tensors = [self._make_canonical(device=device, dtype=dtype)]
             
             if tensors is not None:
                 node.tensor = tensors[0]
    @@ -2283,12 +2332,14 @@ def initialize(self,
                 
                 node.set_tensor(init_method=init_method,
                                 device=device,
    +                            dtype=dtype,
                                 **kwargs)
                 if add_eye:
                     aux_tensor = node.tensor.detach()
                     aux_tensor[:, 0, :] += torch.eye(node.shape[0],
                                                      node.shape[2],
    -                                                 device=device)
    +                                                 device=device,
    +                                                 dtype=dtype)
                     node.tensor = aux_tensor
         
         def copy(self, share_tensors: bool = False) -> 'UMPS':
    @@ -2317,7 +2368,8 @@ def copy(self, share_tensors: bool = False) -> 'UMPS':
                            out_features=self._out_features,
                            n_batches=self._n_batches,
                            init_method=None,
    -                       device=None)
    +                       device=None,
    +                       dtype=None)
             new_mps.name = self.name + '_copy'
             if share_tensors:
                 new_mps.uniform_memory.tensor = self.uniform_memory.tensor
    @@ -2461,6 +2513,8 @@ class MPSLayer(MPS):  # MARK: MPSLayer
             explanation of the different initialization methods.
         device : torch.device, optional
             Device where to initialize the tensors if ``init_method`` is provided.
    +    dtype : torch.dtype, optional
    +        Dtype of the tensor if ``init_method`` is provided.
         kwargs : float
             Keyword arguments for the different initialization methods. See
             :meth:`~tensorkrowch.AbstractNode.make_tensor`.
    @@ -2502,6 +2556,7 @@ def __init__(self,
                      n_batches: int = 1,
                      init_method: Text = 'randn',
                      device: Optional[torch.device] = None,
    +                 dtype: Optional[torch.dtype] = None,
                      **kwargs) -> None:
             
             phys_dim = None
    @@ -2563,6 +2618,7 @@ def __init__(self,
                              n_batches=n_batches,
                              init_method=init_method,
                              device=device,
    +                         dtype=dtype,
                              **kwargs)
             self.name = 'mpslayer'
             self._in_dim = self._phys_dim[:out_position] + \
    @@ -2592,7 +2648,9 @@ def out_node(self) -> ParamNode:
             """Returns the output node."""
             return self._mats_env[self._out_position]
         
    -    def _make_canonical(self, device: Optional[torch.device] = None) -> List[torch.Tensor]:
    +    def _make_canonical(self,
    +                        device: Optional[torch.device] = None,
    +                        dtype: Optional[torch.dtype] = None) -> List[torch.Tensor]:
             """
             Creates random unitaries to initialize the MPS in canonical form with
             orthogonality center at the rightmost node. Unitaries in nodes are
    @@ -2617,14 +2675,16 @@ def _make_canonical(self, device: Optional[torch.device] = None) -> List[torch.T
                     phys_dim = node_shape[1]
                 size = max(aux_shape[0], aux_shape[1])
                 
    -            tensor = random_unitary(size, device=device)
    +            tensor = random_unitary(size, device=device, dtype=dtype)
                 tensor = tensor[:min(aux_shape[0], size), :min(aux_shape[1], size)]
                 tensor = tensor.reshape(*node_shape)
                 
                 left_tensors.append(tensor * sqrt(phys_dim))
             
             # Output node
    -        out_tensor = torch.randn(self.out_node.shape, device=device)
    +        out_tensor = torch.randn(self.out_node.shape,
    +                                 device=device,
    +                                 dtype=dtype)
             phys_dim = out_tensor.shape[1]
             if self._boundary == 'obc':
                 if self._out_position == 0:
    @@ -2651,7 +2711,7 @@ def _make_canonical(self, device: Optional[torch.device] = None) -> List[torch.T
                     phys_dim = node_shape[1]
                 size = max(aux_shape[0], aux_shape[1])
                 
    -            tensor = random_unitary(size, device=device)
    +            tensor = random_unitary(size, device=device, dtype=dtype)
                 tensor = tensor[:min(aux_shape[0], size), :min(aux_shape[1], size)]
                 tensor = tensor.reshape(*node_shape)
                 
    @@ -2662,7 +2722,9 @@ def _make_canonical(self, device: Optional[torch.device] = None) -> List[torch.T
             tensors = left_tensors + [out_tensor] + right_tensors
             return tensors
         
    -    def _make_unitaries(self, device: Optional[torch.device] = None) -> List[torch.Tensor]:
    +    def _make_unitaries(self,
    +                        device: Optional[torch.device] = None,
    +                        dtype: Optional[torch.dtype] = None) -> List[torch.Tensor]:
             """
             Creates random unitaries to initialize the MPS nodes as stacks of
             unitaries.
    @@ -2684,7 +2746,7 @@ def _make_unitaries(self, device: Optional[torch.device] = None) -> List[torch.T
                     size_2 = min(node.shape[2], size)
                 
                 for _ in range(node.shape[1]):
    -                tensor = random_unitary(size, device=device)
    +                tensor = random_unitary(size, device=device, dtype=dtype)
                     tensor = tensor[:size_1, :size_2]
                     units.append(tensor)
                 
    @@ -2699,7 +2761,9 @@ def _make_unitaries(self, device: Optional[torch.device] = None) -> List[torch.T
                     left_tensors.append(units)
             
             # Output node
    -        out_tensor = torch.randn(self.out_node.shape, device=device)
    +        out_tensor = torch.randn(self.out_node.shape,
    +                                 device=device,
    +                                 dtype=dtype)
             if self._boundary == 'obc':
                 if self._out_position == 0:
                     out_tensor = out_tensor[0]
    @@ -2723,7 +2787,7 @@ def _make_unitaries(self, device: Optional[torch.device] = None) -> List[torch.T
                     size_2 = min(node.shape[2], size)
                 
                 for _ in range(node.shape[1]):
    -                tensor = random_unitary(size, device=device).t()
    +                tensor = random_unitary(size, device=device, dtype=dtype).H
                     tensor = tensor[:size_1, :size_2]
                     units.append(tensor)
                 
    @@ -2741,63 +2805,12 @@ def _make_unitaries(self, device: Optional[torch.device] = None) -> List[torch.T
             # All tensors
             tensors = left_tensors + [out_tensor] + right_tensors
             return tensors
    -
    -    def initialize(self,
    -                   tensors: Optional[Sequence[torch.Tensor]] = None,
    -                   init_method: Optional[Text] = 'randn',
    -                   device: Optional[torch.device] = None,
    -                   **kwargs: float) -> None:
    -        """
    -        Initializes all the nodes of the :class:`MPS`. It can be called when
    -        instantiating the model, or to override the existing nodes' tensors.
    -        
    -        There are different methods to initialize the nodes:
    -        
    -        * ``{"zeros", "ones", "copy", "rand", "randn"}``: Each node is
    -          initialized calling :meth:`~tensorkrowch.AbstractNode.set_tensor` with
    -          the given method, ``device`` and ``kwargs``.
    -        
    -        * ``"randn_eye"``: Nodes are initialized as in this
    -          `paper `_, adding identities at the
    -          top of random gaussian tensors. In this case, ``std`` should be
    -          specified with a low value, e.g., ``std = 1e-9``.
    -        
    -        * ``"unit"``: Nodes are initialized as stacks of random unitaries. This,
    -          combined (at least) with an embedding of the inputs as elements of
    -          the computational basis (:func:`~tensorkrowch.embeddings.discretize`
    -          combined with :func:`~tensorkrowch.embeddings.basis`)
    -        
    -        * ``"canonical"```: MPS is initialized in canonical form with a squared
    -          norm `close` to the product of all the physical dimensions (if bond
    -          dimensions are bigger than the powers of the physical dimensions,
    -          the norm could vary). Th orthogonality center is at the rightmost
    -          node.
    -        
    -        Parameters
    -        ----------
    -        tensors : list[torch.Tensor] or tuple[torch.Tensor], optional
    -            Sequence of tensors to set in each of the MPS nodes. If ``boundary``
    -            is ``"obc"``, all tensors should be rank-3, except the first and
    -            last ones, which can be rank-2, or rank-1 (if the first and last are
    -            the same). If ``boundary`` is ``"pbc"``, all tensors should be
    -            rank-3.
    -        init_method : {"zeros", "ones", "copy", "rand", "randn", "randn_eye", "unit", "canonical"}, optional
    -            Initialization method.
    -        device : torch.device, optional
    -            Device where to initialize the tensors if ``init_method`` is provided.
    -        kwargs : float
    -            Keyword arguments for the different initialization methods. See
    -            :meth:`~tensorkrowch.AbstractNode.make_tensor`.
    -        """
    -        if init_method == 'unit':
    -            tensors = self._make_unitaries(device=device)
    -        elif init_method == 'canonical':
    -            tensors = self._make_canonical(device=device)
         
         def initialize(self,
                        tensors: Optional[Sequence[torch.Tensor]] = None,
                        init_method: Optional[Text] = 'randn',
                        device: Optional[torch.device] = None,
    +                   dtype: Optional[torch.dtype] = None,
                        **kwargs: float) -> None:
             """
             Initializes all the nodes of the :class:`MPSLayer`. It can be called when
    @@ -2807,7 +2820,7 @@ def initialize(self,
             
             * ``{"zeros", "ones", "copy", "rand", "randn"}``: Each node is
               initialized calling :meth:`~tensorkrowch.AbstractNode.set_tensor` with
    -          the given method, ``device`` and ``kwargs``.
    +          the given method, ``device``, ``dtype`` and ``kwargs``.
             
             * ``"randn_eye"``: Nodes are initialized as in this
               `paper `_, adding identities at the
    @@ -2836,14 +2849,16 @@ def initialize(self,
                 Initialization method.
             device : torch.device, optional
                 Device where to initialize the tensors if ``init_method`` is provided.
    +        dtype : torch.dtype, optional
    +            Dtype of the tensor if ``init_method`` is provided.
             kwargs : float
                 Keyword arguments for the different initialization methods. See
                 :meth:`~tensorkrowch.AbstractNode.make_tensor`.
             """
             if init_method == 'unit':
    -            tensors = self._make_unitaries(device=device)
    +            tensors = self._make_unitaries(device=device, dtype=dtype)
             elif init_method == 'canonical':
    -            tensors = self._make_canonical(device=device)
    +            tensors = self._make_canonical(device=device, dtype=dtype)
     
             if tensors is not None:
                 if len(tensors) != self._n_features:
    @@ -2855,19 +2870,23 @@ def initialize(self,
                     
                     if device is None:
                         device = tensors[0].device
    +                if dtype is None:
    +                    dtype = tensors[0].dtype
                     
                     if len(tensors) == 1:
                         tensors[0] = tensors[0].reshape(1, -1, 1)
                     else:
                         # Left node
                         aux_tensor = torch.zeros(*self._mats_env[0].shape,
    -                                             device=device)
    +                                             device=device,
    +                                             dtype=dtype)
                         aux_tensor[0] = tensors[0]
                         tensors[0] = aux_tensor
                         
                         # Right node
                         aux_tensor = torch.zeros(*self._mats_env[-1].shape,
    -                                             device=device)
    +                                             device=device,
    +                                             dtype=dtype)
                         aux_tensor[..., 0] = tensors[-1]
                         tensors[-1] = aux_tensor
                     
    @@ -2883,12 +2902,14 @@ def initialize(self,
                 for i, node in enumerate(self._mats_env):
                     node.set_tensor(init_method=init_method,
                                     device=device,
    +                                dtype=dtype,
                                     **kwargs)
                     if add_eye:
                         aux_tensor = node.tensor.detach()
                         eye_tensor = torch.eye(node.shape[0],
                                                node.shape[2],
    -                                           device=device)
    +                                           device=device,
    +                                           dtype=dtype)
                         if i == self._out_position:
                             eye_tensor = eye_tensor.unsqueeze(1)
                             eye_tensor = eye_tensor.expand(node.shape)
    @@ -2898,7 +2919,9 @@ def initialize(self,
                         node.tensor = aux_tensor
                     
                     if self._boundary == 'obc':
    -                    aux_tensor = torch.zeros(*node.shape, device=device)
    +                    aux_tensor = torch.zeros(*node.shape,
    +                                             device=device,
    +                                             dtype=dtype)
                         if i == 0:
                             # Left node
                             aux_tensor[0] = node.tensor[0]
    @@ -2909,8 +2932,12 @@ def initialize(self,
                             node.tensor = aux_tensor
             
             if self._boundary == 'obc':
    -            self._left_node.set_tensor(init_method='copy', device=device)
    -            self._right_node.set_tensor(init_method='copy', device=device)
    +            self._left_node.set_tensor(init_method='copy',
    +                                       device=device,
    +                                       dtype=dtype)
    +            self._right_node.set_tensor(init_method='copy',
    +                                        device=device,
    +                                        dtype=dtype)
     
         def copy(self, share_tensors: bool = False) -> 'MPSLayer':
             """
    @@ -2939,7 +2966,8 @@ def copy(self, share_tensors: bool = False) -> 'MPSLayer':
                                tensors=None,
                                n_batches=self._n_batches,
                                init_method=None,
    -                           device=None)
    +                           device=None,
    +                           dtype=None)
             new_mps.name = self.name + '_copy'
             if share_tensors:
                 for new_node, node in zip(new_mps._mats_env, self._mats_env):
    @@ -3001,6 +3029,8 @@ class UMPSLayer(MPS):  # MARK: UMPSLayer
             explanation of the different initialization methods.
         device : torch.device, optional
             Device where to initialize the tensors if ``init_method`` is provided.
    +    dtype : torch.dtype, optional
    +        Dtype of the tensor if ``init_method`` is provided.
         kwargs : float
             Keyword arguments for the different initialization methods. See
             :meth:`~tensorkrowch.AbstractNode.make_tensor`.
    @@ -3031,6 +3061,7 @@ def __init__(self,
                      n_batches: int = 1,
                      init_method: Text = 'randn',
                      device: Optional[torch.device] = None,
    +                 dtype: Optional[torch.dtype] = None,
                      **kwargs) -> None:
             
             phys_dim = None
    @@ -3106,6 +3137,7 @@ def __init__(self,
                              n_batches=n_batches,
                              init_method=init_method,
                              device=device,
    +                         dtype=dtype,
                              **kwargs)
             self.name = 'umpslayer'
             self._in_dim = self._phys_dim[:out_position] + \
    @@ -3154,7 +3186,9 @@ def _make_nodes(self) -> None:
             for node in in_nodes:
                 node.set_tensor_from(uniform_memory)
         
    -    def _make_canonical(self, device: Optional[torch.device] = None) -> List[torch.Tensor]:
    +    def _make_canonical(self,
    +                        device: Optional[torch.device] = None,
    +                        dtype: Optional[torch.dtype] = None) -> List[torch.Tensor]:
             """
             Creates random unitaries to initialize the MPS in canonical form with
             orthogonality center at the rightmost node. Unitaries in nodes are
    @@ -3169,18 +3203,22 @@ def _make_canonical(self, device: Optional[torch.device] = None) -> List[torch.T
             size = max(aux_shape[0], aux_shape[1])
             phys_dim = node_shape[1]
             
    -        uni_tensor = random_unitary(size, device=device)
    +        uni_tensor = random_unitary(size, device=device, dtype=dtype)
             uni_tensor = uni_tensor[:min(aux_shape[0], size), :min(aux_shape[1], size)]
             uni_tensor = uni_tensor.reshape(*node_shape)
             uni_tensor = uni_tensor * sqrt(phys_dim)
             
             # Output node
    -        out_tensor = torch.randn(self.out_node.shape, device=device)
    +        out_tensor = torch.randn(self.out_node.shape,
    +                                 device=device,
    +                                 dtype=dtype)
             out_tensor = out_tensor / out_tensor.norm() * sqrt(out_tensor.shape[1])
             
             return [uni_tensor, out_tensor]
         
    -    def _make_unitaries(self, device: Optional[torch.device] = None) -> List[torch.Tensor]:
    +    def _make_unitaries(self,
    +                        device: Optional[torch.device] = None,
    +                        dtype: Optional[torch.dtype] = None) -> List[torch.Tensor]:
             """
             Creates random unitaries to initialize the MPS nodes as stacks of
             unitaries.
    @@ -3191,7 +3229,9 @@ def _make_unitaries(self, device: Optional[torch.device] = None) -> List[torch.T
                 
                 units = []
                 for _ in range(node_shape[1]):
    -                tensor = random_unitary(node_shape[0], device=device)
    +                tensor = random_unitary(node_shape[0],
    +                                        device=device,
    +                                        dtype=dtype)
                     units.append(tensor)
                 
                 tensors.append(torch.stack(units, dim=1))
    @@ -3202,6 +3242,7 @@ def initialize(self,
                        tensors: Optional[Sequence[torch.Tensor]] = None,
                        init_method: Optional[Text] = 'randn',
                        device: Optional[torch.device] = None,
    +                   dtype: Optional[torch.dtype] = None,
                        **kwargs: float) -> None:
             """
             Initializes the common tensor of the :class:`UMPSLayer`. It can be called
    @@ -3211,7 +3252,7 @@ def initialize(self,
             
             * ``{"zeros", "ones", "copy", "rand", "randn"}``: The tensor is
               initialized calling :meth:`~tensorkrowch.AbstractNode.set_tensor` with
    -          the given method, ``device`` and ``kwargs``.
    +          the given method, ``device``, ``dtype`` and ``kwargs``.
             
             * ``"randn_eye"``: Tensor is initialized as in this
               `paper `_, adding identities at the
    @@ -3239,14 +3280,16 @@ def initialize(self,
                 Initialization method.
             device : torch.device, optional
                 Device where to initialize the tensors if ``init_method`` is provided.
    +        dtype : torch.dtype, optional
    +            Dtype of the tensor if ``init_method`` is provided.
             kwargs : float
                 Keyword arguments for the different initialization methods. See
                 :meth:`~tensorkrowch.AbstractNode.make_tensor`.
             """
             if init_method == 'unit':
    -            tensors = self._make_unitaries(device=device)
    +            tensors = self._make_unitaries(device=device, dtype=dtype)
             elif init_method == 'canonical':
    -            tensors = self._make_canonical(device=device)
    +            tensors = self._make_canonical(device=device, dtype=dtype)
             
             if tensors is not None:
                 self.uniform_memory.tensor = tensors[0]
    @@ -3261,12 +3304,14 @@ def initialize(self,
                     
                     node.set_tensor(init_method=init_method,
                                     device=device,
    +                                dtype=dtype,
                                     **kwargs)
                     if add_eye:
                         aux_tensor = node.tensor.detach()
                         eye_tensor = torch.eye(node.shape[0],
                                                node.shape[2],
    -                                           device=device)
    +                                           device=device,
    +                                           dtype=dtype)
                         if i == 0:
                             aux_tensor[:, 0, :] += eye_tensor
                         else:
    @@ -3301,7 +3346,8 @@ def copy(self, share_tensors: bool = False) -> 'UMPSLayer':
                                 tensor=None,
                                 n_batches=self._n_batches,
                                 init_method=None,
    -                            device=None)
    +                            device=None,
    +                            dtype=None)
             new_mps.name = self.name + '_copy'
             if share_tensors:
                 new_mps.uniform_memory.tensor = self.uniform_memory.tensor
    @@ -3542,6 +3588,8 @@ class ConvMPS(AbstractConvClass, MPS):  # MARK: ConvMPS
             explanation of the different initialization methods.
         device : torch.device, optional
             Device where to initialize the tensors if ``init_method`` is provided.
    +    dtype : torch.dtype, optional
    +        Dtype of the tensor if ``init_method`` is provided.
         kwargs : float
             Keyword arguments for the different initialization methods. See
             :meth:`~tensorkrowch.AbstractNode.make_tensor`.
    @@ -3568,6 +3616,7 @@ def __init__(self,
                      tensors: Optional[Sequence[torch.Tensor]] = None,
                      init_method: Text = 'randn',
                      device: Optional[torch.device] = None,
    +                 dtype: Optional[torch.dtype] = None,
                      **kwargs):
             
             unfold = self._set_attributes(in_channels=in_channels,
    @@ -3585,6 +3634,7 @@ def __init__(self,
                          n_batches=2,
                          init_method=init_method,
                          device=device,
    +                     dtype=dtype,
                          **kwargs)
             
             self.unfold = unfold
    @@ -3653,7 +3703,8 @@ def copy(self, share_tensors: bool = False) -> 'ConvMPS':
                               boundary=self._boundary,
                               tensors=None,
                               init_method=None,
    -                          device=None)
    +                          device=None,
    +                          dtype=None)
             new_mps.name = self.name + '_copy'
             if share_tensors:
                 for new_node, node in zip(new_mps._mats_env, self._mats_env):
    @@ -3705,6 +3756,8 @@ class ConvUMPS(AbstractConvClass, UMPS):  # MARK: ConvUMPS
             explanation of the different initialization methods.
         device : torch.device, optional
             Device where to initialize the tensors if ``init_method`` is provided.
    +    dtype : torch.dtype, optional
    +        Dtype of the tensor if ``init_method`` is provided.
         kwargs : float
             Keyword arguments for the different initialization methods. See
             :meth:`~tensorkrowch.AbstractNode.make_tensor`.
    @@ -3734,6 +3787,7 @@ def __init__(self,
                      tensor: Optional[torch.Tensor] = None,
                      init_method: Text = 'randn',
                      device: Optional[torch.device] = None,
    +                 dtype: Optional[torch.dtype] = None,
                      **kwargs):
     
             unfold = self._set_attributes(in_channels=in_channels,
    @@ -3750,6 +3804,7 @@ def __init__(self,
                           n_batches=2,
                           init_method=init_method,
                           device=device,
    +                      dtype=dtype,
                           **kwargs)
             
             self.unfold = unfold
    @@ -3817,7 +3872,8 @@ def copy(self, share_tensors: bool = False) -> 'ConvUMPS':
                                dilation=self.dilation,
                                tensor=None,
                                init_method=None,
    -                           device=None)
    +                           device=None,
    +                           dtype=None)
             new_mps.name = self.name + '_copy'
             if share_tensors:
                 new_mps.uniform_memory.tensor = self.uniform_memory.tensor
    @@ -3885,6 +3941,8 @@ class ConvMPSLayer(AbstractConvClass, MPSLayer):  # MARK: ConvMPSLayer
             explanation of the different initialization methods.
         device : torch.device, optional
             Device where to initialize the tensors if ``init_method`` is provided.
    +    dtype : torch.dtype, optional
    +        Dtype of the tensor if ``init_method`` is provided.
         kwargs : float
             Keyword arguments for the different initialization methods. See
             :meth:`~tensorkrowch.AbstractNode.make_tensor`.
    @@ -3914,6 +3972,7 @@ def __init__(self,
                      tensors: Optional[Sequence[torch.Tensor]] = None,
                      init_method: Text = 'randn',
                      device: Optional[torch.device] = None,
    +                 dtype: Optional[torch.dtype] = None,
                      **kwargs):
     
             unfold = self._set_attributes(in_channels=in_channels,
    @@ -3934,6 +3993,7 @@ def __init__(self,
                               n_batches=2,
                               init_method=init_method,
                               device=device,
    +                          dtype=dtype,
                               **kwargs)
             
             self._out_channels = out_channels
    @@ -4009,7 +4069,8 @@ def copy(self, share_tensors: bool = False) -> 'ConvMPSLayer':
                                    boundary=self._boundary,
                                    tensors=None,
                                    init_method=None,
    -                               device=None)
    +                               device=None,
    +                               dtype=None)
             new_mps.name = self.name + '_copy'
             if share_tensors:
                 for new_node, node in zip(new_mps._mats_env, self._mats_env):
    @@ -4070,6 +4131,8 @@ class ConvUMPSLayer(AbstractConvClass, UMPSLayer):  # MARK: ConvUMPSLayer
             detailed explanation of the different initialization methods.
         device : torch.device, optional
             Device where to initialize the tensors if ``init_method`` is provided.
    +    dtype : torch.dtype, optional
    +        Dtype of the tensor if ``init_method`` is provided.
         kwargs : float
             Keyword arguments for the different initialization methods. See
             :meth:`~tensorkrowch.AbstractNode.make_tensor`.
    @@ -4102,6 +4165,7 @@ def __init__(self,
                      tensors: Optional[Sequence[torch.Tensor]] = None,
                      init_method: Text = 'randn',
                      device: Optional[torch.device] = None,
    +                 dtype: Optional[torch.dtype] = None,
                      **kwargs):
     
             unfold = self._set_attributes(in_channels=in_channels,
    @@ -4121,6 +4185,7 @@ def __init__(self,
                                n_batches=2,
                                init_method=init_method,
                                device=device,
    +                           dtype=dtype,
                                **kwargs)
             
             self._out_channels = out_channels
    @@ -4200,7 +4265,8 @@ def copy(self, share_tensors: bool = False) -> 'ConvUMPSLayer':
                                     dilation=self.dilation,
                                     tensor=None,
                                     init_method=None,
    -                                device=None)
    +                                device=None,
    +                                dtype=None)
             new_mps.name = self.name + '_copy'
             if share_tensors:
                 new_mps.uniform_memory.tensor = self.uniform_memory.tensor
    diff --git a/tensorkrowch/models/mps_data.py b/tensorkrowch/models/mps_data.py
    index c15dd9d..8bfa339 100644
    --- a/tensorkrowch/models/mps_data.py
    +++ b/tensorkrowch/models/mps_data.py
    @@ -86,6 +86,8 @@ class MPSData(TensorNetwork):  # MARK: MPSData
             :meth:`add_data` to see how to initialize MPS nodes with data tensors.
         device : torch.device, optional
             Device where to initialize the tensors if ``init_method`` is provided.
    +    dtype : torch.dtype, optional
    +            Dtype of the tensor if ``init_method`` is provided.
         kwargs : float
             Keyword arguments for the different initialization methods. See
             :meth:`~tensorkrowch.AbstractNode.make_tensor`.
    @@ -112,6 +114,7 @@ def __init__(self,
                      tensors: Optional[Sequence[torch.Tensor]] = None,
                      init_method: Optional[Text] = None,
                      device: Optional[torch.device] = None,
    +                 dtype: Optional[torch.dtype] = None,
                      **kwargs) -> None:
     
             super().__init__(name='mps_data')
    @@ -241,6 +244,7 @@ def __init__(self,
             self._make_nodes()
             self.initialize(init_method=init_method,
                                 device=device,
    +                            dtype=dtype,
                                 **kwargs)
             
             if tensors is not None:
    @@ -289,6 +293,15 @@ def mats_env(self) -> List[AbstractNode]:
             """Returns the list of nodes in ``mats_env``."""
             return self._mats_env
         
    +    @property
    +    def tensors(self) -> List[torch.Tensor]:
    +        """Returns the list of MPS tensors."""
    +        mps_tensors = [node.tensor for node in self._mats_env]
    +        if self._boundary == 'obc':
    +            mps_tensors[0] = mps_tensors[0][0, :, :]
    +            mps_tensors[-1] = mps_tensors[-1][:, :, 0]
    +        return mps_tensors
    +    
         # -------
         # Methods
         # -------
    @@ -346,6 +359,7 @@ def _make_nodes(self) -> None:
         def initialize(self,
                        init_method: Optional[Text] = 'randn',
                        device: Optional[torch.device] = None,
    +                   dtype: Optional[torch.dtype] = None,
                        **kwargs: float) -> None:
             """
             Initializes all the nodes of the :class:`MPSData`. It can be called when
    @@ -355,7 +369,7 @@ def initialize(self,
             
             * ``{"zeros", "ones", "copy", "rand", "randn"}``: Each node is
               initialized calling :meth:`~tensorkrowch.AbstractNode.set_tensor` with
    -          the given method, ``device`` and ``kwargs``.
    +          the given method, ``device``, ``dtype`` and ``kwargs``.
             
             Parameters
             ----------
    @@ -364,22 +378,31 @@ def initialize(self,
                 initialize MPS nodes with data tensors.
             device : torch.device, optional
                 Device where to initialize the tensors if ``init_method`` is provided.
    +        dtype : torch.dtype, optional
    +            Dtype of the tensor if ``init_method`` is provided.
             kwargs : float
                 Keyword arguments for the different initialization methods. See
                 :meth:`~tensorkrowch.AbstractNode.make_tensor`.
             """
             if self._boundary == 'obc':
    -            self._left_node.set_tensor(init_method='copy', device=device)
    -            self._right_node.set_tensor(init_method='copy', device=device)
    +            self._left_node.set_tensor(init_method='copy',
    +                                       device=device,
    +                                       dtype=dtype)
    +            self._right_node.set_tensor(init_method='copy',
    +                                        device=device,
    +                                        dtype=dtype)
                     
             if init_method is not None:
                 for i, node in enumerate(self._mats_env):
                     node.set_tensor(init_method=init_method,
                                     device=device,
    +                                dtype=dtype,
                                     **kwargs)
                     
                     if self._boundary == 'obc':
    -                    aux_tensor = torch.zeros(*node.shape, device=device)
    +                    aux_tensor = torch.zeros(*node.shape,
    +                                             device=device,
    +                                             dtype=dtype)
                         if i == 0:
                             # Left node
                             aux_tensor[..., 0, :, :] = node.tensor[..., 0, :, :]
    @@ -435,37 +458,45 @@ def add_data(self, data: Sequence[torch.Tensor]) -> None:
             
             data = data[:]
             device = data[0].device
    +        dtype = data[0].dtype
             for i, node in enumerate(self._mats_env):
                 if self._boundary == 'obc':
    -                aux_tensor = torch.zeros(*node.shape, device=device)
                     if (i == 0) and (i == (self._n_features - 1)):
                         aux_tensor = torch.zeros(*data[i].shape[:-1],
                                                  *node.shape[-3:],
    -                                             device=device)
    +                                             device=device,
    +                                             dtype=dtype)
                         aux_tensor[..., 0, :, 0] = data[i]
                         data[i] = aux_tensor
                     elif i == 0:
                         aux_tensor = torch.zeros(*data[i].shape[:-2],
                                                  *node.shape[-3:-1],
                                                  data[i].shape[-1],
    -                                             device=device)
    +                                             device=device,
    +                                             dtype=dtype)
                         aux_tensor[..., 0, :, :] = data[i]
                         data[i] = aux_tensor
                     elif i == (self._n_features - 1):
                         aux_tensor = torch.zeros(*data[i].shape[:-1],
                                                  *node.shape[-2:],
    -                                             device=device)
    +                                             device=device,
    +                                             dtype=dtype)
                         aux_tensor[..., 0] = data[i]
                         data[i] = aux_tensor
                         
                 node._direct_set_tensor(data[i])
             
    -        # Send left and right nodes to correct device
    +        # Send left and right nodes to correct device and dtype
             if self._boundary == 'obc':
                 if self._left_node.device != device:
                     self._left_node.tensor = self._left_node.tensor.to(device)
    +            if self._left_node.dtype != dtype:
    +                self._left_node.tensor = self._left_node.tensor.to(dtype)
    +            
                 if self._right_node.device != device:
                     self._right_node.tensor = self._right_node.tensor.to(device)
    +            if self._right_node.dtype != dtype:
    +                self._right_node.tensor = self._right_node.tensor.to(dtype)
             
             # Update bond dim
             if self._boundary == 'obc':
    diff --git a/tensorkrowch/operations.py b/tensorkrowch/operations.py
    index 42408c2..db6bcc0 100644
    --- a/tensorkrowch/operations.py
    +++ b/tensorkrowch/operations.py
    @@ -9,8 +9,11 @@
             * permute_           (in-place)
             * tprod
             * mul
    +        * div
             * add
             * sub
    +        * renormalize
    +        * conj
             
         Node-like operations:
             * split
    @@ -1466,6 +1469,120 @@ def renormalize(
     AbstractNode.renormalize = renormalize_node
     
     
    +##################################   conj    ##################################
    +# MARK: conj
    +def _check_first_conj(node: AbstractNode) -> Optional[Successor]:
    +    args = (node,)
    +    successors = node._successors.get('conj')
    +    if not successors:
    +        return None
    +    return successors.get(args)
    +
    +
    +def _conj_first(node: AbstractNode) -> Node:
    +    new_node = Node._create_resultant(axes_names=node.axes_names,
    +                                      name='conj',
    +                                      network=node._network,
    +                                      tensor=node.tensor.conj(),
    +                                      edges=node._edges,
    +                                      node1_list=node.is_node1())
    +    
    +    # Create successor
    +    net = node._network
    +    args = (node,)
    +    successor = Successor(node_ref=node.node_ref(),
    +                          index=node._tensor_info['index'],
    +                          child=new_node)
    +
    +    # Add successor to parent
    +    if 'conj' in node._successors:
    +        node._successors['conj'].update({args: successor})
    +    else:
    +        node._successors['conj'] = {args: successor}
    +
    +    # Add operation to list of performed operations of TN
    +    net._seq_ops.append(('conj', args))
    +
    +    # Record in inverse_memory while tracing
    +    if net._tracing:
    +        node._record_in_inverse_memory()
    +
    +    return new_node
    +
    +
    +def _conj_next(successor: Successor, node: AbstractNode) -> Node:
    +    tensor = node._direct_get_tensor(successor.node_ref,
    +                                     successor.index)
    +    child = successor.child
    +    child._direct_set_tensor(tensor.conj())
    +
    +    # Record in inverse_memory while contracting, if network is traced
    +    # (to delete memory if possible)
    +    if node._network._traced:
    +        node._check_inverse_memory(successor.node_ref)
    +    
    +    return child
    +
    +
    +conj_op = Operation('conj',
    +                    _check_first_conj,
    +                    _conj_first,
    +                    _conj_next)
    +
    +
    +def conj(node: AbstractNode) -> Node:
    +    """
    +    Returns a view of the node's tensor with a flipped conjugate bit. If the
    +    node has a non-complex dtype, this function returns a new node with the
    +    same tensor.
    +
    +    See `conj `_
    +    in the **PyTorch** documentation.
    +
    +    Parameters
    +    ----------
    +    node : Node or ParamNode
    +        Node that is to be conjugated.
    +
    +    Returns
    +    -------
    +    Node
    +    
    +    Examples
    +    --------
    +    >>> nodeA = tk.randn((3, 3), dtype=torch.complex64)
    +    >>> conjA = tk.conj(nodeA)
    +    >>> conjA.is_conj()
    +    True
    +    """
    +    return conj_op(node)
    +
    +
    +conj_node = copy_func(conj)
    +conj_node.__doc__ = \
    +    """
    +    Returns a view of the node's tensor with a flipped conjugate bit. If the
    +    node has a non-complex dtype, this function returns a new node with the
    +    same tensor.
    +
    +    See `conj `_
    +    in the **PyTorch** documentation.
    +
    +    Returns
    +    -------
    +    Node
    +    
    +    Examples
    +    --------
    +    >>> nodeA = tk.randn((3, 3), dtype=torch.complex64)
    +    >>> conjA = nodeA.conj()
    +    >>> conjA.is_conj()
    +    True
    +    """
    +
    +AbstractNode.conj = conj_node
    +
    +
     ###############################################################################
     #                            NODE-LIKE OPERATIONS                             #
     ###############################################################################
    @@ -1599,9 +1716,12 @@ def _split_first(node: AbstractNode,
             u = u[..., :rank]
             s = s[..., :rank]
             vh = vh[..., :rank, :]
    +        
    +        if u.is_complex():
    +            s = s.to(u.dtype)
     
             if mode == 'svdr':
    -            phase = torch.sign(torch.randn(s.shape))
    +            phase = torch.sgn(torch.randn_like(s))
                 phase = torch.diag_embed(phase)
                 u = u @ phase
                 vh = phase @ vh
    @@ -1798,9 +1918,12 @@ def _split_next(successor: Successor,
             u = u[..., :rank]
             s = s[..., :rank]
             vh = vh[..., :rank, :]
    +        
    +        if u.is_complex():
    +            s = s.to(u.dtype)
     
             if mode == 'svdr':
    -            phase = torch.sign(torch.randn(s.shape))
    +            phase = torch.sgn(torch.randn_like(s))
                 phase = torch.diag_embed(phase)
                 u = u @ phase
                 vh = phase @ vh
    diff --git a/tensorkrowch/utils.py b/tensorkrowch/utils.py
    index 15f3dee..db2554b 100644
    --- a/tensorkrowch/utils.py
    +++ b/tensorkrowch/utils.py
    @@ -307,14 +307,16 @@ def split_sequence_into_regions(lst: Sequence[int]) -> List[List[int]]:
         regions.append(current_region)
         return regions
     
    -def random_unitary(n, device: Optional[torch.device] = None):
    +def random_unitary(n,
    +                   device: Optional[torch.device] = None,
    +                   dtype: Optional[torch.dtype] = None):
         """
         Returns random unitary matrix from the Haar measure of size n x n.
         
         Unitary matrix is created as described in this `paper
         `_.
         """
    -    mat = torch.randn(n, n, device=device)
    +    mat = torch.randn(n, n, device=device, dtype=dtype)
         q, r = torch.linalg.qr(mat)
         d = torch.diagonal(r)
         ph = d / d.abs()
    diff --git a/tests/decompositions/test_svd_decompositions.py b/tests/decompositions/test_svd_decompositions.py
    index c454602..7adb682 100644
    --- a/tests/decompositions/test_svd_decompositions.py
    +++ b/tests/decompositions/test_svd_decompositions.py
    @@ -14,72 +14,76 @@ class TestSVDDecompositions:
         
         def test_vec_to_mps(self):
             for renormalize in [True, False]:
    -            for i in range(1, 6):
    -                dims = torch.randint(low=2, high=10, size=(i,))
    -                vec = torch.randn(*dims) * 1e-5
    -                for j in range(i + 1):
    -                    tensors = tk.decompositions.vec_to_mps(vec=vec,
    -                                                           n_batches=j,
    -                                                           rank=5,
    -                                                           renormalize=renormalize)
    -                    
    -                    for k, tensor in enumerate(tensors):
    -                        assert tensor.shape[:j] == vec.shape[:j]
    -                        if k == 0:
    -                            if j < i:
    -                                assert tensor.shape[j] == vec.shape[j]
    -                            if j < i - 1:
    -                                assert tensor.shape[j + 1] <= 5
    -                        elif k < (i - j - 1):
    -                            assert tensor.shape[j + 1] == vec.shape[j + k]
    -                            assert tensor.shape[j + 2] <= 5
    -                        else:
    -                            assert tensor.shape[j + 1] == vec.shape[j + k]
    -                            
    -                    bond_dims = [tensor.shape[-1] for tensor in tensors[:-1]]
    -                    
    -                    if i - j > 0:
    -                        mps = tk.models.MPSData(tensors=tensors,
    -                                                n_batches=j)
    -                        assert mps.phys_dim == dims[j:].tolist()
    -                        assert mps.bond_dim == bond_dims
    -                    
    -                    if j == 0:
    -                        mps = tk.models.MPS(tensors=tensors)
    -                        assert mps.phys_dim == dims.tolist()
    -                        assert mps.bond_dim == bond_dims
    +            for dtype in [torch.float32, torch.complex64]:
    +                for i in range(1, 6):
    +                    dims = torch.randint(low=2, high=10, size=(i,))
    +                    vec = torch.randn(*dims, dtype=dtype) * 1e-5
    +                    for j in range(i + 1):
    +                        tensors = tk.decompositions.vec_to_mps(vec=vec,
    +                                                               n_batches=j,
    +                                                               rank=5,
    +                                                               renormalize=renormalize)
    +                        
    +                        for k, tensor in enumerate(tensors):
    +                            assert tensor.shape[:j] == vec.shape[:j]
    +                            assert tensor.dtype == dtype
    +                            if k == 0:
    +                                if j < i:
    +                                    assert tensor.shape[j] == vec.shape[j]
    +                                if j < i - 1:
    +                                    assert tensor.shape[j + 1] <= 5
    +                            elif k < (i - j - 1):
    +                                assert tensor.shape[j + 1] == vec.shape[j + k]
    +                                assert tensor.shape[j + 2] <= 5
    +                            else:
    +                                assert tensor.shape[j + 1] == vec.shape[j + k]
    +                                
    +                        bond_dims = [tensor.shape[-1] for tensor in tensors[:-1]]
    +                        
    +                        if i - j > 0:
    +                            mps = tk.models.MPSData(tensors=tensors,
    +                                                    n_batches=j)
    +                            assert mps.phys_dim == dims[j:].tolist()
    +                            assert mps.bond_dim == bond_dims
    +                        
    +                        if j == 0:
    +                            mps = tk.models.MPS(tensors=tensors)
    +                            assert mps.phys_dim == dims.tolist()
    +                            assert mps.bond_dim == bond_dims
         
         def test_vec_to_mps_accuracy(self):
             for renormalize in [True, False]:
    -            for i in range(1, 6):
    -                dims = torch.randint(low=2, high=10, size=(i,))
    -                vec = torch.randn(*dims) * 1e-1
    -                for j in range(i + 1):
    -                    tensors = tk.decompositions.vec_to_mps(vec=vec,
    -                                                           n_batches=j,
    -                                                           cum_percentage=0.9999,
    -                                                           renormalize=renormalize)
    -                            
    -                    bond_dims = [tensor.shape[-1] for tensor in tensors[:-1]]
    -                    
    -                    if i - j > 0:
    -                        mps = tk.models.MPSData(tensors=tensors,
    -                                                n_batches=j)
    -                        assert mps.phys_dim == dims[j:].tolist()
    -                        assert mps.bond_dim == bond_dims
    -                    
    -                    if j == 0:
    -                        mps = tk.models.MPS(tensors=tensors)
    -                        assert mps.phys_dim == dims.tolist()
    -                        assert mps.bond_dim == bond_dims
    -                    
    -                    approx_vec = mps.left_node
    -                    for node in mps.mats_env + [mps.right_node]:
    -                        approx_vec @= node
    -                    
    -                    approx_vec = approx_vec.tensor
    -                    diff = vec - approx_vec
    -                    assert diff.norm() < 1e-1
    +            for dtype in [torch.float32, torch.complex64]:
    +                for i in range(1, 6):
    +                    dims = torch.randint(low=2, high=10, size=(i,))
    +                    vec = torch.randn(*dims, dtype=dtype) * 1e-1
    +                    for j in range(i + 1):
    +                        tensors = tk.decompositions.vec_to_mps(
    +                            vec=vec,
    +                            n_batches=j,
    +                            cum_percentage=0.9999,
    +                            renormalize=renormalize)
    +                                
    +                        bond_dims = [tensor.shape[-1] for tensor in tensors[:-1]]
    +                        
    +                        if i - j > 0:
    +                            mps = tk.models.MPSData(tensors=tensors,
    +                                                    n_batches=j)
    +                            assert mps.phys_dim == dims[j:].tolist()
    +                            assert mps.bond_dim == bond_dims
    +                        
    +                        if j == 0:
    +                            mps = tk.models.MPS(tensors=tensors)
    +                            assert mps.phys_dim == dims.tolist()
    +                            assert mps.bond_dim == bond_dims
    +                        
    +                        approx_vec = mps.left_node
    +                        for node in mps.mats_env + [mps.right_node]:
    +                            approx_vec @= node
    +                        
    +                        approx_vec = approx_vec.tensor
    +                        diff = vec - approx_vec
    +                        assert diff.norm() < 1e-1
         
         def test_mat_to_mpo(self):
             for renormalize in [True, False]:
    @@ -164,7 +168,7 @@ def test_mat_to_mpo_permuted_accuracy(self):
                 diff = mat - approx_mat
                 assert diff.norm() < 1e-1
     
    -    def test_mat_to_mpo_permuted_diff_dimsaccuracy(self):
    +    def test_mat_to_mpo_permuted_diff_dims_accuracy(self):
             for renormalize in [True, False]:
                 # inputs (2, 3, 4, 2) x outputs (3, 5, 7, 2)
                 dims = [2, 3, 4, 2, 3, 5, 7, 2]
    diff --git a/tests/models/test_mps.py b/tests/models/test_mps.py
    index a3c0140..5fab322 100644
    --- a/tests/models/test_mps.py
    +++ b/tests/models/test_mps.py
    @@ -39,7 +39,6 @@ def test_initialize_with_tensors(self):
         
         def test_initialize_with_tensors_cuda(self):
             device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    -        
             for n in [1, 2, 5]:
                 # PBC
                 tensors = [torch.randn(10, 2, 10, device=device) for _ in range(n)]
    @@ -63,6 +62,34 @@ def test_initialize_with_tensors_cuda(self):
                 for node in mps.mats_env:
                     assert node.device == device
         
    +    def test_initialize_with_tensors_complex(self):
    +        for n in [1, 2, 5]:
    +            # PBC
    +            tensors = [torch.randn(10, 2, 10,
    +                                   dtype=torch.complex64) for _ in range(n)]
    +            mps = tk.models.MPS(tensors=tensors)
    +            assert mps.n_features == n
    +            assert mps.boundary == 'pbc'
    +            assert mps.phys_dim == [2] * n
    +            assert mps.bond_dim == [10] * n
    +            for node in mps.mats_env:
    +                assert node.dtype == torch.complex64
    +                assert node.is_complex()
    +            
    +            # OBC
    +            tensors = [torch.randn(10, 2, 10,
    +                                   dtype=torch.complex64) for _ in range(n)]
    +            tensors[0] = tensors[0][0]
    +            tensors[-1] = tensors[-1][..., 0]
    +            mps = tk.models.MPS(tensors=tensors)
    +            assert mps.n_features == n
    +            assert mps.boundary == 'obc'
    +            assert mps.phys_dim == [2] * n
    +            assert mps.bond_dim == [10] * (n - 1)
    +            for node in mps.mats_env:
    +                assert node.dtype == torch.complex64
    +                assert node.is_complex()
    +    
         def test_initialize_with_tensors_ignore_rest(self):
             tensors = [torch.randn(10, 2, 10) for _ in range(10)]
             mps = tk.models.MPS(tensors=tensors,
    @@ -104,23 +131,23 @@ def test_initialize_init_method(self):
                     mps = tk.models.MPS(boundary='pbc',
                                         n_features=n,
                                         phys_dim=2,
    -                                    bond_dim=2,
    +                                    bond_dim=5,
                                         init_method=init_method)
                     assert mps.n_features == n
                     assert mps.boundary == 'pbc'
                     assert mps.phys_dim == [2] * n
    -                assert mps.bond_dim == [2] * n
    +                assert mps.bond_dim == [5] * n
                     
                     # OBC
                     mps = tk.models.MPS(boundary='obc',
                                         n_features=n,
                                         phys_dim=2,
    -                                    bond_dim=2,
    +                                    bond_dim=5,
                                         init_method=init_method)
                     assert mps.n_features == n
                     assert mps.boundary == 'obc'
                     assert mps.phys_dim == [2] * n
    -                assert mps.bond_dim == [2] * (n - 1)
    +                assert mps.bond_dim == [5] * (n - 1)
                     
                     assert torch.equal(mps.left_node.tensor[0],
                                        torch.ones_like(mps.left_node.tensor)[0])
    @@ -142,13 +169,13 @@ def test_initialize_init_method_cuda(self):
                     mps = tk.models.MPS(boundary='pbc',
                                         n_features=n,
                                         phys_dim=2,
    -                                    bond_dim=2,
    +                                    bond_dim=5,
                                         init_method=init_method,
                                         device=device)
                     assert mps.n_features == n
                     assert mps.boundary == 'pbc'
                     assert mps.phys_dim == [2] * n
    -                assert mps.bond_dim == [2] * n
    +                assert mps.bond_dim == [5] * n
                     for node in mps.mats_env:
                         assert node.device == device
                     
    @@ -156,13 +183,13 @@ def test_initialize_init_method_cuda(self):
                     mps = tk.models.MPS(boundary='obc',
                                         n_features=n,
                                         phys_dim=2,
    -                                    bond_dim=2,
    +                                    bond_dim=5,
                                         init_method=init_method,
                                         device=device)
                     assert mps.n_features == n
                     assert mps.boundary == 'obc'
                     assert mps.phys_dim == [2] * n
    -                assert mps.bond_dim == [2] * (n - 1)
    +                assert mps.bond_dim == [5] * (n - 1)
                     for node in mps.mats_env:
                         assert node.device == device
                     
    @@ -176,6 +203,51 @@ def test_initialize_init_method_cuda(self):
                     assert torch.equal(mps.right_node.tensor[1:],
                                        torch.zeros_like(mps.right_node.tensor)[1:])
         
    +    def test_initialize_init_method_complex(self):
    +        methods = ['zeros', 'ones', 'copy', 'rand', 'randn',
    +                   'randn_eye', 'unit', 'canonical']
    +        for n in [1, 2, 5]:
    +            for init_method in methods:
    +                # PBC
    +                mps = tk.models.MPS(boundary='pbc',
    +                                    n_features=n,
    +                                    phys_dim=2,
    +                                    bond_dim=5,
    +                                    init_method=init_method,
    +                                    dtype=torch.complex64)
    +                assert mps.n_features == n
    +                assert mps.boundary == 'pbc'
    +                assert mps.phys_dim == [2] * n
    +                assert mps.bond_dim == [5] * n
    +                for node in mps.mats_env:
    +                    assert node.dtype == torch.complex64
    +                    assert node.is_complex()
    +                
    +                # OBC
    +                mps = tk.models.MPS(boundary='obc',
    +                                    n_features=n,
    +                                    phys_dim=2,
    +                                    bond_dim=5,
    +                                    init_method=init_method,
    +                                    dtype=torch.complex64)
    +                assert mps.n_features == n
    +                assert mps.boundary == 'obc'
    +                assert mps.phys_dim == [2] * n
    +                assert mps.bond_dim == [5] * (n - 1)
    +                for node in mps.mats_env:
    +                    assert node.dtype == torch.complex64
    +                    assert node.is_complex()
    +                
    +                assert torch.equal(mps.left_node.tensor[0],
    +                                   torch.ones_like(mps.left_node.tensor)[0])
    +                assert torch.equal(mps.left_node.tensor[1:],
    +                                   torch.zeros_like(mps.left_node.tensor)[1:])
    +                
    +                assert torch.equal(mps.right_node.tensor[0],
    +                                   torch.ones_like(mps.right_node.tensor)[0])
    +                assert torch.equal(mps.right_node.tensor[1:],
    +                                   torch.zeros_like(mps.right_node.tensor)[1:])
    +    
         def test_initialize_canonical(self):
             for n in [1, 2, 5]:
                 # PBC
    @@ -203,6 +275,7 @@ def test_initialize_canonical(self):
                 assert mps.bond_dim == [2] * (n - 1)
                 
                 # Check it has norm == 2**n
    +            norm = mps.norm()
                 assert mps.norm().isclose(torch.tensor(2. ** n).sqrt())
                 # Norm is close to 2**n if bond dimension is <= than
                 # physical dimension, otherwise, it will not be exactly 2**n
    @@ -241,11 +314,50 @@ def test_initialize_canonical_cuda(self):
                     assert node.device == device
                 
                 # Check it has norm == 2**n
    -            norm = mps.norm()
                 assert mps.norm().isclose(torch.tensor(2. ** n).sqrt())
                 # Norm is close to 2**n if bond dimension is <= than
                 # physical dimension, otherwise, it will not be exactly 2**n
         
    +    def test_initialize_canonical_complex(self):
    +        for n in [1, 2, 5]:
    +            # PBC
    +            mps = tk.models.MPS(boundary='pbc',
    +                                n_features=n,
    +                                phys_dim=2,
    +                                bond_dim=2,
    +                                init_method='canonical',
    +                                dtype=torch.complex64)
    +            assert mps.n_features == n
    +            assert mps.boundary == 'pbc'
    +            assert mps.phys_dim == [2] * n
    +            assert mps.bond_dim == [2] * n
    +            for node in mps.mats_env:
    +                assert node.dtype == torch.complex64
    +                assert node.is_complex()
    +            
    +            # For PBC norm does not have to be 2**n always
    +            
    +            # OBC
    +            mps = tk.models.MPS(boundary='obc',
    +                                n_features=n,
    +                                phys_dim=2,
    +                                bond_dim=2,
    +                                init_method='canonical',
    +                                dtype=torch.complex64)
    +            assert mps.n_features == n
    +            assert mps.boundary == 'obc'
    +            assert mps.phys_dim == [2] * n
    +            assert mps.bond_dim == [2] * (n - 1)
    +            for node in mps.mats_env:
    +                assert node.dtype == torch.complex64
    +                assert node.is_complex()
    +            
    +            # Check it has norm == 2**n
    +            assert mps.norm().isclose(
    +                torch.tensor(2. ** n, dtype=torch.complex64).sqrt())
    +            # Norm is close to 2**n if bond dimension is <= than
    +            # physical dimension, otherwise, it will not be exactly 2**n
    +    
         def test_in_and_out_features(self):
             tensors = [torch.randn(10, 2, 10) for _ in range(10)]
             mps = tk.models.MPS(tensors=tensors,
    @@ -495,6 +607,52 @@ def test_all_algorithms_cuda(self):
                                         result.sum().backward()
                                         for node in mps.mats_env:
                                             assert node.grad is not None
    +    
    +    def test_all_algorithms_complex(self):
    +        for n_features in [1, 2, 3, 4, 10]:
    +            for boundary in ['obc', 'pbc']:
    +                # batch x n_features x feature_dim
    +                example = torch.randn(1, n_features, 5, dtype=torch.complex64)
    +                data = torch.randn(100, n_features, 5, dtype=torch.complex64)
    +                
    +                mps = tk.models.MPS(n_features=n_features,
    +                                    phys_dim=5,
    +                                    bond_dim=2,
    +                                    boundary=boundary,
    +                                    dtype=torch.complex64)
    +
    +                for auto_stack in [True, False]:
    +                    for auto_unbind in [True, False]:
    +                        for inline_input in [True, False]:
    +                            for inline_mats in [True, False]:
    +                                for renormalize in [True, False]:
    +                                    mps.auto_stack = auto_stack
    +                                    mps.auto_unbind = auto_unbind
    +
    +                                    mps.trace(example,
    +                                              inline_input=inline_input,
    +                                              inline_mats=inline_mats,
    +                                              renormalize=renormalize)
    +                                    result = mps(data,
    +                                                 inline_input=inline_input,
    +                                                 inline_mats=inline_mats,
    +                                                 renormalize=renormalize)
    +
    +                                    assert result.shape == (100,)
    +                                    assert len(mps.edges) == 0
    +                                    if boundary == 'obc':
    +                                        assert len(mps.leaf_nodes) == n_features + 2
    +                                    else:
    +                                        assert len(mps.leaf_nodes) == n_features
    +                                    assert len(mps.data_nodes) == n_features
    +                                    if not inline_input and auto_stack:
    +                                        assert len(mps.virtual_nodes) == 2
    +                                    else:
    +                                        assert len(mps.virtual_nodes) == 1
    +                                    
    +                                    result.sum().abs().backward()
    +                                    for node in mps.mats_env:
    +                                        assert node.grad is not None
     
         def test_all_algorithms_diff_in_dim(self):
             for n_features in [1, 2, 3, 4, 6]:
    @@ -802,7 +960,8 @@ def test_all_algorithms_marginalize_with_list_matrices(self):
                                         for node in mps.mats_env:
                                             assert node.grad is not None
         
    -    def test_all_algorithms_marginalize_with_matrix(self):
    +    def test_all_algorithms_marginalize_with_list_matrices_cuda(self):
    +        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
             for n_features in [1, 2, 3, 4, 10]:
                 for boundary in ['obc', 'pbc']:
                     in_features = torch.randint(low=0,
    @@ -811,8 +970,8 @@ def test_all_algorithms_marginalize_with_matrix(self):
                     in_features = list(set(in_features))
                     
                     # batch x n_features x feature_dim
    -                example = torch.randn(1, len(in_features), 5)
    -                data = torch.randn(100, len(in_features), 5)
    +                example = torch.randn(1, len(in_features), 5, device=device)
    +                data = torch.randn(100, len(in_features), 5, device=device)
                     
                     if example.numel() == 0:
                         example = None
    @@ -822,9 +981,11 @@ def test_all_algorithms_marginalize_with_matrix(self):
                                         phys_dim=5,
                                         bond_dim=2,
                                         boundary=boundary,
    -                                    in_features=in_features)
    +                                    in_features=in_features,
    +                                    device=device)
                     
    -                embedding_matrix = torch.randn(5, 5)
    +                embedding_matrices = [torch.randn(5, 5, device=device)
    +                                      for _ in range(len(mps.out_features))]
     
                     for auto_stack in [True, False]:
                         for auto_unbind in [True, False]:
    @@ -839,13 +1000,13 @@ def test_all_algorithms_marginalize_with_matrix(self):
                                                   inline_mats=inline_mats,
                                                   renormalize=renormalize,
                                                   marginalize_output=True,
    -                                              embedding_matrices=embedding_matrix)
    +                                              embedding_matrices=embedding_matrices)
                                         result = mps(data,
                                                      inline_input=inline_input,
                                                      inline_mats=inline_mats,
                                                      renormalize=renormalize,
                                                      marginalize_output=True,
    -                                                 embedding_matrices=embedding_matrix)
    +                                                 embedding_matrices=embedding_matrices)
     
                                         if in_features:
                                             assert result.shape == (100, 100)
    @@ -873,8 +1034,7 @@ def test_all_algorithms_marginalize_with_matrix(self):
                                         for node in mps.mats_env:
                                             assert node.grad is not None
         
    -    def test_all_algorithms_marginalize_with_matrix_cuda(self):
    -        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    +    def test_all_algorithms_marginalize_with_list_matrices_complex(self):
             for n_features in [1, 2, 3, 4, 10]:
                 for boundary in ['obc', 'pbc']:
                     in_features = torch.randint(low=0,
    @@ -883,8 +1043,8 @@ def test_all_algorithms_marginalize_with_matrix_cuda(self):
                     in_features = list(set(in_features))
                     
                     # batch x n_features x feature_dim
    -                example = torch.randn(1, len(in_features), 5, device=device)
    -                data = torch.randn(100, len(in_features), 5, device=device)
    +                example = torch.randn(1, len(in_features), 5, dtype=torch.complex64)
    +                data = torch.randn(100, len(in_features), 5, dtype=torch.complex64)
                     
                     if example.numel() == 0:
                         example = None
    @@ -895,9 +1055,10 @@ def test_all_algorithms_marginalize_with_matrix_cuda(self):
                                         bond_dim=2,
                                         boundary=boundary,
                                         in_features=in_features,
    -                                    device=device)
    +                                    dtype=torch.complex64)
                     
    -                embedding_matrix = torch.randn(5, 5, device=device)
    +                embedding_matrices = [torch.randn(5, 5, dtype=torch.complex64)
    +                                      for _ in range(len(mps.out_features))]
     
                     for auto_stack in [True, False]:
                         for auto_unbind in [True, False]:
    @@ -912,13 +1073,13 @@ def test_all_algorithms_marginalize_with_matrix_cuda(self):
                                                   inline_mats=inline_mats,
                                                   renormalize=renormalize,
                                                   marginalize_output=True,
    -                                              embedding_matrices=embedding_matrix)
    +                                              embedding_matrices=embedding_matrices)
                                         result = mps(data,
                                                      inline_input=inline_input,
                                                      inline_mats=inline_mats,
                                                      renormalize=renormalize,
                                                      marginalize_output=True,
    -                                                 embedding_matrices=embedding_matrix)
    +                                                 embedding_matrices=embedding_matrices)
     
                                         if in_features:
                                             assert result.shape == (100, 100)
    @@ -942,106 +1103,322 @@ def test_all_algorithms_marginalize_with_matrix_cuda(self):
                                             
                                         assert len(mps.data_nodes) == len(in_features)
                                         
    -                                    result.sum().backward()
    +                                    result.sum().abs().backward()
                                         for node in mps.mats_env:
                                             assert node.grad is not None
         
    -    def test_all_algorithms_marginalize_with_mpo(self):
    -        for n_features in [1, 2, 3, 4, 10]:
    -            for mps_boundary in ['obc', 'pbc']:
    -                for mpo_boundary in ['obc', 'pbc']:
    -                    in_features = torch.randint(low=0,
    -                                                high=n_features,
    -                                                size=(n_features // 2,)).tolist()
    -                    in_features = list(set(in_features))
    -                    
    -                    # batch x n_features x feature_dim
    -                    example = torch.randn(1, len(in_features), 5)
    -                    data = torch.randn(100, len(in_features), 5)
    -                    
    -                    if example.numel() == 0:
    -                        example = None
    -                        data = None
    -                    
    -                    mps = tk.models.MPS(n_features=n_features,
    -                                        phys_dim=5,
    -                                        bond_dim=2,
    -                                        boundary=mps_boundary,
    -                                        in_features=in_features)
    -                    
    -                    mpo = tk.models.MPO(n_features=n_features - len(in_features),
    -                                        in_dim=5,
    -                                        out_dim=5,
    -                                        bond_dim=2,
    -                                        boundary=mpo_boundary)
    -                    
    -                    # De-parameterize MPO nodes to only train MPS nodes
    -                    mpo = mpo.parameterize(set_param=False, override=True)
    -
    -                    for auto_stack in [True, False]:
    -                        for auto_unbind in [True, False]:
    -                            for inline_input in [True, False]:
    -                                for inline_mats in [True, False]:
    -                                    for renormalize in [True, False]:
    -                                        mps.auto_stack = auto_stack
    -                                        mps.auto_unbind = auto_unbind
    -
    -                                        mps.trace(example,
    -                                                  inline_input=inline_input,
    -                                                  inline_mats=inline_mats,
    -                                                  renormalize=renormalize,
    -                                                  marginalize_output=True,
    -                                                  mpo=mpo)
    -                                        result = mps(data,
    -                                                     inline_input=inline_input,
    -                                                     inline_mats=inline_mats,
    -                                                     renormalize=renormalize,
    -                                                     marginalize_output=True,
    -                                                     mpo=mpo)
    -                                        
    -                                        if in_features:
    -                                            assert result.shape == (100, 100)
    -                                        else:
    -                                            assert result.shape == tuple()
    -                                        
    -                                        if mps_boundary == 'obc':
    -                                            if mpo_boundary == 'obc':
    -                                                leaf = (n_features + 2) + \
    -                                                    (n_features - len(in_features) + 2)
    -                                                assert len(mps.leaf_nodes) == leaf
    -                                            else:
    -                                                leaf = (n_features + 2) + \
    -                                                    (n_features - len(in_features))
    -                                                assert len(mps.leaf_nodes) == leaf
    -                                        else:
    -                                            if mpo_boundary == 'obc':
    -                                                leaf = n_features + \
    -                                                    (n_features - len(in_features) + 2)
    -                                                assert len(mps.leaf_nodes) == leaf
    -                                            else:
    -                                                leaf = n_features + \
    -                                                    (n_features - len(in_features))
    -                                                assert len(mps.leaf_nodes) == leaf
    -                                        
    -                                        result.sum().backward()
    -                                        for node in mps.mats_env:
    -                                            assert node.grad is not None
    -                                        for node in mpo.mats_env:
    -                                            assert node.tensor.grad is None
    -    
    -    def test_all_algorithms_marginalize_with_mpo_cuda(self):
    -        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    +    def test_all_algorithms_marginalize_with_matrix(self):
             for n_features in [1, 2, 3, 4, 10]:
    -            for mps_boundary in ['obc', 'pbc']:
    -                for mpo_boundary in ['obc', 'pbc']:
    -                    in_features = torch.randint(low=0,
    +            for boundary in ['obc', 'pbc']:
    +                in_features = torch.randint(low=0,
                                                 high=n_features,
                                                 size=(n_features // 2,)).tolist()
                     in_features = list(set(in_features))
                     
                     # batch x n_features x feature_dim
    -                example = torch.randn(1, len(in_features), 5, device=device)
    -                data = torch.randn(100, len(in_features), 5, device=device)
    +                example = torch.randn(1, len(in_features), 5)
    +                data = torch.randn(100, len(in_features), 5)
    +                
    +                if example.numel() == 0:
    +                    example = None
    +                    data = None
    +                
    +                mps = tk.models.MPS(n_features=n_features,
    +                                    phys_dim=5,
    +                                    bond_dim=2,
    +                                    boundary=boundary,
    +                                    in_features=in_features)
    +                
    +                embedding_matrix = torch.randn(5, 5)
    +
    +                for auto_stack in [True, False]:
    +                    for auto_unbind in [True, False]:
    +                        for inline_input in [True, False]:
    +                            for inline_mats in [True, False]:
    +                                for renormalize in [True, False]:
    +                                    mps.auto_stack = auto_stack
    +                                    mps.auto_unbind = auto_unbind
    +
    +                                    mps.trace(example,
    +                                              inline_input=inline_input,
    +                                              inline_mats=inline_mats,
    +                                              renormalize=renormalize,
    +                                              marginalize_output=True,
    +                                              embedding_matrices=embedding_matrix)
    +                                    result = mps(data,
    +                                                 inline_input=inline_input,
    +                                                 inline_mats=inline_mats,
    +                                                 renormalize=renormalize,
    +                                                 marginalize_output=True,
    +                                                 embedding_matrices=embedding_matrix)
    +
    +                                    if in_features:
    +                                        assert result.shape == (100, 100)
    +                                        
    +                                        if not inline_input and auto_stack:
    +                                            assert len(mps.virtual_nodes) == \
    +                                                (2 + 2 * len(mps.out_features))
    +                                        else:
    +                                            assert len(mps.virtual_nodes) == \
    +                                                (1 + 2 * len(mps.out_features))
    +                                            
    +                                    else:
    +                                        assert result.shape == tuple()
    +                                        assert len(mps.virtual_nodes) == \
    +                                            2 * len(mps.out_features)
    +                                    
    +                                    if boundary == 'obc':
    +                                        assert len(mps.leaf_nodes) == n_features + 2
    +                                    else:
    +                                        assert len(mps.leaf_nodes) == n_features
    +                                        
    +                                    assert len(mps.data_nodes) == len(in_features)
    +                                    
    +                                    result.sum().backward()
    +                                    for node in mps.mats_env:
    +                                        assert node.grad is not None
    +    
    +    def test_all_algorithms_marginalize_with_matrix_cuda(self):
    +        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    +        for n_features in [1, 2, 3, 4, 10]:
    +            for boundary in ['obc', 'pbc']:
    +                in_features = torch.randint(low=0,
    +                                            high=n_features,
    +                                            size=(n_features // 2,)).tolist()
    +                in_features = list(set(in_features))
    +                
    +                # batch x n_features x feature_dim
    +                example = torch.randn(1, len(in_features), 5, device=device)
    +                data = torch.randn(100, len(in_features), 5, device=device)
    +                
    +                if example.numel() == 0:
    +                    example = None
    +                    data = None
    +                
    +                mps = tk.models.MPS(n_features=n_features,
    +                                    phys_dim=5,
    +                                    bond_dim=2,
    +                                    boundary=boundary,
    +                                    in_features=in_features,
    +                                    device=device)
    +                
    +                embedding_matrix = torch.randn(5, 5, device=device)
    +
    +                for auto_stack in [True, False]:
    +                    for auto_unbind in [True, False]:
    +                        for inline_input in [True, False]:
    +                            for inline_mats in [True, False]:
    +                                for renormalize in [True, False]:
    +                                    mps.auto_stack = auto_stack
    +                                    mps.auto_unbind = auto_unbind
    +
    +                                    mps.trace(example,
    +                                              inline_input=inline_input,
    +                                              inline_mats=inline_mats,
    +                                              renormalize=renormalize,
    +                                              marginalize_output=True,
    +                                              embedding_matrices=embedding_matrix)
    +                                    result = mps(data,
    +                                                 inline_input=inline_input,
    +                                                 inline_mats=inline_mats,
    +                                                 renormalize=renormalize,
    +                                                 marginalize_output=True,
    +                                                 embedding_matrices=embedding_matrix)
    +
    +                                    if in_features:
    +                                        assert result.shape == (100, 100)
    +                                        
    +                                        if not inline_input and auto_stack:
    +                                            assert len(mps.virtual_nodes) == \
    +                                                (2 + 2 * len(mps.out_features))
    +                                        else:
    +                                            assert len(mps.virtual_nodes) == \
    +                                                (1 + 2 * len(mps.out_features))
    +                                            
    +                                    else:
    +                                        assert result.shape == tuple()
    +                                        assert len(mps.virtual_nodes) == \
    +                                            2 * len(mps.out_features)
    +                                    
    +                                    if boundary == 'obc':
    +                                        assert len(mps.leaf_nodes) == n_features + 2
    +                                    else:
    +                                        assert len(mps.leaf_nodes) == n_features
    +                                        
    +                                    assert len(mps.data_nodes) == len(in_features)
    +                                    
    +                                    result.sum().backward()
    +                                    for node in mps.mats_env:
    +                                        assert node.grad is not None
    +    
    +    def test_all_algorithms_marginalize_with_matrix_complex(self):
    +        for n_features in [1, 2, 3, 4, 10]:
    +            for boundary in ['obc', 'pbc']:
    +                in_features = torch.randint(low=0,
    +                                            high=n_features,
    +                                            size=(n_features // 2,)).tolist()
    +                in_features = list(set(in_features))
    +                
    +                # batch x n_features x feature_dim
    +                example = torch.randn(1, len(in_features), 5, dtype=torch.complex64)
    +                data = torch.randn(100, len(in_features), 5, dtype=torch.complex64)
    +                
    +                if example.numel() == 0:
    +                    example = None
    +                    data = None
    +                
    +                mps = tk.models.MPS(n_features=n_features,
    +                                    phys_dim=5,
    +                                    bond_dim=2,
    +                                    boundary=boundary,
    +                                    in_features=in_features,
    +                                    dtype=torch.complex64)
    +                
    +                embedding_matrix = torch.randn(5, 5, dtype=torch.complex64)
    +
    +                for auto_stack in [True, False]:
    +                    for auto_unbind in [True, False]:
    +                        for inline_input in [True, False]:
    +                            for inline_mats in [True, False]:
    +                                for renormalize in [True, False]:
    +                                    mps.auto_stack = auto_stack
    +                                    mps.auto_unbind = auto_unbind
    +
    +                                    mps.trace(example,
    +                                              inline_input=inline_input,
    +                                              inline_mats=inline_mats,
    +                                              renormalize=renormalize,
    +                                              marginalize_output=True,
    +                                              embedding_matrices=embedding_matrix)
    +                                    result = mps(data,
    +                                                 inline_input=inline_input,
    +                                                 inline_mats=inline_mats,
    +                                                 renormalize=renormalize,
    +                                                 marginalize_output=True,
    +                                                 embedding_matrices=embedding_matrix)
    +
    +                                    if in_features:
    +                                        assert result.shape == (100, 100)
    +                                        
    +                                        if not inline_input and auto_stack:
    +                                            assert len(mps.virtual_nodes) == \
    +                                                (2 + 2 * len(mps.out_features))
    +                                        else:
    +                                            assert len(mps.virtual_nodes) == \
    +                                                (1 + 2 * len(mps.out_features))
    +                                            
    +                                    else:
    +                                        assert result.shape == tuple()
    +                                        assert len(mps.virtual_nodes) == \
    +                                            2 * len(mps.out_features)
    +                                    
    +                                    if boundary == 'obc':
    +                                        assert len(mps.leaf_nodes) == n_features + 2
    +                                    else:
    +                                        assert len(mps.leaf_nodes) == n_features
    +                                        
    +                                    assert len(mps.data_nodes) == len(in_features)
    +                                    
    +                                    result.sum().abs().backward()
    +                                    for node in mps.mats_env:
    +                                        assert node.grad is not None
    +    
    +    def test_all_algorithms_marginalize_with_mpo(self):
    +        for n_features in [1, 2, 3, 4, 10]:
    +            for mps_boundary in ['obc', 'pbc']:
    +                for mpo_boundary in ['obc', 'pbc']:
    +                    in_features = torch.randint(low=0,
    +                                                high=n_features,
    +                                                size=(n_features // 2,)).tolist()
    +                    in_features = list(set(in_features))
    +                    
    +                    # batch x n_features x feature_dim
    +                    example = torch.randn(1, len(in_features), 5)
    +                    data = torch.randn(100, len(in_features), 5)
    +                    
    +                    if example.numel() == 0:
    +                        example = None
    +                        data = None
    +                    
    +                    mps = tk.models.MPS(n_features=n_features,
    +                                        phys_dim=5,
    +                                        bond_dim=2,
    +                                        boundary=mps_boundary,
    +                                        in_features=in_features)
    +                    
    +                    mpo = tk.models.MPO(n_features=n_features - len(in_features),
    +                                        in_dim=5,
    +                                        out_dim=5,
    +                                        bond_dim=2,
    +                                        boundary=mpo_boundary)
    +                    
    +                    # De-parameterize MPO nodes to only train MPS nodes
    +                    mpo = mpo.parameterize(set_param=False, override=True)
    +
    +                    for auto_stack in [True, False]:
    +                        for auto_unbind in [True, False]:
    +                            for inline_input in [True, False]:
    +                                for inline_mats in [True, False]:
    +                                    for renormalize in [True, False]:
    +                                        mps.auto_stack = auto_stack
    +                                        mps.auto_unbind = auto_unbind
    +
    +                                        mps.trace(example,
    +                                                  inline_input=inline_input,
    +                                                  inline_mats=inline_mats,
    +                                                  renormalize=renormalize,
    +                                                  marginalize_output=True,
    +                                                  mpo=mpo)
    +                                        result = mps(data,
    +                                                     inline_input=inline_input,
    +                                                     inline_mats=inline_mats,
    +                                                     renormalize=renormalize,
    +                                                     marginalize_output=True,
    +                                                     mpo=mpo)
    +                                        
    +                                        if in_features:
    +                                            assert result.shape == (100, 100)
    +                                        else:
    +                                            assert result.shape == tuple()
    +                                        
    +                                        if mps_boundary == 'obc':
    +                                            if mpo_boundary == 'obc':
    +                                                leaf = (n_features + 2) + \
    +                                                    (n_features - len(in_features) + 2)
    +                                                assert len(mps.leaf_nodes) == leaf
    +                                            else:
    +                                                leaf = (n_features + 2) + \
    +                                                    (n_features - len(in_features))
    +                                                assert len(mps.leaf_nodes) == leaf
    +                                        else:
    +                                            if mpo_boundary == 'obc':
    +                                                leaf = n_features + \
    +                                                    (n_features - len(in_features) + 2)
    +                                                assert len(mps.leaf_nodes) == leaf
    +                                            else:
    +                                                leaf = n_features + \
    +                                                    (n_features - len(in_features))
    +                                                assert len(mps.leaf_nodes) == leaf
    +                                        
    +                                        result.sum().backward()
    +                                        for node in mps.mats_env:
    +                                            assert node.grad is not None
    +                                        for node in mpo.mats_env:
    +                                            assert node.tensor.grad is None
    +    
    +    def test_all_algorithms_marginalize_with_mpo_cuda(self):
    +        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    +        for n_features in [1, 2, 3, 4, 10]:
    +            for mps_boundary in ['obc', 'pbc']:
    +                for mpo_boundary in ['obc', 'pbc']:
    +                    in_features = torch.randint(low=0,
    +                                            high=n_features,
    +                                            size=(n_features // 2,)).tolist()
    +                in_features = list(set(in_features))
    +                
    +                # batch x n_features x feature_dim
    +                example = torch.randn(1, len(in_features), 5, device=device)
    +                data = torch.randn(100, len(in_features), 5, device=device)
                     
                     if example.numel() == 0:
                         example = None
    @@ -1119,6 +1496,96 @@ def test_all_algorithms_marginalize_with_mpo_cuda(self):
                                                 assert node.grad is not None
                                             for node in mpo.mats_env:
                                                 assert node.tensor.grad is None
    +    
    +    def test_all_algorithms_marginalize_with_mpo_complex(self):
    +        for n_features in [1, 2, 3, 4, 10]:
    +            for mps_boundary in ['obc', 'pbc']:
    +                for mpo_boundary in ['obc', 'pbc']:
    +                    in_features = torch.randint(low=0,
    +                                            high=n_features,
    +                                            size=(n_features // 2,)).tolist()
    +                in_features = list(set(in_features))
    +                
    +                # batch x n_features x feature_dim
    +                example = torch.randn(1, len(in_features), 5, dtype=torch.complex64)
    +                data = torch.randn(100, len(in_features), 5, dtype=torch.complex64)
    +                
    +                if example.numel() == 0:
    +                    example = None
    +                    data = None
    +                    
    +                    mps = tk.models.MPS(n_features=n_features,
    +                                        phys_dim=5,
    +                                        bond_dim=2,
    +                                        boundary=mps_boundary,
    +                                        in_features=in_features,
    +                                        dtype=torch.complex64)
    +                    
    +                    mpo = tk.models.MPO(n_features=n_features - len(in_features),
    +                                        in_dim=5,
    +                                        out_dim=5,
    +                                        bond_dim=2,
    +                                        boundary=mpo_boundary,
    +                                        dtype=torch.complex64)
    +                    
    +                    # Send mpo to cuda before deparameterizing, so that all
    +                    # nodes are still in the state_dict of the model and are
    +                    # automatically sent to cuda
    +                    # mpo = mpo.to(device)
    +                    
    +                    # De-parameterize MPO nodes to only train MPS nodes
    +                    mpo = mpo.parameterize(set_param=False, override=True)
    +
    +                    for auto_stack in [True, False]:
    +                        for auto_unbind in [True, False]:
    +                            for inline_input in [True, False]:
    +                                for inline_mats in [True, False]:
    +                                    for renormalize in [True, False]:
    +                                        mps.auto_stack = auto_stack
    +                                        mps.auto_unbind = auto_unbind
    +
    +                                        mps.trace(example,
    +                                                  inline_input=inline_input,
    +                                                  inline_mats=inline_mats,
    +                                                  renormalize=renormalize,
    +                                                  marginalize_output=True,
    +                                                  mpo=mpo)
    +                                        result = mps(data,
    +                                                     inline_input=inline_input,
    +                                                     inline_mats=inline_mats,
    +                                                     renormalize=renormalize,
    +                                                     marginalize_output=True,
    +                                                     mpo=mpo)
    +                                        
    +                                        if in_features:
    +                                            assert result.shape == (100, 100)
    +                                        else:
    +                                            assert result.shape == tuple()
    +                                        
    +                                        if mps_boundary == 'obc':
    +                                            if mpo_boundary == 'obc':
    +                                                leaf = (n_features + 2) + \
    +                                                    (n_features - len(in_features) + 2)
    +                                                assert len(mps.leaf_nodes) == leaf
    +                                            else:
    +                                                leaf = (n_features + 2) + \
    +                                                    (n_features - len(in_features))
    +                                                assert len(mps.leaf_nodes) == leaf
    +                                        else:
    +                                            if mpo_boundary == 'obc':
    +                                                leaf = n_features + \
    +                                                    (n_features - len(in_features) + 2)
    +                                                assert len(mps.leaf_nodes) == leaf
    +                                            else:
    +                                                leaf = n_features + \
    +                                                    (n_features - len(in_features))
    +                                                assert len(mps.leaf_nodes) == leaf
    +                                        
    +                                        result.sum().abs().backward()
    +                                        for node in mps.mats_env:
    +                                            assert node.grad is not None
    +                                        for node in mpo.mats_env:
    +                                            assert node.tensor.grad is None
                                         
         def test_all_algorithms_no_marginalize(self):
             for n_features in [1, 2, 3, 4, 10]:
    @@ -1140,53 +1607,143 @@ def test_all_algorithms_no_marginalize(self):
                                         phys_dim=5,
                                         bond_dim=2,
                                         boundary=boundary,
    -                                    in_features=in_features)
    -
    -                for auto_stack in [True, False]:
    -                    for auto_unbind in [True, False]:
    -                        for inline_input in [True, False]:
    -                            for inline_mats in [True, False]:
    -                                for renormalize in [True, False]:
    -                                    mps.auto_stack = auto_stack
    -                                    mps.auto_unbind = auto_unbind
    -
    -                                    mps.trace(example,
    -                                              inline_input=inline_input,
    -                                              inline_mats=inline_mats,
    -                                              renormalize=renormalize)
    -                                    result = mps(data,
    -                                                 inline_input=inline_input,
    -                                                 inline_mats=inline_mats,
    -                                                 renormalize=renormalize)
    -
    -                                    aux_shape = [5] * len(mps.out_features)
    -                                    if in_features:
    -                                        aux_shape = [100] + aux_shape
    -                                        assert result.shape == tuple(aux_shape)
    -                                        
    -                                        if not inline_input and auto_stack:
    -                                            assert len(mps.virtual_nodes) == 2
    -                                        else:
    -                                            assert len(mps.virtual_nodes) == 1
    -                                            
    -                                    else:
    -                                        assert result.shape == tuple(aux_shape)
    -                                        assert len(mps.virtual_nodes) == 0
    -                                        
    -                                    assert len(mps.edges) == len(mps.out_features)
    -                                    
    -                                    if boundary == 'obc':
    -                                        assert len(mps.leaf_nodes) == n_features + 2
    -                                    else:
    -                                        assert len(mps.leaf_nodes) == n_features
    -                                        
    -                                    assert len(mps.data_nodes) == len(in_features)
    -                                    
    -                                    result.sum().backward()
    -                                    for node in mps.mats_env:
    -                                        assert node.grad is not None
    +                                    in_features=in_features)
    +
    +                for auto_stack in [True, False]:
    +                    for auto_unbind in [True, False]:
    +                        for inline_input in [True, False]:
    +                            for inline_mats in [True, False]:
    +                                for renormalize in [True, False]:
    +                                    mps.auto_stack = auto_stack
    +                                    mps.auto_unbind = auto_unbind
    +
    +                                    mps.trace(example,
    +                                              inline_input=inline_input,
    +                                              inline_mats=inline_mats,
    +                                              renormalize=renormalize)
    +                                    result = mps(data,
    +                                                 inline_input=inline_input,
    +                                                 inline_mats=inline_mats,
    +                                                 renormalize=renormalize)
    +
    +                                    aux_shape = [5] * len(mps.out_features)
    +                                    if in_features:
    +                                        aux_shape = [100] + aux_shape
    +                                        assert result.shape == tuple(aux_shape)
    +                                        
    +                                        if not inline_input and auto_stack:
    +                                            assert len(mps.virtual_nodes) == 2
    +                                        else:
    +                                            assert len(mps.virtual_nodes) == 1
    +                                            
    +                                    else:
    +                                        assert result.shape == tuple(aux_shape)
    +                                        assert len(mps.virtual_nodes) == 0
    +                                        
    +                                    assert len(mps.edges) == len(mps.out_features)
    +                                    
    +                                    if boundary == 'obc':
    +                                        assert len(mps.leaf_nodes) == n_features + 2
    +                                    else:
    +                                        assert len(mps.leaf_nodes) == n_features
    +                                        
    +                                    assert len(mps.data_nodes) == len(in_features)
    +                                    
    +                                    result.sum().backward()
    +                                    for node in mps.mats_env:
    +                                        assert node.grad is not None
    +    
    +    def test_norm(self):
    +        for n_features in [1, 2, 3, 4, 10]:
    +            for boundary in ['obc', 'pbc']: 
    +                in_features = torch.randint(low=0,
    +                                            high=n_features,
    +                                            size=(n_features // 2,)).tolist()
    +                in_features = list(set(in_features))
    +                in_features.sort()
    +                
    +                # batch x n_features x feature_dim
    +                example = torch.randn(1, len(in_features), 5)
    +                
    +                if example.numel() == 0:
    +                    example = None
    +                
    +                mps = tk.models.MPS(n_features=n_features,
    +                                    phys_dim=5,
    +                                    bond_dim=2,
    +                                    boundary=boundary,
    +                                    in_features=in_features)
    +                mps.trace(example)
    +                
    +                assert mps.resultant_nodes
    +                if in_features:
    +                    assert mps.data_nodes
    +                assert mps.in_features == in_features
    +                
    +                for log_scale in [True, False]:
    +                    # MPS has to be reset, otherwise norm automatically calls
    +                    # the forward method that was traced when contracting the MPS
    +                    # with example
    +                    mps.reset()
    +                    norm = mps.norm(log_scale=log_scale)
    +                    assert mps.resultant_nodes
    +                    assert not mps.data_nodes
    +                    assert mps.in_features == in_features
    +                    
    +                    norm.sum().backward()
    +                    for node in mps.mats_env:
    +                        assert node.grad is not None
    +                    
    +                    # Repeat norm
    +                    norm = mps.norm(log_scale=log_scale)
    +    
    +    def test_norm_cuda(self):
    +        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    +        for n_features in [1, 2, 3, 4, 10]:
    +            for boundary in ['obc', 'pbc']: 
    +                in_features = torch.randint(low=0,
    +                                            high=n_features,
    +                                            size=(n_features // 2,)).tolist()
    +                in_features = list(set(in_features))
    +                in_features.sort()
    +                
    +                # batch x n_features x feature_dim
    +                example = torch.randn(1, len(in_features), 5, device=device)
    +                
    +                if example.numel() == 0:
    +                    example = None
    +                
    +                mps = tk.models.MPS(n_features=n_features,
    +                                    phys_dim=5,
    +                                    bond_dim=2,
    +                                    boundary=boundary,
    +                                    in_features=in_features,
    +                                    device=device)
    +                mps.trace(example)
    +                
    +                assert mps.resultant_nodes
    +                if in_features:
    +                    assert mps.data_nodes
    +                assert mps.in_features == in_features
    +                
    +                for log_scale in [True, False]:
    +                    # MPS has to be reset, otherwise norm automatically calls
    +                    # the forward method that was traced when contracting the MPS
    +                    # with example
    +                    mps.reset()
    +                    norm = mps.norm(log_scale=log_scale)
    +                    assert mps.resultant_nodes
    +                    assert not mps.data_nodes
    +                    assert mps.in_features == in_features
    +                    
    +                    norm.sum().backward()
    +                    for node in mps.mats_env:
    +                        assert node.grad is not None
    +                    
    +                    # Repeat norm
    +                    norm = mps.norm(log_scale=log_scale)
         
    -    def test_norm(self):
    +    def test_norm_complex(self):
             for n_features in [1, 2, 3, 4, 10]:
                 for boundary in ['obc', 'pbc']: 
                     in_features = torch.randint(low=0,
    @@ -1196,7 +1753,7 @@ def test_norm(self):
                     in_features.sort()
                     
                     # batch x n_features x feature_dim
    -                example = torch.randn(1, len(in_features), 5)
    +                example = torch.randn(1, len(in_features), 5, dtype=torch.complex64)
                     
                     if example.numel() == 0:
                         example = None
    @@ -1205,7 +1762,8 @@ def test_norm(self):
                                         phys_dim=5,
                                         bond_dim=2,
                                         boundary=boundary,
    -                                    in_features=in_features)
    +                                    in_features=in_features,
    +                                    dtype=torch.complex64)
                     mps.trace(example)
                     
                     assert mps.resultant_nodes
    @@ -1223,7 +1781,7 @@ def test_norm(self):
                         assert not mps.data_nodes
                         assert mps.in_features == in_features
                         
    -                    norm.sum().backward()
    +                    norm.sum().abs().backward()
                         for node in mps.mats_env:
                             assert node.grad is not None
                         
    @@ -1336,6 +1894,59 @@ def test_partial_density_cuda(self):
                     # Repeat density
                     density = mps.partial_density(trace_sites)
         
    +    def test_partial_density_complex(self):
    +        for n_features in [1, 2, 3, 4, 5]:
    +            for boundary in ['obc', 'pbc']:
    +                phys_dim = torch.randint(low=2, high=12,
    +                                         size=(n_features,)).tolist()
    +                bond_dim = torch.randint(low=2, high=10, size=(n_features,)).tolist()
    +                bond_dim = bond_dim[:-1] if boundary == 'obc' else bond_dim
    +                
    +                trace_sites = torch.randint(low=0,
    +                                            high=n_features,
    +                                            size=(n_features // 2,)).tolist()
    +                
    +                mps = tk.models.MPS(n_features=n_features,
    +                                    phys_dim=phys_dim,
    +                                    bond_dim=bond_dim,
    +                                    boundary=boundary,
    +                                    in_features=trace_sites,
    +                                    dtype=torch.complex64)
    +                
    +                in_dims = [phys_dim[i] for i in mps.in_features]
    +                example = [torch.randn(1, d, dtype=torch.complex64) for d in in_dims]
    +                if example == []:
    +                    example = None
    +                
    +                mps.trace(example)
    +                
    +                assert mps.resultant_nodes
    +                if trace_sites:
    +                    assert mps.data_nodes
    +                assert set(mps.in_features) == set(trace_sites)
    +                
    +                # MPS has to be reset, otherwise partial_density automatically
    +                # calls the forward method that was traced when contracting the
    +                # MPS with example
    +                mps.reset()
    +                
    +                # Here, trace_sites are now the out_features,
    +                # not the in_features
    +                density = mps.partial_density(trace_sites)
    +                assert mps.resultant_nodes
    +                assert mps.data_nodes
    +                assert set(mps.out_features) == set(trace_sites)
    +                
    +                assert density.shape == \
    +                    tuple([phys_dim[i] for i in mps.in_features] * 2)
    +                
    +                density.sum().abs().backward()
    +                for node in mps.mats_env:
    +                    assert node.grad is not None
    +                
    +                # Repeat density
    +                density = mps.partial_density(trace_sites)
    +    
         def test_entropy(self):
             for n_features in [1, 2, 3, 4, 10]:
                 for boundary in ['obc', 'pbc']:
    @@ -1343,12 +1954,13 @@ def test_entropy(self):
                         bond_dim = torch.randint(low=2, high=10,
                                                  size=(n_features,)).tolist()
                         bond_dim = bond_dim[:-1] if boundary == 'obc' else bond_dim
    -                                
    +                    
                         mps = tk.models.MPS(n_features=n_features,
                                             phys_dim=2,
                                             bond_dim=bond_dim,
                                             boundary=boundary,
    -                                        in_features=[])
    +                                        in_features=[],
    +                                        init_method='canonical')
                         
                         mps_tensor = mps()
                         assert mps_tensor.shape == (2,) * n_features
    @@ -1364,17 +1976,128 @@ def test_entropy(self):
                         assert len(mps.data_nodes) == n_features
                         
                         # Mutual Information
    -                    scaled_mi, log_norm = mps.entropy(middle_site=middle_site,
    -                                                      renormalize=True)
    -                    mi = mps.entropy(middle_site=middle_site,
    -                                     renormalize=False)
    +                    scaled_entropy, log_norm = mps.entropy(middle_site=middle_site,
    +                                                           renormalize=True)
    +                    entropy = mps.entropy(middle_site=middle_site,
    +                                          renormalize=False)
    +                    
    +                    assert all([mps.bond_dim[i] <= bond_dim[i]
    +                                for i in range(len(bond_dim))])
    +                    
    +                    sq_norm = log_norm.exp().pow(2)
    +                    approx_entropy = sq_norm * scaled_entropy - sq_norm * 2 * log_norm
    +                    assert torch.isclose(entropy, approx_entropy,
    +                                         rtol=1e-03, atol=1e-05)
    +                    
    +                    if mps.boundary == 'obc':
    +                        assert len(mps.leaf_nodes) == n_features + 2
    +                    else:
    +                        assert len(mps.leaf_nodes) == n_features
    +                    assert len(mps.data_nodes) == n_features
    +                    
    +                    mps.unset_data_nodes()
    +                    mps.in_features = []
    +                    approx_mps_tensor = mps()
    +                    assert approx_mps_tensor.shape == (2,) * n_features
    +    
    +    def test_entropy_cuda(self):
    +        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    +        for n_features in [1, 2, 3, 4, 10]:
    +            for boundary in ['obc', 'pbc']:
    +                for middle_site in range(n_features - 1):
    +                    bond_dim = torch.randint(low=2, high=10,
    +                                             size=(n_features,)).tolist()
    +                    bond_dim = bond_dim[:-1] if boundary == 'obc' else bond_dim
    +                                
    +                    mps = tk.models.MPS(n_features=n_features,
    +                                        phys_dim=2,
    +                                        bond_dim=bond_dim,
    +                                        boundary=boundary,
    +                                        in_features=[],
    +                                        device=device,
    +                                        init_method='canonical')
    +                    
    +                    mps_tensor = mps()
    +                    assert mps_tensor.shape == (2,) * n_features
    +                    
    +                    mps.out_features = []
    +                    example = torch.randn(1, n_features, 2, device=device)
    +                    mps.trace(example)
    +                    
    +                    if mps.boundary == 'obc':
    +                        assert len(mps.leaf_nodes) == n_features + 2
    +                    else:
    +                        assert len(mps.leaf_nodes) == n_features
    +                    assert len(mps.data_nodes) == n_features
    +                    
    +                    # Mutual Information
    +                    scaled_entropy, log_norm = mps.entropy(middle_site=middle_site,
    +                                                           renormalize=True)
    +                    entropy = mps.entropy(middle_site=middle_site,
    +                                          renormalize=False)
    +                    
    +                    assert all([mps.bond_dim[i] <= bond_dim[i]
    +                                for i in range(len(bond_dim))])
    +                    
    +                    sq_norm = log_norm.exp().pow(2)
    +                    approx_entropy = sq_norm * scaled_entropy - sq_norm * 2 * log_norm
    +                    assert torch.isclose(entropy, approx_entropy,
    +                                         rtol=1e-03, atol=1e-05)
    +                    
    +                    if mps.boundary == 'obc':
    +                        assert len(mps.leaf_nodes) == n_features + 2
    +                    else:
    +                        assert len(mps.leaf_nodes) == n_features
    +                    assert len(mps.data_nodes) == n_features
    +                    
    +                    mps.unset_data_nodes()
    +                    mps.in_features = []
    +                    approx_mps_tensor = mps()
    +                    assert approx_mps_tensor.shape == (2,) * n_features
    +    
    +    def test_entropy_complex(self):
    +        for n_features in [1, 2, 3, 4, 10]:
    +            for boundary in ['obc', 'pbc']:
    +                for middle_site in range(n_features - 1):
    +                    bond_dim = torch.randint(low=2, high=10,
    +                                             size=(n_features,)).tolist()
    +                    bond_dim = bond_dim[:-1] if boundary == 'obc' else bond_dim
    +                                
    +                    mps = tk.models.MPS(n_features=n_features,
    +                                        phys_dim=2,
    +                                        bond_dim=bond_dim,
    +                                        boundary=boundary,
    +                                        in_features=[],
    +                                        dtype=torch.complex64,
    +                                        init_method='canonical')
    +                    
    +                    mps_tensor = mps()
    +                    assert mps_tensor.shape == (2,) * n_features
    +                    
    +                    mps.out_features = []
    +                    example = torch.randn(1, n_features, 2,
    +                                          dtype=torch.complex64)
    +                    mps.trace(example)
    +                    
    +                    if mps.boundary == 'obc':
    +                        assert len(mps.leaf_nodes) == n_features + 2
    +                    else:
    +                        assert len(mps.leaf_nodes) == n_features
    +                    assert len(mps.data_nodes) == n_features
    +                    
    +                    # Mutual Information
    +                    scaled_entropy, log_norm = mps.entropy(middle_site=middle_site,
    +                                                           renormalize=True)
    +                    entropy = mps.entropy(middle_site=middle_site,
    +                                          renormalize=False)
                         
                         assert all([mps.bond_dim[i] <= bond_dim[i]
                                     for i in range(len(bond_dim))])
                         
                         sq_norm = log_norm.exp().pow(2)
    -                    approx_mi = sq_norm * scaled_mi - sq_norm * 2 * log_norm
    -                    assert torch.isclose(mi, approx_mi)
    +                    approx_entropy = sq_norm * scaled_entropy - sq_norm * 2 * log_norm
    +                    assert torch.isclose(entropy, approx_entropy,
    +                                         rtol=1e-03, atol=1e-05)
                         
                         if mps.boundary == 'obc':
                             assert len(mps.leaf_nodes) == n_features + 2
    @@ -1387,7 +2110,59 @@ def test_entropy(self):
                         approx_mps_tensor = mps()
                         assert approx_mps_tensor.shape == (2,) * n_features
         
    -    def test_canonicalize(self):
    +    def test_canonicalize(self):
    +        for n_features in [1, 2, 3, 4, 10]:
    +            for boundary in ['obc', 'pbc']:
    +                for oc in range(n_features):
    +                    for mode in ['svd', 'svdr', 'qr']:
    +                        for renormalize in [True, False]:
    +                            mps = tk.models.MPS(n_features=n_features,
    +                                                phys_dim=2,
    +                                                bond_dim=10,
    +                                                boundary=boundary,
    +                                                in_features=[])
    +                            
    +                            mps_tensor = mps()
    +                            assert mps_tensor.shape == (2,) * n_features
    +                            
    +                            mps.out_features = []
    +                            example = torch.randn(1, n_features, 2)
    +                            mps.trace(example)
    +                            
    +                            if mps.boundary == 'obc':
    +                                assert len(mps.leaf_nodes) == n_features + 2
    +                            else:
    +                                assert len(mps.leaf_nodes) == n_features
    +                            assert len(mps.data_nodes) == n_features
    +                            
    +                            # Canonicalize
    +                            rank = torch.randint(5, 11, (1,)).item()
    +                            mps.canonicalize(oc=oc,
    +                                             mode=mode,
    +                                             rank=rank,
    +                                             cum_percentage=0.98,
    +                                             cutoff=1e-5,
    +                                             renormalize=renormalize)
    +                            
    +                            if mps.bond_dim and mode != 'qr':
    +                                if mps.boundary == 'obc':
    +                                    assert (torch.tensor(mps.bond_dim) <= rank).all()
    +                                else:
    +                                    assert (torch.tensor(mps.bond_dim[:-1]) <= rank).all()
    +                            
    +                            if mps.boundary == 'obc':
    +                                assert len(mps.leaf_nodes) == n_features + 2
    +                            else:
    +                                assert len(mps.leaf_nodes) == n_features
    +                            assert len(mps.data_nodes) == n_features
    +                            
    +                            mps.unset_data_nodes()
    +                            mps.in_features = []
    +                            approx_mps_tensor = mps()
    +                            assert approx_mps_tensor.shape == (2,) * n_features
    +    
    +    def test_canonicalize_cuda(self):
    +        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
             for n_features in [1, 2, 3, 4, 10]:
                 for boundary in ['obc', 'pbc']:
                     for oc in range(n_features):
    @@ -1397,13 +2172,68 @@ def test_canonicalize(self):
                                                     phys_dim=2,
                                                     bond_dim=10,
                                                     boundary=boundary,
    -                                                in_features=[])
    +                                                in_features=[],
    +                                                device=device)
                                 
                                 mps_tensor = mps()
                                 assert mps_tensor.shape == (2,) * n_features
                                 
                                 mps.out_features = []
    -                            example = torch.randn(1, n_features, 2)
    +                            example = torch.randn(1, n_features, 2,
    +                                                  device=device)
    +                            mps.trace(example)
    +                            
    +                            if mps.boundary == 'obc':
    +                                assert len(mps.leaf_nodes) == n_features + 2
    +                            else:
    +                                assert len(mps.leaf_nodes) == n_features
    +                            assert len(mps.data_nodes) == n_features
    +                            
    +                            # Canonicalize
    +                            rank = torch.randint(5, 11, (1,)).item()
    +                            mps.canonicalize(oc=oc,
    +                                             mode=mode,
    +                                             rank=rank,
    +                                             cum_percentage=0.98,
    +                                             cutoff=1e-5,
    +                                             renormalize=renormalize)
    +                            
    +                            if mps.bond_dim and mode != 'qr':
    +                                if mps.boundary == 'obc':
    +                                    assert (torch.tensor(mps.bond_dim) <= rank).all()
    +                                else:
    +                                    assert (torch.tensor(mps.bond_dim[:-1]) <= rank).all()
    +                            
    +                            if mps.boundary == 'obc':
    +                                assert len(mps.leaf_nodes) == n_features + 2
    +                            else:
    +                                assert len(mps.leaf_nodes) == n_features
    +                            assert len(mps.data_nodes) == n_features
    +                            
    +                            mps.unset_data_nodes()
    +                            mps.in_features = []
    +                            approx_mps_tensor = mps()
    +                            assert approx_mps_tensor.shape == (2,) * n_features
    +    
    +    def test_canonicalize_complex(self):
    +        for n_features in [1, 2, 3, 4, 10]:
    +            for boundary in ['obc', 'pbc']:
    +                for oc in range(n_features):
    +                    for mode in ['svd', 'svdr', 'qr']:
    +                        for renormalize in [True, False]:
    +                            mps = tk.models.MPS(n_features=n_features,
    +                                                phys_dim=2,
    +                                                bond_dim=10,
    +                                                boundary=boundary,
    +                                                in_features=[],
    +                                                dtype=torch.complex64)
    +                            
    +                            mps_tensor = mps()
    +                            assert mps_tensor.shape == (2,) * n_features
    +                            
    +                            mps.out_features = []
    +                            example = torch.randn(1, n_features, 2,
    +                                                  dtype=torch.complex64)
                                 mps.trace(example)
                                 
                                 if mps.boundary == 'obc':
    @@ -1517,6 +2347,75 @@ def test_canonicalize_univocal(self):
                 approx_mps_tensor = mps()
                 assert approx_mps_tensor.shape == (2,) * n_features
                 
    +            assert torch.allclose(mps_tensor, approx_mps_tensor,
    +                                  rtol=1e-2, atol=1e-3)
    +    
    +    def test_canonicalize_univocal_cuda(self):
    +        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    +        for n_features in [1, 2, 3, 4, 10]:
    +            mps = tk.models.MPS(n_features=n_features,
    +                                phys_dim=2,
    +                                bond_dim=10,
    +                                boundary='obc',
    +                                in_features=[],
    +                                init_method='unit',
    +                                device=device)
    +            
    +            mps_tensor = mps()
    +            assert mps_tensor.shape == (2,) * n_features
    +            
    +            mps.out_features = []
    +            example = torch.randn(1, n_features, 2, device=device)
    +            mps.trace(example)
    +            
    +            assert len(mps.leaf_nodes) == n_features + 2
    +            assert len(mps.data_nodes) == n_features
    +            
    +            # Canonicalize
    +            mps.canonicalize_univocal()
    +            
    +            assert len(mps.leaf_nodes) == n_features + 2
    +            assert len(mps.data_nodes) == n_features
    +            
    +            mps.unset_data_nodes()
    +            mps.in_features = []
    +            approx_mps_tensor = mps()
    +            assert approx_mps_tensor.shape == (2,) * n_features
    +            
    +            assert torch.allclose(mps_tensor, approx_mps_tensor,
    +                                  rtol=1e-2, atol=1e-4)
    +    
    +    def test_canonicalize_univocal_complex(self):
    +        for n_features in [1, 2, 3, 4, 10]:
    +            mps = tk.models.MPS(n_features=n_features,
    +                                phys_dim=2,
    +                                bond_dim=10,
    +                                boundary='obc',
    +                                in_features=[],
    +                                init_method='unit',
    +                                dtype=torch.complex64)
    +            
    +            mps_tensor = mps()
    +            assert mps_tensor.shape == (2,) * n_features
    +            
    +            mps.out_features = []
    +            example = torch.randn(1, n_features, 2, dtype=torch.complex64)
    +            mps.trace(example)
    +            
    +            assert len(mps.leaf_nodes) == n_features + 2
    +            assert len(mps.data_nodes) == n_features
    +            
    +            # Canonicalize
    +            mps.canonicalize_univocal()
    +            
    +            assert len(mps.leaf_nodes) == n_features + 2
    +            assert len(mps.data_nodes) == n_features
    +            
    +            mps.unset_data_nodes()
    +            mps.in_features = []
    +            approx_mps_tensor = mps()
    +            assert approx_mps_tensor.shape == (2,) * n_features
    +            
                 assert torch.allclose(mps_tensor, approx_mps_tensor,
                                       rtol=1e-2, atol=1e-4)
         
    @@ -1681,12 +2580,12 @@ def test_initialize_init_method(self):
                 for init_method in methods:
                     mps = tk.models.UMPS(n_features=n,
                                          phys_dim=2,
    -                                     bond_dim=2,
    +                                     bond_dim=5,
                                          init_method=init_method)
                     assert mps.n_features == n
                     assert mps.boundary == 'pbc'
                     assert mps.phys_dim == [2] * n
    -                assert mps.bond_dim == [2] * n
    +                assert mps.bond_dim == [5] * n
         
         def test_initialize_init_method_cuda(self):
             device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    @@ -1696,16 +2595,34 @@ def test_initialize_init_method_cuda(self):
                 for init_method in methods:
                     mps = tk.models.UMPS(n_features=n,
                                          phys_dim=2,
    -                                     bond_dim=2,
    +                                     bond_dim=5,
                                          init_method=init_method,
                                          device=device)
                     assert mps.n_features == n
                     assert mps.boundary == 'pbc'
                     assert mps.phys_dim == [2] * n
    -                assert mps.bond_dim == [2] * n
    +                assert mps.bond_dim == [5] * n
                     for node in mps.mats_env:
                         assert node.device == device
         
    +    def test_initialize_init_method_complex(self):
    +        methods = ['zeros', 'ones', 'copy', 'rand', 'randn',
    +                   'randn_eye', 'unit', 'canonical']
    +        for n in [1, 2, 5]:
    +            for init_method in methods:
    +                mps = tk.models.UMPS(n_features=n,
    +                                     phys_dim=2,
    +                                     bond_dim=5,
    +                                     init_method=init_method,
    +                                     dtype=torch.complex64)
    +                assert mps.n_features == n
    +                assert mps.boundary == 'pbc'
    +                assert mps.phys_dim == [2] * n
    +                assert mps.bond_dim == [5] * n
    +                for node in mps.mats_env:
    +                    assert node.dtype == torch.complex64
    +                    assert node.is_complex()
    +    
         def test_initialize_with_unitaries(self):
             for n in [1, 2, 5]:
                 mps = tk.models.UMPS(n_features=n,
    @@ -2105,6 +3022,55 @@ def test_partial_density_cuda(self):
                 # Repeat density
                 density = mps.partial_density(trace_sites)
         
    +    def test_partial_density_complex(self):
    +        for n_features in [1, 2, 3, 4, 5]:
    +            phys_dim = torch.randint(low=2, high=12, size=(1,)).item()
    +            bond_dim = torch.randint(low=2, high=10, size=(1,)).item()
    +            
    +            trace_sites = torch.randint(low=0,
    +                                        high=n_features,
    +                                        size=(n_features // 2,)).tolist()
    +            
    +            mps = tk.models.UMPS(n_features=n_features,
    +                                 phys_dim=phys_dim,
    +                                 bond_dim=bond_dim,
    +                                 in_features=trace_sites,
    +                                 dtype=torch.complex64)
    +            
    +            # batch x n_features x feature_dim
    +            example = torch.randn(1, n_features // 2, phys_dim,
    +                                  dtype=torch.complex64)
    +            if example.numel() == 0:
    +                example = None
    +            
    +            mps.trace(example)
    +            
    +            assert mps.resultant_nodes
    +            if trace_sites:
    +                assert mps.data_nodes
    +            assert set(mps.in_features) == set(trace_sites)
    +            
    +            # MPS has to be reset, otherwise partial_density automatically
    +            # calls the forward method that was traced when contracting the
    +            # MPS with example
    +            mps.reset()
    +            
    +            # Here, trace_sites are now the out_features,
    +            # not the in_features
    +            density = mps.partial_density(trace_sites)
    +            assert mps.resultant_nodes
    +            assert mps.data_nodes
    +            assert set(mps.out_features) == set(trace_sites)
    +            
    +            assert density.shape == (phys_dim,) * 2 * len(mps.in_features)
    +            
    +            density.sum().abs().backward()
    +            for node in mps.mats_env:
    +                assert node.grad is not None
    +            
    +            # Repeat density
    +            density = mps.partial_density(trace_sites)
    +    
         def test_canonicalize_error(self):
             mps = tk.models.UMPS(n_features=10,
                                  phys_dim=2,
    @@ -2175,26 +3141,26 @@ def test_initialize_init_method(self):
                                              n_features=n,
                                              in_dim=2,
                                              out_dim=10,
    -                                         bond_dim=2,
    +                                         bond_dim=5,
                                              init_method=init_method)
                     assert mps.n_features == n
                     assert mps.boundary == 'pbc'
                     assert mps.in_dim == [2] * (n - 1)
                     assert mps.out_dim == 10
    -                assert mps.bond_dim == [2] * n
    +                assert mps.bond_dim == [5] * n
                     
                     # OBC
                     mps = tk.models.MPSLayer(boundary='obc',
                                              n_features=n,
                                              in_dim=2,
                                              out_dim=10,
    -                                         bond_dim=2,
    +                                         bond_dim=5,
                                              init_method=init_method)
                     assert mps.n_features == n
                     assert mps.boundary == 'obc'
                     assert mps.in_dim == [2] * (n - 1)
                     assert mps.out_dim == 10
    -                assert mps.bond_dim == [2] * (n - 1)
    +                assert mps.bond_dim == [5] * (n - 1)
                     
                     assert torch.equal(mps.left_node.tensor[0],
                                        torch.ones_like(mps.left_node.tensor)[0])
    @@ -2217,14 +3183,14 @@ def test_initialize_init_method_cuda(self):
                                              n_features=n,
                                              in_dim=2,
                                              out_dim=10,
    -                                         bond_dim=2,
    +                                         bond_dim=5,
                                              init_method=init_method,
                                              device=device)
                     assert mps.n_features == n
                     assert mps.boundary == 'pbc'
                     assert mps.in_dim == [2] * (n - 1)
                     assert mps.out_dim == 10
    -                assert mps.bond_dim == [2] * n
    +                assert mps.bond_dim == [5] * n
                     for node in mps.mats_env:
                         assert node.device == device
                     
    @@ -2233,14 +3199,14 @@ def test_initialize_init_method_cuda(self):
                                              n_features=n,
                                              in_dim=2,
                                              out_dim=10,
    -                                         bond_dim=2,
    +                                         bond_dim=5,
                                              init_method=init_method,
                                              device=device)
                     assert mps.n_features == n
                     assert mps.boundary == 'obc'
                     assert mps.in_dim == [2] * (n - 1)
                     assert mps.out_dim == 10
    -                assert mps.bond_dim == [2] * (n - 1)
    +                assert mps.bond_dim == [5] * (n - 1)
                     for node in mps.mats_env:
                         assert node.device == device
                     
    @@ -2254,6 +3220,55 @@ def test_initialize_init_method_cuda(self):
                     assert torch.equal(mps.right_node.tensor[1:],
                                        torch.zeros_like(mps.right_node.tensor)[1:])
         
    +    def test_initialize_init_method_complex(self):
    +        methods = ['zeros', 'ones', 'copy', 'rand', 'randn',
    +                   'randn_eye', 'unit', 'canonical']
    +        for n in [1, 2, 5]:
    +            for init_method in methods:
    +                # PBC
    +                mps = tk.models.MPSLayer(boundary='pbc',
    +                                         n_features=n,
    +                                         in_dim=2,
    +                                         out_dim=10,
    +                                         bond_dim=5,
    +                                         init_method=init_method,
    +                                         dtype=torch.complex64)
    +                assert mps.n_features == n
    +                assert mps.boundary == 'pbc'
    +                assert mps.in_dim == [2] * (n - 1)
    +                assert mps.out_dim == 10
    +                assert mps.bond_dim == [5] * n
    +                for node in mps.mats_env:
    +                    assert node.dtype == torch.complex64
    +                    assert node.is_complex()
    +                
    +                # OBC
    +                mps = tk.models.MPSLayer(boundary='obc',
    +                                         n_features=n,
    +                                         in_dim=2,
    +                                         out_dim=10,
    +                                         bond_dim=5,
    +                                         init_method=init_method,
    +                                         dtype=torch.complex64)
    +                assert mps.n_features == n
    +                assert mps.boundary == 'obc'
    +                assert mps.in_dim == [2] * (n - 1)
    +                assert mps.out_dim == 10
    +                assert mps.bond_dim == [5] * (n - 1)
    +                for node in mps.mats_env:
    +                    assert node.dtype == torch.complex64
    +                    assert node.is_complex()
    +                
    +                assert torch.equal(mps.left_node.tensor[0],
    +                                   torch.ones_like(mps.left_node.tensor)[0])
    +                assert torch.equal(mps.left_node.tensor[1:],
    +                                   torch.zeros_like(mps.left_node.tensor)[1:])
    +                
    +                assert torch.equal(mps.right_node.tensor[0],
    +                                   torch.ones_like(mps.right_node.tensor)[0])
    +                assert torch.equal(mps.right_node.tensor[1:],
    +                                   torch.zeros_like(mps.right_node.tensor)[1:])
    +    
         def test_initialize_canonical(self):
             for n in [1, 2, 5]:
                 # PBC
    @@ -2283,7 +3298,6 @@ def test_initialize_canonical(self):
                 assert mps.bond_dim == [10] * (n - 1)
                 
                 # Check it has norm == 10**n
    -            norm = mps.norm()
                 assert mps.norm().isclose(torch.tensor(10. ** n).sqrt())
                 # Norm is close to 10**n if bond dimension is <= than
                 # physical dimension, otherwise, it will not be exactly 10**n
    @@ -2324,11 +3338,52 @@ def test_initialize_canonical_cuda(self):
                     assert node.device == device
                 
                 # Check it has norm == 10**n
    -            norm = mps.norm()
                 assert mps.norm().isclose(torch.tensor(10. ** n).sqrt())
                 # Norm is close to 10**n if bond dimension is <= than
                 # physical dimension, otherwise, it will not be exactly 10**n
         
    +    def test_initialize_canonical_complex(self):
    +        for n in [1, 2, 5]:
    +            # PBC
    +            mps = tk.models.MPSLayer(boundary='pbc',
    +                                     n_features=n,
    +                                     in_dim=10,
    +                                     out_dim=10,
    +                                     bond_dim=10,
    +                                     init_method='canonical',
    +                                     dtype=torch.complex64)
    +            assert mps.n_features == n
    +            assert mps.boundary == 'pbc'
    +            assert mps.phys_dim == [10] * n
    +            assert mps.bond_dim == [10] * n
    +            for node in mps.mats_env:
    +                assert node.dtype == torch.complex64
    +                assert node.is_complex()
    +            
    +            # For PBC norm does not have to be 10**n always
    +            
    +            # OBC
    +            mps = tk.models.MPSLayer(boundary='obc',
    +                                     n_features=n,
    +                                     in_dim=10,
    +                                     out_dim=10,
    +                                     bond_dim=10,
    +                                     init_method='canonical',
    +                                     dtype=torch.complex64)
    +            assert mps.n_features == n
    +            assert mps.boundary == 'obc'
    +            assert mps.phys_dim == [10] * n
    +            assert mps.bond_dim == [10] * (n - 1)
    +            for node in mps.mats_env:
    +                assert node.dtype == torch.complex64
    +                assert node.is_complex()
    +            
    +            # Check it has norm == 10**n
    +            assert mps.norm().isclose(torch.tensor(10. ** n,
    +                                                   dtype=torch.complex64).sqrt())
    +            # Norm is close to 10**n if bond dimension is <= than
    +            # physical dimension, otherwise, it will not be exactly 10**n
    +    
         def test_in_and_out_features(self):
             tensors = [torch.randn(10, 2, 10) for _ in range(10)]
             mps = tk.models.MPSLayer(tensors=tensors,
    @@ -2700,7 +3755,7 @@ def test_initialize_init_method(self):
                             mps = tk.models.MPSData(boundary=boundary,
                                                     n_features=n_features,
                                                     phys_dim=2,
    -                                                bond_dim=2,
    +                                                bond_dim=5,
                                                     n_batches=n_batches,
                                                     init_method=init_method)
                             
    @@ -2709,16 +3764,16 @@ def test_initialize_init_method(self):
                             assert mps.phys_dim == [2] * n_features
                             
                             if boundary == 'obc':
    -                            assert mps.bond_dim == [2] * (n_features - 1)
    +                            assert mps.bond_dim == [5] * (n_features - 1)
                             else:
    -                            assert mps.bond_dim == [2] * n_features
    +                            assert mps.bond_dim == [5] * n_features
                             
                             if (n_features == 1) and (boundary == 'obc'):
                                 node = mps.mats_env[0]
                                 assert node.shape == tuple([1] * n_batches + [1, 2, 1])
                             else:  
                                 for node in mps.mats_env:
    -                                assert node.shape == tuple([1] * n_batches + [2] * 3)
    +                                assert node.shape == tuple([1] * n_batches + [5, 2, 5])
         
         def test_initialize_add_data(self):
             for n_features in [1, 2, 5]:
    diff --git a/tests/test_components.py b/tests/test_components.py
    index 050815f..bbf55e4 100644
    --- a/tests/test_components.py
    +++ b/tests/test_components.py
    @@ -599,16 +599,54 @@ def test_set_init_method(self, setup):
             assert node1.tensor is None
     
             # Initialize tensor of node1
    -        node1.set_tensor(init_method='randn', mean=1., std=2.)
    -        assert node1.tensor is not None
    +        for init_method in ["zeros", "ones", "copy", "rand", "randn"]:
    +            node1.set_tensor(init_method=init_method)
    +            assert node1.tensor is not None
     
    -        # Set node1's tensor as node2's tensor
    -        node2.tensor = node1.tensor
    -        assert torch.equal(node1.tensor, node2.tensor)
    +            # Set node1's tensor as node2's tensor
    +            node2.tensor = node1.tensor
    +            assert torch.equal(node1.tensor, node2.tensor)
     
    -        # Changing node1's tensor changes node2's tensor
    -        node1.tensor[0, 0, 0] = 1000
    -        assert node2.tensor[0, 0, 0] == 1000
    +            # Changing node1's tensor changes node2's tensor
    +            node1.tensor[0, 0, 0] = 1000
    +            assert node2.tensor[0, 0, 0] == 1000
    +    
    +    def test_set_init_method_cuda(self, setup):
    +        node1, node2, tensor = setup
    +        assert node1.tensor is None
    +        
    +        # Send to cuda if possible
    +        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    +
    +        # Initialize tensor of node1
    +        for init_method in ["zeros", "ones", "copy", "rand", "randn"]:
    +            node1.set_tensor(init_method=init_method, device=device)
    +            assert node1.tensor is not None
    +
    +            # Set node1's tensor as node2's tensor
    +            node2.tensor = node1.tensor
    +            assert torch.equal(node1.tensor, node2.tensor)
    +
    +            # Changing node1's tensor changes node2's tensor
    +            node1.tensor[0, 0, 0] = 1000
    +            assert node2.tensor[0, 0, 0] == 1000
    +    
    +    def test_set_init_method_complex(self, setup):
    +        node1, node2, tensor = setup
    +        assert node1.tensor is None
    +
    +        # Initialize tensor of node1
    +        for init_method in ["zeros", "ones", "copy", "rand", "randn"]:
    +            node1.set_tensor(init_method=init_method, dtype=torch.complex64)
    +            assert node1.tensor is not None
    +
    +            # Set node1's tensor as node2's tensor
    +            node2.tensor = node1.tensor
    +            assert torch.equal(node1.tensor, node2.tensor)
    +
    +            # Changing node1's tensor changes node2's tensor
    +            node1.tensor[0, 0, 0] = 1000
    +            assert node2.tensor[0, 0, 0] == 1000
     
         def test_set_tensor_from(self, setup):
             node1, node2, tensor = setup
    @@ -901,16 +939,54 @@ def test_set_init_method(self, setup):
             assert node1.tensor is None
     
             # Initialize tensor of node1
    -        node1.set_tensor(init_method='randn', mean=1., std=2.)
    -        assert node1.tensor is not None
    +        for init_method in ["zeros", "ones", "copy", "rand", "randn"]:
    +            node1.set_tensor(init_method=init_method)
    +            assert node1.tensor is not None
     
    -        # Set node1's tensor as node2's tensor
    -        node2.tensor = node1.tensor
    -        assert torch.equal(node1.tensor, node2.tensor)
    +            # Set node1's tensor as node2's tensor
    +            node2.tensor = node1.tensor
    +            assert torch.equal(node1.tensor, node2.tensor)
     
    -        # Cannot change element of Parameter
    -        with pytest.raises(RuntimeError):
    -            node1.tensor[0, 0, 0] = 1000
    +            # Cannot change element of Parameter
    +            with pytest.raises(RuntimeError):
    +                node1.tensor[0, 0, 0] = 1000
    +    
    +    def test_set_init_method_cuda(self, setup):
    +        node1, node2, tensor = setup
    +        assert node1.tensor is None
    +        
    +        # Send to cuda if possible
    +        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    +
    +        # Initialize tensor of node1
    +        for init_method in ["zeros", "ones", "copy", "rand", "randn"]:
    +            node1.set_tensor(init_method=init_method, device=device)
    +            assert node1.tensor is not None
    +
    +            # Set node1's tensor as node2's tensor
    +            node2.tensor = node1.tensor
    +            assert torch.equal(node1.tensor, node2.tensor)
    +
    +            # Cannot change element of Parameter
    +            with pytest.raises(RuntimeError):
    +                node1.tensor[0, 0, 0] = 1000
    +    
    +    def test_set_init_method_complex(self, setup):
    +        node1, node2, tensor = setup
    +        assert node1.tensor is None
    +
    +        # Initialize tensor of node1
    +        for init_method in ["zeros", "ones", "copy", "rand", "randn"]:
    +            node1.set_tensor(init_method=init_method, dtype=torch.complex64)
    +            assert node1.tensor is not None
    +
    +            # Set node1's tensor as node2's tensor
    +            node2.tensor = node1.tensor
    +            assert torch.equal(node1.tensor, node2.tensor)
    +
    +            # Cannot change element of Parameter
    +            with pytest.raises(RuntimeError):
    +                node1.tensor[0, 0, 0] = 1000
     
         def test_set_parametric(self, setup):
             node1, node2, tensor = setup
    diff --git a/tests/test_operations.py b/tests/test_operations.py
    index 9d51a49..0c180e1 100644
    --- a/tests/test_operations.py
    +++ b/tests/test_operations.py
    @@ -2,7 +2,7 @@
     This script contains tests for operations:
     
         * TestPermute
    -    * TestBasicOps
    +    * TestTensorOps
         * TestSplitSVD
         * TestSplitSVDR
         * TestSplitQR
    @@ -159,7 +159,7 @@ def test_permute_paramnode_in_place(self):
             assert torch.equal(permuted_node.tensor, node.tensor.permute(0, 2, 1))
     
     
    -class TestBasicOps:
    +class TestTensorOps:
     
         @pytest.fixture
         def setup(self):
    @@ -521,6 +521,22 @@ def test_renormalize(self):
             
             # Repeat
             node4 = node1.renormalize(axis=['left', 'right'])
    +    
    +    def test_conj(self):
    +        node1 = tk.randn(shape=(3, 3, 3),
    +                         axes_names=('left', 'input', 'right'),
    +                         dtype=torch.complex64)
    +        assert node1.successors == dict()
    +        
    +        node2 = tk.conj(node1)
    +        assert node2 is not node1
    +        assert node2.shape == (3, 3, 3)
    +        assert (node2.tensor == node1.tensor.conj()).all()
    +        assert node2.is_conj()
    +        assert node1.successors != dict()
    +        
    +        # Repeat
    +        node2 = tk.conj(node1)
     
         @pytest.fixture
         def setup_param(self):
    @@ -915,6 +931,23 @@ def test_renormalize_param(self):
             
             # Repeat
             node4 = node1.renormalize(axis=['left', 'right'])
    +    
    +    def test_conj_param(self):
    +        node1 = tk.randn(shape=(3, 3, 3),
    +                         axes_names=('left', 'input', 'right'),
    +                         param_node=True,
    +                         dtype=torch.complex64)
    +        assert node1.successors == dict()
    +        
    +        node2 = tk.conj(node1)
    +        assert node2 is not node1
    +        assert node2.shape == (3, 3, 3)
    +        assert (node2.tensor == node1.tensor.conj()).all()
    +        assert node2.is_conj()
    +        assert node1.successors != dict()
    +        
    +        # Repeat
    +        node2 = tk.conj(node1)
     
     
     class TestSplitSVD:
    @@ -1106,7 +1139,7 @@ def test_split_contracted_node_cum_percentage(self):
             assert new_node1.shape == (10, 2, 5, 1)
             assert new_node2.shape == (10, 1, 5, 3)
     
    -    def test_split_contracted_node_paramnodes_cum_percentage(self):
    +    def test_split_contracted_node_paramnode_cum_percentage(self):
             net = tk.TensorNetwork()
             node1 = tk.ParamNode(shape=(10, 2, 5, 4),
                                  axes_names=('batch', 'left', 'input', 'right'),
    @@ -1264,6 +1297,59 @@ def test_split_contracted_node_rank_cum_percentage_cutoff(self):
             
             assert new_node1.shape == (10, 2, 5, 8)
             assert new_node2.shape == (10, 8, 5, 3)
    +    
    +    def test_split_contracted_complex_node(self):
    +        net = tk.TensorNetwork()
    +        node1 = tk.Node(shape=(10, 2, 5, 4),
    +                        axes_names=('batch', 'left', 'input', 'right'),
    +                        name='node1',
    +                        init_method='randn',
    +                        dtype=torch.complex64,
    +                        network=net)
    +        node2 = tk.Node(shape=(10, 4, 5, 3),
    +                        axes_names=('batch', 'left', 'input', 'right'),
    +                        name='node2',
    +                        init_method='randn',
    +                        dtype=torch.complex64,
    +                        network=net)
    +        edge = node1[3] ^ node2[1]
    +        result = node1 @ node2
    +
    +        assert len(net.nodes) == 3
    +        assert len(net.leaf_nodes) == 2
    +        assert len(net.resultant_nodes) == 1
    +        assert node1.successors['contract_edges'][(None, node1, node2)].child == result
    +
    +        # Split result
    +        new_node1, new_node2 = result.split(node1_axes=['left', 'input_0'],
    +                                            node2_axes=['input_1', 'right'])
    +
    +        assert new_node1.shape == (10, 2, 5, 10)
    +        assert new_node1['batch'].size() == 10
    +        assert new_node1['left'].size() == 2
    +        assert new_node1['input'].size() == 5
    +        assert new_node1['split'].size() == 10
    +
    +        assert new_node2.shape == (10, 10, 5, 3)
    +        assert new_node2['batch'].size() == 10
    +        assert new_node2['split'].size() == 10
    +        assert new_node2['input'].size() == 5
    +        assert new_node2['right'].size() == 3
    +
    +        assert len(net.nodes) == 5
    +        assert len(net.leaf_nodes) == 2
    +        assert len(net.resultant_nodes) == 3
    +        assert result.successors['split'][(result,
    +                                           tuple(['left', 'input_0']),
    +                                           tuple(['input_1', 'right']),
    +                                           'svd',
    +                                           'left',
    +                                           None,
    +                                           None,
    +                                           None)].child == [new_node1, new_node2]
    +
    +        assert net.edges == [node1['left'], node1['input'],
    +                             node2['input'], node2['right']]
     
         def test_split_in_place(self):
             net = tk.TensorNetwork()
    @@ -1659,7 +1745,7 @@ def test_split_contracted_node_cum_percentage(self):
             assert new_node1.shape == (10, 2, 5, 1)
             assert new_node2.shape == (10, 1, 5, 3)
     
    -    def test_split_contracted_node_paramnodes_cum_percentage(self):
    +    def test_split_contracted_node_paramnode_cum_percentage(self):
             net = tk.TensorNetwork()
             node1 = tk.ParamNode(shape=(10, 2, 5, 4),
                                  axes_names=('batch', 'left', 'input', 'right'),
    @@ -1824,6 +1910,60 @@ def test_split_contracted_node_rank_cum_percentage_cutoff(self):
             
             assert new_node1.shape == (10, 2, 5, 8)
             assert new_node2.shape == (10, 8, 5, 3)
    +    
    +    def test_split_contracted_complex_node(self):
    +        net = tk.TensorNetwork()
    +        node1 = tk.Node(shape=(10, 2, 5, 4),
    +                        axes_names=('batch', 'left', 'input', 'right'),
    +                        name='node1',
    +                        init_method='randn',
    +                        dtype=torch.complex64,
    +                        network=net)
    +        node2 = tk.Node(shape=(10, 4, 5, 3),
    +                        axes_names=('batch', 'left', 'input', 'right'),
    +                        name='node2',
    +                        init_method='randn',
    +                        dtype=torch.complex64,
    +                        network=net)
    +        edge = node1[3] ^ node2[1]
    +        result = node1 @ node2
    +
    +        assert len(net.nodes) == 3
    +        assert len(net.leaf_nodes) == 2
    +        assert len(net.resultant_nodes) == 1
    +        assert node1.successors['contract_edges'][(None, node1, node2)].child == result
    +
    +        # Split result
    +        new_node1, new_node2 = result.split(node1_axes=['left', 'input_0'],
    +                                            node2_axes=['input_1', 'right'],
    +                                            mode='svdr')
    +
    +        assert new_node1.shape == (10, 2, 5, 10)
    +        assert new_node1['batch'].size() == 10
    +        assert new_node1['left'].size() == 2
    +        assert new_node1['input'].size() == 5
    +        assert new_node1['split'].size() == 10
    +
    +        assert new_node2.shape == (10, 10, 5, 3)
    +        assert new_node2['batch'].size() == 10
    +        assert new_node2['split'].size() == 10
    +        assert new_node2['input'].size() == 5
    +        assert new_node2['right'].size() == 3
    +
    +        assert len(net.nodes) == 5
    +        assert len(net.leaf_nodes) == 2
    +        assert len(net.resultant_nodes) == 3
    +        assert result.successors['split'][(result,
    +                                           tuple(['left', 'input_0']),
    +                                           tuple(['input_1', 'right']),
    +                                           'svdr',
    +                                           'left',
    +                                           None,
    +                                           None,
    +                                           None)].child == [new_node1, new_node2]
    +
    +        assert net.edges == [node1['left'], node1['input'],
    +                             node2['input'], node2['right']]
     
         def test_split_in_place(self):
             net = tk.TensorNetwork()
    @@ -6676,7 +6816,7 @@ def test_mps(self, setup_mps):
             image_size = (10, 10)
             mps = MPS(image_size=image_size)
     
    -        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    +        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
             mps = mps.to(device)
     
             # batch_size x height x width
    @@ -6700,7 +6840,7 @@ def test_uniform_mps(self, setup_mps):
             image_size = (10, 10)
             mps = MPS(image_size=image_size, uniform=True)
     
    -        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    +        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
             mps = mps.to(device)
     
             # batch_size x height x width
    @@ -6918,7 +7058,7 @@ def test_peps(self, setup_peps):
             image_size = (3, 3)
             peps = PEPS(image_size=image_size)
     
    -        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    +        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
             peps = peps.to(device)
     
             # batch_size x height x width
    @@ -6942,7 +7082,7 @@ def test_uniform_peps(self, setup_peps):
             image_size = (3, 3)
             peps = PEPS(image_size=image_size, uniform=True)
     
    -        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    +        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
             peps = peps.to(device)
     
             # batch_size x height x width