diff --git a/tests/test_attentions.py b/tests/test_attentions.py deleted file mode 100644 index ba8c7fd282..0000000000 --- a/tests/test_attentions.py +++ /dev/null @@ -1,505 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - -import math -from typing import Tuple - -import pytest -import torch - -from xformers.components import ( - InputProjection, - InputProjectionConfig, - MultiHeadDispatch, -) - -# Automatically test all the registered attentions -from xformers.components.attention import ( - _DENSITY_THRESHOLD, - ATTENTION_REGISTRY, - build_attention, -) - -DEVICES = ( - [torch.device("cpu")] if not torch.cuda.is_available() else [torch.device("cuda")] -) - -BATCH = 2 -SEQ = 128 if torch.cuda.is_available() else 36 -MODEL = 128 if torch.cuda.is_available() else 16 -GLOBAL_ATTENTION_RATIO = ( - _DENSITY_THRESHOLD * 0.9 -) # Make sure that we test the sparse implementation, no matter the threshold - -assert ATTENTION_REGISTRY.keys(), "Attention layers should have been registered" - -_non_order_invariant_attentions = ["visual", "pooling"] - - -def _get_multihead( - attention_name, - attn_dropout, - res_dropout, - causal, - heads, - device, - skip_output_projection=False, - use_separate_proj_weights=True, -): - test_config = { - "name": attention_name, - "dropout": attn_dropout, - "causal": causal, - "seq_len": SEQ, - "window_size": SEQ // 8 + 1, # local attention - "attention_query_mask": torch.rand((SEQ, 1)) < GLOBAL_ATTENTION_RATIO, - "dim_model": MODEL, - "num_heads": heads, - "dim_head": MODEL / heads, - "num_rules": 2, # Compositional Attention - "r": 0.5, # random attention, ratio of tokens that the attention can attend to - } - - if skip_output_projection: - - def noop(x): - return x - - test_config["out_proj"] = noop - - attention = build_attention(test_config) - - # build a multi head dispatch to test this attention mechanism - multi_head = MultiHeadDispatch( - seq_len=SEQ, - dim_model=MODEL, - residual_dropout=res_dropout, - num_heads=heads, - attention=attention, - use_separate_proj_weight=use_separate_proj_weights, - ).to(device) - - return multi_head - - -@pytest.mark.parametrize("attn_dropout", [0.0, 0.3]) -@pytest.mark.parametrize("residual_dropout", [0.0, 0.1]) -@pytest.mark.parametrize("causal", [True, False]) -@pytest.mark.parametrize("heads", [1, 4]) -@pytest.mark.parametrize( - "attention_name", ATTENTION_REGISTRY.keys() - _non_order_invariant_attentions -) -@pytest.mark.parametrize("device", DEVICES) -def test_order_invariance( - attention_name: str, - heads: int, - attn_dropout: float, - residual_dropout: float, - causal: bool, - device: torch.device, -): - if ( - torch.version.hip - and device == torch.device("cuda") - and attention_name == "local" - ): - # Backend calls into Sputnik library which isn't built on ROCm - device = torch.device("cpu") - - torch.manual_seed(42) - torch.cuda.manual_seed_all(42) - - multi_head = _get_multihead( - attention_name, - attn_dropout, - residual_dropout, - causal, - heads, - device, - use_separate_proj_weights=False, - ) - - if ( - int(math.sqrt(SEQ)) ** 2 != SEQ - and multi_head.attention.requires_squared_context - ): - pytest.skip(f"{attention_name} requires squared sequence lengths") - - # Check that we can pass a smaller sequence - seqs = ( - [SEQ, SEQ // 2] - if not multi_head.attention.requires_same_k_q_dimensions - else [SEQ] - ) - - for seq in seqs: - # Check that 
the attention is invariant to a permutation of K, V - inputs = torch.rand(BATCH, seq, MODEL, device=device) - shuffle = torch.randperm(inputs.shape[1]) - inputs_shuffled = inputs[:, shuffle, :].clone() - - results = multi_head(inputs, inputs, inputs) - results_shuffled = multi_head(inputs, inputs_shuffled, inputs_shuffled) - torch.allclose(results, results_shuffled) - - # Check that the attention is equivariant to a permutation of Q, - # meaning that the result is permuted in the same way - results_shuffled = multi_head(inputs_shuffled, inputs, inputs) - torch.allclose(results[:, shuffle, :], results_shuffled) - - # Check that dropout actually drops some values - if attn_dropout > 0: - att_1 = multi_head(inputs, inputs_shuffled, inputs) - att_2 = multi_head(inputs, inputs_shuffled, inputs) - assert (att_1 != att_2).any() - - # Test AMP, if available - if device.type == "cuda": - with torch.amp.autocast("cuda", enabled=True): - _ = multi_head(inputs, inputs_shuffled, inputs) - - -@pytest.mark.parametrize("heads", [1, 4]) -@pytest.mark.parametrize("attention_name", ["scaled_dot_product"]) -@pytest.mark.parametrize("device", DEVICES) -def test_kqv_ordering( - attention_name: str, - heads: int, - device: torch.device, -): - multi_head = _get_multihead(attention_name, 0.0, 0.0, False, heads, device) - - # Check kqv are not flipped - # this will not catch all issues, but would catch a V being misplaced - # make k and q complimentary, so that QKt is all zero and attention is uniform - - q = torch.cat( - ( - torch.rand((1, MODEL // 2), device=device), - torch.zeros((1, MODEL // 2), device=device), - ), - dim=1, - ).expand((BATCH, SEQ, MODEL)) - - k = torch.cat( - ( - torch.zeros((1, MODEL // 2), device=device), - torch.rand((1, MODEL // 2), device=device), - ), - dim=1, - ).expand((BATCH, SEQ, MODEL)) - v = torch.rand(BATCH, SEQ, MODEL, device=device) - - # Normal call - res = multi_head(query=q, key=k, value=v) - for i in range(BATCH): - assert torch.allclose(res[i, :, :], res[i, 0, :].unsqueeze(-2)) - - assert not torch.allclose(res[0, :, :], res[1, :, :]) - - # Flip qkv, and check that we invert the above check properly - res_false = multi_head(query=v, key=k, value=q) - assert torch.allclose(res_false[0, :, :], res_false[1, :, :]) - - -@pytest.mark.parametrize("heads", [1, 4]) -@pytest.mark.parametrize("attention_name", ["scaled_dot_product"]) -@pytest.mark.parametrize("device", DEVICES) -def test_different_seqlen( - attention_name: str, - heads: int, - device: torch.device, -): - multi_head = _get_multihead(attention_name, 0.0, 0.0, False, heads, device) - - # Check kqv are not flipped - # this will not catch all issues, but would catch a V being misplaced - # make k and q complimentary, so that QKt is all zero and attention is uniform - - q = torch.cat( - ( - torch.rand((1, MODEL // 2), device=device), - torch.zeros((1, MODEL // 2), device=device), - ), - dim=1, - ).expand((BATCH, SEQ, MODEL)) - - k = torch.cat( - ( - torch.zeros((1, MODEL // 2), device=device), - torch.rand((1, MODEL // 2), device=device), - ), - dim=1, - ).expand((BATCH, SEQ, MODEL)) - v = torch.rand(BATCH, SEQ, MODEL, device=device) - - # Normal call - res = multi_head(query=q, key=k, value=v) - - # Changing sequence length by dividing by two to simulate differing sequence length - q2 = torch.cat( - ( - torch.rand((1, MODEL // 2), device=device), - torch.zeros((1, MODEL // 2), device=device), - ), - dim=1, - ).expand((BATCH, SEQ // 2, MODEL)) - - k2 = torch.cat( - ( - torch.zeros((1, MODEL // 2), device=device), - 
torch.rand((1, MODEL // 2), device=device), - ), - dim=1, - ).expand((BATCH, SEQ // 2, MODEL)) - - v2 = torch.rand(BATCH, SEQ // 2, MODEL, device=device) - - res2 = multi_head(query=q2, key=k2, value=v2) - - assert res.shape != res2.shape - - -@pytest.mark.parametrize("proj_bias", [False, True]) -@pytest.mark.parametrize("same_sizes", [False, True]) -@pytest.mark.parametrize("same_settings", [False, True]) -def test_inproj(proj_bias: bool, same_sizes: bool, same_settings: bool): - - test_config = { - "name": "scaled_dot_product", - "dropout": 0.1, - "causal": False, - "seq_len": SEQ, - "window_size": SEQ // 8 + 1, - "num_heads": 1, - "dim_head": MODEL, - } - - attention = build_attention(test_config) - - # Construct the initial projection, test different options - in_params = InputProjectionConfig(MODEL, MODEL, proj_bias) - - if same_settings: - in_proj = InputProjection(in_params, None, None) - out_features = MODEL - else: - out_features = MODEL if same_sizes else MODEL // 2 - in_params_flip = InputProjectionConfig(MODEL, out_features, proj_bias) - in_proj = InputProjection( - in_params_flip, # Q proj - in_params_flip, # K proj - in_params, # V proj - ) - - # build a multi head dispatch to test this attention mechanism - multi_head = MultiHeadDispatch( - seq_len=SEQ, - dim_model=MODEL, - residual_dropout=0.1, - num_heads=1, - attention=attention, - in_proj_container=in_proj, - dim_key=out_features, - dim_value=MODEL, - ) - - # Check kqv are not flipped - # this will not catch all issues, but would catch a V being misplaced - # make k and q complimentary, so that QKt is all zero and attention is uniform - - q = torch.cat( - ( - torch.rand((1, MODEL // 2)), - torch.zeros((1, MODEL // 2)), - ), - dim=1, - ).expand((BATCH, SEQ, MODEL)) - - k = torch.cat( - ( - torch.zeros((1, MODEL // 2)), - torch.rand((1, MODEL // 2)), - ), - dim=1, - ).expand((BATCH, SEQ, MODEL)) - v = torch.rand(BATCH, SEQ, MODEL) - - # just check that a FW does not assert out - _ = multi_head(query=q, key=k, value=v) - - -@pytest.mark.parametrize("heads", [1, 4]) -@pytest.mark.parametrize("attention_name", ATTENTION_REGISTRY.keys()) -@pytest.mark.parametrize("device", DEVICES) -def test_different_kq_dimensions( - attention_name: str, - heads: int, - device: torch.device, -): - - multi_head = _get_multihead(attention_name, 0.0, 0.0, False, heads, device) - - if multi_head.attention.requires_same_k_q_dimensions: - # pyre-fixme[29]: The library function `pytest.skip` is not supported by Pyre. 
- pytest.skip(f"{attention_name} does not support different k, q dimensions yet.") - - seq_q = SEQ // 2 - q = torch.rand((BATCH, seq_q, MODEL), device=device) - k = torch.rand((BATCH, SEQ, MODEL), device=device) - v = torch.rand((BATCH, SEQ, MODEL), device=device) - - res = multi_head(query=q, key=k, value=v) - assert res.shape == torch.Size([BATCH, seq_q, MODEL]) - - -@pytest.mark.parametrize("heads", [1, 4]) -@pytest.mark.parametrize("attention_name", ATTENTION_REGISTRY.keys()) -@pytest.mark.parametrize("device", DEVICES) -@pytest.mark.parametrize( - "batch_sizes", - [ - (1, BATCH, BATCH), - (BATCH, 1, BATCH), - (BATCH, BATCH, 1), - (1, 1, BATCH), - (BATCH, 1, 1), - (1, BATCH, 1), - ], -) -def test_broadcast_batch_dimension( - attention_name: str, - heads: int, - device: torch.device, - batch_sizes: Tuple[int, int, int], -): - Q_BATCH, K_BATCH, V_BATCH = batch_sizes - multi_head = _get_multihead(attention_name, 0.0, 0.0, False, heads, device) - - if ( - int(math.sqrt(SEQ)) ** 2 != SEQ - and multi_head.attention.requires_squared_context - ): - pytest.skip(f"{attention_name} requires squared sequence lengths") - - if multi_head.attention.requires_same_k_q_dimensions: - # pyre-fixme[29]: The library function `pytest.skip` is not supported by Pyre. - pytest.skip(f"{attention_name} does not support different k, q dimensions yet.") - - q = torch.rand((Q_BATCH, SEQ, MODEL), device=device) - k = torch.rand((K_BATCH, SEQ, MODEL), device=device) - v = torch.rand((V_BATCH, SEQ, MODEL), device=device) - - res = multi_head(query=q, key=k, value=v) - assert res.shape == torch.Size([BATCH, SEQ, MODEL]) - - -@pytest.mark.parametrize("heads", [1, 4]) -@pytest.mark.parametrize("attention_name", ["scaled_dot_product", "favor"]) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires a CUDA gpu") -def test_causal( - attention_name: str, - heads: int, -): - """ - Make sure that the causal flag is respected. 
- The input data is orthogonal by design if causal is respected, but if the attention looks ahead this will fail - """ - - torch.random.manual_seed(42) - - device = torch.device("cuda") - - multi_head = _get_multihead( - attention_name, - 0.0, - 0.0, - causal=True, - heads=heads, - device=device, - skip_output_projection=True, - ) - - k = ( - torch.tril(torch.ones((SEQ, SEQ), device=device), diagonal=0) - .unsqueeze(0) - .expand(1, -1, -1) - ) - q = ( - torch.triu(torch.ones((SEQ, SEQ), device=device), diagonal=0) - .unsqueeze(0) - .expand(1, -1, -1) - ) - v = ( - torch.arange(SEQ, device=device) - .float() - .unsqueeze(0) - .unsqueeze(-1) - .expand(1, -1, SEQ) - ) - - # Make sure that we don´t project, to keep the embeddings orthogonal - multi_head.attention.requires_input_projection = False - - res = multi_head(query=q, key=k, value=v).squeeze(0) - - # Consolidate along the embedding, if causal was respected the amplitude should be sorted already - res_sum = torch.sum(res, dim=1).cpu() - - assert torch.allclose(torch.sort(res_sum)[1], torch.arange(SEQ)) or torch.allclose( - torch.sort(res_sum, descending=True)[1], torch.arange(SEQ) - ), res_sum - - -@pytest.mark.parametrize("attn_dropout", [0.0, 0.1]) -@pytest.mark.parametrize("heads", [2]) -@pytest.mark.parametrize("attention_name", ATTENTION_REGISTRY.keys()) -@pytest.mark.skipif(torch.cuda.is_available(), reason="CUDA gpu not supported yet") -def test_torch_script_ability( - attention_name: str, - heads: int, - attn_dropout: float, -): - if attention_name in {"favor", "global", "local", "random"}: - # pyre-fixme[29]: The library function `pytest.skip` is not supported by Pyre. - pytest.skip(f"{attention_name} does not support scripting yet.") - - device = torch.device("cpu") - - multi_head = _get_multihead(attention_name, attn_dropout, 0.0, False, heads, device) - - if ( - int(math.sqrt(SEQ)) ** 2 != SEQ - and multi_head.attention.requires_squared_context - ): - pytest.skip(f"{attention_name} requires squared sequence lengths") - - # input for tracing the function - q = torch.rand((BATCH, SEQ, MODEL), device=device) - k = torch.rand((BATCH, SEQ, MODEL), device=device) - v = torch.rand((BATCH, SEQ, MODEL), device=device) - - # to make sure dropout behaves deterministically - torch.random.manual_seed(42) - # tracing the attention module - traced_multi_head = torch.jit.trace(multi_head, (q, k, v)) - - # create new random inputs for testing the eager model and traced model - q = torch.rand((BATCH, SEQ, MODEL), device=device) - k = torch.rand((BATCH, SEQ, MODEL), device=device) - v = torch.rand((BATCH, SEQ, MODEL), device=device) - - # to make sure dropout behaves deterministically need to set the seed again - torch.random.manual_seed(42) - res = multi_head(query=q, key=k, value=v) - - # to make sure dropout behaves deterministically need to set the seed again - torch.random.manual_seed(42) - res_traced = traced_multi_head(query=q, key=k, value=v) - - assert torch.allclose(res, res_traced) - - -# TODO: way more unit tests.. diff --git a/tests/test_compositional_attention.py b/tests/test_compositional_attention.py deleted file mode 100644 index 615a1b1d0e..0000000000 --- a/tests/test_compositional_attention.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
- -import pytest -import torch - -from xformers.components import MultiHeadDispatch - -# Automatically test all the registered attentions -from xformers.components.attention import ATTENTION_REGISTRY, build_attention - -DEVICES = ( - [torch.device("cpu")] if not torch.cuda.is_available() else [torch.device("cuda")] -) - -BATCH = 2 -SEQ = 128 if torch.cuda.is_available() else 16 -MODEL = 128 if torch.cuda.is_available() else 32 - -assert ATTENTION_REGISTRY.keys(), "Attention layers should have been registered" - - -@pytest.mark.parametrize("heads", [4]) -@pytest.mark.parametrize("attn_dropout", [0.0, 0.3]) -@pytest.mark.parametrize("causal", [True, False]) -@pytest.mark.parametrize("rules", [4]) -@pytest.mark.parametrize("q_compose", [False, True]) -@pytest.mark.parametrize("dim_selection", [MODEL // 2, None]) -@pytest.mark.parametrize("num_rules", [2]) -@pytest.mark.parametrize("qk_rule", [True, False]) -@pytest.mark.parametrize("nonlinear", [True, False]) -@pytest.mark.parametrize("device", DEVICES) -def test_build_and_run( - heads: int, - attn_dropout: float, - causal: bool, - rules: int, - q_compose: bool, - dim_selection: int, - num_rules: int, - qk_rule: bool, - nonlinear: bool, - device: torch.device, -): - - torch.manual_seed(42) - - test_config = { - "name": "compositional", - "dropout": attn_dropout, - "causal": causal, - "seq_len": SEQ, - "dim_model": MODEL, - "num_heads": heads, - "num_rules": num_rules, - "q_compose": q_compose, - "rules": rules, - "dim_selection": dim_selection, - "qk_rule": qk_rule, - "nonlinear": nonlinear, - } - - attention = build_attention(test_config) - - # build a multi head dispatch to test this attention mechanism - multi_head = MultiHeadDispatch( - seq_len=SEQ, - dim_model=MODEL, - num_heads=heads, - attention=attention, - residual_dropout=0.0, - ).to(device) - - # Check that a shuffled input produces the same results - seqs = [SEQ, SEQ // 2] - - for seq in seqs: - # Check that we can pass a smaller sequence - inputs = torch.rand(BATCH, seq, MODEL, device=device) - shuffle = torch.randperm(inputs.shape[1]) - inputs_shuffled = inputs[:, shuffle, :].clone() - - results = multi_head(inputs, inputs, inputs) - results_shuffled = multi_head(inputs_shuffled, inputs_shuffled, inputs_shuffled) - - if attn_dropout == 0.0 and num_rules == 1 and not causal: - assert (results[:, shuffle, :] - results_shuffled).abs().max() < 1e-3 - - # Test the non-self-attention codepath - att = multi_head(inputs, inputs_shuffled, inputs) - - # Check that dropout actually drops some values - if attn_dropout > 0: - att_2 = multi_head(inputs, inputs_shuffled, inputs) - assert (att != att_2).any() diff --git a/tests/test_embedding.py b/tests/test_embedding.py deleted file mode 100644 index 4bf6e3ef6b..0000000000 --- a/tests/test_embedding.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
- -import pytest -import torch - -from xformers.components import PatchEmbeddingConfig, build_patch_embedding -from xformers.components.positional_embedding import ( - POSITION_EMBEDDING_REGISTRY, - build_positional_embedding, -) - -BATCH = 20 -SEQ = 512 -MODEL = 384 -assert ( - POSITION_EMBEDDING_REGISTRY.keys() -), "Positional encoding layers should have been registered" - - -@pytest.mark.parametrize("encoding_name", POSITION_EMBEDDING_REGISTRY.keys()) -@pytest.mark.parametrize("dropout", [0.0, 0.2]) -def test_dimensions(encoding_name: str, dropout: float): - test_config = { - "name": encoding_name, - "dim_model": MODEL, - "vocab_size": 32, - "dropout": dropout, - "seq_len": SEQ, - } - - # dummy, just check construction and dimensions in the FW pass - encoding = build_positional_embedding(test_config) - inputs = (torch.rand(BATCH, SEQ) * 10).abs().to(torch.int) - _ = encoding(inputs) - - # Test that inputs having an embedding dimension would also work out - if "name" == "sine": - inputs = (torch.rand(BATCH, SEQ, MODEL) * 10).abs().to(torch.int) - _ = encoding(inputs) - - -def test_patch_embedding(): - patch_embedding_config = { - "in_channels": 3, - "out_channels": 64, - "kernel_size": 7, - "stride": 4, - "padding": 2, - } - - # dummy, just check construction and dimensions in the FW pass - patch_emb = build_patch_embedding(PatchEmbeddingConfig(**patch_embedding_config)) - - # Check BHWC - inputs = torch.rand(BATCH, 32 * 32, 3) - out = patch_emb(inputs) - assert out.shape[-1] == 64 - - # Check BCHW - inputs = torch.rand(BATCH, 3, 32, 32) - out = patch_emb(inputs) - assert out.shape[-1] == 64 diff --git a/tests/test_favor.py b/tests/test_favor.py deleted file mode 100644 index 44213b0f44..0000000000 --- a/tests/test_favor.py +++ /dev/null @@ -1,167 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
- -import math - -import pytest -import torch - -from xformers.components.attention import FavorAttention, ScaledDotProduct -from xformers.components.attention.feature_maps import ( - FeatureMapType, - NormDistribution, - SMHyperbolic, - SMOrf, - SMReg, -) - -_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - - -@pytest.mark.parametrize("features", [SMOrf, SMHyperbolic, SMReg]) -def test_random_matrix(features): - torch.random.manual_seed(0) - - DRAWS = 100 - DIM = 10 - for _ in range(DRAWS): - q = features._get_random_ortho_matrix( - 1, DIM, device=_device, norm_distribution=NormDistribution.Xi - ).squeeze(0) - - # Check that the matrix is indeed orthonormal - torch.allclose( - torch.diag(q @ q.transpose(0, 1)), - torch.diag(torch.ones(10, device=_device)), - ) - - # Check that the row norm is in the right ballpark (sqrt(dim)) - assert abs(torch.mean(torch.norm(q, dim=1)).item() - math.sqrt(DIM)) < 1.0 - - -def _plot_distribution(ortho_feature_map): - # Debug helper, check the uniformity of the random matrix draws - DRAWS = 1000 - DIM = 50 - q = ortho_feature_map._get_random_ortho_matrix(DRAWS, DIM, device=_device) - x, y = [], [] - - for qq in q: - # For every matrix, look at the real and imaginary eigen value - e = torch.linalg.eigvals(qq) - x.append(e.real) - y.append(e.imag) - - # Ideally the repartition of the real and imaginary eigenvalues - # should build a circle in the complex plane - import matplotlib.pyplot as plt - import seaborn as sns - - sns.kdeplot(x=torch.cat(x).cpu().numpy(), y=torch.cat(y).cpu().numpy()) - plt.axis("equal") - plt.savefig("kde.png") - - -def _get_rng_data(device): - emb = 10 - batch_size = 2 - seq_len = 20 - num_heads = 1 - - shape = (batch_size * num_heads, seq_len, emb) - return torch.randn(shape, device=device) - - -def test_feature_map_shape(): - # Check the delayed initialization of the feature map - nb_random_features = 1000 - batch = _get_rng_data(_device) - att = FavorAttention( - dropout=0.0, - dim_features=nb_random_features, - feature_map_type=FeatureMapType.SMOrf, - ) - _ = att(batch, batch, batch) - - assert att.feature_map.features.shape[0] == batch.shape[-1] - assert att.feature_map.features.shape[1] == nb_random_features - - -def test_feature_map_redraw(): - # Check the delayed initialization of the feature map - nb_random_features = 1000 - batch = _get_rng_data(_device) - - def check(should_redraw: bool): - att = FavorAttention( - dropout=0.0, - dim_features=nb_random_features, - feature_map_type=FeatureMapType.SMOrf, - iter_before_redraw=1 if should_redraw else 100, - ) - v0 = att(batch, batch, batch) - assert att.feature_map is not None - - f0 = att.feature_map.features - - v1 = att(batch, batch, batch) - f1 = att.feature_map.features - - # There should not have been a redraw after v0 - assert should_redraw != torch.allclose(v0, v1) - assert should_redraw != torch.allclose(f0, f1) # type: ignore - - check(should_redraw=True) - check(should_redraw=False) - - -@pytest.mark.parametrize("feature", ["sm_orf", "sm_hyp", "sm_reg"]) -@pytest.mark.parametrize("causal", [True, False]) -@pytest.mark.parametrize("normalize_inputs", [True, False]) -@pytest.mark.parametrize("device", [_device]) -def test_favor_approximation_accuracy(feature, causal, normalize_inputs, device): - # Run two attentions in parallel, the normal scaled dot product and the favor approximation - - torch.random.manual_seed(0) - query, key, value = ( - _get_rng_data(device), - _get_rng_data(device), - _get_rng_data(device), - ) - - for x in 
(query, key, value): - x.requires_grad = True - - # Build the two attention heads - sdp_attention = ScaledDotProduct(dropout=0.0, causal=causal).to(device) - approx_attention = FavorAttention( - dropout=0.0, - causal=causal, - dim_head=10, - feature_map_type=FeatureMapType(feature), - normalize_inputs=normalize_inputs, - ).to(device) - - with torch.amp.autocast("cuda", enabled=_device.type == "cuda"): - standard_attention_result = sdp_attention(query, key, value) - approx_attention_result = approx_attention(query, key, value) - - mismatch = torch.mean( - (standard_attention_result - approx_attention_result) ** 2 - ).item() - - if causal: - # FIXME(@lefaudeux) the causal case seems significantly worse, not obvious why, - # could be worth investigating - assert mismatch < 0.6 - else: - assert mismatch < 0.23 - - # Check trainability - torch.sum(approx_attention_result).backward() - - -if __name__ == "__main__": - _plot_distribution(SMOrf) diff --git a/tests/test_global_attention.py b/tests/test_global_attention.py deleted file mode 100644 index b21e7382a4..0000000000 --- a/tests/test_global_attention.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - -import torch - -from xformers.components.attention import GlobalAttention, ScaledDotProduct - - -def test_global_attention(): - b, s, d = 2, 90, 40 - - torch.cuda.manual_seed(42) - torch.manual_seed(42) - - def test_ratio(global_attention_ratio: float): - # Make sure that Global and Normal attention get the same results for the corresponding tokens - a = torch.rand(b, s, d) - config = { - "name": "global", - "dropout": 0.0, - "causal": False, - "max_seq_len": s, - "attention_query_mask": torch.rand((s, 1)) < global_attention_ratio, - } - - global_attention = GlobalAttention(**config) - sdp_attention = ScaledDotProduct(**config) - - r_global = global_attention(a, a, a) - r_dense = sdp_attention(a, a, a) - - # Check that the tokens which have access to the full attention give the same - # results as the monolithic dense scaled_dot_product - mask = config["attention_query_mask"][:, 0] - assert torch.allclose(r_global[:, mask, :], r_dense[:, mask, :]) - - # Test with different levels of sparsity, to make sure that all the paths are covered - test_ratio(0.02) - test_ratio(0.5) - test_ratio(1.0) # All queries allowed diff --git a/tests/test_nystrom_attention.py b/tests/test_nystrom_attention.py deleted file mode 100644 index 6d57d1dbba..0000000000 --- a/tests/test_nystrom_attention.py +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import random - -import pytest -import torch - -from xformers.components.attention import NystromAttention, ScaledDotProduct -from xformers.components.attention.utils import maybe_merge_masks - - -@pytest.mark.parametrize("pinverse_original_init", [True, False]) -@pytest.mark.parametrize("use_razavi_pinverse", [True, False]) -@pytest.mark.parametrize("num_landmarks", [30, 33, 905]) -def test_nystrom_attention_close_to_sdp( - pinverse_original_init: bool, - use_razavi_pinverse: bool, - num_landmarks: int, -): - # TODO: conv_kernel_size parameter not set to None fails this test. Investigate. 
- b, s, d = 2, 900, 40 - num_heads = 2 - seed = 42 - torch.random.manual_seed(seed) - random.seed(seed) - - nystrom_config = { - "name": "nystrom", - "dropout": 0.0, - "num_landmarks": num_landmarks, - "num_heads": num_heads, - "pinverse_original_init": pinverse_original_init, - "use_razavi_pinverse": use_razavi_pinverse, - } - - sdp_config = { - "name": "scaled_dot_product", - "dropout": 0.0, - } - - a = torch.rand(b, s, d) - - def test_close_to_sdp(): - # Make sure that Nystrom and Normal attention are not too far off. - - nystrom_attention = NystromAttention(**nystrom_config) - sdp_attention = ScaledDotProduct(**sdp_config) - - r_nystrom = nystrom_attention(a, a, a, att_mask=None) - r_sdp = sdp_attention(a, a, a, att_mask=None) - - assert torch.allclose(r_nystrom, r_sdp, rtol=0.005, atol=1e-2) - - # Make sure that Nystrom and Normal attention are not too far off. - - nystrom_attention = NystromAttention(**nystrom_config) - sdp_attention = ScaledDotProduct(**sdp_config) - - r_nystrom = nystrom_attention(a, a, a, att_mask=None) - r_sdp = sdp_attention(a, a, a, att_mask=None) - - assert torch.allclose(r_nystrom, r_sdp, rtol=0.005, atol=1e-2) - - test_close_to_sdp() - - -@pytest.mark.parametrize("pinverse_original_init", [True]) -@pytest.mark.parametrize("use_razavi_pinverse", [True]) -@pytest.mark.parametrize("num_landmarks", [30]) -def test_nystrom_attention( - pinverse_original_init: bool, - use_razavi_pinverse: bool, - num_landmarks: int, -): - # TODO: conv_kernel_size parameter not set to None fails this test. Investigate. - b, s, d = 2, 900, 40 - num_heads = 2 - seed = 42 - torch.random.manual_seed(seed) - random.seed(seed) - - nystrom_config = { - "name": "nystrom", - "dropout": 0.0, - "num_landmarks": num_landmarks, - "num_heads": num_heads, - "pinverse_original_init": pinverse_original_init, - "use_razavi_pinverse": use_razavi_pinverse, - } - - sdp_config = { - "name": "scaled_dot_product", - "dropout": 0.0, - } - - a = torch.rand(b, s, d) - - def test_att_mask_ignored(): - # If an sxs attention mask is passed in, it should be ignored. - # Results should be the same as if no mask was passed in. - nystrom_attention = NystromAttention(**nystrom_config) - sdp_attention = ScaledDotProduct(**sdp_config) - - key_padding_mask = None - att_mask = torch.randint(0, 2, (s, s)).to(dtype=torch.bool) - sdp_mask = maybe_merge_masks( - att_mask=None, - key_padding_mask=key_padding_mask, - batch_size=b // num_heads, - src_len=s, - num_heads=num_heads, - ) - r_nystrom = nystrom_attention( - a, a, a, att_mask=att_mask, key_padding_mask=key_padding_mask - ) - r_sdp = sdp_attention(a, a, a, att_mask=sdp_mask) - assert torch.allclose(r_nystrom, r_sdp, rtol=0.005, atol=1e-2) - - def test_masking(): - # FIXME - # nystrom_config["causal"] = True - # sdp_config["causal"] = True - - nystrom_attention = NystromAttention(**nystrom_config) - sdp_attention = ScaledDotProduct(**sdp_config) - - key_padding_mask = torch.rand((b // num_heads, s)) > 0.1 - att_mask = None - mask = maybe_merge_masks( - att_mask, - key_padding_mask, - batch_size=b // num_heads, - src_len=s, - num_heads=num_heads, - ) - r_nystrom = nystrom_attention(a, a, a, key_padding_mask=key_padding_mask) - r_sdp = sdp_attention(a, a, a, att_mask=mask) - - # Not very close, but more so testing functionality. - assert torch.allclose( - r_nystrom, r_sdp, rtol=0.1, atol=0.5 - ), f"max diff {torch.max(torch.abs(r_nystrom-r_sdp))}" - - # Error when key padding mask doesn't have expected dimensions. 
- key_padding_mask = torch.randint(0, 2, (s, b)).to(dtype=torch.bool) - with pytest.raises(AssertionError): - nystrom_attention(a, a, a, key_padding_mask=key_padding_mask) - - test_att_mask_ignored() - test_masking() diff --git a/tests/test_ortho_attention.py b/tests/test_ortho_attention.py deleted file mode 100644 index 22a62f8f87..0000000000 --- a/tests/test_ortho_attention.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import random - -import pytest -import torch - -from xformers.components.attention import OrthoFormerAttention, ScaledDotProduct -from xformers.components.attention.utils import maybe_merge_masks - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") -@pytest.mark.parametrize( - "landmark_selection", ["orthogonal", "kmeans", "kmeans_spherical", "random"] -) -@pytest.mark.parametrize("num_landmarks", [30, 33, 905]) -@pytest.mark.parametrize("subsample_fraction", [1.0, 0.3]) -def test_ortho_attention( - landmark_selection: str, num_landmarks: int, subsample_fraction: float -): - # TODO: conv_kernel_size parameter not set to None fails this test. Investigate. - b, s, d = 8, 900, 32 - num_heads = 2 - seed = 42 - torch.random.manual_seed(seed) - random.seed(seed) - - ortho_config = { - "name": "orthoformer", - "dropout": 0.0, - "num_landmarks": num_landmarks, - "num_heads": num_heads, - "landmark_selection": landmark_selection, - "subsample_fraction": subsample_fraction, - } - - sdp_config = { - "name": "scaled_dot_product", - "dropout": 0.0, - } - - a = torch.rand(b, s, d, device=torch.device("cuda")) - - def test_close_to_sdp(): - # Make sure that Ortho and Normal attention are not too far off. - ortho_attention = OrthoFormerAttention(**ortho_config).cuda() - sdp_attention = ScaledDotProduct(**sdp_config).cuda() - - r_ortho = ortho_attention(a, a, a, att_mask=None) - r_sdp = sdp_attention(a, a, a, att_mask=None) - - assert torch.allclose(r_ortho, r_sdp, rtol=0.02, atol=1e-1) - - # Make sure that OrthoFormerAttention and Normal attention are not too far off. - ortho_attention = OrthoFormerAttention(**ortho_config).cuda() - sdp_attention = ScaledDotProduct(**sdp_config).cuda() - - r_ortho = ortho_attention(a, a, a, att_mask=None) - r_sdp = sdp_attention(a, a, a, att_mask=None) - - assert torch.allclose(r_ortho, r_sdp, rtol=0.02, atol=1e-1) - - def test_att_mask_ignored(): - # If an sxs attention mask is passed in, it should be ignored. - # Results should be the same as if no mask was passed in. 
- ortho_attention = OrthoFormerAttention(**ortho_config).cuda() - sdp_attention = ScaledDotProduct(**sdp_config).cuda() - - key_padding_mask = None - att_mask = torch.randint(0, 2, (s, s), device=torch.device("cuda")).to( - dtype=torch.bool - ) - sdp_mask = maybe_merge_masks( - att_mask=None, - key_padding_mask=key_padding_mask, - batch_size=b // num_heads, - src_len=s, - num_heads=num_heads, - ) - r_ortho = ortho_attention( - a, a, a, att_mask=att_mask, key_padding_mask=key_padding_mask - ) - r_sdp = sdp_attention(a, a, a, att_mask=sdp_mask) - assert torch.allclose(r_ortho, r_sdp, rtol=0.02, atol=1e-1) - - test_close_to_sdp() - test_att_mask_ignored() diff --git a/tests/test_rotary_embeddings.py b/tests/test_rotary_embeddings.py deleted file mode 100644 index e0b43ac105..0000000000 --- a/tests/test_rotary_embeddings.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - -import pytest -import torch - -from xformers.components.positional_embedding import RotaryEmbedding -from xformers.components.positional_embedding.rotary import ( - apply_rotary_pos_emb, - rotate_half, -) - -DEVICES = ( - [torch.device("cpu")] - if not torch.cuda.is_available() - else [ - torch.device("cuda") - ] # save a bit on CI for now, we have separate cpu and gpu jobs -) -BATCH = 2 -SEQ = 32 -HEADS = 2 -EMB = 32 - - -def test_helper_methods(): - # rotate_half - tens = torch.tensor([[0, 1, 2, 3], [3, 1, 2, 0], [0, 1, 0, 1], [1, 0, 1, 0]]) - tens_rotated = rotate_half(tens) - assert torch.equal( - tens_rotated, - torch.tensor([[-2, -3, 0, 1], [-2, 0, 3, 1], [0, -1, 0, 1], [-1, 0, 1, 0]]), - ) - - # apply_rotary_pos_emb - cos_test = torch.ones((1, 1, 4, 4)) - sin_test = cos_test.clone() - q_test = 3 * torch.ones((2, 2, 3, 4)) - q_applied = apply_rotary_pos_emb(q_test, cos_test, sin_test) - assert torch.equal( - q_applied, - torch.concat( - ( - torch.zeros((2, 2, 3, 2), dtype=torch.float), - 6 * torch.ones((2, 2, 3, 2), dtype=torch.float), - ), - dim=-1, - ), - ) - - -@pytest.mark.parametrize("device", DEVICES) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) -def test_rotary_embeddings(device, dtype): - rotary = RotaryEmbedding(EMB).to(device) - - # Generate dummy inputs - q = torch.ones( - (BATCH, HEADS, SEQ, EMB), device=device, dtype=dtype - ) # uniform on purpose - k = q.clone() - - q_rot, k_rot = rotary(q, k) - - assert q_rot.dtype == q.dtype - assert k_rot.dtype == k.dtype - - # Check that the sequences now encode relative position information - q, k = q.float(), k.float() - q_rot, k_rot = q_rot.float(), k_rot.float() - - att = torch.einsum("bhne,bhme->bhnm", q, k) - att_rot = torch.einsum("bhne,bhme->bhnm", q_rot, k_rot) - - # - the attention for the same positions is not meaningfully changed - assert torch.allclose( - torch.diag(att[0, 0, :, :]), torch.diag(att_rot[0, 0, :, :]), rtol=0.1 - ) - - # - the post-rotary attention is more focused on the diagonal - diag_max = torch.max(torch.diag(att_rot[0, 0, :, :])) - att_rot -= diag_max - att_rot = ( - att_rot <= 1e-4 - ) # all non diagonal elements had lower attention than diagonal (+ float tolerance) - assert torch.all(att_rot) - - # Test that different sequence lengths is ok - _, _ = rotary(q[:, :, :-16, :], k) diff --git a/xformers/benchmarks/benchmark_multi_head_dispatch.py b/xformers/benchmarks/benchmark_multi_head_dispatch.py deleted file mode 100644 index 
2345cf2a5f..0000000000 --- a/xformers/benchmarks/benchmark_multi_head_dispatch.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - - -from typing import Any, Dict - -import torch -import torch.nn as nn -import triton - -from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print -from xformers.components import MultiHeadDispatch -from xformers.components.attention import ScaledDotProduct - -SHAPES = [ - (8, 384, 128), - (8, 784, 512), - (4, 1024, 768), - (4, 2048, 1024), - (2, 2048, 2048), - (2, 2048, 4096), - (2, 4096, 4096), - (1, 2048, 12288), -] - -N_HEADS = [4] - - -def bench_multihead_dispatch(backward: bool, self_attention: bool): - device = torch.device("cuda") - bw = "+bw" if backward else "" - sa = " (self_attn)" if self_attention else "" - - for dtype in [torch.float16, torch.float32]: - results: Dict[str, Any] = {} - - for B, M, K in SHAPES: - for heads in N_HEADS: - xf_multi_head = MultiHeadDispatch( - dim_model=K, - residual_dropout=0.0, - num_heads=heads, - attention=ScaledDotProduct(), - bias=(True, True, True, True), - ).to(device=device, dtype=dtype) - torch_multi_head = nn.MultiheadAttention( - embed_dim=K, num_heads=heads, batch_first=True - ).to(device=device, dtype=dtype) - - q = torch.randn( - (B, M, K), requires_grad=backward, device=device, dtype=dtype - ) - - if self_attention: - k = q - v = q - else: - k = torch.randn( - (B, M, K), requires_grad=backward, device=device, dtype=dtype - ) - v = torch.randn( - (B, M, K), requires_grad=backward, device=device, dtype=dtype - ) - - def torch_mha(): - y, _ = torch_multi_head(query=q, key=k, value=v) - if backward: - torch.norm(y).backward() - return y - - def xformers_mha(): - y = xf_multi_head(query=q, key=k, value=v) - if backward: - torch.norm(y).backward() - return y - - for testcase in [ - TestCase(torch_mha, f"torch - fw{bw}{sa}"), - TestCase(xformers_mha, f"xf - fw{bw}{sa}"), - ]: - time = triton.testing.do_bench(testcase.function)[0] - key = f"B={B}, M={M}, K={K}, N_HEADS={heads}" - if key not in results: - results[key] = {} - - results[key][testcase.name] = f"{time:.2f}" - - pretty_print( - results, - title=f"\n --- Type: {dtype} --- ", - units="runtime in ms, lower is better", - ) - pretty_plot( - results, - title=f"MHA-FW{bw}-{dtype}", - units="runtime in ms, lower is better", - dash_key="torch", - ) - - -for bw in [False, True]: - for self_attention in [False, True]: - bench_multihead_dispatch(bw, self_attention) diff --git a/xformers/components/__init__.py b/xformers/components/__init__.py index ed31269284..429d4f3936 100644 --- a/xformers/components/__init__.py +++ b/xformers/components/__init__.py @@ -5,19 +5,12 @@ import warnings -from dataclasses import fields from pathlib import Path -from typing import Any, Dict, Union from xformers.utils import import_all_modules -from .activations import Activation, build_activation # noqa from .attention import Attention, build_attention # noqa from .input_projection import InputProjection, InputProjectionConfig # noqa -from .multi_head_dispatch import MultiHeadDispatch # noqa -from .multi_head_dispatch import MultiHeadDispatchConfig -from .patch_embedding import PatchEmbeddingConfig # noqa -from .patch_embedding import build_patch_embedding # noqa from .residual import NormalizationType # noqa from .residual import PostNorm # noqa from .residual import PreNorm # noqa 
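For context on the hunk that follows, which drops the build_multi_head_attention() helper: the removed tests earlier in this diff construct the same object directly, by building the attention from a config dict and handing it to MultiHeadDispatch. A minimal sketch of that pattern, with import paths and argument names taken from the deleted test files above (some of these may no longer be re-exported from xformers.components after this change, so treat it as illustrative only):

    import torch

    from xformers.components import MultiHeadDispatch
    from xformers.components.attention import build_attention

    SEQ, MODEL, HEADS = 128, 128, 4

    # Build the attention mechanism from a config, as the removed tests did
    attention = build_attention(
        {
            "name": "scaled_dot_product",
            "dropout": 0.0,
            "causal": False,
            "seq_len": SEQ,
        }
    )

    # Wrap it in a multi-head dispatch instead of calling the deleted helper
    multi_head = MultiHeadDispatch(
        seq_len=SEQ,
        dim_model=MODEL,
        num_heads=HEADS,
        residual_dropout=0.0,
        attention=attention,
    )

    q = k = v = torch.rand(2, SEQ, MODEL)
    out = multi_head(query=q, key=k, value=v)  # shape: (2, SEQ, MODEL)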
@@ -35,52 +28,3 @@ # automatically import any Python files in the directory import_all_modules(str(Path(__file__).parent), "xformers.components") - - -def build_multi_head_attention( - multi_head_config: Union[MultiHeadDispatchConfig, Dict[str, Any]], -): - """Builds a multihead attention from a config. - - This assumes a 'name' key in the config which is used to determine what - attention class to instantiate. For instance, a config `{"name": "my_attention", - "foo": "bar"}` will find a class that was registered as "my_attention" - (see :func:`register_attention`) and call .from_config on it.""" - - if not isinstance(multi_head_config, MultiHeadDispatchConfig): - # Extract the required fields - field_names = list(map(lambda x: x.name, fields(MultiHeadDispatchConfig))) - - # The missing fields get Noned - for k in field_names: - if k not in multi_head_config.keys(): - multi_head_config[k] = None - - # Could be that the attention needs to be instantiated - if not isinstance(multi_head_config["attention"], Attention): - # Convenience: fill in possible missing fields - if "num_heads" not in multi_head_config["attention"]: - multi_head_config["attention"]["num_heads"] = multi_head_config[ - "num_heads" - ] - - if "dim_model" not in multi_head_config["attention"]: - multi_head_config["attention"]["dim_model"] = multi_head_config[ - "dim_model" - ] - - if ( - "dim_features" not in multi_head_config["attention"] - or multi_head_config["attention"]["dim_features"] is None - ): - multi_head_config["attention"]["dim_features"] = ( - multi_head_config["dim_model"] // multi_head_config["num_heads"] - ) - - multi_head_config["attention"] = build_attention( - multi_head_config["attention"] - ) - - multi_head_config = MultiHeadDispatchConfig(**multi_head_config) - - return MultiHeadDispatch.from_config(multi_head_config) diff --git a/xformers/components/activations.py b/xformers/components/activations.py deleted file mode 100644 index 314a7962df..0000000000 --- a/xformers/components/activations.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
- - -from enum import Enum -from typing import Optional - -import torch -from torch import nn - -from xformers._deprecation_warning import deprecated_function - - -class Activation(str, Enum): - SquaredReLU = "squared_relu" - GeLU = "gelu" - LeakyReLU = "leaky_relu" - ReLU = "relu" - SmeLU = "smelu" - StarReLU = "star_relu" - - -# For unit testing / parity comparisons, probably not the fastest way -class SquaredReLU(nn.Module): - def __init__(self) -> None: - super().__init__() - deprecated_function(self) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x_ = torch.nn.functional.relu(x) - return x_ * x_ - - -class StarReLU(nn.Module): - def __init__(self) -> None: - super().__init__() - deprecated_function(self) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x_ = torch.nn.functional.relu(x) - return 0.8944 * x_ * x_ - 0.4472 - - -class SmeLU(nn.Module): - def __init__(self, beta: float = 2.0) -> None: - super().__init__() - self.beta = beta - deprecated_function(self) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - relu = torch.where( - x >= self.beta, - x, - torch.tensor([0.0], device=x.device, dtype=x.dtype), - ) - return torch.where( - torch.abs(x) <= self.beta, - ((x + self.beta) ** 2).type_as(x) / (4.0 * self.beta), - relu, - ) - - -def build_activation(activation: Optional[Activation]): - if not activation: - return nn.Identity() - - return { - Activation.ReLU: nn.ReLU, - Activation.GeLU: nn.GELU, - Activation.LeakyReLU: nn.LeakyReLU, - Activation.SquaredReLU: SquaredReLU, - Activation.StarReLU: StarReLU, - Activation.SmeLU: SmeLU, - }[activation]() diff --git a/xformers/components/attention/__init__.py b/xformers/components/attention/__init__.py index b19d99efb5..ff72d3598c 100644 --- a/xformers/components/attention/__init__.py +++ b/xformers/components/attention/__init__.py @@ -96,24 +96,10 @@ def sparsify(matrix): return matrix.to_sparse() -from .favor import FavorAttention # noqa -from .global_tokens import GlobalAttention # noqa -from .linformer import LinformerAttention # noqa -from .local import LocalAttention # noqa -from .nystrom import NystromAttention # noqa -from .ortho import OrthoFormerAttention # noqa -from .random import RandomAttention # noqa from .scaled_dot_product import ScaledDotProduct # noqa __all__ = [ "ScaledDotProduct", - "LocalAttention", - "LinformerAttention", - "NystromAttention", - "RandomAttention", - "OrthoFormerAttention", - "GlobalAttention", - "FavorAttention", "Attention", "AttentionMask", "build_attention", diff --git a/xformers/components/attention/compositional.py b/xformers/components/attention/compositional.py deleted file mode 100644 index a06053c27d..0000000000 --- a/xformers/components/attention/compositional.py +++ /dev/null @@ -1,341 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - - -# Credits: this is heavily inspired by the official implementation, present in -# https://github.com/sarthmit/Compositional-Attention -# Original author: Sarthak Mittal - -# This is a simplified version, for the sake of clarity, and because some features could be exposed later -# via the library directly. 
-# In particular, code paths for TPUs, quantization and gumbel softmax have been removed -# We're also following the same dimension ordering as in the rest of the xformers library -# which is to say [Batch, Sequence, Embedding] wherever possible - -import math -from dataclasses import dataclass -from typing import Optional - -import torch -import torch.nn.functional as F -from torch import Tensor, nn - -from xformers.components.attention import ( - Attention, - AttentionConfig, - AttentionMask, - register_attention, -) -from xformers.components.attention.core import _softmax -from xformers.components.input_projection import InputProjection, InputProjectionConfig - - -def _either_or(a: Optional[int], b: int) -> int: - return a if a is not None else b - - -@dataclass -class CompositionalAttentionConfig(AttentionConfig): - dim_model: int - num_heads: int - dim_attn: Optional[int] = None - num_rules: Optional[int] = None - dim_key: Optional[int] = None - dim_value: Optional[int] = None - dim_selection: Optional[int] = None - dropout: float - qk_rule: bool = False - nonlinear: bool = False - q_compose: bool = False - bias: bool = True - causal: Optional[bool] = False - in_proj_container: Optional[InputProjection] = None - use_separate_proj_weight: Optional[bool] = False - - -@register_attention("compositional", CompositionalAttentionConfig) -class CompositionalAttention(Attention): - """Compositional Attention, as proposed in - "Compositional Attention: Disentangling search and retrieval"_, S. Mittal et al. - - A key insight from this proposal is that the attention mechanism can be conceived as two steps: - a search and a retrieval operation. When queried, the model can search for the most relevant information - (Softmax(QKt)), then retrieve information given the Value. - - Contrary to the original attention proposal, which does not consider interactions in between heads, - the compositional attention will consider all possible interactions and softmax over that dimension, - so that the information retrieved covers the most relevant dimensions. The number of heads and rules to - use is thus typically smaller than for a comparable traditional Transformer, and asking for the same number of heads - may not fit in memory. 
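A shape-level sketch of the two steps described above, following the forward() implementation later in this file (dimension names are informal; B = batch, Sq/Sk = query/key lengths):

    search:    A = softmax(Q @ K.T * scaling)               # [B, num_heads, Sq, Sk]
    retrieve:  R = A @ V_rules                              # [B, num_heads, num_rules, Sq, value_dim]
    score:     S = softmax(score_network(...), dim=rules)   # [B, Sq, num_heads, num_rules, 1]
    output:    out_proj((R * S).sum(dim=rules))             # [B, Sq, num_heads * value_dim]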
- - Args: - dim_model: dimension of the incoming latent space - num_heads: number of heads *for the search operation* - dim_attn: dimension (embedding) of the attention - num_rules: number of rules to consider *for the retrieval operation* - dim_selection: dimension of the scoring/selection space for the retrievals - dim_key, dim_value: dimensions of K and V, if different from Q - dropout: attention dropout probability - qk_rule: QK product will drive the retrieval process - nonlinear: use a non linear method to score the retrievals - bias: use bias in the initial projection step - causal: causal computations (attend to the past only) - - _"Compositional Attention: Disentangling search and retrieval": https://arxiv.org/pdf/2110.09419v1.pdf - """ - - def __init__( - self, - dim_model: int, - num_heads: int, - dim_attn: Optional[int] = None, - num_rules: Optional[int] = None, - dim_selection: Optional[int] = None, - dim_key: Optional[int] = None, - dim_value: Optional[int] = None, - dropout=0.0, - qk_rule=False, - nonlinear=False, - q_compose=False, - in_proj_container: Optional[InputProjection] = None, - use_separate_proj_weight: Optional[bool] = False, - bias=True, - causal=False, - *_, - **__, - ): - super().__init__() - - # Define the inherited flags - self.requires_skip_multi_head = ( - True # This attention owns the multi-head mechanism - ) - - # Handle defaults / undefined values - self.dim_model = dim_model - num_rules = _either_or(num_rules, num_heads) - dim_selection = _either_or(dim_selection, dim_model // num_heads) - - # All the initial definition plumbing - dim_attn = _either_or(dim_attn, dim_model) - dim_key = _either_or(dim_key, dim_model) - dim_value = _either_or(dim_value, dim_model) - - self.in_proj_container = ( - in_proj_container - if in_proj_container is not None - else InputProjection( - query_proj_params=InputProjectionConfig(dim_model, dim_key, bias=bias), - key_proj_params=InputProjectionConfig(dim_model, dim_key, bias=bias) - if use_separate_proj_weight - else None, - value_proj_params=InputProjectionConfig(dim_model, dim_value, bias=bias) - if use_separate_proj_weight - else None, - ) - ) - - self.num_heads = num_heads - self.num_rules = num_rules - self.qk_rule = qk_rule - self.dim_selection = dim_selection - self.nonlinear = nonlinear - self.q_compose = q_compose - - self.dropout_module = nn.Dropout(dropout) - self.dim_head = dim_model // num_heads - self.value_dim = dim_attn // num_rules - - assert ( - self.value_dim * num_rules == dim_attn - ), "value_dim must be divisible by num_rules" - - self.scaling = self.dim_head**-0.5 - self.scaling_values = self.dim_selection**-0.5 - - self.out_proj = nn.Linear(self.num_heads * self.value_dim, dim_model, bias=bias) - - if self.qk_rule: - self.value_k = nn.Linear(self.value_dim, self.dim_selection, bias=bias) - if self.q_compose: - self.value_q = nn.Linear(self.dim_head, self.dim_selection, bias=bias) - else: - self.value_q = nn.Linear( - dim_model, self.dim_selection * self.num_heads, bias=bias - ) - else: - if self.q_compose: - self.value_q = nn.Linear(self.dim_head, self.dim_selection, bias=bias) - else: - self.value_q = nn.Linear( - dim_model, self.dim_selection * self.num_heads, bias=bias - ) - if self.nonlinear: - self.score_network: nn.Module = nn.Sequential( - nn.Linear( - self.dim_selection + self.value_dim, - self.dim_selection, - bias=bias, - ), - nn.ReLU(), - nn.Linear(self.dim_selection, 1, bias=bias), - ) - else: - self.score_network = nn.Linear( - self.dim_selection + self.value_dim, 1, bias=bias - ) - - 
self.causal = causal - - # Properties specific to this attention mechanism - self.supports_attention_mask = True - self.supports_key_padding_mask = False - - self._reset_parameters() - - def _reset_parameters(self): - # NOTE: in_proj_container is already initialized - - if self.qk_rule: - nn.init.xavier_uniform_(self.value_k.weight, gain=1 / math.sqrt(2)) - nn.init.xavier_uniform_(self.value_q.weight, gain=1 / math.sqrt(2)) - else: - nn.init.xavier_uniform_(self.value_q.weight) - if self.nonlinear: - nn.init.xavier_uniform_(self.score_network[0].weight) - nn.init.xavier_uniform_(self.score_network[2].weight) - else: - nn.init.xavier_uniform_(self.score_network.weight) - - nn.init.xavier_uniform_(self.out_proj.weight) - if self.out_proj.bias is not None: - nn.init.constant_(self.out_proj.bias, 0.0) - - def forward( - self, - q: Tensor, - k: Tensor, - v: Tensor, - att_mask: Optional[Tensor] = None, - *args, - **kwargs, - ) -> Tensor: - """ - Input shape: Time x Batch x Channel - - Args: - att_mask (ByteTensor, optional): typically used to - implement causal attention, where the mask prevents the - attention from looking forward in time (default: None). - """ - - B, Sq, E = q.shape - _, Sk, _ = k.shape - - assert E == self.dim_model - - # First define projected query/key/values - # We keep the projected and original tensors in flight, - # depending on the options the original values could be reused - q_unprojected = q - q, k, v = self.in_proj_container(query=q, key=k, value=v) - q *= self.scaling - - # Init causal mask if needed, now that we know the context length - if self.causal and ( - self._causal_mask is None or self._causal_mask.shape[0] != Sk - ): - self._causal_mask = AttentionMask.make_causal(Sq, Sq, device=q.device) - - # Convenience, create an attention mask if a tensor was passed - # This sanitizes different mask types being passed, from now on it's additive - if isinstance(att_mask, torch.Tensor): - # By default we don't know of the causality, and a check would be expensive - att_mask_additive: Optional[AttentionMask] = ( - AttentionMask.from_bool(att_mask) - if att_mask.dtype == torch.bool - else AttentionMask(att_mask, is_causal=False) - ) - else: - att_mask_additive = None - - # Handle the attention and key padding masks - if self._causal_mask is not None: - # Optionally add the causal mask - if att_mask_additive is not None: - att_mask_additive += self._causal_mask - else: - att_mask_additive = self._causal_mask - - # Flatten the heads or the rules - q = ( - q.view(B, Sq, self.num_heads, self.dim_head) - .movedim(2, 1) - .flatten(0, 1) # [B * num_heads, Sq, dim_head] - ) - k = ( - k.view(B, Sk, self.num_heads, self.dim_head).movedim(2, 1).flatten(0, 1) - ) # [B * num_heads, Sk, dim_head] - v = v.view(B, -1, self.num_rules, self.value_dim).movedim(2, 1).flatten(0, 1) - - # Compute the search: Softmax(QKt) - attn_weights = torch.bmm(q, k.transpose(1, 2)) # [B * self.num_heads, Sq, Sk] - - if att_mask_additive is not None: - attn_weights += att_mask_additive.values - - attn_weights = _softmax(attn_weights, causal=self.causal) - - attn_weights = attn_weights.view(B, self.num_heads, Sq, Sk) - attn_probs = self.dropout_module(attn_weights) - - # Now compute the information retrieval - # keep all the heads in flight, we'll score the different possibilities - # - compute all the possible retrievals - v = v.view(B, 1, self.num_rules, Sk, self.value_dim) - attn_probs = attn_probs.unsqueeze(2) - attn = torch.matmul(attn_probs, v).view( - B, self.num_heads, self.num_rules, Sq, 
self.value_dim - ) - - attn = attn.movedim(3, 1) # [B, Sq, H, Rules, Values] - - # - search the most appropriate retrieval among all the values - if self.q_compose: - v_q = self.value_q(q.transpose(0, 1)).view( - B, Sq, self.num_heads, 1, self.dim_selection - ) - else: - v_q = self.value_q(q_unprojected).view( - B, Sq, self.num_heads, 1, self.dim_selection - ) - - if self.qk_rule: - v_q *= self.scaling_values - v_k = ( - self.value_k(attn) - .view(B, Sq, self.num_heads, self.num_rules, self.dim_selection) - .transpose(4, 3) - .contiguous() - ) - v_score = torch.matmul(v_q, v_k).view( - B, Sq, self.num_heads, self.num_rules, 1 - ) - else: - v_q = v_q.expand(-1, -1, -1, self.num_rules, -1) - v_in = torch.cat([attn, v_q], dim=-1) - v_score = self.score_network(v_in).view( - B, Sq, self.num_heads, self.num_rules, 1 - ) - - v_score = F.softmax(v_score, dim=3) - - # - extracted values are the original attention (inc. all the values) weighted by value score - attn = (attn * v_score).sum(dim=3).view(B, Sq, self.num_heads * self.value_dim) - - # Final attention projection, same as other mechanisms - attn = self.out_proj(attn) - - return attn diff --git a/xformers/components/attention/favor.py b/xformers/components/attention/favor.py deleted file mode 100644 index d7dfbc53ab..0000000000 --- a/xformers/components/attention/favor.py +++ /dev/null @@ -1,173 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - -import logging -import math -from dataclasses import dataclass -from typing import Optional, Tuple - -import torch -import torch.nn as nn -from torch.amp import autocast - -from xformers.components.attention import Attention, AttentionConfig, register_attention -from xformers.components.attention.feature_maps import ( - FeatureMap, - FeatureMapType, - SMHyperbolic, - SMOrf, - SMReg, -) - -logger = logging.getLogger("xformers") - - -@dataclass -class FavorAttentionConfig(AttentionConfig): - causal: Optional[bool] - dim_features: Optional[int] = None # The dimensions of the random features - dim_head: Optional[ - int - ] = None # The embedding dimension of the inputs. Only useful to get a dim_features estimate - iter_before_redraw: Optional[ - int - ] = None # The number of iterations before the random features are re-drawn from scratch - feature_map: Optional[FeatureMapType] = None - - -@register_attention("favor", FavorAttentionConfig) -class FavorAttention(Attention): - def __init__( - self, - causal: bool = False, - dropout: float = 0.0, - dim_features: Optional[int] = None, - dim_head: Optional[int] = None, - iter_before_redraw: Optional[int] = None, - feature_map_type: FeatureMapType = FeatureMapType.SMReg, - normalize_inputs: bool = False, - *_, - **__, - ): - r""" - Kernelized attention, as proposed in Performers_ - ("Rethinking attention with performers." K. Choromanski et al. (2020).). - - FAVOR stands for "Fast Attention Via positive Orthogonal Random features" - - Args: - dropout (float): the probability of an output to be randomly dropped at training time - dim_features (int): the dimension of the random features space - iter_before_redraw (int): the number of steps (forward calls) before a redraw of the features - feature_map_type (FeatureMapType): the type of feature map being used, - for instance orthogonal random features. - - .. 
_Performers: https://arxiv.org/pdf/2009.14794v1.pdf - """ - super().__init__() - - self.causal = causal - self.iter_before_redraw = ( - (2 * iter_before_redraw) - if iter_before_redraw is not None - else iter_before_redraw - ) # This will be used for both key and query - self.normalize_inputs = normalize_inputs - self.feature_map_type = feature_map_type - self.attn_drop = nn.Dropout(dropout, inplace=True) - - # Setup dimension-dependent variables - # Reasonable dimension default - if dim_features is None: - assert dim_head is not None, "dim_features or dim_head needs to be passed" - self.dim_features = math.ceil(dim_head * (1 + math.log2(dim_head))) - self.dim_features = 2 * ( - self.dim_features // 2 - ) # needs to be even for some variants - logger.info( - f"FAVOR: Automatically setting the random mapping dimension to {self.dim_features} from {dim_head}" - ) - else: - self.dim_features = dim_features - - feature_map_constructor = { - FeatureMapType.SMHyp: SMHyperbolic, - FeatureMapType.SMReg: SMReg, - FeatureMapType.SMOrf: SMOrf, - }[self.feature_map_type] - - feature_settings = { - "dim_features": self.dim_features, - "iter_before_redraw": self.iter_before_redraw, - "normalize_inputs": self.normalize_inputs, - } - - self.feature_map: FeatureMap = feature_map_constructor(**feature_settings) # type: ignore - - # Properties specific to this attention mechanism - self.supports_attention_mask = False - self.supports_key_padding_mask = False - - @staticmethod - def _maybe_promote(x: torch.Tensor) -> torch.Tensor: - # Only promote fp16 buffers, bfloat16 would be fine for instance - return x.float() if x.dtype == torch.float16 else x - - @staticmethod - def _causal_attention( - k_prime: torch.Tensor, q_prime: torch.Tensor, v: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Algorithm 1 in the paper - ref_v = torch.ones_like(v.unsqueeze(2)) # BATCH x SEQ x 1 x EMB - Gps = k_prime.unsqueeze(3) * v.unsqueeze(2) - Grenorm = k_prime.unsqueeze(3) * ref_v - - # Consolidate against the feature dimension - att_raw = torch.einsum("bcfe,bcf->bce", Gps, q_prime) - att_norm = torch.einsum("bcfe,bcf->bce", Grenorm, q_prime) - - # Cumulative sum over the sequence - att_raw = att_raw.cumsum(2) - att_norm = att_norm.cumsum(2) - - return att_raw, att_norm - - def forward( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - *_, - **__, - ): - - # Project key and queries onto the feature map space - k_prime = self.feature_map(k) - q_prime = self.feature_map(q) - - with autocast("cuda", enabled=False): - # The softmax kernel approximation for Favor will easily overflow - # Force the computations here to stay in fp32 for numerical stability - # Note that the dimensions are vastly reduced when compared to scaled_dot_product - k_prime = self._maybe_promote(k_prime) - q_prime = self._maybe_promote(q_prime) - v = self._maybe_promote(v) - - if not self.causal: - att_normalization = q_prime @ ( - k_prime.transpose(-2, -1) @ torch.ones_like(v) - ) - att_raw = q_prime @ (k_prime.transpose(-2, -1) @ v) - else: - # Actually compute attention - att_raw, att_normalization = self._causal_attention(k_prime, q_prime, v) - - # Normalize - att = att_raw / att_normalization - - if self.attn_drop is not None: - att = self.attn_drop(att) - - return att diff --git a/xformers/components/attention/feature_maps/__init__.py b/xformers/components/attention/feature_maps/__init__.py deleted file mode 100644 index ed308d17a8..0000000000 --- a/xformers/components/attention/feature_maps/__init__.py +++ /dev/null @@ -1,26 
+0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - - -from enum import Enum - -from .base import FeatureMap, FeatureMapConfig -from .softmax import NormDistribution, SMHyperbolic, SMOrf, SMReg - - -class FeatureMapType(str, Enum): - SMOrf = "sm_orf" - SMHyp = "sm_hyp" - SMReg = "sm_reg" # regularized softmax kernel - - -__all__ = [ - "SMOrf", - "SMReg", - "SMHyperbolic", - "NormDistribution", - "FeatureMapConfig", - "FeatureMap", -] diff --git a/xformers/components/attention/feature_maps/base.py b/xformers/components/attention/feature_maps/base.py deleted file mode 100644 index 8d41de827a..0000000000 --- a/xformers/components/attention/feature_maps/base.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - - -from abc import abstractmethod -from dataclasses import asdict, dataclass -from typing import Optional, Type, TypeVar - -import torch - -""" -Feature maps allow for a given query or key to be encoded in a different space. -""" - -Self = TypeVar("Self", bound="FeatureMap") - - -@dataclass -class FeatureMapConfig: - name: str - dim_features: int - iter_before_redraw: Optional[int] - normalize_inputs: Optional[bool] - epsilon: Optional[float] - - -class FeatureMap(torch.nn.Module): - def __init__( - self, - dim_features: int, - iter_before_redraw: Optional[int] = None, - normalize_inputs: bool = False, - epsilon: float = 1e-6, - ): - super().__init__() - - self.dim_features = dim_features - self.dim_feature_map = dim_features - - self.iter_before_redraw = iter_before_redraw - self.features: Optional[torch.Tensor] = None - self.epsilon = epsilon - self.normalize_inputs = normalize_inputs - - self._iter_counter = 0 - - @abstractmethod - def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device): - raise NotImplementedError() - - @classmethod - def from_config(cls: Type[Self], config: FeatureMapConfig) -> Self: - # Generate the class inputs from the config - fields = asdict(config) - - # Skip all Nones so that default values are used - fields = {k: v for k, v in fields.items() if v is not None} - - return cls(**fields) diff --git a/xformers/components/attention/feature_maps/softmax.py b/xformers/components/attention/feature_maps/softmax.py deleted file mode 100644 index d0dd1df734..0000000000 --- a/xformers/components/attention/feature_maps/softmax.py +++ /dev/null @@ -1,288 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - - -import math -from enum import Enum, auto -from typing import Optional - -import torch -from torch.autograd.profiler import record_function - -from .base import FeatureMap - -""" -A set of feature maps which approximate the softmax kernel, as per the Performers_ paper. - -_Performers: "Rethinking attention with performers." K. Choromanski et al. (2020). 
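A minimal sketch, using plain Gaussian random features and hypothetical shapes, of the property these feature maps are built around: with phi(x) = exp(w.x - |x|^2 / 2) / sqrt(m), the product phi(q).phi(k) is an unbiased estimate of exp(q.k), which is what lets the FAVOR forward earlier in this diff rearrange softmax attention into q' (k'^T v) without ever materializing the S x S attention matrix. The deleted code draws orthogonal feature blocks; i.i.d. Gaussian features are used here only to keep the snippet short.

import torch

torch.manual_seed(0)
B, S, D, M = 1, 16, 8, 256                      # batch, sequence, head dim, number of random features
q, k, v = (torch.randn(B, S, D) / D**0.5 for _ in range(3))
w = torch.randn(M, D)                           # random projection, rows ~ N(0, I)

def phi(x):
    # Positive softmax-kernel estimator: exp(w.x - |x|^2 / 2) / sqrt(M)
    return torch.exp(x @ w.t() - 0.5 * (x ** 2).sum(-1, keepdim=True)) / M**0.5

q_prime, k_prime = phi(q), phi(k)

# Exact softmax attention ...
exact = torch.softmax(q @ k.transpose(-2, -1), dim=-1) @ v

# ... and its linearized counterpart: numerator and normalizer both cost O(S * M)
num = q_prime @ (k_prime.transpose(-2, -1) @ v)
den = q_prime @ k_prime.sum(-2, keepdim=True).transpose(-2, -1)
print((exact - num / den).abs().max())          # shrinks as M grows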
- https://arxiv.org/pdf/2009.14794v1.pdf -""" - - -class NormDistribution(Enum): - Xi = auto() - Uniform = auto() - - -class SoftMaxPositiveEstimators(FeatureMap): - def __init__( - self, - dim_features: int, - iter_before_redraw: Optional[int], - normalize_inputs: bool = False, - epsilon: float = 1e-6, - softmax_temp: float = -1, - ): - super().__init__(dim_features, iter_before_redraw, normalize_inputs, epsilon) - self.softmax_temp = softmax_temp - - # Handle the scaling from all kernels by √m. - # This normalizes for all the feature maps involved - self.h_scale = math.log(math.sqrt(self.dim_features)) - - def pre_scale(self, x: torch.Tensor) -> torch.Tensor: - with record_function("feature_map::pre_scale"): - # Re-draw counting logic - if ( - ( - self.iter_before_redraw is not None - and self._iter_counter > self.iter_before_redraw - ) - or self.features is None - or self.features.device != x.device - ): - # The feature map is actually using half the dimension, we'll concatenate + and - features - self._iter_counter = 1 - self.features = self._get_feature_map( - x.shape[-1], self.dim_feature_map, x.device - ) - - features = self.features - assert features is not None - - if features.dtype != x.dtype: - self.features = features.to(x.dtype) - - self._iter_counter += 1 - - # Normalization / softmax - if self.softmax_temp < 0: - # A = exp(QK.t/√d), so each input will be scaled by √√d - self.softmax_temp = x.shape[-1] ** -0.25 - - x_scaled = x * self.softmax_temp - - # Compute the scaling factors in logspace, applied from within the exponential - # - dimnish possible exponential overflow - # - remove a multiply across the batch, replace by an addition - norm_x_2 = torch.einsum("...d,...d->...", x_scaled, x_scaled).unsqueeze(-1) - self.offset = -0.5 * norm_x_2 - self.h_scale + self.epsilon - - if self.normalize_inputs: - # L0 normalize the exponential term, can be useful for numerical stability - # This ensures that features +- offset is below 1 - self.offset -= norm_x_2.max(1, keepdim=True)[0] - - # Return the scaled inputs, the rest depends on the kernel being used - return x_scaled - - @staticmethod - @torch.no_grad() - def _get_random_ortho_matrix( - blocks: int, - dim: int, - device: torch.device, - norm_distribution: NormDistribution = NormDistribution.Uniform, - ) -> torch.Tensor: - r""" - Generate a random matrix whose rows are exactly orthonormal - - "How to generate random matrices from the classical compact groups", Mezzadri, 2007 - https://arxiv.org/pdf/math-ph/0609050v2.pdf - - .. note: the typical qr decomposition does not give uniform results, qr decomposition is not - unique and the qr decomposition routines are biased towards numerical stability. See the above - paper for more information. - - .. note: this does not follow the original implementation from the Performers authors. - see docs/assets/kde plots to visualize the impact of using the R signs to correct Q - """ - - H = torch.randn((blocks, dim, dim), device=device, requires_grad=False) - - # Randomly scale the norms of the features, Xi distributed - if norm_distribution == NormDistribution.Xi: - # NOTE: This averages to sqrt(d) - norms = torch.sqrt(torch.einsum("...d,...d->...", H, H)) - - Q, R = torch.linalg.qr(H) - Q = torch.diag_embed(torch.sign(torch.diagonal(R, dim1=1, dim2=2))) @ Q - - # Normalize if need be. 
Uniform NormDistribution does nothing, Q is already orthonormal - if norm_distribution == NormDistribution.Xi: - return torch.diag_embed(norms) @ Q - - return Q - - -class SMOrf(SoftMaxPositiveEstimators): - """ - "Positive random orthogonal features" softmax estimator, - SM_ort^m+, as proposed in the Performers_ paper, Lemma 1. - - _Performers: "Rethinking attention with performers." K. Choromanski et al. (2020). - https://arxiv.org/pdf/2009.14794v1.pdf - """ - - @torch.no_grad() - def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device): - """ - Generate the projection matrix onto the random features - - .. note: The heads dimension needs to be taken into account, hence the per-block random matrix - and not uniformally random. - """ - - # Get per block random unitary matrices. - # We need enough of them to project the whole input dimension, regardless of the - # requested dimension of the features - features = self._get_random_ortho_matrix( - math.ceil(dim_input / dim_features), - dim_features, - norm_distribution=NormDistribution.Xi, - device=device, - ) - - return features.flatten(0, 1)[:dim_input] - - def forward(self, x: torch.Tensor) -> torch.Tensor: - # Softmax-dimension related scaling, shared for all kernels - x_scaled = super().pre_scale(x) - assert self.features is not None - - # Project onto the random feature map. - x_scaled = x_scaled @ self.features - return torch.exp(x_scaled + self.offset) - - -class SMHyperbolic(SoftMaxPositiveEstimators): - """ - "Positive random features hyperbolic" estimator, SMHyp+, - as proposed in the Performers_ paper, Lemma 1. - - _Performers: "Rethinking attention with performers." K. Choromanski et al. (2020). - https://arxiv.org/pdf/2009.14794v1.pdf - """ - - def __init__( - self, - dim_features: int, - iter_before_redraw: Optional[int], - normalize_inputs: bool = False, - epsilon: float = 1e-6, - softmax_temp: float = -1, - ): - super().__init__( - dim_features, iter_before_redraw, normalize_inputs, epsilon, softmax_temp - ) - - assert ( - dim_features % 2 == 0 - ), "The feature dimension needs to be even with this kernel" - self.dim_feature_map = self.dim_features // 2 - - @torch.no_grad() - def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device): - """ - Generate the projection matrix onto the random features - - .. note: The heads dimension needs to be taken into account, hence the per-block random matrix - and not uniformally random. - """ - - # Get per block random unitary matrices. - # We need enough of them to project the whole input dimension, regardless of the - # requested dimension of the features - features = self._get_random_ortho_matrix( - math.ceil(dim_input / dim_features), - dim_features, - norm_distribution=NormDistribution.Xi, - device=device, - ) - - return features.flatten(0, 1)[:dim_input] - - def forward(self, x: torch.Tensor) -> torch.Tensor: - # Softmax-dimension related scaling, shared for all kernels - x_scaled = super().pre_scale(x) - - # Project onto the random feature map, concatenate both + and - results - # This follows Lemma 1 in the original Performers Paper to best approximate a - # softmax kernel (cosh representation) - x_scaled = x_scaled @ self.features - return torch.cat( - [torch.exp(x_scaled + self.offset), torch.exp(-x_scaled + self.offset)], - dim=-1, - ) - - -class SMReg(SoftMaxPositiveEstimators): - """ - "Regularized softmax kernel" estimator, SMREG+, as proposed in the Performers_ paper. 
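The _get_random_ortho_matrix helper above relies on a QR decomposition whose Q factor is corrected by the signs of diag(R) (Mezzadri, 2007) so that the draw is uniform over orthogonal matrices. A short, self-contained check of that construction, with hypothetical block sizes:

import torch

torch.manual_seed(0)
blocks, dim = 4, 16
H = torch.randn(blocks, dim, dim)
Q, R = torch.linalg.qr(H)
Q = torch.diag_embed(torch.sign(torch.diagonal(R, dim1=1, dim2=2))) @ Q

# Every block now has exactly orthonormal rows: Q Q^T = I
eye = torch.eye(dim).expand(blocks, dim, dim)
print(torch.allclose(Q @ Q.transpose(-2, -1), eye, atol=1e-5))  # True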
- - _Performers: "Rethinking attention with performers." K. Choromanski et al. (2020). - https://arxiv.org/pdf/2009.14794v1.pdf - """ - - def __init__( - self, - dim_features: int, - iter_before_redraw: Optional[int], - normalize_inputs: bool = False, - epsilon: float = 1e-6, - softmax_temp: float = -1, - ): - super().__init__( - dim_features, iter_before_redraw, normalize_inputs, epsilon, softmax_temp - ) - - assert ( - dim_features % 2 == 0 - ), "The feature dimension needs to be even with this kernel" - self.dim_feature_map = self.dim_features // 2 - - @torch.no_grad() - def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device): - """ - Generate the projection matrix onto the random features - - .. note: The heads dimension needs to be taken into account, hence the per-block random matrix - and not uniformally random. - """ - - # Get per block random unitary matrices. - # We need enough of them to project the whole input dimension, regardless of the - # requested dimension of the features - features = self._get_random_ortho_matrix( - math.ceil(dim_input / dim_features), - dim_features, - norm_distribution=NormDistribution.Uniform, - device=device, - ).flatten(0, 1) - norms = math.sqrt(dim_input) * torch.ones(features.shape[0], device=device) - return (torch.diag(norms) @ features)[:dim_input] - - def forward(self, x: torch.Tensor) -> torch.Tensor: - # Softmax-dimension related scaling, shared for all kernels - x_scaled = super().pre_scale(x) - - # Project onto the random feature map, concatenate both + and - results - # This follows Lemma 1 in the original Performers Paper to best approximate a - # softmax kernel (cosh representation + sample regularization) - x_scaled = x_scaled @ self.features - return torch.cat( - [torch.exp(x_scaled + self.offset), torch.exp(-x_scaled + self.offset)], - dim=-1, - ) diff --git a/xformers/components/attention/global_tokens.py b/xformers/components/attention/global_tokens.py deleted file mode 100644 index c6a5284a2e..0000000000 --- a/xformers/components/attention/global_tokens.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - - -from dataclasses import dataclass -from typing import Optional, Union - -import torch -import torch.nn as nn - -from xformers.components.attention import ( - Attention, - AttentionConfig, - AttentionMask, - maybe_sparsify, - register_attention, - sparsify, -) -from xformers.components.attention.attention_patterns import ( - causal_1d_pattern, - global_token_pattern, -) -from xformers.components.attention.core import scaled_dot_product_attention - - -@dataclass -class GlobalAttentionConfig(AttentionConfig): - attention_query_mask: torch.Tensor # Mark the queries which have global attention - causal: Optional[bool] - force_sparsity: Optional[bool] - - -@register_attention("global", GlobalAttentionConfig) -class GlobalAttention(Attention): - def __init__( - self, - dropout: float, - attention_query_mask: torch.Tensor, - causal: bool = False, - force_sparsity: bool = False, - *_, - **__, - ): - r""" - Global attention, as proposed for instance in BigBird_ or Longformer_. - - Global means in that case that the queries positively labelled in the ```attention_query_mask``` can attend - to all the other queries. The queries negatively labelled in the ```attention_query_mask``` cannot attend to - any other query. 
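A plain-torch sketch of the boolean pattern this describes: positions flagged in the query mask attend to (and are reachable from) every position, while the remaining queries only reach the flagged ones. The inline construction below is illustrative and merely stands in for the library's global_token_pattern helper.

import torch

seq_len = 8
query_mask = torch.zeros(seq_len, dtype=torch.bool)
query_mask[[0, 3]] = True                              # tokens 0 and 3 are global

pattern = query_mask[:, None] | query_mask[None, :]    # [seq, seq] boolean attention pattern
print(pattern.int())
# Row i lists the keys query i may attend to; rows of non-global tokens
# are empty except at the global columns.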
- - This implementation is sparse-aware, meaning that the empty attention parts will not be represented in memory. - - Args: - dropout (float): probability of an element to be zeroed - attention_query_mask (torch.Tensor): if true, this query can attend to all the others - - """ - super().__init__() - - assert attention_query_mask.dtype == torch.bool, "A boolean mask is expected" - assert ( - attention_query_mask.shape[1] == 1 - and attention_query_mask.shape[0] > attention_query_mask.shape[1] - ), "A N x 1 query mask is expected" - - self.attn_drop = nn.Dropout(dropout, inplace=False) - self.attention_mask = global_token_pattern(attention_query_mask[:, 0]) - self.force_sparsity = force_sparsity - - if causal: - self.attention_mask &= causal_1d_pattern(attention_query_mask.shape[1]) - - self.attention_mask = ( - sparsify(self.attention_mask) - if self.force_sparsity - else maybe_sparsify(self.attention_mask) - ) - - # Properties specific to this attention mechanism - self.requires_same_k_q_dimensions = True - self.supports_attention_mask = False - self.supports_key_padding_mask = False - - def forward( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None, - *_, - **__, - ): - # Make sure that the mask is on the right device - if self.attention_mask.device != q.device: - self.attention_mask = self.attention_mask.to(q.device) - - # Mask-aware attention - if att_mask is not None: - if att_mask.dtype == torch.bool and isinstance( - self.attention_mask, AttentionMask - ): - if not isinstance(att_mask, AttentionMask): - att_mask = AttentionMask.from_bool(att_mask) - mask = self.attention_mask + att_mask - else: - mask = self.attention_mask & att_mask - else: - mask = self.attention_mask - - # Handle q/k/v which would not fit the mask - seq_len = q.shape[-2] - q_, k_, v_ = map(lambda x: self._maybe_pad_sequence(x, mask), (q, k, v)) - - # Normal attention with the global tokens mask - att = scaled_dot_product_attention( - q=q_, k=k_, v=v_, att_mask=mask, dropout=self.attn_drop - ) - - # Take into account an hypothetical padding - return att[:, :seq_len, :] diff --git a/xformers/components/attention/lambda_layer.py b/xformers/components/attention/lambda_layer.py deleted file mode 100644 index 0002a20cbc..0000000000 --- a/xformers/components/attention/lambda_layer.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - - -from dataclasses import dataclass - -import torch - -from xformers.components.attention import Attention, AttentionConfig, register_attention - - -def calc_rel_pos(n: int): - # Adapted from LucidRains - # https://github.com/lucidrains/lambda-networks/blob/main/lambda_networks/lambda_networks.py - rel_pos = torch.arange(n)[None, :] - torch.arange(n)[:, None] # [n, n] - rel_pos += n - 1 # shift value range from [-n+1, n-1] to [0, 2n-2] - return rel_pos - - -@dataclass -class LambdaLayerConfig(AttentionConfig): - seq_len: int # dimension of the input sequence - dim_head: int - - -@register_attention("lambda", LambdaLayerConfig) -class LambdaLayer(Attention): - def __init__(self, dropout: float, seq_len: int, dim_head: int, *_, **__): - """ - Attention approximation using Lambda layers, from - "Lambda networks: modeling long-range interactions without attention.", Bello, I. (2021). 
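The calc_rel_pos helper above turns every (query, key) pair into an index in [0, 2n-2], which is then used to gather a learnt relative-position embedding of shape (2n - 1, dim_head). A quick standalone illustration (the helper is restated so the snippet runs on its own):

import torch

def calc_rel_pos(n: int) -> torch.Tensor:
    rel_pos = torch.arange(n)[None, :] - torch.arange(n)[:, None]  # [n, n], values in [-n+1, n-1]
    return rel_pos + n - 1                                         # shift to [0, 2n-2]

n, dim_head = 4, 8
rel_pos = calc_rel_pos(n)
print(rel_pos)
# tensor([[3, 4, 5, 6],
#         [2, 3, 4, 5],
#         [1, 2, 3, 4],
#         [0, 1, 2, 3]])

rel_pos_emb = torch.randn(2 * n - 1, dim_head)
print(rel_pos_emb[rel_pos].shape)  # torch.Size([4, 4, 8]): one embedding per (query, key) pair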
- """ - super().__init__() - - # Possible extensions: - # - support different dimensions for key and queries - # - support varying dimensions in between inputs and outputs - # - support u hyperparam - - self.rel_pos_emb = torch.nn.Parameter( - torch.randn(2 * seq_len - 1, int(dim_head)) - ) - self.rel_pos = calc_rel_pos(seq_len) - self.attn_drop = torch.nn.Dropout(dropout, inplace=True) - - # Properties specific to this attention mechanism - self.requires_same_k_q_dimensions = True - self.supports_attention_mask = False - self.supports_key_padding_mask = False - - def forward( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs - ): - """..NOTE: We're reusing the einsum notation suggested by the paper, changed in that - heads are folded in the batch dimension""" - - content_lambda = torch.einsum("bnk,bnv->bkv", torch.softmax(k, dim=-1), v) - content_output = torch.einsum("bnk,bkv->bnv", q, content_lambda) - - rel_pos_emb = self.rel_pos_emb[self.rel_pos] - - # Handle real sequence length being possibly smaller - seq_len = q.shape[1] - rel_pos_emb = rel_pos_emb[:seq_len, :seq_len, :] - - # Compute the position lambda for every possible combination in one go, then compute the - # position related contribution - position_lambdas = torch.einsum( - "mnk,bnv->bnkv", rel_pos_emb, v - ) # one lambda per position - position_output = (q.unsqueeze(2) @ position_lambdas).squeeze() - att = content_output + position_output - - att = self.attn_drop(att) - - return att diff --git a/xformers/components/attention/linformer.py b/xformers/components/attention/linformer.py deleted file mode 100644 index af6f20b599..0000000000 --- a/xformers/components/attention/linformer.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - - -from dataclasses import dataclass -from typing import Optional - -import torch -import torch.nn as nn - -from xformers.components.attention import Attention, AttentionConfig, register_attention -from xformers.components.attention.core import scaled_dot_product_attention - - -@dataclass -class LinformerSelfAttentionConfig(AttentionConfig): - seq_len: int # dimension of the input sequence - k: Optional[int] # dimension of the internal space - - -@register_attention("linformer", LinformerSelfAttentionConfig) -class LinformerAttention(Attention): - def __init__( - self, dropout: float, seq_len: int, k: Optional[int] = None, *args, **kwargs - ): - """ - Linformer attention mechanism, - from `Linformer: Self-Attention with Linear Complexity`_, Wang et al (2020). - The original notation is kept as is. - - .. 
_`Linformer: Self-Attention with Linear Complexity` : https://arxiv.org/abs/2006.04768v2 - """ - super().__init__() - - if k is None: - k = seq_len // 4 - - self.k = k - self.E = nn.Linear(seq_len, k, bias=False) - self.F = nn.Linear(seq_len, k, bias=False) - self.attn_drop = nn.Dropout(dropout, inplace=False) - self.seq_len = seq_len - - # MHA related flags: - # kq need to have the same dimension - self.requires_same_k_q_dimensions = True - - # This attention does not support attention masks - self.supports_attention_mask = False - - def forward( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs - ): - # Handle a smaller dimension than expected - padding = 0 - if q.shape[1] < self.seq_len: - padding = self.seq_len - q.shape[1] - pad_dims = (0, 0, 0, padding) - q = torch.nn.functional.pad(q, pad_dims) - k = torch.nn.functional.pad(k, pad_dims) - v = torch.nn.functional.pad(v, pad_dims) - - k_projected = self.E(k.transpose(-2, -1)).transpose(-2, -1) - v_projected = self.F(v.transpose(-2, -1)).transpose(-2, -1) - - y = scaled_dot_product_attention( - q=q, k=k_projected, v=v_projected, att_mask=None, dropout=self.attn_drop - ) - - y = self.attn_drop(y) - - return y[:, :-padding, :] if padding > 0 else y diff --git a/xformers/components/attention/local.py b/xformers/components/attention/local.py deleted file mode 100644 index 3220a8d401..0000000000 --- a/xformers/components/attention/local.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - - -from dataclasses import dataclass -from typing import Optional, Union - -import torch -import torch.nn as nn - -from xformers.components.attention import ( - Attention, - AttentionConfig, - AttentionMask, - maybe_sparsify, - register_attention, - sparsify, -) -from xformers.components.attention.attention_patterns import ( - causal_1d_pattern, - local_1d_pattern, -) -from xformers.components.attention.core import scaled_dot_product_attention - - -@dataclass -class LocalAttentionConfig(AttentionConfig): - causal: Optional[bool] = None - window_size: Optional[int] = None - force_sparsity: Optional[bool] = None - - -@register_attention("local", LocalAttentionConfig) -class LocalAttention(Attention): - def __init__( - self, - dropout: float = 0.0, - causal: bool = False, - window_size: int = 5, - force_sparsity: bool = False, - *args, - **kwargs, - ): - - r""" - An implementation of a sliding window attention, as proposed in RoutingTransformer_, LongFormer_ or BigBird_ - - - Args: - dropout (float): the probability of an output to be randomly dropped at training time - causal (bool): apply a causal mask, in that the attention cannot be applied to the future - window_size (int): the overall window size for local attention. - Odd number is expected if the mask is not causal, as the window size will be evenly - distributed on both sides of each query - - - .. _RoutingTransformer: https://arxiv.org/pdf/2003.05997.pdf - - .. _BigBird: https://arxiv.org/pdf/2007.14062.pdf - - .. 
_Longformer: https://arxiv.org/pdf/2004.05150.pdf - - """ - super().__init__() - - self.attn_drop = nn.Dropout(dropout, inplace=False) - self.causal = causal - self.force_sparsity = force_sparsity - - if not self.causal: - assert ( - window_size % 2 == 1 - ), "The window size is assumed to be odd (counts self-attention + 2 wings)" - - self.window_size = window_size - self.attention_mask: Optional[torch.Tensor] = None - self.requires_same_k_q_dimensions = True - - # Properties specific to this attention mechanism - self.supports_attention_mask = True - self.supports_key_padding_mask = False - - def _get_local_mask(self, shape: torch.Size) -> torch.Tensor: - window_size = self.window_size * 2 + 1 if self.causal else self.window_size - mask = local_1d_pattern(shape[1], window_size) - - if self.causal: - mask &= causal_1d_pattern(shape[1]) - - mask = sparsify(mask) if self.force_sparsity else maybe_sparsify(mask) - - return mask - - def forward( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None, - *args, - **kwargs, - ): - # Local window attention masking - if self.attention_mask is None or self.attention_mask.shape[1] != q.shape[1]: - self.attention_mask = self._get_local_mask(q.shape).to(q.device) - - # Take into account the optional user mask - if att_mask is None: - mask = self.attention_mask - else: - if isinstance(att_mask, AttentionMask): - # Needed because & op not defined for SparseCS with AttentionMask - att_mask = att_mask.to_bool() - mask = self.attention_mask & att_mask - - return scaled_dot_product_attention( - q=q, k=k, v=v, att_mask=mask, dropout=self.attn_drop - ) diff --git a/xformers/components/attention/nystrom.py b/xformers/components/attention/nystrom.py deleted file mode 100644 index 93e40b74de..0000000000 --- a/xformers/components/attention/nystrom.py +++ /dev/null @@ -1,295 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - - -import logging -from dataclasses import dataclass -from typing import Optional - -import torch -import torch.nn as nn - -from xformers.components.attention import Attention, AttentionConfig, register_attention -from xformers.components.attention.core import ( - scaled_dot_product_attention, - scaled_query_key_softmax, -) -from xformers.components.attention.utils import ( - bool_mask_to_additive, - iterative_pinv, - reshape_key_padding_mask, -) - -logger = logging.getLogger("xformers") - - -@dataclass -class NystromSelfAttentionConfig(AttentionConfig): - """ - num_heads Number of heads. - num_landmarks Number of landmarks to use for softmax approximation. 64 often sufficient for a good - approximation according to https://arxiv.org/pdf/2102.03902.pdf. - causal Apply a causal mask, in that the attention cannot be applied to the future. - use_razavi_pinverse If true, use iterative method from (Razavi et al. 2014) to approximate the Moore-Penrose - inverse, otherwise use standard torch inverse. - pinverse_original_init True if using original initialization when calculating Moore-Penrose pseudo inverse using - method from (Razavi et al. 2014). - False if using exact coefficient computation (leads to faster convergence). - inv_iterations Number of iterations for calculating the Moore-Penrose pseudo inverse. 
- v_skip_connection A module that will take V as input and will be added as a skip connection to the - softmax approximation. A skip connection is added in the paper to help with training. - conv_kernel_size Kernel size for convolution optionally added to help in training. - If v_skip_connection is not specified, this will be used to define the default - depth wise convolution used as a skip connection. - If both conv_kernel_size and v_skip_connection are None, no skip connection will - be added. - landmark_pooling Which module to use when computing landmarks. Default is AdaptiveAvgPool2d. - """ - - num_heads: int - num_landmarks: Optional[int] - landmark_pooling: Optional[nn.Module] - causal: Optional[bool] - pinverse_original_init: Optional[bool] - inv_iterations: Optional[int] - v_skip_connection: Optional[nn.Module] - conv_kernel_size: Optional[int] - use_razavi_pinverse: Optional[bool] - - -class AvgPool(nn.Module): - def __init__(self, n: int): - super().__init__() - self.n = n - - def forward(self, x: torch.Tensor): - # Average independently for every segment in the sequence dimension - seq_len = x.shape[1] - head_dim = x.shape[2] - segments = seq_len // self.n - assert segments > 0, "num_landmarks should be smaller than the sequence length" - - # Dimensions are a match - if seq_len % self.n == 0: - return x.reshape( - -1, - self.n, - segments, - head_dim, - ).mean(dim=-2) - - # Handle the last segment boundary being off - n_round = self.n - seq_len % self.n - - x_avg_round = ( - x[:, : n_round * segments, :] - .reshape(-1, n_round, segments, head_dim) - .mean(dim=-2) - ) - x_avg_off = ( - x[:, n_round * segments :, :] - .reshape(-1, self.n - n_round, segments + 1, head_dim) - .mean(dim=-2) - ) - return torch.cat((x_avg_round, x_avg_off), dim=-2) - - -@register_attention("nystrom", NystromSelfAttentionConfig) -class NystromAttention(Attention): - # TODO: update defaults for use_razavi_pinverse and inv_iterations - def __init__( - self, - dropout: float, - num_heads: int, - num_landmarks: int = 64, - landmark_pooling: Optional[nn.Module] = None, - causal: bool = False, - use_razavi_pinverse: bool = True, - pinverse_original_init: bool = False, - inv_iterations: int = 6, # recommended default in paper was 6. - v_skip_connection: Optional[nn.Module] = None, - conv_kernel_size: Optional[int] = None, - *args, - **kwargs, - ): - """ - Nystrom attention mechanism, from Nystromformer_. - :: - - "A Nystrom-based Algorithm for Approximating Self-Attention." - Xiong, Y., Zeng, Z., Chakraborty, R., Tan, M., Fung, G., Li, Y., Singh, V. (2021) - - Reference codebase: https://github.com/mlpen/Nystromformer - - .. 
_Nystromformer: https://arxiv.org/pdf/2102.03902.pdf - - """ - super().__init__() - # merged key padding mask and attention mask is not accepted - self.requires_separate_masks = True - self.num_landmarks = num_landmarks - # TODO: should be able to not have to pass in num_heads - self.num_heads = num_heads - self.use_razavi_pinverse = use_razavi_pinverse - self.pinverse_original_init = pinverse_original_init - self.inv_iterations = inv_iterations - self.attn_drop = nn.Dropout(dropout) - self.skip_connection = v_skip_connection - self.causal = causal - - if self.skip_connection is None and conv_kernel_size is not None: - self.skip_connection = nn.Conv2d( - in_channels=self.num_heads, - out_channels=self.num_heads, - kernel_size=(conv_kernel_size, 1), - padding=(conv_kernel_size // 2, 0), - bias=False, - groups=self.num_heads, - ) - - if landmark_pooling is not None: - self.landmark_pooling = landmark_pooling - else: - self.landmark_pooling = AvgPool(n=self.num_landmarks) - - # Optional lower triangular masks for causal attention - self.causal_mask_1: Optional[torch.Tensor] = None - self.causal_mask_2: Optional[torch.Tensor] = None - self.causal_mask_3: Optional[torch.Tensor] = None - - # This attention does not support attention masks - self.supports_attention_mask = False - self.supports_key_padding_mask = True - - def forward( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - key_padding_mask: Optional[torch.Tensor] = None, - *args, - **kwargs, - ): - r""" - key_padding_mask Only a key padding mask is accepted here. The size must be (batch size, sequence length) or - (batch size * num_heads, 1, sequence length). If dimensions are not correct, the mask will - be ignored. An additive mask is expected, meaning float values using "-inf" to mask values - """ - - batched_dim = k.size(0) - seq_len = k.size(-2) - tt = {"dtype": q.dtype, "device": q.device} - - if key_padding_mask is not None: - if key_padding_mask.dtype == torch.bool: - logger.warning( - "Bool mask found, but an additive mask is expected. Converting but this is slow" - ) - - key_padding_mask = bool_mask_to_additive(key_padding_mask) - - if key_padding_mask.ndim == 2: - key_padding_mask = reshape_key_padding_mask( - key_padding_mask, batched_dim - ) - - zeros = torch.zeros_like(key_padding_mask) - ones = torch.ones_like(key_padding_mask) - is_masked = torch.isinf(-key_padding_mask) - - # _mask takes 1 if the token is not padded, otherwise 0. - _mask = torch.where(is_masked, zeros, ones) - _mask = _mask.transpose(2, 1) - assert _mask.shape == (batched_dim, q.shape[1], 1) - - # Mask q and k before pooling - # https://github.com/mlpen/Nystromformer/blob/main/code/attention_nystrom.py#L31 - q = q * _mask - k = k * _mask - - assert key_padding_mask.size() == (batched_dim, 1, seq_len), ( - f"key_padding_mask has invalid dimensions {key_padding_mask.size()}." - f" Must have dimensions {batched_dim, 1, seq_len} or (batch_size, {seq_len})." 
- ) - - if self.num_landmarks >= seq_len: - mask: Optional[torch.Tensor] = None - - if self.causal: - mask = self._triu_mask(batched_dim, seq_len, seq_len, **tt) - - if key_padding_mask is not None: - mask = key_padding_mask if mask is None else mask + key_padding_mask - - x = scaled_dot_product_attention(q=q, k=k, v=v, att_mask=mask) - - else: - q_landmarks = self.landmark_pooling(q) - k_landmarks = self.landmark_pooling(k) - - if self.causal and ( - self.causal_mask_1 is None - or (batched_dim, seq_len, self.num_landmarks) - != self.causal_mask_1.size() - ): - self.causal_mask_1 = self._triu_mask( - batched_dim, seq_len, self.num_landmarks, **tt - ) - self.causal_mask_2 = self._triu_mask( - batched_dim, self.num_landmarks, self.num_landmarks, **tt - ) - self.causal_mask_3 = self._triu_mask( - batched_dim, self.num_landmarks, seq_len, **tt - ) - - mask_3: Optional[torch.Tensor] = self.causal_mask_3 - if key_padding_mask is not None: - mask_3 = ( - key_padding_mask if mask_3 is None else mask_3 + key_padding_mask - ) - - kernel_1 = scaled_query_key_softmax(q=q, k=k_landmarks, att_mask=None) - kernel_2 = scaled_query_key_softmax( - q=q_landmarks, k=k_landmarks, att_mask=None - ) - kernel_3 = scaled_dot_product_attention( - q=q_landmarks, k=k, v=v, att_mask=mask_3 - ) - - kernel_2_inv = ( - iterative_pinv( - kernel_2, self.inv_iterations, self.pinverse_original_init - ) - if self.use_razavi_pinverse - else torch.linalg.pinv(kernel_2) - ) - - x = torch.matmul( - torch.matmul( - kernel_1, - kernel_2_inv, - ), - kernel_3, - ) - - if self.skip_connection: - # Assumption here is that v is 3D. - v_conv = self.skip_connection( - v.reshape(-1, self.num_heads, v.size(-2), v.size(-1)) - ) - x += v_conv.reshape(-1, v_conv.size(-2), v_conv.size(-1)) - x = self.attn_drop(x) - return x - - def _triu_mask(self, dim_1: int, dim_2: int, dim_3: int, **kwargs) -> torch.Tensor: - device = kwargs["device"] - dtype = kwargs["dtype"] - - return torch.triu( - torch.ones(dim_2, dim_3, dtype=dtype, device=device) * float("-inf"), - diagonal=1, - ).expand( - dim_1, -1, -1 - ) # micro optim, save memory on the batch dimension diff --git a/xformers/components/attention/ortho.py b/xformers/components/attention/ortho.py deleted file mode 100644 index 3d6de43a3a..0000000000 --- a/xformers/components/attention/ortho.py +++ /dev/null @@ -1,324 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - - -import logging -from dataclasses import dataclass -from enum import Enum -from typing import Optional, Union - -import torch -import torch.autograd.profiler as profiler -import torch.nn as nn -import torch.nn.functional as Fn - -from xformers.components.attention import ( - Attention, - AttentionConfig, - AttentionMask, - register_attention, -) -from xformers.components.attention.core import ( - scaled_dot_product_attention, - scaled_query_key_softmax, -) - -logger = logging.getLogger("xformers") - - -class LandmarkSelection(str, Enum): - Orthogonal = "orthogonal" - KMeans = "kmeans" - KMeans_Spherical = "kmeans_spherical" - Random = "random" - - -@dataclass -class OrthoformerAttentionConfig(AttentionConfig): - """ - num_landmarks Number of landmarks to use for softmax approximation. 
- subsample_fraction Percentage of q_samples matrix to sample per iteration - landmark_selection Landmark selection strategy - """ - - num_landmarks: Optional[int] - subsample_fraction: Optional[float] - landmark_selection: Optional[LandmarkSelection] - - -@register_attention("orthoformer", OrthoformerAttentionConfig) -class OrthoFormerAttention(Attention): - def __init__( - self, - dropout: float, - num_landmarks: int = 32, - subsample_fraction: float = 1.0, - landmark_selection: LandmarkSelection = LandmarkSelection.Orthogonal, - *args, - **kwargs, - ): - """ - Orthoformer_ attention mechanism. - :: - - "Keeping Your Eye on the Ball: Trajectory Attention in Video Transformers" - Patrick, M., Campbell, D., Asano, Y., Misra, I., Metze, F., Feichtenhofer, - C., Vedaldi, A., Henriques, J. (2021) - - Reference codebase: https://github.com/facebookresearch/Motionformer - - .. _Orthoformer: https://arxiv.org/abs/2106.05392 - - """ - super().__init__() - - self.num_landmarks = num_landmarks - self.attn_drop = nn.Dropout(dropout) - self.subsample_fraction = subsample_fraction - self.landmark_selection = landmark_selection - - # Properties specific to this attention mechanism - self.supports_attention_mask = True - self.supports_key_padding_mask = False - - def forward( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - att_mask: Optional[Union[AttentionMask, torch.Tensor]] = None, - *args, - **kwargs, - ): - N = k.shape[1] - - if self.num_landmarks == N: - # Default attention - x = scaled_dot_product_attention(q, k, v, att_mask) - else: - with torch.no_grad(), profiler.record_function("select landmarks"): - if self.landmark_selection == LandmarkSelection.Orthogonal: - landmarks = self._compute_orthogonal_landmarks(q) - elif self.landmark_selection == LandmarkSelection.Random: - half_L = self.num_landmarks // 2 - landmarks_q = q[:, torch.randint(q.size(1), (half_L,)), :] - landmarks_k = k[:, torch.randint(k.size(1), (half_L,)), :] - landmarks = torch.cat((landmarks_q, landmarks_k), dim=-2) - elif self.landmark_selection == LandmarkSelection.KMeans: - landmarks = self._cluster_landmarks(q) - elif self.landmark_selection == LandmarkSelection.KMeans_Spherical: - landmarks = self._cluster_landmarks(q, spherical=True) - - if att_mask is not None: - logger.warning( - "Orthoformer: attention mask passed alongside with using landmarks to reduce dimensions. \ - The two are typically not compatible" - ) - # FIXME: Should we still accept a mask in that case ? - att_mask = None - - # pyre-ignore[61]: TODO(T103337542): `landmarks` mistakenly seems - # like it could be uninitialized. - kernel_1 = scaled_query_key_softmax(q, landmarks, att_mask) - # pyre-ignore[61]: TODO(T103337542): `landmarks` mistakenly seems - # like it could be uninitialized. - kernel_2 = scaled_query_key_softmax(landmarks, k, att_mask) - x = torch.matmul(kernel_1, torch.matmul(kernel_2, v)) - x = self.attn_drop(x) - return x - - def _cluster_landmarks( - self, - q: torch.Tensor, - spherical: bool = False, - num_iters: int = 6, - ) -> torch.Tensor: - """ - Construct set of landmarks by recursively selecting new landmarks - that are maximally orthogonal to the existing set. - Returns near orthogonal landmarks with shape (B, M, D). 
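A minimal sketch of the landmark factorization used in the forward above (and, with an additional pseudo-inverse term, by the Nystrom attention earlier in this diff): routing attention through M landmarks brings the cost down from O(N^2) to O(N * M). Plain softmax stands in for scaled_query_key_softmax, and all shapes are hypothetical.

import torch

B, N, M, D = 2, 64, 8, 16
q, k, v = torch.randn(B, N, D), torch.randn(B, N, D), torch.randn(B, N, D)
landmarks = q[:, torch.randperm(N)[:M], :]   # any landmark selection strategy fits here

scale = D ** -0.5
kernel_1 = torch.softmax(scale * q @ landmarks.transpose(-2, -1), dim=-1)  # (B, N, M)
kernel_2 = torch.softmax(scale * landmarks @ k.transpose(-2, -1), dim=-1)  # (B, M, N)
x = kernel_1 @ (kernel_2 @ v)                                              # (B, N, D)
print(x.shape)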
- """ - - num_landmarks = min(self.num_landmarks, q.shape[1]) - - if self.subsample_fraction < 1.0: - num_samples = max( - int(self.subsample_fraction * q.size(-2)), num_landmarks - ) # Need at least M/2 samples of queries and keys - q_samples = q[:, torch.randint(q.size(-2), (num_samples,)), :] # (B, N, D) - else: - q_samples = q # (B, N, D) - - if spherical: - q_samples_normalized = Fn.normalize( - q_samples, p=2, dim=-1 - ) # may need to change default eps to eps=1e-8 for mixed precision compatibility - landmarks = self._kmeans_spherical( - q_samples_normalized, num_landmarks, num_iters - ) - else: - landmarks = self._kmeans(q_samples, num_landmarks, num_iters) - return landmarks # (B, M, D) - - def _kmeans(self, x: torch.Tensor, K: int, num_iters: int = 10): - """ - Arguments: - x: (B, N, D) - K: number of clusters - num_iters: the number of kmeans updates - """ - - B, N, D = x.size() - assert K <= N, f"{K} > {N}" - - c = x[ - :, torch.randperm(N, device=x.device)[:K], : - ].clone() # initialisation for the centroids - - with profiler.record_function("kmeans"): - x_i = x.view(B, N, 1, D) - c_j = c.view(B, 1, K, D) - counts = c.new_zeros(B, K) - ones = x.new_ones((B, N)) - - for _ in range(num_iters): - # E step: assign points to the nearest cluster - D_ij = ((x_i - c_j) ** 2).sum(-1) # (B, N, K) squared distances - cl = D_ij.argmin( - dim=-1, keepdim=True - ).long() # (B, N, 1) index of point to nearest cluster - - # M step: update the centroids - c.zero_() - c.scatter_add_(-2, cl.repeat(1, 1, D), x) # sum of points per cluster - counts.fill_(1e-6) # avoid div0 - counts.scatter_add_( - -1, cl.squeeze(-1), ones - ) # number of points per cluster - c.divide_(counts.unsqueeze(-1)) # compute the average - - return c - - def _kmeans_spherical(self, x: torch.Tensor, K: int, num_iters=10): - """ - Arguments: - x: (B, N, D) - """ - B, N, D = x.size() - assert K <= N, f"{K} > {N}" - - # initialisation for the centroids - c = x[:, torch.randperm(N, device=x.device)[:K], :].clone() - - with profiler.record_function("kmeans_spherical"): - counts = c.new_zeros(B, K) - ones = x.new_ones((B, N)) - - for _ in range(num_iters): - # E step: assign points to the nearest cluster - D_ij = torch.matmul( - x, c.transpose(-2, -1) - ) # (B, N, K) cosine similarity - cl = D_ij.argmax( - dim=-1, keepdim=True - ).long() # (B, N, 1) index of point to nearest cluster - - # M step: update the centroids - c.zero_() - c.scatter_add_(-2, cl.repeat(1, 1, D), x) # sum of points per cluster - counts.fill_(1e-6) # avoid div0 - counts.scatter_add_( - -1, cl.squeeze(-1), ones - ) # number of points per cluster - c.divide_(counts.unsqueeze(-1)) # compute the average - c = Fn.normalize(c, p=2, dim=-1) # renormalise - return c - - def _compute_orthogonal_landmarks(self, q: torch.Tensor) -> torch.Tensor: - """ - Construct set of landmarks by recursively selecting new landmarks - that are maximally orthogonal to the existing set. - Returns near orthogonal landmarks with shape (B, M, D). 
- """ - - if self.subsample_fraction < 1.0: - # Need at least M samples of queries - num_samples = max( - int(self.subsample_fraction * q.size(-2)), self.num_landmarks - ) - q_samples = q[ - :, torch.randint(q.size(-2), (num_samples,), device=q.device), : - ] - else: - # (B, N, D) - q_samples = q - - # may need to change default eps to eps=1e-8 for mixed precision compatibility - q_samples_normalized = Fn.normalize(q_samples, p=2, dim=-1) - B, N, D = q_samples_normalized.shape - - selected_mask = torch.zeros((B, N, 1), device=q_samples_normalized.device) - landmark_mask = torch.ones( - (B, 1, 1), dtype=selected_mask.dtype, device=q_samples_normalized.device - ) - - #  Get initial random landmark - random_idx = torch.randint( - q_samples_normalized.size(-2), (B, 1, 1), device=q_samples_normalized.device - ) - selected_mask.scatter_(-2, random_idx, landmark_mask) - - #  Selected landmarks - selected_landmarks = torch.empty( - (B, self.num_landmarks, D), - device=q_samples_normalized.device, - dtype=q_samples_normalized.dtype, - ) - selected_landmarks[:, 0, :] = q_samples_normalized[ - torch.arange(q_samples_normalized.size(0)), random_idx.view(-1), : - ].view(B, D) - - # Store computed cosine similarities - cos_sims = torch.empty( - (B, N, self.num_landmarks), - device=q_samples_normalized.device, - dtype=q_samples_normalized.dtype, - ) - - for M in range(1, self.num_landmarks): - with profiler.record_function("find new landmark"): - #  Calculate absolute cosine similarity between selected and unselected landmarks - # (B, N, D) * (B, D) -> (B, N) - cos_sims[:, :, M - 1] = torch.einsum( - "b n d, b d -> b n", - q_samples_normalized, - selected_landmarks[:, M - 1, :], - ).abs() - - # (B, N, M) cosine similarities of current set of landmarks wrt all queries and keys - cos_sim_set = cos_sims[:, :, :M] - - #  Get orthogonal landmark: landmark with smallest absolute cosine similarity: - # set cosine similarity for already selected landmarks to > 1 - cos_sim_set.view(-1, M)[selected_mask.flatten().bool(), :] = 10 - - # (B,) - want max for non - selected_landmark_idx = cos_sim_set.amax(-1).argmin(-1) - - #  Add most orthogonal landmark to selected landmarks: - selected_landmarks[:, M, :] = q_samples_normalized[ - torch.arange(q_samples_normalized.size(0)), selected_landmark_idx, : - ].view(B, D) - - #  Removed selected indices from non-selected mask: - selected_mask.scatter_( - -2, selected_landmark_idx.unsqueeze(-1).unsqueeze(-1), landmark_mask - ) - - # (B, M, D) - landmarks = torch.masked_select(q_samples, selected_mask.bool()).reshape( - B, -1, D - ) - return landmarks # (B, M, D) diff --git a/xformers/components/attention/pooling.py b/xformers/components/attention/pooling.py deleted file mode 100644 index 6c93193e75..0000000000 --- a/xformers/components/attention/pooling.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
- - -import math -from dataclasses import dataclass -from typing import Optional - -import torch -import torch.nn as nn - -from xformers.components.attention import Attention, AttentionConfig, register_attention - - -@dataclass -class PoolingAttentionConfig(AttentionConfig): - pool_size: int # dimension of the input sequence - stride: Optional[int] # dimension of the internal space - padding: Optional[int] - - -@register_attention("pooling", PoolingAttentionConfig) -class Pooling(Attention): - def __init__( - self, - pool_size: int = 3, - stride: int = 1, - padding: Optional[int] = None, - *_, - **__, - ): - """ - Pooling token mixing mechanism, as proposed in - `Metaformer is actually what you need for vision`_, Yu et al (2021). - - The original notation is kept as is. - - .. _`Metaformer is actually what you need for vision` : https://arxiv.org/pdf/2111.11418v1.pdf - """ - super().__init__() - - padding = padding if padding is not None else pool_size // 2 - self.pool = nn.AvgPool2d( - pool_size, - stride=stride, - padding=pool_size // 2, - count_include_pad=False, - ) - - # MHA related flags: - # kq need to have the same dimension - self.requires_same_k_q_dimensions = False - - # This attention does not support attention masks - self.supports_attention_mask = False - - # This "attention" (token mixing) skips the multihead attention altogether - self.requires_skip_multi_head = True - self.requires_input_projection = False - - # This operator does not really handle q,k,v - self.requires_same_k_q_dimensions = True - - # This attention requires the 2d structure out of the context, - # implictly assumed to be a squared length - self.requires_squared_context = True - - def forward(self, q: torch.Tensor, *_, **__): - # Expose the 2D token structure - B, HW, C = q.shape - H = int(math.sqrt(HW)) - assert H * H == HW - - q = q.transpose(-2, -1).reshape(B, C, H, H) - - # 2D pool - x_pool = self.pool(q) - q # compensate for the residual path - - # Get back to B HW C - return x_pool.flatten(2, 3).transpose(-2, -1) diff --git a/xformers/components/attention/random.py b/xformers/components/attention/random.py deleted file mode 100644 index e07e6c8679..0000000000 --- a/xformers/components/attention/random.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - - -from dataclasses import dataclass -from typing import Optional, Union - -import torch -import torch.nn as nn - -from xformers.components.attention import ( - Attention, - AttentionConfig, - AttentionMask, - maybe_sparsify, - register_attention, - sparsify, -) -from xformers.components.attention.attention_patterns import ( - causal_1d_pattern, - random_pattern, -) -from xformers.components.attention.core import scaled_dot_product_attention - - -@dataclass -class RandomAttentionConfig(AttentionConfig): - r: Optional[ - float - ] # the ratio of keys that the query can attend to. 
1.0 means dense attention - constant_masking: Optional[ - bool - ] # whether the randomness is per query or defined at construction time - force_sparsity: Optional[bool] # use sparsity in any case (potentially slower) - - -@register_attention("random", RandomAttentionConfig) -class RandomAttention(Attention): - def __init__( - self, - dropout: float, - causal: bool = False, - r: float = 0.01, - constant_masking: bool = True, - force_sparsity: bool = False, - *args, - **kwargs, - ): - """ - "Random" attention, as proposed for instance in BigBird_. - Random means in that case that each query can attend to a random set of keys. - This implementation is sparse-aware, meaning that the empty attention parts will not be represented in memory. - - Args: - r (float): the ratio in [0,1] of keys that the query can attend to - constant_masking (bool): if true, keep the same random set for all queries. - - .. _BigBird: https://arxiv.org/pdf/2007.14062.pdf - - """ - super().__init__() - - self.attn_drop = nn.Dropout(dropout, inplace=False) - self.causal = causal - self.r = r - self.rand_attention_mask: Optional[torch.Tensor] = None - self.constant_masking = constant_masking - self.force_sparsity = force_sparsity - - # Properties specific to this attention mechanism - self.supports_attention_mask = True - self.supports_key_padding_mask = False - - self.requires_same_k_q_dimensions = True - - def _get_rand_mask(self, shape: torch.Size) -> torch.Tensor: - sparsity = 1 - self.r - mask = random_pattern(shape[1], sparsity=sparsity) - - if self.causal: - mask &= causal_1d_pattern(shape[1]) - - mask = sparsify(mask) if self.force_sparsity else maybe_sparsify(mask) - - return mask - - def forward( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None, - *args, - **kwargs, - ): - # Rand masking - if not self.constant_masking or self.rand_attention_mask is None: - self.rand_attention_mask = self._get_rand_mask(q.shape).to(q.device) - - # Mask-aware attention - if att_mask is not None: - if att_mask.dtype == torch.bool and isinstance( - self.rand_attention_mask, AttentionMask - ): - mask = self.rand_attention_mask + AttentionMask.from_bool(att_mask) - else: - if isinstance(att_mask, AttentionMask): - # Needed because & op not defined for SparseCS with AttentionMask - att_mask = att_mask.to_bool() - mask = self.rand_attention_mask & att_mask - else: - mask = self.rand_attention_mask - - # Handle q/k/v which would not fit the mask - seq_len = q.shape[-2] - q_, k_, v_ = map(lambda x: self._maybe_pad_sequence(x, mask), (q, k, v)) - - # Normal attention with the random mask - att = scaled_dot_product_attention( - q=q_, k=k_, v=v_, att_mask=mask, dropout=self.attn_drop - ) - - # Take into account an hypothetical padding - return att[:, :seq_len, :] diff --git a/xformers/components/attention/visual.py b/xformers/components/attention/visual.py deleted file mode 100644 index 6ea81f41c2..0000000000 --- a/xformers/components/attention/visual.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
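A plain-torch sketch of the kind of mask RandomAttention builds above: each query may attend to roughly a fraction r of the keys, optionally intersected with a causal pattern. The ad-hoc masks below stand in for random_pattern and causal_1d_pattern so the snippet is self-contained.

import torch

torch.manual_seed(0)
seq_len, r = 8, 0.25
rand_mask = torch.rand(seq_len, seq_len) < r               # True where attention is allowed
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
mask = rand_mask & causal                                  # random sparsity, restricted to the past
print(mask.int())
print(mask.float().mean())                                 # resulting density, roughly r / 2 here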
- - -import math -from dataclasses import dataclass - -import torch -import torch.nn as nn - -from xformers.components.attention import Attention, AttentionConfig, register_attention - - -@dataclass -class VisualAttentionConfig(AttentionConfig): - dim_model: int # dimension of the input sequence - - -class LKA(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim) - self.conv_spatial = nn.Conv2d( - dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3 - ) - self.conv1 = nn.Conv2d(dim, dim, 1) - - def forward(self, x: torch.Tensor): - u = x.clone() - attn = self.conv0(x) - attn = self.conv_spatial(attn) - attn = self.conv1(attn) - - return u * attn - - -@register_attention("visual", VisualAttentionConfig) -class Visual(Attention): - def __init__( - self, - dim_model: int, - *_, - **__, - ): - """ - Large kernel attention mechanism, as proposed in `Visual Attention Network`_, Guo et al (2022). - The original notation is tentatively kept as is. See https://github.com/Visual-Attention-Network - for the reference implementation - - .. Note: compared to the paper, this block contains the LKA (Large Kernel Attention) - and the prior and posterior transformations (Conv2d and activation) - - .. _`Visual Attention Network` : https://arxiv.org/pdf/2202.09741.pdf - """ - super().__init__() - - self.block = nn.Sequential( - nn.Conv2d(dim_model, dim_model, 1), - nn.GELU(), - LKA(dim_model), - nn.Conv2d(dim_model, dim_model, 1), - ) - - # MHA related flags: - self.requires_same_k_q_dimensions = ( - True # This mechanism only really supports self attention - ) - self.supports_attention_mask = False - self.requires_skip_multi_head = ( - True # This mechanism skips the multihead attention altogether - ) - self.requires_squared_context = ( - True # Recovering the 2D structure from context assumes squared content - ) - - self.requires_input_projection = ( - False # This mechanism does not require that the MHA projects inputs - ) - - def forward(self, q: torch.Tensor, *_, **__): - # Expose the 2D token structure - B, HW, C = q.shape - H = int(math.sqrt(HW)) - assert H * H == HW - - x = q.transpose(-2, -1).reshape(B, C, H, H) - - # Large kernel attention - residual = x.clone() - x = self.block(x) - x = x + residual - - # Get back to B HW C - return x.flatten(2, 3).transpose(-2, -1) diff --git a/xformers/components/feedforward/__init__.py b/xformers/components/feedforward/__init__.py deleted file mode 100644 index 4df0a5ce91..0000000000 --- a/xformers/components/feedforward/__init__.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - - -from pathlib import Path -from typing import Any, Callable, Dict, Set, Union - -from xformers.utils import ( - generate_matching_config, - get_registry_decorator, - import_all_modules, -) - -from .base import Feedforward, FeedforwardConfig # noqa - -# CREDITS: Classy Vision registry mechanism - -FEEDFORWARD_REGISTRY: Dict[str, Any] = {} -FEEDFORWARD_CLASS_NAMES: Set[str] = set() - - -def build_feedforward(config: Union[Dict[str, Any], FeedforwardConfig]): - """Builds a feedforward from a config. - - This assumes a 'name' key in the config which is used to determine what - attention class to instantiate. 
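Both of the 2D token mixers in this diff (Pooling and Visual) recover an image grid from the flattened NLC sequence, apply a convolution-style operator, and flatten back, which is why they require a squared context length. A minimal round trip with hypothetical sizes, using an average pool with residual compensation in the spirit of the Pooling block:

import math

import torch
import torch.nn as nn

B, HW, C = 2, 49, 32                              # requires_squared_context: HW must be a square
H = int(math.sqrt(HW))
assert H * H == HW

x = torch.randn(B, HW, C)
grid = x.transpose(-2, -1).reshape(B, C, H, H)    # NLC -> NCHW
mixed = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)(grid) - grid
out = mixed.flatten(2, 3).transpose(-2, -1)       # back to NLC
print(out.shape)                                  # torch.Size([2, 49, 32])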
For instance, a config `{"name": "my_feedforward", - "foo": "bar"}` will find a class that was registered as "my_feedforward" - (see :func:`register_feedforward`) and call .from_config on it.""" - - if not isinstance(config, FeedforwardConfig): - config_instance = generate_matching_config( - config, FEEDFORWARD_REGISTRY[config["name"]].config - ) - else: - config_instance = config - - return FEEDFORWARD_REGISTRY[config_instance.name].constructor.from_config( - config_instance - ) - - -"""Registers a Feedforward subclass. - - This decorator allows xFormers to instantiate a subclass of Feedforward - from a configuration file, even if the class itself is not part of the - xFormers framework. To use it, apply this decorator to a Feedforward - subclass, like this: - - .. code-block:: python - - @dataclass - class MyConfig: - ... - - @register_feedforward('my_ff', MyConfig) - class MyFeedforward(Feedforward): - ... - - To instantiate a feedforward from a configuration file, see :func:`build_feedforward`.""" -register_feedforward: Callable[ - [str, Any], Callable[[Any], Any] -] = get_registry_decorator( - FEEDFORWARD_REGISTRY, FEEDFORWARD_CLASS_NAMES, Feedforward, FeedforwardConfig -) - -from .mlp import MLP # noqa - -__all__ = [ - "MLP", - "Feedforward", - "build_feedforward", - "register_feedforward", -] - -# automatically import any Python files in the directory -import_all_modules(str(Path(__file__).parent), "xformers.components.feedforward") diff --git a/xformers/components/feedforward/base.py b/xformers/components/feedforward/base.py deleted file mode 100644 index 76a357cfb7..0000000000 --- a/xformers/components/feedforward/base.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - - -from abc import ABCMeta, abstractmethod -from dataclasses import asdict, dataclass -from typing import Optional, Type, TypeVar - -import torch.nn as nn - -from xformers._deprecation_warning import deprecated_function -from xformers.components import Activation - -Self = TypeVar("Self", bound="Feedforward") - - -@dataclass -class FeedforwardConfig: - name: str - dim_model: int - dropout: float - activation: Activation - - -# Define the common interface, every feedforward block needs to derive from it -class Feedforward(nn.Module, metaclass=ABCMeta): - @abstractmethod - def __init__( - self, - dim_model: Optional[int] = None, - dropout: Optional[float] = None, - activation: Optional[Activation] = None, - *args, - **kwargs, - ): - super().__init__() - deprecated_function(self) - - # This feedforward requires a CUDA accelerator - self.requires_cuda = False - - # This feedforward requires a context length which is squared, often due to 2D pooling - self.requires_squared_context = False - - @classmethod - def from_config(cls: Type[Self], config: FeedforwardConfig) -> Self: - # Generate the class inputs from the config - fields = asdict(config) - - # Skip all Nones so that default values are used - fields = {k: v for k, v in fields.items() if v is not None} - - return cls(**fields) diff --git a/xformers/components/feedforward/conv_mlp.py b/xformers/components/feedforward/conv_mlp.py deleted file mode 100644 index 895211977d..0000000000 --- a/xformers/components/feedforward/conv_mlp.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
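The build_feedforward / register_feedforward pair above resolves the "name" key of a config dict against the registered config dataclass and then calls .from_config on the matching class. A small usage sketch, assuming an xformers release that still ships xformers.components.feedforward; the field values are arbitrary.

import torch

from xformers.components import Activation
from xformers.components.feedforward import build_feedforward

ff = build_feedforward(
    {
        "name": "MLP",  # registered via register_feedforward("MLP", MlpConfig)
        "dim_model": 128,
        "dropout": 0.1,
        "activation": Activation.GeLU,
        "hidden_layer_multiplier": 4,
        "bias": True,
    }
)
y = ff(torch.randn(2, 16, 128))  # feedforwards act token-wise on [batch, seq, dim_model]
print(y.shape)  # torch.Size([2, 16, 128])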
-# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - - -# CREDITS: Largely reusing the code from the reference VAN implementation -# see https://github.com/Visual-Attention-Network - -import math -from dataclasses import dataclass -from typing import Optional - -import torch.nn as nn - -from xformers.components import Activation, build_activation -from xformers.components.feedforward import Feedforward, FeedforwardConfig - -from . import register_feedforward - - -@dataclass -class ConvMlpConfig(FeedforwardConfig): - hidden_layer_multiplier: int - dim_model: int - dim_model_out: Optional[int] - act_layer: Activation - dropout: float - - -@register_feedforward("Conv2DFeedforward", ConvMlpConfig) -class Conv2DFeedforward(Feedforward): - """ - A Convolutional feed-forward network, as proposed in VAN_ (Vision Attention Network, Guo et al.) - - .. _VAN: https://arxiv.org/pdf/2202.09741.pdf - """ - - def __init__( - self, - dim_model: int, - hidden_layer_multiplier: int = 1, - dim_model_out: Optional[int] = None, - activation: Activation = Activation.GeLU, - dropout=0.0, - *args, - **kwargs, - ): - super().__init__() - out_features = dim_model_out or dim_model - hidden_features = hidden_layer_multiplier * dim_model - - self.conv_mlp = nn.Sequential( - nn.Conv2d(dim_model, hidden_features, 1), - nn.Conv2d( - hidden_features, - hidden_features, - 3, - 1, - 1, - bias=True, - groups=hidden_features, - ), - build_activation(activation), - nn.Conv2d(hidden_features, out_features, 1), - nn.Dropout(dropout), - ) - - # This feedforward requires a context length which is squared, often due to 2D pooling - self.requires_squared_context = True - - def init_weights(self, **kwargs): - # Follow the original init, but also make it possible to initialize from the outside - def init_module(m: nn.Module): - if isinstance(m, nn.Conv2d): - fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - fan_out //= m.groups - m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) - if m.bias is not None: - m.bias.data.zero_() - - self.apply(init_module) - - def forward(self, x): - # The conv layers expect NCHW, we have NLC by default - B, L, C = x.shape - HW = int(math.sqrt(x.shape[-2])) - assert HW**2 == L, "Conv2DFeedforward requires squared context lengths" - - x = x.reshape((B, HW, HW, C)).swapdims(1, -1) - - # The actual FW, including the 2d convolutions - x = self.conv_mlp(x) - - # back to NLC - x = x.transpose(1, -1) - return x.flatten(1, 2) diff --git a/xformers/components/feedforward/mixture_of_experts.py b/xformers/components/feedforward/mixture_of_experts.py deleted file mode 100644 index b6ab1841f4..0000000000 --- a/xformers/components/feedforward/mixture_of_experts.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
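Because Conv2DFeedforward above recovers a 2D layout from the token axis, it only accepts context lengths that are perfect squares. A short usage sketch under the same pre-removal assumption; the sizes are arbitrary, and the import path mirrors the deleted conv_mlp.py module.

import torch

from xformers.components import Activation
from xformers.components.feedforward.conv_mlp import Conv2DFeedforward

conv_ff = Conv2DFeedforward(
    dim_model=64,
    hidden_layer_multiplier=2,
    activation=Activation.GeLU,
    dropout=0.0,
)
x = torch.randn(2, 49, 64)  # 49 = 7 * 7, so the [B, L, C] tensor can be folded back to 2D
print(conv_ff(x).shape)  # torch.Size([2, 49, 64]); a non-square L trips the assert in forward()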
- - -import logging -from dataclasses import dataclass -from enum import Enum -from typing import Any, Callable, Optional, Union - -import torch - -from xformers.components import Activation -from xformers.components.feedforward import ( - Feedforward, - FeedforwardConfig, - register_feedforward, -) - -logger = logging.getLogger("xformers") - - -_is_fairscale_available = True - -try: - import torch.distributed as dist - from fairscale.nn import MOELayer, Top2Gate # type: ignore - - from xformers.components.feedforward import MLP - -except ImportError: - logger.warning( - "Either FairScale or torch distributed is not available, MixtureOfExperts will not be exposed." - " Please install them if you would like to use MoE" - ) - _is_fairscale_available = False - - -if _is_fairscale_available: - - # Credits: initially implemented in FairScale for sanity checking - class RoundRobinGate(torch.nn.Module): - def __init__(self, model_dim, num_experts): - super().__init__() - self.model_dim = model_dim - self.num_experts = num_experts - - def forward(self, input): - s = input.shape[0] - assert s % self.num_experts == 0, f"{s} % {self.num_experts} != 0" - capacity = 2 * s // self.num_experts - output = torch.zeros( - s, self.num_experts, capacity, dtype=input.dtype, device=input.device - ) - for i in range(s): - output[i, i % self.num_experts, i // self.num_experts] = 1.0 - return 0.0, output, output.bool() - - class GateConfig(str, Enum): - RoundRobin = "round_robin" - Top2 = "top_2" - # Other gating techniques could be exposed here - - @dataclass - class MoEConfig(FeedforwardConfig): - number_of_experts: int - gate: GateConfig - number_of_local_experts: Optional[int] = None - expert_constructor: Optional[Any] = None - hidden_layer_multiplier: Optional[int] = None - group: Optional[Any] = None - - @register_feedforward("MixtureOfExperts", MoEConfig) - class MixtureOfExperts(Feedforward): - """ - A MLP variant which uses the "Mixture of Experts" paradigm, as described in Gshard_. - xFormers uses the FairScale_ implementation under the hood. - - .. warning: Please note that most of the benefits of MoE are present in a distributed training environmentt - - .. _Gshard: https://arxiv.org/pdf/2006.16668.pdf - .. 
_FairScale: https://github.com/facebookresearch/fairscale/ - """ - - def __init__( - self, - dim_model: int, - dropout: float, - activation: Activation, - number_of_experts: int, - gate: Union[GateConfig, torch.nn.Module], - number_of_local_experts: Optional[int] = None, - expert_constructor: Optional[Callable[[], torch.nn.Module]] = None, - hidden_layer_multiplier: Optional[int] = None, - group: Optional[Any] = None, - *_, - **__, - ): - super().__init__() - - # Handle a possibly uninitialized process group - assert ( - dist.is_initialized() - ), "Mixture of Experts require torch distributed to be initialized" - - if number_of_local_experts is not None: - assert number_of_experts >= number_of_local_experts - else: - if dist.get_world_size() == 1: - logger.warning("Local experts no specified but world size of 1") - logger.warning("Assuming that all experts are local") - number_of_local_experts = number_of_experts - else: - number_of_local_experts = 1 - - # Programatically handle the gating technique - if not isinstance(gate, torch.nn.Module): - gate_constructor = { - GateConfig.RoundRobin: RoundRobinGate, - GateConfig.Top2: Top2Gate, - }[gate] - - self.gate = gate_constructor(dim_model, number_of_experts) - else: - self.gate = gate - - # Programatically handle the experts - if expert_constructor is None: - - multiplier = ( - hidden_layer_multiplier - if hidden_layer_multiplier is not None - else 4 - ) - - def expert_constructor() -> torch.nn.Module: - return MLP(dim_model, dropout, activation, multiplier) - - assert expert_constructor is not None - - local_experts = torch.nn.ModuleList( - [expert_constructor() for _ in range(number_of_local_experts)] - ) - - self.moe = MOELayer(gate=self.gate, experts=local_experts, group=group) - - self.requires_cuda = True - - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - # FairScale MoE assumes that the dimensions are [S, B, E] - # xFormers assumes [B, S, E] - return self.moe(inputs.movedim(0, 1)).movedim(0, 1) diff --git a/xformers/components/feedforward/mlp.py b/xformers/components/feedforward/mlp.py deleted file mode 100644 index fefb328682..0000000000 --- a/xformers/components/feedforward/mlp.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - - -from dataclasses import dataclass - -import torch -import torch.nn as nn - -from xformers.components import Activation, build_activation -from xformers.components.feedforward import Feedforward, FeedforwardConfig - -from . 
import register_feedforward - - -@dataclass -class MlpConfig(FeedforwardConfig): - hidden_layer_multiplier: int - bias: bool - - -@register_feedforward("MLP", MlpConfig) -class MLP(Feedforward): - def __init__( - self, - dim_model: int, - dropout: float, - activation: Activation, - hidden_layer_multiplier: int, - bias: bool = True, - *args, - **kwargs, - ): - super().__init__() - dim_mlp = hidden_layer_multiplier * dim_model - self.mlp = nn.Sequential( - nn.Linear(in_features=dim_model, out_features=dim_mlp, bias=bias), - build_activation(activation), - nn.Dropout(dropout), - nn.Linear(in_features=dim_mlp, out_features=dim_model, bias=bias), - nn.Dropout(dropout), - ) - - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - return self.mlp(inputs) diff --git a/xformers/components/multi_head_dispatch.py b/xformers/components/multi_head_dispatch.py deleted file mode 100644 index d0f75b2645..0000000000 --- a/xformers/components/multi_head_dispatch.py +++ /dev/null @@ -1,271 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - - -import logging -from dataclasses import asdict, dataclass -from typing import Optional, Tuple - -import torch -import torch.nn as nn -from torch.nn.init import constant_ - -from xformers._deprecation_warning import deprecated_function -from xformers.components.attention import Attention -from xformers.components.input_projection import InputProjection, InputProjectionConfig -from xformers.components.positional_embedding import RotaryEmbedding - -logger = logging.getLogger("xformers") - - -@dataclass -class MultiHeadDispatchConfig: - dim_model: int - num_heads: int - attention: Attention - bias: bool - residual_dropout: float - dim_key: Optional[int] - dim_value: Optional[int] - in_proj_container: Optional[InputProjection] - use_separate_proj_weight: Optional[bool] - use_rotary_embeddings: Optional[bool] - out_proj: Optional[nn.Module] - - def __getitem__(self, item): - return getattr(self, item) - - -# Move head forward and fold into batch dim. dimensions become (B * nh, S, hs) -def _fold_heads(t: torch.Tensor, B: int, S: int, H: int, Hs: int): - return t.view(B, S, H, Hs).transpose(1, 2).flatten(start_dim=0, end_dim=1) - - -# Move head forward and fold into batch dim. dimensions become (B, nh, S, hs) -def _split_heads(t: torch.Tensor, B: int, S: int, H: int, Hs: int): - return t.view(B, S, H, Hs).transpose(1, 2) - - -class MultiHeadDispatch(nn.Module): - """ - A multi-head masked self-attention dispatch mechanism, with a projection at the end, - following the architecture proposed in `Attention is all you need`_, Vaswani et al. - - The actual attention mechanism can vary, as well as the projections. - This can be used to wrap the proposed attention mechanisms and make them multi-head aware, - but it is optional. 
- - Args: - dim_model: The model/embedding dimension - num_heads: The number of heads being used - attention: The attention mechanism (needs to be registered to the xformers library) - bias: Whether to use bias for the projections : (Q, K, V, Output) - residual_dropout: Amount of dropout on the residual path - use_separate_proj_weight: Use different weights for the Q, K, V projections - dim_key: Optionally use a different dimension for the key - dim_value: Optionally use a different dimension for the value - in_proj_container: Optionally provide the input projection module - use_rotary_embeddings: Use rotary embeddings - out_proj: Optionally provide the output projection module - - - .. _`Attention is all you need`: https://arxiv.org/abs/1706.03762v5 - """ - - def __init__( - self, - dim_model: int, - num_heads: int, - attention: Attention, - bias: Tuple[bool, bool, bool, bool] = (True, True, True, True), - residual_dropout: float = 0.0, - use_separate_proj_weight: bool = True, - dim_key: Optional[int] = None, - dim_value: Optional[int] = None, - in_proj_container: Optional[InputProjection] = None, - use_rotary_embeddings: Optional[bool] = False, - out_proj: Optional[nn.Module] = None, - *args, - **kwargs, - ): - super().__init__() - deprecated_function(self) - - if isinstance(bias, bool): - logger.warning( - "Single bias value provided for the MHA projections." - + f" Assuming the same parameter ({bias}) is to be used everywhere" - ) - bias = (bias, bias, bias, bias) - - assert ( - dim_model % num_heads == 0 - ) # static preset for now, each head works on 1/d the embeddings, could be relaxed - assert num_heads > 0 - - # Popular default is that all latent dimensions are the same - dim_key, dim_value = map(lambda x: x if x else dim_model, (dim_key, dim_value)) - - self.num_heads = num_heads - self.dim_key_head = dim_key // num_heads - self.dim_value_head = dim_value // num_heads - self.dim_model = dim_model - self.attention = attention - - # key, query, value projections for all heads - # critical options are - # - are we sharing weights ? - # - are we adding biases ? 
- if attention.requires_input_projection: - self.in_proj_container = ( - in_proj_container - if in_proj_container is not None - else InputProjection( - query_proj_params=InputProjectionConfig( - dim_model, dim_key, bias=bias[0] - ), - key_proj_params=InputProjectionConfig( - dim_model, dim_key, bias=bias[1] - ), - value_proj_params=InputProjectionConfig( - dim_model, dim_value, bias=bias[2] - ), - use_separate_proj_weight=use_separate_proj_weight, - ) - ) - - # Optional rotary embeddings - self.rotary_embeddings = ( - RotaryEmbedding(self.dim_key_head) if use_rotary_embeddings else None - ) - - # Regularization - self.resid_drop = nn.Dropout(residual_dropout, inplace=False) - - # Output projection - self.proj = ( - out_proj if out_proj else nn.Linear(dim_model, dim_model, bias=bias[3]) - ) - if isinstance(self.proj, nn.Linear) and self.proj.bias is not None: - constant_(self.proj.bias, 0.0) - - def forward( - self, - query: torch.Tensor, - key: Optional[torch.Tensor] = None, - value: Optional[torch.Tensor] = None, - att_mask: Optional[torch.Tensor] = None, - key_padding_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """ - Expected input dimensions are [batch size, sequence length, embed dim] - Output dimensions are [batch size, sequence length, embed dim] - """ - - if key is None: - key = query - if value is None: - value = query - - if query.shape[0] != key.shape[0] or query.shape[0] != value.shape[0]: - max_batch = max((query.shape[0], key.shape[0], value.shape[0])) - query, key, value = map( - lambda x: x.expand(max_batch, -1, -1), [query, key, value] - ) - - B, S_Q, _ = query.size() # Batch x Sequence x Embedding (latent) - _, S_K, _ = key.size() # K, Q's sequence length could differ - - # Catch different query and key length but a causal attention - if S_Q != S_K: - assert ( - not self.attention.requires_same_k_q_dimensions - ), "This attention mechanism requires query and key to have the same sequence (context) lengths" - - if hasattr(self.attention, "causal"): - assert not self.attention.causal, ( - "Causal attention is not supported when key and query have different sequence lengths.\n" - + "In that case causality is ill-determined. 
Please pad your sequences accordingly" - ) - - kw_mask_args = {} - if att_mask is not None: - assert ( - self.attention.supports_attention_mask - ), "This attention does not support attention masks" - kw_mask_args["att_mask"] = att_mask - - if key_padding_mask is not None: - assert ( - self.attention.supports_key_padding_mask - ), "This attention does not support key padding masks" - kw_mask_args["key_padding_mask"] = key_padding_mask - - if self.attention.requires_skip_multi_head: - return self.attention(query, key, value, **kw_mask_args) - - # Calculate query, key, values for all heads in batch - if self.attention.requires_input_projection: - q, k, v = self.in_proj_container(query=query, key=key, value=value) - else: - k, q, v = key, query, value - - # Check the dimensions properly - def check(t, name): - assert ( - t.shape[2] % self.num_heads == 0 - ), f"the {name} embeddings need to be divisible by the number of heads" - - check(q, "projected query") - check(v, "projected value") - check(k, "projected key") - - # Optional: rotary embedding, add relative positioning information - if self.rotary_embeddings: - # rotary requires the head dimension - q = _split_heads(q, B, S_Q, self.num_heads, self.dim_key_head) - k = _split_heads(k, B, S_K, self.num_heads, self.dim_key_head) - v = _split_heads(v, B, S_K, self.num_heads, self.dim_value_head) - - q, k = self.rotary_embeddings(q=q, k=k) - - if not self.attention.requires_head_dimension: - q, k, v = q.flatten(0, 1), k.flatten(0, 1), v.flatten(0, 1) - - else: - # Reshape k/q/v to either expose the heads, or fold the head dimension into the batch - reshape_fn = ( - _split_heads if self.attention.requires_head_dimension else _fold_heads - ) - - q = reshape_fn(q, B, S_Q, self.num_heads, self.dim_key_head) - k = reshape_fn(k, B, S_K, self.num_heads, self.dim_key_head) - v = reshape_fn(v, B, S_K, self.num_heads, self.dim_value_head) - - # Self-attend - y = self.attention(q, k, v, **kw_mask_args) - - # Re-assemble all head outputs side by side - y = ( - y.view(B, self.num_heads, S_Q, self.dim_value_head) - .transpose(1, 2) - .flatten(start_dim=2, end_dim=3) - ) - - # Output projection, dropout and good to go - y = self.resid_drop(self.proj(y)) - - # Return the same sequence size as the input - return y - - @classmethod - def from_config(cls, config: MultiHeadDispatchConfig): - # Generate the class inputs from the config - fields = asdict(config) - - # Skip all Nones so that default values are used - fields = {k: v for k, v in fields.items() if v is not None} - - return cls(**fields) diff --git a/xformers/components/patch_embedding.py b/xformers/components/patch_embedding.py deleted file mode 100644 index dc3afb8d2e..0000000000 --- a/xformers/components/patch_embedding.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - -import math -from dataclasses import dataclass -from enum import Enum - -import torch - -from xformers._deprecation_warning import deprecated_function - - -class PoolType(str, Enum): - Conv2D = "CONV_2D" - # ... - # TODO: Support more cases ? - - -@dataclass -class PatchEmbeddingConfig: - """ - The configuration for the patch embedding layer, which takes the raw token passed in - and returns a pooled representation along a given embedding dimension. 
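For reference, a sketch of how the MultiHeadDispatch block removed above was typically combined with a registered attention, here the "random" attention from this same diff. It assumes a pre-removal xformers install and that build_attention tolerates the optional config fields being omitted.

import torch

from xformers.components import MultiHeadDispatch
from xformers.components.attention import build_attention

attention = build_attention({"name": "random", "dropout": 0.0, "r": 0.5})
mha = MultiHeadDispatch(dim_model=64, num_heads=4, attention=attention)

x = torch.randn(2, 16, 64)  # [batch, seq, dim_model]
y = mha(x)  # key and value default to the query, i.e. self-attention
print(y.shape)  # torch.Size([2, 16, 64])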
- - This typically trades the spatial (context length) representation with the embedding size - - This is canonicaly used by ViT, but other papers (like MetaFormer or other hierarchical transformers) - propose a more general use case for this - """ - - in_channels: int - out_channels: int - kernel_size: int - stride: int - padding: int = 0 - pool_type: PoolType = PoolType.Conv2D - - -class ConditionalReshape(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - deprecated_function(self) - - def forward(self, x): - if x.ndim == 3: - B, HW, C = x.shape - # NOTE: We're assuming a square sample here - H = int(math.sqrt(HW)) - assert H * H == HW, f"{H, HW}" - x = x.transpose(1, 2).reshape(B, C, H, H) - - return x - - -class PatchToSequence(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - deprecated_function(self) - - def forward(self, x): - return x.flatten(2, 3).transpose(1, 2).contiguous() # B HW C - - -def build_patch_embedding(config: PatchEmbeddingConfig): - if not isinstance(config, PatchEmbeddingConfig): - config = PatchEmbeddingConfig(**config) - - if config.pool_type == PoolType.Conv2D: - pool = torch.nn.Conv2d( - config.in_channels, - config.out_channels, - kernel_size=config.kernel_size, - stride=config.stride, - padding=config.padding, - ) - else: - raise NotImplementedError - - # The patch embedding supposes that the input really is 2D in essence - # If this block is in the middle of a stack, we need to reshape - return torch.nn.Sequential(ConditionalReshape(), pool, PatchToSequence()) diff --git a/xformers/components/positional_embedding/__init__.py b/xformers/components/positional_embedding/__init__.py deleted file mode 100644 index 0f7f02c2d1..0000000000 --- a/xformers/components/positional_embedding/__init__.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - - -from pathlib import Path -from typing import Any, Callable, Dict, Set, Union - -from xformers.utils import ( - generate_matching_config, - get_registry_decorator, - import_all_modules, -) - -from .base import PositionEmbedding, PositionEmbeddingConfig # noqa - -# CREDITS: Classy Vision registry mechanism - -POSITION_EMBEDDING_REGISTRY: Dict[str, Any] = {} -POSITION_EMBEDDING_CLASS_NAMES: Set[str] = set() - - -def build_positional_embedding(config: Union[Dict[str, Any], PositionEmbeddingConfig]): - """Builds a position encoding from a config. - - This assumes a 'name' key in the config which is used to determine what - attention class to instantiate. For instance, a config `{"name": "my_position_encoding", - "foo": "bar"}` will find a class that was registered as "my_position_encoding" - (see :func:`register_positional_embedding`) and call .from_config on it.""" - - if not isinstance(config, PositionEmbeddingConfig): - config_instance = generate_matching_config( - config, POSITION_EMBEDDING_REGISTRY[config["name"]].config - ) - else: - config_instance = config - - return POSITION_EMBEDDING_REGISTRY[config_instance.name].constructor.from_config( - config_instance - ) - - -"""Registers a PositionEncoding subclass. - - This decorator allows xFormers to instantiate a subclass of PositionEncoding - from a configuration file, even if the class itself is not part of the - xFormers framework. To use it, apply this decorator to a `PositionEncoding` - subclass, like this: - - .. 
code-block:: python - - @dataclass - class MyConfig: - ... - - @register_positional_embedding('my_encoding', MyConfig) - class MyEncoding(PositionEncoding): - ... - - To instantiate a position encoding from a configuration file, see :func:`build_positional_embedding`.""" -register_positional_embedding: Callable[ - [str, Any], Callable[[Any], Any] -] = get_registry_decorator( - POSITION_EMBEDDING_REGISTRY, - POSITION_EMBEDDING_CLASS_NAMES, - PositionEmbedding, - PositionEmbeddingConfig, -) - - -from .rotary import RotaryEmbedding # noqa -from .sine import SinePositionalEmbedding # type: ignore # noqa -from .vocab import VocabEmbedding # noqa - -__all__ = [ - "RotaryEmbedding", - "SinePositionalEmbedding", - "VocabEmbedding", - "build_positional_embedding", - "register_positional_embedding", -] - -# automatically import any Python files in the directory -import_all_modules( - str(Path(__file__).parent), "xformers.components.positional_embedding" -) diff --git a/xformers/components/positional_embedding/base.py b/xformers/components/positional_embedding/base.py deleted file mode 100644 index c998487660..0000000000 --- a/xformers/components/positional_embedding/base.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - - -from abc import ABCMeta, abstractmethod -from dataclasses import asdict, dataclass -from typing import Type, TypeVar - -import torch.nn as nn - -from xformers._deprecation_warning import deprecated_function - -Self = TypeVar("Self", bound="PositionEmbedding") - - -@dataclass -class PositionEmbeddingConfig: - name: str - dim_model: int - seq_len: int - - -class PositionEmbedding(nn.Module, metaclass=ABCMeta): - @abstractmethod - def __init__(self, *args, **kwargs) -> None: - super().__init__() - deprecated_function(self) - - @classmethod - def from_config(cls: Type[Self], config: PositionEmbeddingConfig) -> Self: - # Generate the class inputs from the config - fields = asdict(config) - - # Skip all Nones so that default values are used - fields = {k: v for k, v in fields.items() if v is not None} - return cls(**fields) diff --git a/xformers/components/positional_embedding/param.py b/xformers/components/positional_embedding/param.py deleted file mode 100644 index bc96cf6787..0000000000 --- a/xformers/components/positional_embedding/param.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
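As with the other registries in this diff, build_positional_embedding matches the "name" key against a registered config class. A usage sketch with the "vocab" embedding, assuming a pre-removal xformers install; the dimensions and vocabulary size are arbitrary.

import torch

from xformers.components.positional_embedding import build_positional_embedding

# "vocab" combines learned word and position embeddings, so it consumes token ids
pos = build_positional_embedding(
    {"name": "vocab", "dim_model": 128, "seq_len": 64, "vocab_size": 1000, "dropout": 0.1}
)
tokens = torch.randint(0, 1000, (2, 32))  # [batch, seq] with seq <= seq_len
print(pos(tokens).shape)  # torch.Size([2, 32, 128])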
- - -from dataclasses import dataclass - -import torch - -from xformers.components.positional_embedding import ( - PositionEmbedding, - PositionEmbeddingConfig, - register_positional_embedding, -) - - -@dataclass -class LearnablePositionalEmbeddingConfig(PositionEmbeddingConfig): - name: str - seq_len: int - dim_model: int - add_class_token: bool - - -@register_positional_embedding("learnable", LearnablePositionalEmbeddingConfig) -class LearnablePositionalEmbedding(PositionEmbedding): - def __init__( - self, seq_len: int, dim_model: int, add_class_token: bool = False, *_, **__ - ): - super().__init__() - - # 0.02 is BERT initialization - self.pos_emb = torch.nn.Parameter( - torch.randn(1, seq_len + int(add_class_token), dim_model) * 0.02 - ) - - self.class_token = ( - torch.nn.Parameter(torch.zeros(dim_model)) if add_class_token else None - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.class_token is not None: - # Prepend class token - clf_token = ( - torch.ones(x.shape[0], 1, self.pos_emb.shape[-1], device=x.device) - * self.class_token - ) - x = torch.cat([clf_token, x], dim=1) - - if x.ndim == 2: - x = x.unsqueeze(-1) - - return x + self.pos_emb diff --git a/xformers/components/positional_embedding/rotary.py b/xformers/components/positional_embedding/rotary.py deleted file mode 100644 index 551089b3b3..0000000000 --- a/xformers/components/positional_embedding/rotary.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - - -# CREDITS: This implementation is inspired by GPT-NeoX https://github.com/EleutherAI/gpt-neox -# NOTE: Almost the same right now, moving parts to Triton is the next step - -from typing import Tuple - -import torch - - -def rotate_half(x): - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - - -@torch.jit.script -def apply_rotary_pos_emb(x, cos, sin): - # NOTE: This could probably be moved to Triton - - # Handle a possible sequence length mismatch in between q and k - cos = cos[:, :, : x.shape[-2], :] - sin = sin[:, :, : x.shape[-2], :] - - return (x * cos) + (rotate_half(x) * sin) - - -class RotaryEmbedding(torch.nn.Module): - """ - The rotary position embeddings from RoFormer_ (Su et. al). - A crucial insight from the method is that the query and keys are - transformed by rotation matrices which depend on the relative positions. - - Other implementations are available in the Rotary Transformer repo_ and in - GPT-NeoX_, GPT-NeoX was an inspiration - - .. _RoFormer: https://arxiv.org/abs/2104.09864 - .. _repo: https://github.com/ZhuiyiTechnology/roformer - .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox - - - .. 
warning: Please note that this embedding is not registered on purpose, as it is transformative - (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis - """ - - def __init__(self, dim_model: int, *_, **__): - super().__init__() - # Generate and save the inverse frequency buffer (non trainable) - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model)) - self.register_buffer("inv_freq", inv_freq) - - self._seq_len_cached = None - self._cos_cached = None - self._sin_cached = None - - def _update_cos_sin_tables(self, x, seq_dimension=1): - seq_len = x.shape[seq_dimension] - - # Reset the tables if the sequence length has changed, - # or if we're on a new device (possibly due to tracing for instance) - if ( - seq_len != self._seq_len_cached - or self._cos_cached.device != x.device - or self._cos_cached.dtype != x.dtype - ): - self._seq_len_cached = seq_len - t = torch.arange( - x.shape[seq_dimension], device=x.device, dtype=torch.float32 - ) - freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype)) - emb = torch.cat((freqs, freqs), dim=-1).to(x.device) - - self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype) - self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype) - - return self._cos_cached, self._sin_cached - - def forward( - self, q: torch.Tensor, k: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - self._cos_cached, self._sin_cached = self._update_cos_sin_tables( - k, seq_dimension=-2 - ) - - return ( - apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached), - apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached), - ) diff --git a/xformers/components/positional_embedding/sine.py b/xformers/components/positional_embedding/sine.py deleted file mode 100644 index 321920c5ac..0000000000 --- a/xformers/components/positional_embedding/sine.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - - -# Silence Mypy errors in this file. -# type: ignore - -import math - -import torch - -from xformers.components.positional_embedding import ( - PositionEmbedding, - PositionEmbeddingConfig, - register_positional_embedding, -) - - -@register_positional_embedding("sine", PositionEmbeddingConfig) -class SinePositionalEmbedding(PositionEmbedding): - def __init__(self, dim_model: int, *args, **kwargs): - super().__init__() - self.dim_model = dim_model - - def forward(self, x: torch.Tensor) -> torch.Tensor: - seq_len = x.shape[1] - pos = ( - torch.arange(0, seq_len, device=x.device, dtype=torch.float32) - .unsqueeze(1) - .repeat(1, self.dim_model) - ) - dim = ( - torch.arange(0, self.dim_model, device=x.device, dtype=torch.float32) - .unsqueeze(0) - .repeat(seq_len, 1) - ) - div = torch.exp(-math.log(10000) * (2 * (dim // 2) / self.dim_model)) - pos *= div - pos[:, 0::2] = torch.sin(pos[:, 0::2]) - pos[:, 1::2] = torch.cos(pos[:, 1::2]) - - output = x.unsqueeze(-1) if x.ndim == 2 else x - - return output + pos.unsqueeze(0) diff --git a/xformers/components/positional_embedding/vocab.py b/xformers/components/positional_embedding/vocab.py deleted file mode 100644 index dd18777eb1..0000000000 --- a/xformers/components/positional_embedding/vocab.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
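RotaryEmbedding above rotates queries and keys instead of adding a positional term, so it is applied after the head dimension has been exposed and is sized by the per-head dimension. A minimal usage sketch (pre-removal xformers; the shapes are illustrative).

import torch

from xformers.components.positional_embedding import RotaryEmbedding

heads, dim_head = 4, 16
rope = RotaryEmbedding(dim_model=dim_head)  # sized by the per-head dimension

q = torch.randn(2, heads, 32, dim_head)  # [batch, heads, seq, dim_head]
k = torch.randn(2, heads, 32, dim_head)
q_rot, k_rot = rope(q=q, k=k)
print(q_rot.shape, k_rot.shape)  # torch.Size([2, 4, 32, 16]) twice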
-# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - - -from dataclasses import dataclass -from typing import Optional - -import torch -import torch.nn as nn - -from xformers.components.positional_embedding import ( - PositionEmbedding, - PositionEmbeddingConfig, - register_positional_embedding, -) - - -@dataclass -class VocabEmbeddingConfig(PositionEmbeddingConfig): - vocab_size: int - dropout: float - - -@register_positional_embedding("vocab", VocabEmbeddingConfig) -class VocabEmbedding(PositionEmbedding): - def __init__( - self, - dim_model: int, - seq_len: int, - vocab_size: int, - dropout: float = 0.0, - *args, - **kwargs - ): - super().__init__() - - self.vocab_size = vocab_size - self.dim_model = dim_model - - self.dropout = torch.nn.Dropout(p=dropout) - self.position_embeddings = nn.Embedding(seq_len, self.dim_model) - self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model) - - self.position_ids: Optional[torch.Tensor] = None - - self.init_weights() - - def init_weights(self, gain: float = 1.0): - torch.nn.init.normal_(self.position_embeddings.weight, std=0.02 * gain) - torch.nn.init.normal_(self.word_embeddings.weight, std=0.02 * gain) - - def forward(self, x: torch.Tensor): - position_ids = torch.arange(x.shape[1], dtype=torch.long, device=x.device)[ - None, : - ].repeat(x.shape[0], 1) - - X_token = self.word_embeddings(x) - X_pos = self.position_embeddings(position_ids) - - X = X_token + X_pos - X = self.dropout(X) - - return X diff --git a/xformers/components/reversible.py b/xformers/components/reversible.py deleted file mode 100644 index b961018b62..0000000000 --- a/xformers/components/reversible.py +++ /dev/null @@ -1,160 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - - -from typing import List - -import torch -import torch.nn as nn -from torch.autograd.function import Function -from torch.utils.checkpoint import get_device_states, set_device_states - -from xformers._deprecation_warning import deprecated_function -from xformers.components import RequiresWrappedInputs - -# CREDITS: Code adapted from -# https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py -# https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py, -# https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html - - -# pyre-fixme[13]: `cpu_state` is not initialized in the constructor. 
-class Deterministic(nn.Module): - def __init__(self, net: nn.Module): - super().__init__() - deprecated_function(self) - self.net = net - self.cpu_state: torch.Tensor = torch.get_rng_state() - self.cuda_in_fwd: bool = False - self.gpu_devices: List[int] = [] - self.gpu_states: List[torch.Tensor] = [] - self.wrap_inputs = isinstance(net, RequiresWrappedInputs) - - def record_rng(self, *args): - self.cpu_state = torch.get_rng_state() - if torch.cuda._initialized: - self.cuda_in_fwd = True - self.gpu_devices, self.gpu_states = get_device_states(*args) - - def forward(self, *args, record_rng: bool = False, set_rng: bool = False, **kwargs): - if record_rng: - self.record_rng(*args) - - if not set_rng: - # Normal FW run - if self.wrap_inputs: - return self.net(inputs=args, **kwargs) - else: - return self.net(*args, **kwargs) - - else: # pragma: no cover # this is called in the backward pass, not picked up - # This is analogous to checkpointing, reset the original random state - rng_devices: List[int] = [] - if self.cuda_in_fwd: - rng_devices = self.gpu_devices - - with torch.random.fork_rng(devices=rng_devices, enabled=True): - torch.set_rng_state(self.cpu_state) - if self.cuda_in_fwd: - set_device_states(self.gpu_devices, self.gpu_states) - - if self.wrap_inputs: - return self.net(inputs=args, **kwargs) - else: - return self.net(*args, **kwargs) - - -class ReversibleBlock(nn.Module): - def __init__(self, f: nn.Module, g: nn.Module, split_dim: int = -1): - super().__init__() - self.f = Deterministic(f) - self.g = Deterministic(g) - self.split_dim = split_dim - - def forward(self, x: torch.Tensor, f_args={}, g_args={}): - x1, x2 = torch.chunk(x, 2, dim=-1) - y1, y2 = None, None - - with torch.no_grad(): - y1 = x1 + self.f(x2, record_rng=self.training, **f_args) - y2 = x2 + self.g(y1, record_rng=self.training, **g_args) - - return torch.cat([y1, y2], dim=self.split_dim) - - def backward_pass( - self, y: torch.Tensor, dy: torch.Tensor, f_args={}, g_args={} - ): # pragma: no cover # this is covered, but called directly from C++ - y1, y2 = torch.chunk(y, 2, dim=self.split_dim) - del y - - dy1, dy2 = torch.chunk(dy, 2, dim=self.split_dim) - del dy - - with torch.enable_grad(): - y1.requires_grad = True - gy1 = self.g(y1, set_rng=True, **g_args) - torch.autograd.backward(gy1, dy2) - - with torch.no_grad(): - x2 = y2 - gy1 - del y2, gy1 - - dx1 = dy1 + y1.grad - del dy1 - y1.grad = None - - with torch.enable_grad(): - x2.requires_grad = True - fx2 = self.f(x2, set_rng=True, **f_args) - torch.autograd.backward(fx2, dx1) - - with torch.no_grad(): - x1 = y1 - fx2 - del y1, fx2 - - dx2 = dy2 + x2.grad - del dy2 - x2.grad = None - - x = torch.cat([x1, x2.detach()], dim=self.split_dim) - dx = torch.cat([dx1, dx2], dim=self.split_dim) - - return x, dx - - -class _ReversibleFunction(Function): - @staticmethod - def forward(ctx, x, blocks, kwargs): - ctx.kwargs = kwargs - for block in blocks: - x = block(x, **kwargs) - ctx.y = x.detach() - ctx.blocks = blocks - return x - - @staticmethod - def backward( - ctx, dy - ): # pragma: no cover # this is covered, but called directly from C++ - y = ctx.y - kwargs = ctx.kwargs - for block in ctx.blocks[::-1]: - y, dy = block.backward_pass(y, dy, **kwargs) - return dy, None, None - - -class ReversibleSequence(nn.Module): - def __init__(self, blocks: nn.ModuleList): - super().__init__() - deprecated_function(self) - - # pyre-fixme[23]: Unable to unpack `torch.nn.Module` into 2 values. 
- self.blocks = nn.ModuleList([ReversibleBlock(f, g) for f, g in blocks]) - - def forward(self, x, arg_route=(True, False), **kwargs): - f_args, g_args = map(lambda route: kwargs if route else {}, arg_route) - block_kwargs = {"f_args": f_args, "g_args": g_args} - - return _ReversibleFunction.apply(x, self.blocks, block_kwargs) diff --git a/xformers/components/simplicial_embedding.py b/xformers/components/simplicial_embedding.py deleted file mode 100644 index a6ccbdac64..0000000000 --- a/xformers/components/simplicial_embedding.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - -from dataclasses import asdict, dataclass -from typing import Optional, Type, TypeVar - -import torch - -from xformers._deprecation_warning import deprecated_function - -Self = TypeVar("Self", bound="SimplicialEmbedding") - - -@dataclass -class SimplicialEmbeddingConfig: - L: int - temperature: float - - -class SimplicialEmbedding(torch.nn.Module): - """ - An implementation of the "Simplicial Embeddings"_, as proposed by Lavoie et. al - - Arguments: - - L: the number of embedding chunks - - temperature: optional scaling parameter for the softmax operation. - A small (<1.) temperature will lead to a sparse representation (up to one-hot), - while a large (>1.) temperature will make the vector more uniform - - _"Simplicial Embeddings": https://arxiv.org/pdf/2204.00616.pdf - """ - - def __init__(self, L: int, temperature: Optional[float] = None) -> None: - super().__init__() - deprecated_function(self) - self.L = L - self.temperature = temperature - - def forward(self, x: torch.Tensor) -> torch.Tensor: - assert ( - x.shape[-1] % self.L == 0 - ), f"The embedding dimension {x.shape[-1]} is not divisible by the chosen L parameter {self.L}" - - # Separate the input tensor into V chunks - B, C, E = x.shape - V = E // self.L - - Vs = x.reshape(B, C, self.L, V) - - # Softmax normalize them, with the proposed temperature - # This is done over the last dimension, so only within Vs - if self.temperature is not None: - Vs /= self.temperature - - Vs = torch.nn.functional.softmax(Vs, dim=-1) - - # Concatenate back and return - return Vs.reshape(B, C, E) - - @classmethod - def from_config(cls: Type[Self], config: SimplicialEmbeddingConfig) -> Self: - # Generate the class inputs from the config - fields = asdict(config) - - return cls(**fields)
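SimplicialEmbedding above is a chunk-wise softmax over the feature axis, turning each of the L chunks into a probability vector. A usage sketch (pre-removal xformers) with a quick check that every chunk sums to one.

import torch

from xformers.components.simplicial_embedding import SimplicialEmbedding

emb = SimplicialEmbedding(L=4, temperature=0.5)  # temperature < 1 sharpens each chunk
x = torch.randn(2, 16, 32)  # [batch, context, embedding], 32 % 4 == 0
y = emb(x)

chunks = y.reshape(2, 16, 4, 32 // 4)
print(torch.allclose(chunks.sum(dim=-1), torch.ones(2, 16, 4)))  # True: each chunk lies on a simplex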