
Commit

Add mutual information
joserapa98 committed Apr 13, 2024
1 parent ec61d50 commit 175d59a
Showing 2 changed files with 178 additions and 0 deletions.
129 changes: 129 additions & 0 deletions tensorkrowch/models/mps.py
@@ -1202,6 +1202,10 @@ def partial_density(self, trace_sites: Sequence[int] = []) -> torch.Tensor:
it should be ``reset`` before calling ``partial_density`` to avoid
undesired behaviour.
Since the density matrix is computed by contracting the MPS, one can
take gradients of it with respect to the MPS tensors, if needed.
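As a minimal sketch (the ``mps`` instance, the traced sites and the
scalar being backpropagated are illustrative):

>>> rho = mps.partial_density(trace_sites=[0, 1])
>>> rho.norm().backward()
>>> assert mps.mats_env[0].grad is not None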
This method may also alter the attribute :attr:`n_batches` of the
:class:`MPS`.
@@ -1269,6 +1273,131 @@ def partial_density(self, trace_sites: Sequence[int] = []) -> torch.Tensor:
result = self.forward(marginalize_output=True)

return result

@torch.no_grad()
def mi(self,
middle_site: int,
renormalize: bool = False) -> Union[float, Tuple[float, float]]:
r"""
Computes the Mutual Information between subsystems :math:`A` and
:math:`B`, :math:`\textrm{MI}(A:B)`, where :math:`A` goes from site
0 to ``middle_site``, and :math:`B` goes from ``middle_site + 1`` to
``n_features - 1``.
To compute the mutual information, the MPS is put into canonical form
with the orthogonality center at ``middle_site``. Bond dimensions are
kept unchanged whenever possible; a bond is cropped only when its
dimension exceeds the product of the node's physical dimension and its
other bond dimension.
If the MPS is not normalized, the computation of the mutual
information may fail due to errors in the Singular Value
Decompositions. To avoid this, it is recommended to set
``renormalize = True``. In that case, the norm of each node is
extracted in logarithmic form after its SVD and accumulated. As a
result, the function returns the tuple ``(mi, log_norm)``, where ``mi``
is a `scaled` mutual information; the actual mutual information can be
recovered as ``exp(log_norm) * mi``.
Parameters
----------
middle_site : int
Position that separates regions :math:`A` and :math:`B`. It should
be between 0 and ``n_features - 2``.
renormalize : bool
Indicates whether nodes should be renormalized after SVD/QR
decompositions. If not, the norm may explode as it is accumulated from
all nodes. Renormalization avoids this undesired behavior by extracting
the norm of each node on a logarithmic scale after the SVD/QR
decompositions are computed. Finally, the normalization factor is
evenly distributed among all nodes of the MPS.
Returns
-------
float or tuple[float, float]
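Examples
--------
A minimal sketch, assuming a small, randomly initialized ``MPS`` (the
constructor arguments shown here are illustrative):

>>> import tensorkrowch as tk
>>> mps = tk.models.MPS(n_features=10,
...                     phys_dim=2,
...                     bond_dim=5)
>>> scaled_mi, log_norm = mps.mi(middle_site=4, renormalize=True)
>>> mi = mps.mi(middle_site=4)
>>> # The actual mutual information can be recovered from the scaled one
>>> mi_recovered = log_norm.exp() * scaled_mi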
"""
self.reset()

prev_auto_stack = self._auto_stack
self.auto_stack = False

if (middle_site < 0) or (middle_site > (self._n_features - 2)):
raise ValueError(
'`middle_site` should be between 0 and `n_features` - 2')

log_norm = 0

nodes = self._mats_env[:]
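# For 'obc' boundary, zero out all but the first slice along the
# boundary bonds of the first and last nodes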
if self._boundary == 'obc':
nodes[0].tensor[1:] = torch.zeros_like(
nodes[0].tensor[1:])
nodes[-1].tensor[..., 1:] = torch.zeros_like(
nodes[-1].tensor[..., 1:])

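# Left-to-right sweep: canonicalize sites 0 .. middle_site - 1 via SVDs
# on their right bonds, pushing the singular values towards `middle_site`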
for i in range(middle_site):
result1, result2 = nodes[i]['right'].svd_(
side='right',
rank=nodes[i]['right'].size())

if renormalize:
aux_norm = result2.norm() / sqrt(result2.shape[0])
if not aux_norm.isinf() and (aux_norm > 0):
result2.tensor = result2.tensor / aux_norm
log_norm += aux_norm.log()

result1 = result1.parameterize()
nodes[i] = result1
nodes[i + 1] = result2

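# Right-to-left sweep: canonicalize sites n_features - 1 .. middle_site + 1
# via SVDs on their left bonds, leaving the orthogonality center at
# `middle_site`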
for i in range(len(nodes) - 1, middle_site, -1):
result1, result2 = nodes[i]['left'].svd_(
side='left',
rank=nodes[i]['left'].size())

if renormalize:
aux_norm = result1.norm() / sqrt(result1.shape[0])
if not aux_norm.isinf() and (aux_norm > 0):
result1.tensor = result1.tensor / aux_norm
log_norm += aux_norm.log()

result2 = result2.parameterize()
nodes[i] = result2
nodes[i - 1] = result1

nodes[middle_site] = nodes[middle_site].parameterize()

# Compute mutual information
middle_tensor = nodes[middle_site].tensor.clone()
_, s, _ = torch.linalg.svd(
middle_tensor.reshape(middle_tensor.shape[:-1].numel(), # left x input
middle_tensor.shape[-1]), # right
full_matrices=False)

s = s[s > 0]
mutual_info = -(s * (s.log() + log_norm)).sum()

# Rescale
if log_norm != 0:
rescale = (log_norm / len(nodes)).exp()

if renormalize and (log_norm != 0):
for node in nodes:
node.tensor = node.tensor * rescale

# Update variables
if self._boundary == 'obc':
self._bond_dim = [node['right'].size() for node in nodes[:-1]]
else:
self._bond_dim = [node['right'].size() for node in nodes]
self._mats_env = nodes

self.auto_stack = prev_auto_stack

if renormalize:
return mutual_info, log_norm
else:
return mutual_info

@torch.no_grad()
def canonicalize(self,
49 changes: 49 additions & 0 deletions tests/models/test_mps.py
@@ -1211,6 +1211,55 @@ def test_partial_density(self):
for node in mps.mats_env:
assert node.grad is not None

def test_mutual_information(self):
for n_features in [1, 2, 3, 4, 10]:
for boundary in ['obc', 'pbc']:
for middle_site in range(n_features - 1):
bond_dim = torch.randint(low=2, high=10,
size=(n_features,)).tolist()
bond_dim = bond_dim[:-1] if boundary == 'obc' else bond_dim

mps = tk.models.MPS(n_features=n_features,
phys_dim=2,
bond_dim=bond_dim,
boundary=boundary,
in_features=[])

mps_tensor = mps()
assert mps_tensor.shape == (2,) * n_features

mps.out_features = []
example = torch.randn(1, n_features, 2)
mps.trace(example)

if mps.boundary == 'obc':
assert len(mps.leaf_nodes) == n_features + 2
else:
assert len(mps.leaf_nodes) == n_features
assert len(mps.data_nodes) == n_features

# Mutual Information
scaled_mi, log_norm = mps.mi(middle_site=middle_site,
renormalize=True)
mi = mps.mi(middle_site=middle_site,
renormalize=False)

assert all([mps.bond_dim[i] <= bond_dim[i]
for i in range(len(bond_dim))])

assert torch.isclose(mi, log_norm.exp() * scaled_mi)

if mps.boundary == 'obc':
assert len(mps.leaf_nodes) == n_features + 2
else:
assert len(mps.leaf_nodes) == n_features
assert len(mps.data_nodes) == n_features

mps.unset_data_nodes()
mps.in_features = []
approx_mps_tensor = mps()
assert approx_mps_tensor.shape == (2,) * n_features

def test_canonicalize(self):
for n_features in [1, 2, 3, 4, 10]:
for boundary in ['obc', 'pbc']:
