[RFC] Refactor and simplify scaled_dot_product_attention for easier extensibility #351

Open
fmassa opened this issue Jul 5, 2022 · 3 comments
Labels
brainstorm (dropping an idea, may or may not be implemented in the end), RFC

Comments

@fmassa
Contributor

fmassa commented Jul 5, 2022

🚀 Feature

scaled_dot_product_attention is a key component of most transformer architectures, and it is used almost everywhere in xformers.
Over time, its API has been extended to support a few different cases (either performance improvements or new functionality). With additional improvements planned down the road to enable still more cases, I think this is a good moment to step back and rewrite scaled_dot_product_attention so that it is extensible and generic enough for the use-cases we care about.

Here is a list of proposed changes:

1 - Let it take [batch, head, seq, dim] instead of 3d tensors

We originally made scaled_dot_product_attention take 3d tensors to simplify the implementation of some early sparse components. While this is ok in general, handling different kinds of masks (the same mask for a whole batch, or one mask per head) means that we would need to materialize the full mask instead of relying on broadcasting (which would save memory). Plus, taking 4d tensors would also align its API with what PyTorch expects.
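
For illustration, here is a minimal sketch of what a 4d version could look like. The exact signature and argument names (attn_bias, p) are assumptions for this example, not the final API:

```python
import torch

def scaled_dot_product_attention(q, k, v, attn_bias=None, p=0.0):
    # q, k, v: [batch, head, seq, dim]. attn_bias broadcasts against the
    # [batch, head, seq_q, seq_k] score matrix, so a per-batch mask can be
    # passed as [batch, 1, seq_q, seq_k] and a per-head one as
    # [1, head, seq_q, seq_k] without materializing the full 4d mask.
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    if attn_bias is not None:
        scores = scores + attn_bias  # broadcasting handles batch/head dims
    attn = scores.softmax(dim=-1)
    if p > 0.0:
        attn = torch.nn.functional.dropout(attn, p=p)
    return attn @ v

# Per-head bias shared across the batch dimension:
q = k = v = torch.randn(2, 8, 128, 64)
bias = torch.randn(1, 8, 128, 128)
out = scaled_dot_product_attention(q, k, v, attn_bias=bias)  # [2, 8, 128, 64]
```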

2 - Remove AttentionMask support inside scaled_dot_product_attention

The use of AttentionMask inside scaled_dot_product_attention is mostly there as a performance optimization for the softmax in the causal case.
Instead, I propose to slightly modify its behavior so that this becomes an implementation detail of each specific Tensor subclass passed as a mask to scaled_dot_product_attention.

3 - Use __torch_function__ to handle the dispatch for different mask types

A lot of the code in scaled_dot_product_attention is essentially dispatching between the different mask types (SparseCS / Tensor), yet it doesn't handle block-sparse masks, for example.

I think we should instead always use custom Tensor types and let PyTorch's dispatch mechanism perform the dispatch under the hood. This would allow new Tensor subclasses to be added without changing the underlying implementation of scaled_dot_product_attention. This is drafted in https://github.com/facebookresearch/xformers/tree/sparse_cleanup (but needs a rebase).
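
As a rough sketch of the mechanism (the subclass name and the causal behavior here are purely illustrative, not the actual xformers types): a mask type defines __torch_function__, the attention code keeps doing plain tensor ops such as scores + attn_bias, and the subclass decides how that op is actually carried out:

```python
import torch

class LowerTriangularMask(torch.Tensor):
    """Illustrative mask subclass (not the actual xformers type)."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in (torch.add, torch.Tensor.add, torch.Tensor.__add__):
            scores = args[0]
            # Implementation detail of this subclass: apply causal masking
            # without materializing a full [batch, head, seq, seq] bias.
            seq = scores.shape[-1]
            keep = torch.ones(seq, seq, dtype=torch.bool, device=scores.device).tril()
            return scores.masked_fill(~keep, float("-inf"))
        # Everything else falls back to the default Tensor behavior.
        return super().__torch_function__(func, types, args, kwargs)

# The attention implementation only ever does `scores + attn_bias`;
# new mask types plug in without touching scaled_dot_product_attention.
attn_bias = torch.empty(0).as_subclass(LowerTriangularMask)
scores = torch.randn(2, 8, 16, 16)
masked_scores = scores + attn_bias
```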

4 - attn_mask becomes attn_bias

We currently support two (partly overlapping) behaviors: masking and bias addition. I think we should only handle bias addition, and change the sparse tensors to behave as masked tensors backed by sparse operators.
If users want the masking behavior, it can be expressed with -inf entries in the bias.
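
To make that last point concrete, a boolean mask (using True = keep as the convention, which is an assumption for this example) can always be rewritten as an additive bias:

```python
import torch

# bool_mask: True = keep, False = drop (convention assumed for this example)
bool_mask = torch.rand(1, 1, 128, 128) > 0.1
attn_bias = torch.zeros(bool_mask.shape)
attn_bias.masked_fill_(~bool_mask, float("-inf"))

# The additive bias subsumes masking: -inf entries get zero weight after the
# softmax, while finite entries act as a bias (relative position, ALiBi, ...).
scores = torch.randn(2, 8, 128, 128)
probs = (scores + attn_bias).softmax(dim=-1)
```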

Thoughts?

cc @blefaudeux @dianaml0 for thoughts

@fmassa fmassa added the brainstorm (dropping an idea, may or may not be implemented in the end) and RFC labels on Jul 5, 2022
@dianaml0
Contributor

These seem like really nice improvements; I think 1 and 4 would help with integration into other codebases. For 2, would the functionality of AttentionMask just be encompassed by the Tensor classes instead?

@fmassa
Contributor Author

fmassa commented Jul 19, 2022

For 2, would the functionality of AttentionMask just be encompassed by the Tensor classes instead?

Yes, the idea is that most of what is currently in AttentionMask would instead live inside the tensor subclasses. What would be left out are the convenience functions to convert between bool and float masks, but those can be kept as separate, standalone functions.

@blefaudeux
Contributor

Thanks for the write-up @fmassa, and sorry for the delay.

1 and 2 are kind of no-brainers to me: 1 (4d tensors, aligned with PyTorch) is side-effect free as long as the repo stays internally coherent (and the unit tests will catch that), and there are material benefits. 2 (AttentionMask) is completely fine by me; I think it really goes together with 3 and 4, but even just 2 would be ok. At the time, AttentionMask was a way to factor out some redundant code and do the checks and transforms only once (not just for the causal case, but also for accepting bool and float masks; I think you'll see that in the tests).

3 (__torch_function__ to dispatch on the mask types) is also fine by me. I think it makes it a little more complicated for outsiders to contribute, but it's not typically the part that gets people's attention (it's super useful, but I think people mostly look for the new fancy thing instead), so that's all good. I agree that having all the dispatch, size, and type handling nicely wrapped somewhere is nice. Off the top of my head, wasn't __torch_function__ unavailable until a couple of torch releases ago?

4 (attn_mask -> attn_bias) is also fine by me; just an open question with respect to the mask size in that case: floats typically take 16 to 32 bits, which is a lot more than your typical bool, isn't that an issue? Other than that, it makes it very clear what it does, it's nice for models like Swin, and it handles both the mask and bias purposes, so I just see wins.

fmassa pushed a commit that referenced this issue Aug 10, 2022
authored-by: danthe3rd <danthe3rd@users.noreply.github.com>
fmassa added a commit that referenced this issue Aug 25, 2022
* Enable masking in memory-efficient attention (#333)

* Add attention bias in memory-efficient attention

* Add gradient for attn_mask support

* Add CPU implementation

* clang-format

* Add benchmark scripts

* Add extra loop in benchmarks

* Move zeros array out of helper function

* clang-format

* Enable dropout in memory-efficient attention (#334)

* Merge compute_scaling_coeffs and update_scaling_coeffs into a single function

It wasn't needed to break it in two functions to begin with

* Add CUDA implementation for dropout

* clang-format

* Make p be drop probability

* Only CUDA supports dropout

* Add benchmarks

* Remove unused variables

* Fix test

* Cleanups and comments

* Fix masking corner case when full block is masked (#339)

* Add cutlass 2.9 - 858c735856a7f17bd33fe438ec76d3c9f0234e7f

* Option to load from shared memory for PredicatedTileIterator

* Add cutlass include dir

* Ignore files in third-party for flake8/coverage

* third-party -> third_party

* Address comments

* Revert some un-needed mods

* Add attention_forward_generic.cu

* Add tests

* Fix duplicate calculations on baseline for mem efficient transformers

* Always run all linters in CI

* clang-format attention_forward_generic.cu

* Benchmark: Add possibility to compare benchmarks

* [isort] Ignore third_party

* black autoformat

* Black again + ignore third_party properly

* black

* Fix memory leak between the 2 benchmarks in backward

* Exclude third_party/ without using pyproject.toml as it imposes isolated build which is a pain

* Remove progress bar when finished

* mypy

* flake8

* Save results to shared folder in home location

* run black

* clang-format with 'run-clang-format.py'

* Fix cutlass build for arch>=75

* Set tests precision for gradient more accurately

* Fix precision margin

* Revert changes to black

* [feat] Fix importing xformers when not built (#351)

authored-by: danthe3rd <danthe3rd@users.noreply.github.com>

* Update black to 22.3.0

* Tweak precision for mem_eff_attention test

* mem-efficient impl for f16 (#352)

Co-authored-by: danthe3rd <danthe3rd>

* Add support for f16 with tensorcores [sm70/sm75/sm80] (#354)

* Add support for f16 with tensorcores

* sm75 minimum for tensorcores

* Run tests with CUDA_LAUNCH_BLOCKING=1

* Support sm70 properly

* Disable tensorcore when not correctly aligned - and use 32bit accessors

Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>

* Optimize backward of memory-efficient attention by ~20% (#355)

* Optimize backward by 15% by using equivalent formulation

* Unify everything into single kernel

* Remove unused implementation

* clang-format

* Remove unused tensor

* Display results as we progress during benchmark (#357)

Co-authored-by: danthe3rd <danthe3rd>

* RFC: Ops dispatch (#356)

* Ops dispatch

* CI: Fix doc build

* memory_efficient_attention raises when no implementation is available

* type: ignore

* Fix torch.device/str comparison

* Make mypy happy

Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>
Co-authored-by: danthe3rd <danthe3rd>

* [A100/f32] Use TensorCores for Q.K_t matmul with FastF32 (#358)

* Use TensorCores for MM0 on Float as well

* Use MultiStage MMA when available - change to FastF32 rather than FastF16

* Better alignment calculation

* Just use regular f32, no fastf32

* Hackfix to handle alignment

* HeuristicsMM0 -> GemmTypeQK

* No longer use f16 for matmul

* Add some doc

* Typo

* Fix build <sm80

* Alignment check based on current device compute capability

* Use TORCH_INTERNAL_ASSERT

Co-authored-by: danthe3rd <danthe3rd>

* FlashAttention implem and dispatch (#360)

* FlashAttention implem WIP

* Fix flashattention forward+backward

* Fix forward/backward for FlashAttention

* Enable tests (more permissive) for f16 backward

* Fix CI

* flashattn only supports Sm75 and above

* Fix CI2

* Disable K=128 when below sm80 for flashattn

Co-authored-by: danthe3rd <danthe3rd>

* Misc performance improvements for generic mem-efficient attention (#361)

* 3% speedup by calculating mi from registers

* Also compute m_prime/s_prime and exponentiate from registers

* Support for Simt tiles

* Fix TensorOp for V100

* Fix for A100

* Fix Simt alignment calculation

* clang-format

* WarpReduction before atomic call for Simt

Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>

* Update flashattention to support bf16 (#363)

* Update flashattention to support bf16

* bfloat16 only on sm80 and above

Co-authored-by: danthe3rd <danthe3rd>

* Flashattn causal (#364)

* Implement causal memory-efficient attention with FlashAttention

* Update benchmarks

* Fix mypy

Co-authored-by: danthe3rd <danthe3rd>

* Option to disable flashattention (long to build) (#362)

* Option to disable flashattention (long to build)

* Update setup.py

Co-authored-by: danthe3rd <danthe3rd>

* Remove code duplicate in attention_scaling_coefs_updater.h (#367)

Co-authored-by: danthe3rd <danthe3rd>

* Update .gitmodules (#366)

* MemoryEff attention forward: Properly fuse matmul and enable TensorCores on the second matmul (#368)

* Generic backwards

* Guard backward to sm75 only

* bounds checking for gradV

* clang-format

* Fused gemm working for Sm80/Sm75 f16/f32

* WIP

* Volta TensorOp for f16

* Working on A100 again

* SIMT working

* Code cleanup 1

* Code cleanup2

* BUGFIX for shared memory limit

* Remove code

* clang-format

* Remove code again

* Remove draft of backward

* Enforce alignment for fp16

* Fix tests

* Fix constraint on seq length when not using tensorcores

* Fix alignment requirements for V100/tensorcores

* Clang-format

* Update xformers/components/attention/csrc/cuda/attention_forward_generic.cu

Co-authored-by: Francisco Massa <fvsmassa@gmail.com>

* Address comments from fmassa

Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>

* Update install instructions with submodule (#365)

* Generic backward implem with cutlass (#371)

* Old bw code

* P100: gradV working

* gk/gq working (at least for small values of M, and on P100/f16)

* Further restrict supported values for bw

* Fix storage into smem for Simt

* More tooling for print/debug

* Remove tests we dont need for now

* Tests pass on P100 :D

* 4 warps per block

* Restraint on q length

* Use tensorcores on V100 for f16

* Support dynamic smem for bw

* Handle alignment and different dtype/arch

* Fix NaNS by initializing shared memory

* bw.py

* Fix launch bounds

* Faster 'computeDi'

* minus_lse can operate on arrays

* Output number of regs used etc...

* Code cleanup

* Hackfix for alignment check during forward

* zFill to avoid nans in Sm80 + fix launch bounds

* Code cleanup 1

* clang-format

* Fix tests

* Add benchmark for K=64

Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>
Co-authored-by: danthe3rd <danthe3rd>

* Cutlass as submodule (#375)

* Make cutlass be back at 858c735856a7f17bd33fe438ec76d3c9f0234e7f

* Remove cutlass

* Update submodules

* Add submodule (properly)

* spaces / tab

* Make submodule init be recursive

* Fix bad rebase

* Bump tolerance for backward (#377)

* Add verbose flag to CI builds (#376)

* Add verbose flag to CI builds

* Spurious change to rebuild cache

* Add ninja

* Ninja wasn't visible before, install through conda

* Debugging

* Source env

* One more try

* Forgot to uncomment a line

* Another try

* Cleanup

* Fix for FlashAttention dispatch

It requires device capability >= 7.5

* Remove generated file

* Address some reviewer feedback

Remove unused function and typo fix

* Perf improvement on backward (#378)

* Fast again on V100

* Fix correctness - missing syncthreads

* Get rid of AttentionInfo

Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>

Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>
Co-authored-by: dan_the_3rd <43445237+danthe3rd@users.noreply.github.com>