Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[conformer] support flash att by torch sdpa #2360

Merged
merged 7 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 116 additions & 4 deletions test/wenet/transformer/test_attention.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import torch
import pytest
from wenet.transformer.attention import MultiHeadedAttention
from wenet.transformer.encoder_layer import TransformerEncoderLayer
from wenet.transformer.attention import (MultiHeadedAttention,
RelPositionMultiHeadedAttention)
from wenet.transformer.embedding import RelPositionalEncoding
from wenet.transformer.encoder_layer import (ConformerEncoderLayer,
TransformerEncoderLayer)
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
from wenet.utils.class_utils import WENET_ACTIVATION_CLASSES

Expand Down Expand Up @@ -30,7 +33,7 @@
"dropout_rate": 0.0
},
])
def test_sdpa(args):
def test_multi_head_attention_sdpa(args):
torch.manual_seed(777)
mha_module = MultiHeadedAttention(use_sdpa=False, **args)
torch.manual_seed(777)
Expand Down Expand Up @@ -106,6 +109,115 @@ def test_sdpa(args):
atol=9e-7,
rtol=9e-4,
)
# assert torch.allclose(cache, cache_with_sdpa)
assert torch.allclose(cache, cache_with_sdpa)

q = output


@pytest.mark.parametrize("args", [
{
"n_feat": 256,
"n_head": 4,
"dropout_rate": 0.0
},
{
"n_feat": 512,
"n_head": 8,
"dropout_rate": 0.0
},
{
"n_feat": 1280,
"n_head": 20,
"dropout_rate": 0.0
},
{
"n_feat": 512,
"n_head": 4,
"dropout_rate": 0.0
},
])
def test_rel_position_multi_head_attention_sdpa(args):
rel_pos_moduls = RelPositionalEncoding(args['n_feat'], dropout_rate=0.0)
torch.manual_seed(777)
rel_mha_module = RelPositionMultiHeadedAttention(use_sdpa=False, **args)
torch.manual_seed(777)
rel_mha_module_with_sdpa = RelPositionMultiHeadedAttention(use_sdpa=True,
**args)
rel_mha_module.eval()
rel_mha_module_with_sdpa.eval()

q = torch.rand(10, 100, args['n_feat'])
_, pos_emb = rel_pos_moduls(q)
k = torch.rand(10, 100, args['n_feat'])
v = torch.rand(10, 100, args['n_feat'])
input_lens = torch.tensor([100, 90, 80, 79, 60, 51, 40, 30, 10, 5])
mask = make_non_pad_mask(input_lens).unsqueeze(1)
att_mask = add_optional_chunk_mask(q,
mask,
use_dynamic_chunk=True,
decoding_chunk_size=0,
static_chunk_size=0,
use_dynamic_left_chunk=True,
num_decoding_left_chunks=-1)
output, cache = rel_mha_module(q, k, v, mask=att_mask, pos_emb=pos_emb)

att_mask_bias = (1.0 - att_mask.float()) * torch.finfo(torch.float).min
output_with_sdpa, cache_with_sdpa = rel_mha_module_with_sdpa(
q, k, v, mask=att_mask_bias, pos_emb=pos_emb)
assert torch.allclose(
output * mask.transpose(1, 2),
output_with_sdpa * mask.transpose(1, 2),
atol=9e-7,
)
assert torch.allclose(cache, cache_with_sdpa)

n_blocks = 12
torch.manual_seed(777)
rel_mha_layers = [
ConformerEncoderLayer(
args['n_feat'],
RelPositionMultiHeadedAttention(use_sdpa=False, **args),
PositionwiseFeedForward(
args['n_feat'],
2048,
0.0,
WENET_ACTIVATION_CLASSES['swish'](),
),
None,
None,
0.0,
normalize_before=True,
) for _ in range(n_blocks)
]

