[Attention] Disable weight scaling #522

Closed
comaniac opened this issue Nov 12, 2022 · 4 comments

Comments

@comaniac
Contributor

🚀 Feature

The current attention implementations scale the query by 1/sqrt(head_dim), as shown in the reference implementation (https://github.com/facebookresearch/xformers/blob/main/tests/test_mem_eff_attention.py#L154):

q = q * (1 / q.shape[-1] ** 0.5)
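
For context, a minimal reference attention (a sketch in plain PyTorch, not the xformers internals) shows where this factor enters:

import torch

def reference_attention(q, k, v):
    # The hard-coded scaling this issue is about: queries are multiplied
    # by 1/sqrt(head_dim) before the softmax.
    scale = 1 / q.shape[-1] ** 0.5
    attn = ((q * scale) @ k.transpose(-2, -1)).softmax(dim=-1)
    return attn @ v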

However, in some models such as T5, the weight scaling is already applied during weight initialization: https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L803.

As a result, using memory-efficient attention in T5 is inefficient: we either have to manually change the weight initialization or undo this scaling before calling the attention op.
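
To illustrate the second workaround (a rough sketch assuming 3D [batch, seq_len, head_dim] inputs and a CUDA device; shown only to make the extra cost concrete):

import torch
import xformers.ops as xops

q = torch.randn(2, 128, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# T5 already bakes the 1/sqrt(d) factor into its weights, so we cancel the
# op's internal scaling with an extra elementwise multiply on every call.
q_unscaled = q * (q.shape[-1] ** 0.5)
out = xops.memory_efficient_attention(q_unscaled, k, v)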

Pitch

Let users configure the scaling factor used by attention ops, including disabling the scaling entirely.
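
For example, the API could look something like this (a hypothetical signature sketch; scale is not an existing parameter of memory_efficient_attention):

from typing import Optional
import torch

def memory_efficient_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias=None,
    p: float = 0.0,
    scale: Optional[float] = None,  # None -> default 1/sqrt(head_dim); 1.0 -> no scaling
) -> torch.Tensor:
    ...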

@danthe3rd
Contributor

Hi @comaniac

That's a good point, and we could indeed add an argument to memory_efficient_attention with the scale to apply. I'll add this to my todo list, but I have a few other things to do first.

@comaniac
Contributor Author

Thanks for the reply. Will this be tracked by a task-tracking issue or roadmap so that I can follow up later? Or, if you have a concrete suggestion in mind about how this should be done (e.g., how it should be configured by users), I could try to add the support myself.

@danthe3rd
Contributor

We don't have a public roadmap. Tracking could be done with this issue.
If you want to contribute, that should be possible as well and we welcome pull requests :)

You would need to:
(1) Add an argument (e.g. scale) to the function, defaulting to None (which keeps the default scaling we have currently): https://github.com/facebookresearch/xformers/blob/main/xformers/ops/memory_efficient_attention.py#L734
(2) Modify the dispatcher, because some kernels will not support a custom scale; for instance, add something like a "has_custom_scale" flag:
https://github.com/facebookresearch/xformers/blob/main/xformers/ops/memory_efficient_attention.py#L673
It can then be filtered here:

def supports(cls, d: "AttentionOpDispatch") -> bool:

(3) All the classes that inherit from "AttentionOpBase" should accept an additional "scale" argument in their forwards (a sketch of these changes follows below).
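
A rough sketch of (1)-(3), with invented names (SUPPORTS_CUSTOM_SCALE and the d.scale field are hypothetical, not existing xformers attributes):

from typing import Optional

class AttentionOpBase:
    # Hypothetical capability flag: kernels that cannot take a runtime
    # softmax scale leave this as False and are filtered out by the
    # dispatcher whenever a custom scale is requested.
    SUPPORTS_CUSTOM_SCALE: bool = False

    @classmethod
    def supports(cls, d: "AttentionOpDispatch") -> bool:
        # d.scale would be the new (hypothetical) field carrying the
        # user-requested scale; None means "use the default 1/sqrt(head_dim)".
        if getattr(d, "scale", None) is not None and not cls.SUPPORTS_CUSTOM_SCALE:
            return False
        return True

    @classmethod
    def forward(cls, ctx, query, key, value, attn_bias, p, scale: Optional[float] = None):
        # Each subclass would thread `scale` down to its kernel.
        raise NotImplementedError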

This is for the basic support. Then you can add support in each operator individually. It's probably easiest to start with Flash if that's the backend you're interested in, as you just have to change the value here:

softmax_scale = query.shape[-1] ** (-0.5)
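
For instance (a sketch, assuming the new argument reaches this wrapper as scale):

softmax_scale = scale if scale is not None else query.shape[-1] ** (-0.5)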

Finally, you will need to add some tests in https://github.com/facebookresearch/xformers/blob/main/tests/test_mem_eff_attention.py
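
A test could compare the op against a plain PyTorch reference for a few scale values; a sketch (the scale argument here is the proposed one, not the current API):

import pytest
import torch
import xformers.ops as xops

@pytest.mark.parametrize("scale", [None, 1.0, 0.25])
def test_custom_scale(scale):
    q = torch.randn(8, 32, 64, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    # Reference: plain softmax attention with the same scaling factor.
    ref_scale = scale if scale is not None else q.shape[-1] ** -0.5
    attn = (q.float() * ref_scale) @ k.float().transpose(-2, -1)
    ref = attn.softmax(-1) @ v.float()

    out = xops.memory_efficient_attention(q, k, v, scale=scale)  # proposed argument
    torch.testing.assert_close(out.float(), ref, atol=2e-2, rtol=2e-2)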

@comaniac
Contributor Author

Thanks for the details. I'll try to implement it when I get a chance.
