diff --git a/tensorkrowch/models/mps.py b/tensorkrowch/models/mps.py
index 9de5a77..17a575b 100644
--- a/tensorkrowch/models/mps.py
+++ b/tensorkrowch/models/mps.py
@@ -1202,6 +1202,10 @@ def partial_density(self, trace_sites: Sequence[int] = []) -> torch.Tensor:
         it should be ``reset`` before calling ``partial_density`` to avoid
         undesired behaviour.
 
+        Since the density matrix is computed by contracting the MPS, it means
+        one can take gradients of it with respect to the MPS tensors, if it
+        is needed.
+
         This method may also alter the attribute :attr:`n_batches` of the
         :class:`MPS`.
 
@@ -1269,6 +1273,141 @@ def partial_density(self, trace_sites: Sequence[int] = []) -> torch.Tensor:
 
         result = self.forward(marginalize_output=True)
         return result
+
+    @torch.no_grad()
+    def mi(self,
+           middle_site: int,
+           renormalize: bool = False) -> Union[float, Tuple[float]]:
+        r"""
+        Computes the Mutual Information between subsystems :math:`A` and
+        :math:`B`, :math:`\textrm{MI}(A:B)`, where :math:`A` goes from site
+        0 to ``middle_site``, and :math:`B` goes from ``middle_site + 1`` to
+        ``n_features - 1``.
+
+        To compute the mutual information, the MPS is put into canonical form
+        with orthogonality center at ``middle_site``. Bond dimensions are not
+        changed if possible. Only when the bond dimension is bigger than the
+        physical dimension multiplied by the other bond dimension of the node,
+        it will be cropped to that size.
+
+        If the MPS is not normalized, it may happen that the computation of the
+        mutual information fails due to errors in the Singular Value
+        Decompositions. To avoid this, it is recommended to set
+        ``renormalize = True``. In this case, the norm of each node after the
+        SVD is extracted in logarithmic form, and accumulated. As a result,
+        the function will return the tuple ``(mi, log_norm)``, which is a sort
+        of `scaled` mutual information. The actual mutual information could be
+        obtained as ``exp(log_norm) * mi``.
+
+        Parameters
+        ----------
+        middle_site : int
+            Position that separates regions :math:`A` and :math:`B`. It should
+            be between 0 and ``n_features - 2``.
+        renormalize : bool
+            Indicates whether nodes should be renormalized after SVD/QR
+            decompositions. If not, it may happen that the norm explodes as it
+            is being accumulated from all nodes. Renormalization aims to avoid
+            this undesired behavior by extracting the norm of each node on a
+            logarithmic scale after SVD/QR decompositions are computed. Finally,
+            the normalization factor is evenly distributed among all nodes of
+            the MPS.
+
+        Returns
+        -------
+        float or tuple[float, float]
+            The mutual information alone if ``renormalize = False``;
+            otherwise, the tuple ``(mi, log_norm)``.
+
+        Raises
+        ------
+        ValueError
+            If ``middle_site`` is not between 0 and ``n_features - 2``.
+        """
+        # Validate before touching the network, so that a bad argument does
+        # not leave the MPS reset and with ``auto_stack`` disabled
+        if (middle_site < 0) or (middle_site > (self._n_features - 2)):
+            raise ValueError(
+                '`middle_site` should be between 0 and `n_features` - 2')
+
+        self.reset()
+
+        prev_auto_stack = self._auto_stack
+        self.auto_stack = False
+
+        log_norm = 0
+
+        nodes = self._mats_env[:]
+        if self._boundary == 'obc':
+            nodes[0].tensor[1:] = torch.zeros_like(
+                nodes[0].tensor[1:])
+            nodes[-1].tensor[..., 1:] = torch.zeros_like(
+                nodes[-1].tensor[..., 1:])
+
+        # Left-to-right sweep up to ``middle_site``
+        for i in range(middle_site):
+            result1, result2 = nodes[i]['right'].svd_(
+                side='right',
+                rank=nodes[i]['right'].size())
+
+            if renormalize:
+                aux_norm = result2.norm() / sqrt(result2.shape[0])
+                if not aux_norm.isinf() and (aux_norm > 0):
+                    result2.tensor = result2.tensor / aux_norm
+                    log_norm += aux_norm.log()
+
+            result1 = result1.parameterize()
+            nodes[i] = result1
+            nodes[i + 1] = result2
+
+        # Right-to-left sweep down to ``middle_site``
+        for i in range(len(nodes) - 1, middle_site, -1):
+            result1, result2 = nodes[i]['left'].svd_(
+                side='left',
+                rank=nodes[i]['left'].size())
+
+            if renormalize:
+                aux_norm = result1.norm() / sqrt(result1.shape[0])
+                if not aux_norm.isinf() and (aux_norm > 0):
+                    result1.tensor = result1.tensor / aux_norm
+                    log_norm += aux_norm.log()
+
+            result2 = result2.parameterize()
+            nodes[i] = result2
+            nodes[i - 1] = result1
+
+        nodes[middle_site] = nodes[middle_site].parameterize()
+
+        # Compute mutual information
+        middle_tensor = nodes[middle_site].tensor.clone()
+        _, s, _ = torch.linalg.svd(
+            middle_tensor.reshape(middle_tensor.shape[:-1].numel(),  # left x input
+                                  middle_tensor.shape[-1]),          # right
+            full_matrices=False)
+
+        s = s[s > 0]
+        mutual_info = -(s * (s.log() + log_norm)).sum()
+
+        # Rescale: ``log_norm`` is only accumulated when ``renormalize`` is
+        # set, so a single guard covers both the factor and its application
+        if renormalize and (log_norm != 0):
+            rescale = (log_norm / len(nodes)).exp()
+            for node in nodes:
+                node.tensor = node.tensor * rescale
+
+        # Update variables
+        if self._boundary == 'obc':
+            self._bond_dim = [node['right'].size() for node in nodes[:-1]]
+        else:
+            self._bond_dim = [node['right'].size() for node in nodes]
+        self._mats_env = nodes
+
+        self.auto_stack = prev_auto_stack
+
+        if renormalize:
+            return mutual_info, log_norm
+        else:
+            return mutual_info
 
     @torch.no_grad()
     def canonicalize(self,
diff --git a/tests/models/test_mps.py b/tests/models/test_mps.py
index ce63dd7..41a054d 100644
--- a/tests/models/test_mps.py
+++ b/tests/models/test_mps.py
@@ -1211,6 +1211,55 @@ def test_partial_density(self):
                 for node in mps.mats_env:
                     assert node.grad is not None
 
+    def test_mutual_information(self):
+        for n_features in [1, 2, 3, 4, 10]:
+            for boundary in ['obc', 'pbc']:
+                for middle_site in range(n_features - 1):
+                    bond_dim = torch.randint(low=2, high=10,
+                                             size=(n_features,)).tolist()
+                    bond_dim = bond_dim[:-1] if boundary == 'obc' else bond_dim
+
+                    mps = tk.models.MPS(n_features=n_features,
+                                        phys_dim=2,
+                                        bond_dim=bond_dim,
+                                        boundary=boundary,
+                                        in_features=[])
+
+                    mps_tensor = mps()
+                    assert mps_tensor.shape == (2,) * n_features
+
+                    mps.out_features = []
+                    example = torch.randn(1, n_features, 2)
+                    mps.trace(example)
+
+                    if mps.boundary == 'obc':
+                        assert len(mps.leaf_nodes) == n_features + 2
+                    else:
+                        assert len(mps.leaf_nodes) == n_features
+                    assert len(mps.data_nodes) == n_features
+
+                    # Mutual Information
+                    scaled_mi, log_norm = mps.mi(middle_site=middle_site,
+                                                 renormalize=True)
+                    mi = mps.mi(middle_site=middle_site,
+                                renormalize=False)
+
+                    assert all([mps.bond_dim[i] <= bond_dim[i]
+                                for i in range(len(bond_dim))])
+
+                    assert torch.isclose(mi, log_norm.exp() * scaled_mi)
+
+                    if mps.boundary == 'obc':
+                        assert len(mps.leaf_nodes) == n_features + 2
+                    else:
+                        assert len(mps.leaf_nodes) == n_features
+                    assert len(mps.data_nodes) == n_features
+
+                    mps.unset_data_nodes()
+                    mps.in_features = []
+                    approx_mps_tensor = mps()
+                    assert approx_mps_tensor.shape == (2,) * n_features
+
     def test_canonicalize(self):
         for n_features in [1, 2, 3, 4, 10]:
             for boundary in ['obc', 'pbc']: