# Test for GeneralizedPermutation
+
+def test_generalizedpermutation(*, length, alphabet_size, init_kwargs):
+= init_kwargs["group_sizes"]
+ group_sizes = sum(group_sizes)
+ n_samples
+= torch.randn(n_samples, length, alphabet_size)
+ x = GeneralizedPermutation(**init_kwargs)
+ perm = perm()
+ mats = MatrixApply(group_sizes)
+ mat_apply = mat_apply(x, mats=mats)
+ y
+assert y.shape == x.shape
+ assert y.requires_grad
+
+
+ perm.hard_()assert perm.mode == "hard"
+
+
+
+ test_generalizedpermutation(=5,
+ length=10,
+ alphabet_size={
+ init_kwargs"group_sizes": [3, 2, 4],
+ "fixed_pairings": [[(0, 1)], [(0, 0)], [(1, 0), (2, 3)]],
+ "tau": 0.1,
+
+ }
+ )
+
+def test_batch_perm(shape: tuple[int, int, int, int]):
+= torch.randn(*shape)
+ perms = torch.randn(shape[-2], shape[-1])
+ x
+= perms.argmax(-1)
+ argmax = x[argmax]
+ x_permuted_rows = argmax.view(*argmax.shape[:-1], 1, -1).expand_as(perms)
+ index = torch.gather(x_permuted_rows, -1, index)
+ output
+= torch.stack([
+ expected
+ torch.stack([for j in range(shape[1])
+ x[argmax[i, j], :][:, argmax[i, j]] =0) for i in range(shape[0])
+ ], dim=0)
+ ], dim
+assert torch.equal(output, expected)
+
+
+2, 5, 4, 4)) test_batch_perm((
base
+Type aliases
+= list # List indexed by bootstrap iteration
+ BootstrapList = list # List indexed by gradient descent iteration
+ GradientDescentList = list # List indexed by group index
+ GroupByGroupList
+= tuple[int, int] # Pair of indices
+ IndexPair = list[IndexPair] # Pairs of indices in a group of sequences
+ IndexPairsInGroup = list[IndexPairsInGroup] # Pairs of indices in groups of sequences IndexPairsInGroups
+ +
make_pbar
++++make_pbar (epochs:int, show_pbar:bool)
+ +
dccn
++++dccn (x:torch.Tensor)
+ +
DiffPaSSResults
++++DiffPaSSResults (log_alphas:Union[list[list[numpy.ndarray]],list[list[lis + t[numpy.ndarray]]],NoneType], soft_perms:Union[list[list + [numpy.ndarray]],list[list[list[numpy.ndarray]]],NoneTyp + e], hard_perms:Union[list[list[numpy.ndarray]],list[list + [list[numpy.ndarray]]]], hard_losses:Union[list[list[num + py.ndarray]],list[list[list[numpy.ndarray]]]], soft_loss + es:Union[list[list[numpy.ndarray]],list[list[list[numpy. + ndarray]]],NoneType])
Container for results of DiffPaSS fits.
++ +
DiffPaSSModel
++++DiffPaSSModel (*args, **kwargs)
Base class for DiffPaSS models.
++ +
DiffPaSSModel.fit
++++DiffPaSSModel.fit (x:torch.Tensor, y:torch.Tensor, epochs:int=1, + optimizer_name:Optional[str]='SGD', + optimizer_kwargs:Optional[dict[str,Any]]=None, + mean_centering:bool=False, show_pbar:bool=False, + compute_final_soft:bool=False, + record_log_alphas:bool=False, + record_soft_perms:bool=False, + record_soft_losses:bool=False)
Fit permutations to data using gradient descent.
++ | Type | +Default | +Details | +
---|---|---|---|
x | +Tensor | ++ | The object (MSA or adjacency matrix of graphs) to be permuted | +
y | +Tensor | ++ | The target object (MSA or adjacency matrix of graphs), that the objects represented by x should be paired with. Not acted upon by soft/hard permutations |
+
epochs | +int | +1 | ++ |
optimizer_name | +typing.Optional[str] | +SGD | ++ |
optimizer_kwargs | +typing.Optional[dict[str, typing.Any]] | +None | ++ |
mean_centering | +bool | +False | ++ |
show_pbar | +bool | +False | ++ |
compute_final_soft | +bool | +False | ++ |
record_log_alphas | +bool | +False | ++ |
record_soft_perms | +bool | +False | ++ |
record_soft_losses | +bool | +False | ++ |
Returns | +DiffPaSSResults | ++ | + |
+ +
DiffPaSSModel.fit_bootstrap
++++DiffPaSSModel.fit_bootstrap (x:torch.Tensor, y:torch.Tensor, + n_start:int=1, n_end:Optional[int]=None, + step_size:int=1, show_pbar:bool=True, + single_fit_cfg:Optional[dict]=None)
*Fit permutations to data using the DiffPaSS bootstrap.
+The DiffPaSS bootstrap consists of a sequence of short gradient descent runs (default: one epoch per run). At the end of each run, a subset of the found pairings is chosen uniformly at random and fixed for the next run. The number of pairings fixed at each iteration ranges between n_start
(default: 1) and n_end
(default: total number of pairs), with a step size of step_size
.*
+ | Type | +Default | +Details | +
---|---|---|---|
x | +Tensor | ++ | The object (MSA or adjacency matrix of graphs) to be permuted | +
y | +Tensor | ++ | The target object (MSA or adjacency matrix of graphs), that the objects represented by x should be paired with. Not acted upon by soft/hard permutations |
+
n_start | +int | +1 | ++ |
n_end | +typing.Optional[int] | +None | ++ |
step_size | +int | +1 | ++ |
show_pbar | +bool | +True | ++ |
single_fit_cfg | +typing.Optional[dict] | +None | ++ |
Returns | +DiffPaSSResults | ++ | + |