
Commit d572815

using nn.Linear in the projections to match Timm, probably simpler
blefaudeux committed May 9, 2022
1 parent 7705e5e commit d572815
Showing 8 changed files with 286 additions and 245 deletions.
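
The commit title refers to timm's practice of building the q/k/v projections from plain nn.Linear modules rather than hand-managed parameter tensors. As a rough, self-contained sketch of that pattern (illustrative only: TimmStyleProjection and its signature are invented here and are not the InputProjection class introduced in this diff):

# Illustration of the nn.Linear-based projection idea; not xformers code.
import torch
import torch.nn as nn


class TimmStyleProjection(nn.Module):
    def __init__(self, dim_model: int, bias: bool = True, self_attention: bool = False):
        super().__init__()
        self.self_attention = self_attention
        # One linear layer per stream, as in timm's attention blocks
        self.q_proj = nn.Linear(dim_model, dim_model, bias=bias)
        self.k_proj = nn.Linear(dim_model, dim_model, bias=bias)
        self.v_proj = nn.Linear(dim_model, dim_model, bias=bias)

    def forward(self, query, key=None, value=None):
        if self.self_attention:
            # Self-attention: q, k and v are all projections of the same input
            key, value = query, query
        return self.q_proj(query), self.k_proj(key), self.v_proj(value)


# Quick shape check
x = torch.rand(2, 16, 64)
q, k, v = TimmStyleProjection(64, self_attention=True)(x)
assert q.shape == k.shape == v.shape == x.shape
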
1 change: 1 addition & 0 deletions examples/microViT.py
@@ -78,6 +78,7 @@ def __init__(
"dropout": attn_pdrop,
"causal": False,
},
"self_attention": True,
},
"feedforward_config": {
"name": "FusedMLP",
79 changes: 62 additions & 17 deletions tests/test_attentions.py
@@ -4,12 +4,13 @@
# LICENSE file in the root directory of this source tree.

import math
from contextlib import nullcontext
from typing import Tuple

import pytest
import torch

from xformers.components import InProjContainer, InProjParams, MultiHeadDispatch
from xformers.components import InProjParams, InputProjection, MultiHeadDispatch

# Automatically test all the registered attentions
from xformers.components.attention import (
@@ -52,6 +53,7 @@ def _get_multihead(
"num_heads": heads,
"dim_head": MODEL / heads,
"num_rules": 2, # Compositional Attention
"r": 0.5, # make sure that there's something left to drop / random attention
}

if skip_output_projection:
@@ -181,10 +183,14 @@ def test_kqv_ordering(
@pytest.mark.parametrize("proj_bias", [False, True])
@pytest.mark.parametrize("same_sizes", [False, True])
@pytest.mark.parametrize("same_settings", [False, True])
@pytest.mark.parametrize("self_attention", [False, True])
def test_inproj(
small_init: bool, proj_bias: bool, same_sizes: bool, same_settings: bool
small_init: bool,
proj_bias: bool,
same_sizes: bool,
same_settings: bool,
self_attention: bool,
):

test_config = {
"name": "scaled_dot_product",
"dropout": 0.1,
@@ -201,11 +207,36 @@ def test_inproj(
in_params = InProjParams(MODEL, MODEL, proj_bias, small_init)

if same_settings:
in_proj = InProjContainer(in_params, None, None)
in_proj = InputProjection(
in_params,
None,
None,
self_attention=self_attention,
)
else:
out_features = MODEL if same_sizes else MODEL // 2
in_params_flip = InProjParams(MODEL, out_features, proj_bias, small_init)
in_proj = InProjContainer(in_params, in_params_flip, in_params_flip)
in_params_flip = InProjParams(
MODEL,
out_features,
not proj_bias,
small_init,
)

# Different settings and self attention is not supported, and should raise
context = nullcontext() if not self_attention else pytest.raises(AssertionError)

with context:
in_proj = InputProjection(
in_params,
in_params_flip,
in_params_flip,
self_attention=self_attention,
)

if self_attention:
return # done testing this case

in_proj = InputProjection(in_params, in_params_flip, in_params_flip)

# build a multi head dispatch to test this attention mechanism
multi_head = MultiHeadDispatch(
@@ -215,6 +246,7 @@
num_heads=1,
attention=attention,
in_proj_container=in_proj,
self_attention=self_attention,
)

# Check kqv are not flipped
@@ -229,18 +261,29 @@
dim=1,
).expand((BATCH, SEQ, MODEL))

k = torch.cat(
(
torch.zeros((1, MODEL // 2)),
torch.rand((1, MODEL // 2)),
),
dim=1,
).expand((BATCH, SEQ, MODEL))
v = torch.rand(BATCH, SEQ, MODEL)

# just check that a FW does not assert out
if self_attention:
k = q
v = q

else:
k = torch.cat(
(
torch.zeros((1, MODEL // 2)),
torch.rand((1, MODEL // 2)),
),
dim=1,
).expand((BATCH, SEQ, MODEL))
v = torch.rand(BATCH, SEQ, MODEL)

_ = multi_head(query=q, key=k, value=v)

# Check that self_attention and mismatching inputs asserts out
if self_attention:
with pytest.raises(AssertionError):
v = torch.rand(BATCH, SEQ, MODEL)
_ = multi_head(query=q, key=k, value=v)


@pytest.mark.parametrize("heads", [1, 4])
@pytest.mark.parametrize("attention_name", ATTENTION_REGISTRY.keys())
@@ -317,13 +360,14 @@ def test_causal(
"""

torch.random.manual_seed(42)
torch.backends.cuda.matmul.allow_tf32 = False

device = torch.device("cuda")

multi_head = _get_multihead(
attention_name,
0.0,
0.0,
attn_dropout=0.0,
res_dropout=0.0,
causal=True,
heads=heads,
device=device,
Expand All @@ -335,6 +379,7 @@ def test_causal(
.unsqueeze(0)
.expand(1, -1, -1)
)

q = (
torch.triu(torch.ones((SEQ, SEQ), device=device), diagonal=0)
.unsqueeze(0)
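The reworked test_inproj above relies on contextlib.nullcontext so that the same with-block either expects an AssertionError (mismatched projection settings combined with self_attention=True) or runs normally. A self-contained sketch of that pattern, with a made-up check_positive helper standing in for the real constructor:

# Toy example of conditionally expecting an exception in a parametrized test;
# check_positive is invented for illustration.
from contextlib import nullcontext

import pytest


def check_positive(x: int) -> int:
    assert x > 0, "expected a positive value"
    return x


@pytest.mark.parametrize("value", [-1, 1])
def test_check_positive(value: int):
    # Expect an AssertionError only for the invalid case, no-op context otherwise
    context = pytest.raises(AssertionError) if value <= 0 else nullcontext()
    with context:
        assert check_positive(value) == value
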
1 change: 1 addition & 0 deletions xformers/benchmarks/benchmark_pytorch_transformer.py
@@ -129,6 +129,7 @@ def bench_pytorch_encoder(
"seq_len": seq,
},
"dim_model": emb,
"self_attention": True,
},
"feedforward_config": {
"name": "FusedMLP",
2 changes: 1 addition & 1 deletion xformers/components/__init__.py
@@ -12,7 +12,7 @@

from .activations import Activation, build_activation # noqa
from .attention import Attention, build_attention # noqa
from .in_proj_container import InProjContainer, InProjParams # noqa
from .input_projection import InProjParams, InputProjection # noqa
from .multi_head_dispatch import MultiHeadDispatch # noqa
from .multi_head_dispatch import MultiHeadDispatchConfig
from .patch_embedding import PatchEmbeddingConfig # noqa
8 changes: 4 additions & 4 deletions xformers/components/attention/compositional.py
@@ -29,7 +29,7 @@
register_attention,
)
from xformers.components.attention.core import _softmax
from xformers.components.in_proj_container import InProjContainer, InProjParams
from xformers.components.input_projection import InProjParams, InputProjection


def _either_or(a: Optional[int], b: int) -> int:
@@ -51,7 +51,7 @@ class CompositionalAttentionConfig(AttentionConfig):
q_compose: bool = False
bias: bool = True
causal: Optional[bool] = False
in_proj_container: Optional[InProjContainer] = None
in_proj_container: Optional[InputProjection] = None
use_separate_proj_weight: Optional[bool] = False


@@ -99,7 +99,7 @@ def __init__(
qk_rule=False,
nonlinear=False,
q_compose=False,
in_proj_container: Optional[InProjContainer] = None,
in_proj_container: Optional[InputProjection] = None,
use_separate_proj_weight: Optional[bool] = False,
bias=True,
causal=False,
@@ -126,7 +126,7 @@ def __init__(
self.in_proj_container = (
in_proj_container
if in_proj_container is not None
else InProjContainer(
else InputProjection(
query_proj_params=InProjParams(dim_model, dim_key, bias=bias),
key_proj_params=InProjParams(dim_model, dim_key, bias=bias)
if use_separate_proj_weight
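Putting the renamed pieces together, the construction pattern this diff settles on is to build one InProjParams per stream and hand them to InputProjection, or pass a single set plus self_attention=True. The sketch below only reuses constructor calls that appear in this diff; the final forward call is an assumption (mirroring the former InProjContainer) and the shapes are arbitrary:

# Constructor calls taken from the diff above; the forward signature at the end
# is an assumption, not something shown in this commit.
import torch

from xformers.components import InProjParams, InputProjection

MODEL = 384

# (in_features, out_features, bias, small_init), as in test_inproj
params = InProjParams(MODEL, MODEL, True, False)

# Self-attention: one set of projection params, the other two left as None
in_proj = InputProjection(params, None, None, self_attention=True)

x = torch.rand(2, 16, MODEL)
q, k, v = in_proj(x, x, x)  # assumed forward: (query, key, value) -> (q, k, v)
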
3 changed files not shown.
