Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Added four blocksparsity layouts #320

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Removed dupliacated biases in the FusedMLP layers [#317]

### Added
- Four blocksparsity layouts from DeepSpeed [#320]

## [0.0.11] - 2022-05-30
### Fixed
Expand Down
263 changes: 263 additions & 0 deletions tests/test_attention_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@
import torch

import xformers.components.attention.attention_patterns as AP
from xformers.components.attention.sparsity_config import (
BigBirdSparsityConfig,
BSLongformerSparsityConfig,
DenseSparsityConfig,
FixedSparsityConfig,
VariableSparsityConfig,
)


# baseline implementations
Expand Down Expand Up @@ -210,3 +217,259 @@ def test_alibi_pattern():
mask = AP.alibi_pattern(1e-3, (16, 128, 128))
# Minor, check that all the top left corners are True
assert torch.sum(mask[:, 0, 0]) == 16


def test_quick_layouts():

seq_size = 128
block_size = 16
num_heads = 2

# Fixed
assert torch.allclose(
AP.quick_fixed_layout(num_heads, block_size, seq_size),
torch.Tensor(
[
[
[1, 1, 1, 1, 0, 0, 0, 1],
[1, 1, 1, 1, 0, 0, 0, 1],
[1, 1, 1, 1, 0, 0, 0, 1],
[1, 1, 1, 1, 0, 0, 0, 1],
[0, 0, 0, 1, 1, 1, 1, 1],
[0, 0, 0, 1, 1, 1, 1, 1],
[0, 0, 0, 1, 1, 1, 1, 1],
[0, 0, 0, 1, 1, 1, 1, 1],
],
[
[1, 1, 1, 1, 0, 0, 0, 1],
[1, 1, 1, 1, 0, 0, 0, 1],
[1, 1, 1, 1, 0, 0, 0, 1],
[1, 1, 1, 1, 0, 0, 0, 1],
[0, 0, 0, 1, 1, 1, 1, 1],
[0, 0, 0, 1, 1, 1, 1, 1],
[0, 0, 0, 1, 1, 1, 1, 1],
[0, 0, 0, 1, 1, 1, 1, 1],
],
]
).long(),
)

# BSLongformer
assert torch.allclose(
AP.quick_bslongformer_layout(num_heads, block_size, seq_size),
torch.Tensor(
[
[
[1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 0, 1, 1, 1, 0, 0, 0],
[1, 0, 0, 1, 1, 1, 0, 0],
[1, 0, 0, 0, 1, 1, 1, 0],
[1, 0, 0, 0, 0, 1, 1, 1],
[1, 0, 0, 0, 0, 0, 1, 1],
],
[
[1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 0, 1, 1, 1, 0, 0, 0],
[1, 0, 0, 1, 1, 1, 0, 0],
[1, 0, 0, 0, 1, 1, 1, 0],
[1, 0, 0, 0, 0, 1, 1, 1],
[1, 0, 0, 0, 0, 0, 1, 1],
],
]
).long(),
)

# Variable
assert torch.allclose(
AP.quick_variable_layout(num_heads, block_size, seq_size),
torch.Tensor(
[
[
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 0, 0, 0, 1, 1, 1, 1],
[1, 0, 0, 0, 1, 1, 1, 1],
[1, 0, 0, 0, 1, 1, 1, 1],
[1, 0, 0, 0, 1, 1, 1, 1],
],
[
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 0, 0, 0, 1, 1, 1, 1],
[1, 0, 0, 0, 1, 1, 1, 1],
[1, 0, 0, 0, 1, 1, 1, 1],
[1, 0, 0, 0, 1, 1, 1, 1],
],
]
).long(),
)

