[refactor][DRAFT] Tentative input projection (inc. self attention) rewrite #299

Closed · wants to merge 1 commit
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Better asserts on QKV dimensions [#264]
- Better perfs for FusedMLP and FusedLinearLayer [#283]
- Deepnorm init missing self-attention [#284]
- Better self-attention projection [#299]

### Added
- Simplicial Embeddings [#259]
3 changes: 2 additions & 1 deletion examples/microViT.py
@@ -78,6 +78,7 @@ def __init__(
"dropout": attn_pdrop,
"causal": False,
},
"self_attention": True,
},
"feedforward_config": {
"name": "FusedMLP",
@@ -246,7 +247,7 @@ def test_step(self, batch, _):
num_classes=num_classes,
attention="scaled_dot_product",
classifier=Classifier.TOKEN,
layer_norm_style="pre",
layer_norm_style="deepnorm",
use_rotary_embeddings=True,
)
trainer = pl.Trainer(
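For reference, a minimal sketch of where the new flag sits in an xFormers block config, mirroring the example above; every key and value other than "self_attention" is illustrative rather than copied from microViT.py:

# Hypothetical block-config fragment; only the "self_attention" key is introduced by this PR.
multi_head_config = {
    "num_heads": 6,
    "attention": {
        "name": "scaled_dot_product",
        "dropout": 0.1,
        "causal": False,
    },
    # New flag: share a single input projection across query, key and value.
    "self_attention": True,
}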
79 changes: 62 additions & 17 deletions tests/test_attentions.py
@@ -4,12 +4,13 @@
# LICENSE file in the root directory of this source tree.

import math
from contextlib import nullcontext
from typing import Tuple

import pytest
import torch

from xformers.components import InProjContainer, InProjParams, MultiHeadDispatch
from xformers.components import InProjParams, InputProjection, MultiHeadDispatch

# Automatically test all the registered attentions
from xformers.components.attention import (
@@ -52,6 +53,7 @@ def _get_multihead(
"num_heads": heads,
"dim_head": MODEL / heads,
"num_rules": 2, # Compositional Attention
"r": 0.5, # make sure that there's something left to drop / random attention
}

if skip_output_projection:
@@ -181,10 +183,14 @@ def test_kqv_ordering(
@pytest.mark.parametrize("proj_bias", [False, True])
@pytest.mark.parametrize("same_sizes", [False, True])
@pytest.mark.parametrize("same_settings", [False, True])
@pytest.mark.parametrize("self_attention", [False, True])
def test_inproj(
small_init: bool, proj_bias: bool, same_sizes: bool, same_settings: bool
small_init: bool,
proj_bias: bool,
same_sizes: bool,
same_settings: bool,
self_attention: bool,
):

test_config = {
"name": "scaled_dot_product",
"dropout": 0.1,
@@ -201,11 +207,36 @@ def test_inproj(
in_params = InProjParams(MODEL, MODEL, proj_bias, small_init)

if same_settings:
in_proj = InProjContainer(in_params, None, None)
in_proj = InputProjection(
in_params,
None,
None,
self_attention=self_attention,
)
else:
out_features = MODEL if same_sizes else MODEL // 2
in_params_flip = InProjParams(MODEL, out_features, proj_bias, small_init)
in_proj = InProjContainer(in_params, in_params_flip, in_params_flip)
in_params_flip = InProjParams(
MODEL,
out_features,
not proj_bias,
small_init,
)

# Different per-input settings combined with self-attention are not supported and should raise
context = nullcontext() if not self_attention else pytest.raises(AssertionError)

with context:
in_proj = InputProjection(
in_params,
in_params_flip,
in_params_flip,
self_attention=self_attention,
)

if self_attention:
return # done testing this case

in_proj = InputProjection(in_params, in_params_flip, in_params_flip)

# build a multi head dispatch to test this attention mechanism
multi_head = MultiHeadDispatch(
@@ -215,6 +246,7 @@
num_heads=1,
attention=attention,
in_proj_container=in_proj,
self_attention=self_attention,
)

# Check kqv are not flipped
@@ -229,18 +261,29 @@
dim=1,
).expand((BATCH, SEQ, MODEL))

k = torch.cat(
(
torch.zeros((1, MODEL // 2)),
torch.rand((1, MODEL // 2)),
),
dim=1,
).expand((BATCH, SEQ, MODEL))
v = torch.rand(BATCH, SEQ, MODEL)

# just check that a FW does not assert out
if self_attention:
k = q
v = q

else:
k = torch.cat(
(
torch.zeros((1, MODEL // 2)),
torch.rand((1, MODEL // 2)),
),
dim=1,
).expand((BATCH, SEQ, MODEL))
v = torch.rand(BATCH, SEQ, MODEL)

_ = multi_head(query=q, key=k, value=v)

# Check that self_attention with mismatching inputs asserts out
if self_attention:
with pytest.raises(AssertionError):
v = torch.rand(BATCH, SEQ, MODEL)
_ = multi_head(query=q, key=k, value=v)


@pytest.mark.parametrize("heads", [1, 4])
@pytest.mark.parametrize("attention_name", ATTENTION_REGISTRY.keys())
@@ -317,13 +360,14 @@ def test_causal(
"""

torch.random.manual_seed(42)
torch.backends.cuda.matmul.allow_tf32 = False

device = torch.device("cuda")

multi_head = _get_multihead(
attention_name,
0.0,
0.0,
attn_dropout=0.0,
res_dropout=0.0,
causal=True,
heads=heads,
device=device,
Expand All @@ -335,6 +379,7 @@ def test_causal(
.unsqueeze(0)
.expand(1, -1, -1)
)

q = (
torch.triu(torch.ones((SEQ, SEQ), device=device), diagonal=0)
.unsqueeze(0)
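As a reading aid, a minimal sketch of the constraint the updated test_inproj exercises; the dimension 384 is illustrative, and the InProjParams argument order (in_features, out_features, bias, small_init) follows the test above:

# Sketch of the InputProjection self-attention behaviour exercised in test_inproj.
import pytest

from xformers.components import InProjParams, InputProjection

params = InProjParams(384, 384, True, False)

# Self-attention: key/value projection params are left as None, weights are shared across q/k/v.
shared_proj = InputProjection(params, None, None, self_attention=True)

# Distinct per-input settings combined with self_attention are expected to assert at construction.
other = InProjParams(384, 192, False, False)
with pytest.raises(AssertionError):
    InputProjection(params, other, other, self_attention=True)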
5 changes: 4 additions & 1 deletion tests/test_pytorch_transformer_parity.py
@@ -19,7 +19,7 @@
EMB = 8
VOCAB = 8
HEADS = 4
DROP = 0.1
DROP = 0.0
LAYERS = 2
ACTIVATION = "relu"

@@ -39,6 +39,7 @@
"seq_len": SEQ,
},
"dim_model": EMB,
"self_attention": True,
},
"feedforward_config": {
"name": "MLP",
@@ -161,6 +162,8 @@ def test_pytorch_encoder_parity(device=torch.device("cuda")):
def test_pytorch_tranformer_parity(device=torch.device("cuda")):
# Build both a xFormers and Pytorch model
reset_seeds()
torch.backends.cuda.matmul.allow_tf32 = False

model_xformers = xFormer.from_config(xFormerConfig(_test_config)).to(device)
print(model_xformers)

3 changes: 2 additions & 1 deletion xformers/benchmarks/benchmark_multi_head_dispatch.py
@@ -44,6 +44,7 @@ def bench_multihead_dispatch(backward: bool, self_attention: bool):
num_heads=heads,
attention=ScaledDotProduct(),
bias=True,
self_attention=self_attention,
).to(device=device, dtype=dtype)
torch_multi_head = nn.MultiheadAttention(
embed_dim=K, num_heads=heads, batch_first=True
@@ -81,7 +82,7 @@ def xformers_mha():
TestCase(xformers_mha, f"xf - fw{bw}{sa}"),
]:
time = triton.testing.do_bench(testcase.function)[0]
key = f"B={B}, M={M}, K={K}, N_HEADS={heads}"
key = f"({B},{M},{K}) - {heads}H"
if key not in results:
results[key] = {}

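Outside the benchmark, the same flag is passed straight to MultiHeadDispatch; a minimal sketch, where the dim_model / residual_dropout argument names and the tensor sizes are assumptions not shown in this diff:

# Illustrative self-attention dispatch, modelled on the benchmark above.
import torch

from xformers.components import MultiHeadDispatch
from xformers.components.attention import ScaledDotProduct

BATCH, SEQ, DIM, HEADS = 2, 128, 256, 4

mhd = MultiHeadDispatch(
    dim_model=DIM,
    residual_dropout=0.0,
    num_heads=HEADS,
    attention=ScaledDotProduct(),
    bias=True,
    self_attention=True,  # new in this PR: one shared input projection for q, k and v
)

x = torch.rand(BATCH, SEQ, DIM)
# With self_attention=True the three inputs must match, mirroring the check added in tests/test_attentions.py.
out = mhd(query=x, key=x, value=x)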
1 change: 1 addition & 0 deletions xformers/benchmarks/benchmark_pytorch_transformer.py
@@ -129,6 +129,7 @@ def bench_pytorch_encoder(
"seq_len": seq,
},
"dim_model": emb,
"self_attention": True,
},
"feedforward_config": {
"name": "FusedMLP",
2 changes: 1 addition & 1 deletion xformers/components/__init__.py
@@ -12,7 +12,7 @@

from .activations import Activation, build_activation # noqa
from .attention import Attention, build_attention # noqa
from .in_proj_container import InProjContainer, InProjParams # noqa
from .input_projection import InProjParams, InputProjection # noqa
from .multi_head_dispatch import MultiHeadDispatch # noqa
from .multi_head_dispatch import MultiHeadDispatchConfig
from .patch_embedding import PatchEmbeddingConfig # noqa
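For downstream code, the rename boils down to an import update; a sketch of the migration implied by the line above, assuming a checkout that includes this PR:

# Old (pre-#299):
#   from xformers.components import InProjContainer, InProjParams
# New:
from xformers.components import InProjParams, InputProjection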
8 changes: 4 additions & 4 deletions xformers/components/attention/compositional.py
@@ -29,7 +29,7 @@
register_attention,
)
from xformers.components.attention.core import _softmax
from xformers.components.in_proj_container import InProjContainer, InProjParams
from xformers.components.input_projection import InProjParams, InputProjection


def _either_or(a: Optional[int], b: int) -> int:
@@ -51,7 +51,7 @@ class CompositionalAttentionConfig(AttentionConfig):
q_compose: bool = False
bias: bool = True
causal: Optional[bool] = False
in_proj_container: Optional[InProjContainer] = None
in_proj_container: Optional[InputProjection] = None
use_separate_proj_weight: Optional[bool] = False


@@ -99,7 +99,7 @@ def __init__(
qk_rule=False,
nonlinear=False,
q_compose=False,
in_proj_container: Optional[InProjContainer] = None,
in_proj_container: Optional[InputProjection] = None,
use_separate_proj_weight: Optional[bool] = False,
bias=True,
causal=False,
@@ -126,7 +126,7 @@ def __init__(
self.in_proj_container = (
in_proj_container
if in_proj_container is not None
else InProjContainer(
else InputProjection(
query_proj_params=InProjParams(dim_model, dim_key, bias=bias),
key_proj_params=InProjParams(dim_model, dim_key, bias=bias)
if use_separate_proj_weight