diff --git a/README.md b/README.md
index 59993aa..7cf5e05 100644
--- a/README.md
+++ b/README.md
@@ -61,7 +61,8 @@ model = PerceiverIO(
     latent_heads = 8,            # number of heads for latent self attention, 8
     cross_dim_head = 64,         # number of dimensions per cross attention head
     latent_dim_head = 64,        # number of dimensions per latent self attention head
-    weight_tie_layers = False    # whether to weight tie layers (optional, as indicated in the diagram)
+    weight_tie_layers = False,   # whether to weight tie layers (optional, as indicated in the diagram)
+    seq_dropout_prob = 0.2       # fraction of the tokens from the input sequence to dropout (structured dropout, for saving compute and regularizing effects)
 )

 seq = torch.randn(1, 512, 32)
diff --git a/perceiver_pytorch/perceiver_io.py b/perceiver_pytorch/perceiver_io.py
index 7eb79ee..4fa7919 100644
--- a/perceiver_pytorch/perceiver_io.py
+++ b/perceiver_pytorch/perceiver_io.py
@@ -28,6 +28,28 @@ def cached_fn(*args, _cache = True, **kwargs):
         return cache
     return cached_fn

+# structured dropout, more effective than traditional attention dropouts
+
+def dropout_seq(seq, mask, dropout):
+    b, n, *_, device = *seq.shape, seq.device
+    logits = torch.randn(b, n, device = device)
+
+    if exists(mask):
+        logits = logits.masked_fill(~mask, -torch.finfo(logits.dtype).max)
+
+    num_keep = max(1, int((1 - dropout) * n))
+    keep_indices = logits.topk(num_keep, dim = 1).indices
+
+    batch_indices = torch.arange(b, device = device)
+    batch_indices = rearrange(batch_indices, 'b -> b 1')
+
+    seq = seq[batch_indices, keep_indices]
+
+    if exists(mask):
+        mask = mask[batch_indices, keep_indices]
+
+    return seq, mask
+
 # helper classes

 class PreNorm(nn.Module):
@@ -117,9 +139,12 @@ def __init__(
         cross_dim_head = 64,
         latent_dim_head = 64,
         weight_tie_layers = False,
-        decoder_ff = False
+        decoder_ff = False,
+        seq_dropout_prob = 0.
     ):
         super().__init__()
+        self.seq_dropout_prob = seq_dropout_prob
+
         self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))

         self.cross_attend_blocks = nn.ModuleList([
@@ -157,6 +182,11 @@ def forward(

         cross_attn, cross_ff = self.cross_attend_blocks

+        # structured dropout (as done in perceiver AR https://arxiv.org/abs/2202.07765)
+
+        if self.training and self.seq_dropout_prob > 0.:
+            data, mask = dropout_seq(data, mask, self.seq_dropout_prob)
+
         # cross attention only happens once for Perceiver IO

         x = cross_attn(x, context = data, mask = mask) + x
diff --git a/setup.py b/setup.py
index 6370c62..b9a375d 100644
--- a/setup.py
+++ b/setup.py
@@ -3,9 +3,10 @@
 setup(
   name = 'perceiver-pytorch',
   packages = find_packages(),
-  version = '0.8.3',
+  version = '0.8.4',
   license='MIT',
   description = 'Perceiver - Pytorch',
+  long_description_content_type = 'text/markdown',
   author = 'Phil Wang',
   author_email = 'lucidrains@gmail.com',
   url = 'https://github.com/lucidrains/perceiver-pytorch',
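Note: a minimal sketch of what the new structured dropout does, assuming dropout_seq is importable from perceiver_pytorch.perceiver_io as added above. In PerceiverIO it only runs during training (and when seq_dropout_prob > 0.); at eval time the full input sequence is attended to.

import torch
from perceiver_pytorch.perceiver_io import dropout_seq

seq  = torch.randn(2, 100, 32)                  # (batch, tokens, dim)
mask = torch.ones(2, 100, dtype = torch.bool)   # all tokens valid

# keep a random 80% of tokens per batch element (masked-out tokens are preferentially dropped)
out, out_mask = dropout_seq(seq, mask, 0.2)

print(out.shape)       # torch.Size([2, 80, 32])
print(out_mask.shape)  # torch.Size([2, 80])

Because whole tokens are removed before the cross attention, this both regularizes and reduces the cost of the encode step, rather than merely zeroing attention weights as a conventional dropout would.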