# BigBird (just the shape)
assert AP.quick_bigbird_layout(num_heads, block_size, seq_size).shape == torch.Size(
[num_heads, seq_size // block_size, seq_size // block_size]
)


def test_layout_to_pattern():
torch.allclose(
AP.layout_to_pattern(
layout=torch.Tensor([[[0, 1], [1, 0]], [[1, 0], [0, 1]]]), block_size=2
),
torch.Tensor(
[
[
[0.0, 0.0, 1.0, 1.0],
[0.0, 0.0, 1.0, 1.0],
[1.0, 1.0, 0.0, 0.0],
[1.0, 1.0, 0.0, 0.0],
],
[
[1.0, 1.0, 0.0, 0.0],
[1.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 1.0],
[0.0, 0.0, 1.0, 1.0],
],
]
),
)


def test_dense_sparsity_config():
sc = DenseSparsityConfig(num_heads=1, block_size=16)
with pytest.raises(expected_exception=ValueError):
sc.setup_layout(seq_len=17)
assert torch.allclose(
sc.make_layout(seq_len=32), torch.Tensor([[[1, 1], [1, 1]]]).long()
)


def test_big_bird_sparsity_config():
sc = BigBirdSparsityConfig(
num_heads=1,
block_size=16,
num_random_blocks=2,
num_sliding_window_blocks=1,
num_global_blocks=1,
)
with pytest.raises(expected_exception=ValueError):
sc.make_layout(seq_len=16)
sc = BigBirdSparsityConfig(
num_heads=1,
block_size=16,
num_random_blocks=1,
num_sliding_window_blocks=2,
num_global_blocks=1,
)
with pytest.raises(expected_exception=ValueError):
sc.make_layout(seq_len=16)
sc = BigBirdSparsityConfig(
num_heads=1,
block_size=16,
num_random_blocks=1,
num_sliding_window_blocks=1,
num_global_blocks=2,
)
with pytest.raises(expected_exception=ValueError):
sc.make_layout(seq_len=16)
with pytest.raises(expected_exception=NotImplementedError):
BigBirdSparsityConfig(num_heads=1, attention="directional")


def test_bslongformer_sparsity_config():
sc = BSLongformerSparsityConfig(num_heads=1, global_block_end_indices=[1])
assert torch.allclose(
sc.make_layout(128),
torch.Tensor(
[
[
[1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 0, 1, 1, 1, 0, 0, 0],
[1, 0, 0, 1, 1, 1, 0, 0],
[1, 0, 0, 0, 1, 1, 1, 0],
[1, 0, 0, 0, 0, 1, 1, 1],
[1, 0, 0, 0, 0, 0, 1, 1],
]
]
).long(),
)
with pytest.raises(expected_exception=ValueError):
BSLongformerSparsityConfig(num_heads=1, global_block_end_indices=[])
with pytest.raises(expected_exception=ValueError):
BSLongformerSparsityConfig(num_heads=1, global_block_end_indices=[-1])


def test_fixed_sparsity_config():
# chech that the case end < num_blocks is correct
sc = FixedSparsityConfig(num_heads=1, horizontal_global_attention=True)
assert torch.allclose(
sc.make_layout(112),
torch.Tensor(
[
[
[1, 1, 1, 1, 0, 0, 1],
[1, 1, 1, 1, 0, 0, 1],
[1, 1, 1, 1, 0, 0, 1],
[1, 1, 1, 1, 1, 1, 1],
[0, 0, 0, 1, 1, 1, 1],
[0, 0, 0, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1],
]
]
).long(),
)
with pytest.raises(expected_exception=ValueError):
FixedSparsityConfig(num_heads=1, num_local_blocks=3, num_global_blocks=2)
with pytest.raises(expected_exception=NotImplementedError):
FixedSparsityConfig(num_heads=1, attention="directional")
with pytest.raises(expected_exception=ValueError):
FixedSparsityConfig(
num_heads=1, attention="unidirectional", horizontal_global_attention=True
)
with pytest.raises(expected_exception=ValueError):
FixedSparsityConfig(
num_heads=1,
num_different_global_patterns=2,
different_layout_per_head=False,
)
with pytest.raises(expected_exception=ValueError):
FixedSparsityConfig(
num_heads=1,
num_different_global_patterns=10,
num_local_blocks=4,
num_global_blocks=1,
)


def test_variable_sparsity_config():
sc = VariableSparsityConfig(num_heads=1, global_block_end_indices=[1])
assert torch.allclose(
sc.make_layout(128),
torch.Tensor(
[
[
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 0, 0, 0, 1, 1, 1, 1],
[1, 0, 0, 0, 1, 1, 1, 1],
[1, 0, 0, 0, 1, 1, 1, 1],
[1, 0, 0, 0, 1, 1, 1, 1],
]
]
).long(),
)
with pytest.raises(expected_exception=ValueError):
VariableSparsityConfig(num_heads=1, global_block_end_indices=[])
with pytest.raises(expected_exception=ValueError):
VariableSparsityConfig(num_heads=1, global_block_end_indices=[-1])
4 changes: 2 additions & 2 deletions xformers/benchmarks/LRA/code/config_nystrom.json
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
"hidden_layer_multiplier": 2
}
}

],
"extra_settings": {
"attention": {
Expand Down Expand Up @@ -199,7 +199,7 @@
"hidden_layer_multiplier": 2
}
}

],
"extra_settings": {
"attention": {
Expand Down
35 changes: 35 additions & 0 deletions xformers/components/attention/attention_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@
import numpy as np
import torch

from xformers.components.attention.sparsity_config import (
BigBirdSparsityConfig,
BSLongformerSparsityConfig,
FixedSparsityConfig,
VariableSparsityConfig,
)


# generic nd cases
def _generate_nd_grid(*sizes):
Expand Down Expand Up @@ -258,3 +265,31 @@ def get_slopes_power_of_2(n: int) -> List[float]:

# Now threshold arbitrarily, report the mask
return alibi < threshold


def quick_fixed_layout(num_heads: int, block_size: int, seq_len: int):
config = FixedSparsityConfig(num_heads=num_heads, block_size=block_size)
return config.make_layout(seq_len)


def quick_variable_layout(num_heads: int, block_size: int, seq_len: int):
config = VariableSparsityConfig(num_heads=num_heads, block_size=block_size)
return config.make_layout(seq_len)


def quick_bigbird_layout(num_heads: int, block_size: int, seq_len: int):
config = BigBirdSparsityConfig(num_heads=num_heads, block_size=block_size)
return config.make_layout(seq_len)


def quick_bslongformer_layout(num_heads: int, block_size: int, seq_len: int):
config = BSLongformerSparsityConfig(num_heads=num_heads, block_size=block_size)
return config.make_layout(seq_len)


def layout_to_pattern(layout: torch.Tensor, block_size: int):
r"""
create a pattern of shape [heads, seq, seq] out of a blocksparse
layout of shape [heads, seq/block_size, seq/block_size]
"""
return torch.kron(layout, torch.ones(block_size, block_size))
Loading