torch.manual_seed(777)
rel_mha_layers_with_sdpa = [
ConformerEncoderLayer(
args['n_feat'],
RelPositionMultiHeadedAttention(use_sdpa=True, **args),
PositionwiseFeedForward(
args['n_feat'],
2048,
0.0,
WENET_ACTIVATION_CLASSES['swish'](),
),
None,
None,
0.0,
normalize_before=True,
) for _ in range(n_blocks)
]

for i in range(n_blocks):
output, _, cache, _ = rel_mha_layers[i](q, att_mask, pos_emb, mask)
output_with_sdpa, _, cache_with_sdpa, _ = rel_mha_layers_with_sdpa[i](
q, att_mask_bias, pos_emb, mask)

assert torch.allclose(
output * mask.transpose(1, 2),
output_with_sdpa * mask.transpose(1, 2),
atol=9e-7,
rtol=9e-4,
)
assert torch.allclose(cache, cache_with_sdpa)
q = output
44 changes: 33 additions & 11 deletions wenet/transformer/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import torch
from torch import nn

from wenet.utils.common import get_dtype_min


class MultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
Expand Down Expand Up @@ -227,9 +229,10 @@ def __init__(self,
n_head: int,
n_feat: int,
dropout_rate: float,
key_bias: bool = True):
key_bias: bool = True,
use_sdpa: bool = True):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_head, n_feat, dropout_rate, key_bias)
super().__init__(n_head, n_feat, dropout_rate, key_bias, use_sdpa)
# linear transformation for positional encoding
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
# these two learnable bias are used in matrix c and matrix d
Expand Down Expand Up @@ -330,20 +333,39 @@ def forward(
# (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

# compute matrix b and matrix d
# (batch, head, time1, time2)
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
# Remove rel_shift since it is useless in speech recognition,
# and it requires special attention for streaming.
# matrix_bd = self.rel_shift(matrix_bd)
if not self.use_sdpa:
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

scores = (matrix_ac + matrix_bd) / math.sqrt(
self.d_k) # (batch, head, time1, time2)
scores = (matrix_ac + matrix_bd) / math.sqrt(
self.d_k) # (batch, head, time1, time2)

return self.forward_attention(v, scores, mask), new_cache
return self.forward_attention(v, scores, mask), new_cache
else:
# NOTE(Mddct): we need mask bias, not boolean mask
assert mask.dtype != torch.bool
mask = mask.unsqueeze(1)
# matrix_bd as a mask bias
mask = torch.where(mask == get_dtype_min(mask.dtype), mask,
matrix_bd / math.sqrt(self.d_k))
output = torch.nn.functional.scaled_dot_product_attention(
q_with_bias_u,
k,
v,
attn_mask=mask,
dropout_p=self.dropout_rate,
scale=1 / math.sqrt(self.d_k),
)
output = (output.transpose(1, 2).contiguous().view(
query.size(0), -1,
self.h * self.d_k)) # (batch, time1, d_model)
return self.linear_out(output), new_cache
5 changes: 4 additions & 1 deletion wenet/transformer/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ def __init__(
cnn_module_norm: str = "batch_norm",
key_bias: bool = True,
gradient_checkpointing: bool = False,
use_sdpa: bool = False,
):
"""Construct ConformerEncoder

Expand All @@ -441,7 +442,8 @@ def __init__(
positional_dropout_rate, attention_dropout_rate,
input_layer, pos_enc_layer_type, normalize_before,
static_chunk_size, use_dynamic_chunk, global_cmvn,
use_dynamic_left_chunk, gradient_checkpointing, False)
use_dynamic_left_chunk, gradient_checkpointing,
use_sdpa)
activation = WENET_ACTIVATION_CLASSES[activation_type]()

# self-attention module definition
Expand All @@ -450,6 +452,7 @@ def __init__(
output_size,
attention_dropout_rate,
key_bias,
use_sdpa,
)
# feed-forward module definition
positionwise_layer_args = (
Expand Down
Loading