
[0.0.18, memory_efficient_attention with attn_bias]Getting NANs with arbitrary attn_bias mask with xformers==0.0.18 #722

Closed
toothacher17 opened this issue Apr 8, 2023 · 12 comments

toothacher17 commented Apr 8, 2023

🐛 Bug

I am trying to use xformers to replace my native PyTorch MHA implementation, which looks something like this:

import torch
import torch.nn.functional as F

def native_attention(query, key, value, attn_bias=None, p=0.0):
    # Scaled dot-product attention with an optional additive bias and dropout
    scale = 1 / query.shape[-1] ** 0.5
    query = query * scale
    attn = query @ key.transpose(-2, -1)
    if attn_bias is not None:
        attn = attn + attn_bias
    attn = attn.softmax(-1)
    attn = F.dropout(attn, p)
    return attn @ value

After switching to xformers, I am using xops.memory_efficient_attention(q, k, v, attn_bias).

This works fine when I am using a lower triangular mask, either by passing in a LowerTriangularMask() or by passing a torch.Tensor of the same shape that I build myself.
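
For reference, a minimal sketch of the two working setups (the shapes and tensors here are just illustrative):

import torch
import xformers.ops as xops

B, M, H, K = 1, 1024, 2, 64
dtype, device = torch.float16, "cuda"
q, k, v = [torch.randn([B, M, H, K], dtype=dtype, device=device) for _ in range(3)]

# Built-in causal mask: no NaNs
out1 = xops.memory_efficient_attention(q, k, v, attn_bias=xops.LowerTriangularMask())

# Hand-built additive bias of shape [B, H, M, M] with -inf above the diagonal: also no NaNs
causal = torch.triu(torch.ones(M, M, device=device), diagonal=1).bool()
bias = torch.zeros([B, H, M, M], dtype=dtype, device=device).masked_fill(causal, float("-inf"))
out2 = xops.memory_efficient_attention(q, k, v, attn_bias=bias)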

However, when I switch to an arbitrary mask (suppose that in the pretraining stage you enable the reset_position_ids and reset_attention_mask flags, so a new start can occur inside one sequence), I get NaNs during evaluation (no-grad forward) and during training (with grad). Based on the log, the program is using the CUTLASS op.

Based on my observations, xformers saves 10-15% of GPU memory and improves overall TFLOPs by 10-15%, so I really want to use it to replace my native PyTorch implementation. Could you help with this issue?

Environment

Some dependencies:
triton==2.0
xformers==0.0.18
pytorch==2.0


toothacher17 commented Apr 8, 2023

So my question is: is an arbitrary attention mask supported by xformers 0.0.18 yet?
I tried to follow several issue threads from the past, and based on the docs my understanding is that the answer is YES. However, I could not make it work.

@danthe3rd
Contributor

Hi @toothacher17
This should be supported by xFormers, but the behavior you report is definitely a bug. Unless a line in your mask is entirely masked out, it shouldn't give you NaNs.
Do you have an independent minimal repro example so I can try it?
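
For context, a line that is entirely masked out is expected to give NaNs even in the native implementation, since softmax over a row of all -inf values is 0/0. A quick illustration:

import torch

fully_masked_row = torch.full((4,), float("-inf"))
print(torch.softmax(fully_masked_row, dim=-1))  # tensor([nan, nan, nan, nan])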

@danthe3rd
Contributor

I actually managed to repro it with this script:

import math
import torch
import xformers.ops.fmha as fmha

B, M, H, K = 1, 1024, 2, 64
dtype = torch.float16
device = "cuda"

q, k, v = [torch.randn([B, M, H, K], dtype=dtype, device=device) for _ in range(3)]
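# Additive bias: for the first 256 queries, mask out the first 256 keys with -inf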
mask = torch.zeros([B, H, M, M], dtype=dtype, device=device)
mask[:, :, :256, :256] = -math.inf
out = fmha.memory_efficient_attention(q, k, v, attn_bias=mask)
print(out.sum())

It looks like it happens when the first 128 tokens of a sentence are entirely masked out.

Regarding your issue specifically: since you want to handle sequences of varying length, I recommend you use this mask; it will also save compute.

@danthe3rd danthe3rd self-assigned this Apr 11, 2023
@danthe3rd danthe3rd added the bug and ongoing labels Apr 11, 2023
@danthe3rd danthe3rd pinned this issue Apr 11, 2023
@toothacher17
Author

Hi @danthe3rd,

Thanks a lot for the reply.

Yes, I took another look at BlockDiagonalCausalMask in the flash_attention and Megatron-LM repos a few days ago. It basically meets my need for pretraining with reset-position-ids and reset-attention-mask, since resetting positions turns the lower-triangular mask into a block-diagonal one. To use this mask, like Megatron-LM, I'll need to change the shapes: merge all sentences in a batch into a single sequence, calculate the cumulative sequence lengths, and pass them to the ops.
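
Something like the following is what I have in mind (a rough sketch; the sequence lengths and shapes are just an example):

import torch
import xformers.ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

H, K = 2, 64
dtype, device = torch.float16, "cuda"

# Three "sentences" of a batch, packed back to back into one sequence of 1024 tokens
seqlens = [256, 512, 256]
total = sum(seqlens)
q, k, v = [torch.randn([1, total, H, K], dtype=dtype, device=device) for _ in range(3)]

# Each sentence attends causally to itself only; cross-sentence attention is masked out
attn_bias = BlockDiagonalCausalMask.from_seqlens(seqlens)
out = xops.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
print(out.shape)  # torch.Size([1, 1024, 2, 64])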

@toothacher17
Author

Btw, I tested with the CUTLASS-based operator, as it is the only one that supports a custom attention bias, and the NaNs are generated by the CUTLASS operator...

@danthe3rd
Contributor

To use this mask, like Megatron-LM, I'll need to change the shapes: merge all sentences in a batch into a single sequence, calculate the cumulative sequence lengths, and pass them to the ops.

It shouldn't be that involved, as normally everything operates at the token level except for the attention. It might be a bit more involved if you are using some specific positional embedding, though.

Regarding the bug, I believe I understand where it comes from and should have a fix coming soon.

@toothacher17
Author

Thanks for getting back to me. I am using something like RoPE, but that is applied before the attention, so as long as the cumulative sequence lengths for q/k/v are calculated correctly, the block-diagonal causal mask should be fine.

How soon do you think you can release the fix? If it is really coming soon, I'll wait for it before changing my code over to the block-diagonal causal mask.

@danthe3rd
Contributor

How soon do you think you can release the fix? If it is really coming soon, I'll wait for it before changing my code over to the block-diagonal causal mask.

Hopefully this week or next week for the xFormers development version (e.g. we might not necessarily push a new version tag yet).

@danthe3rd danthe3rd added this to the v0.0.19 milestone Apr 14, 2023
@danthe3rd
Contributor

It should be fixed as of 540fcbf and will be included in the next release (0.0.19). In the meantime, you can also use a development build >=0.0.19.dev516.

@danthe3rd danthe3rd changed the title Getting NANs with arbitrary attn_bias mask with xformers==0.0.18 [0.0.18, memory_efficient_attention with attn_bias]Getting NANs with arbitrary attn_bias mask with xformers==0.0.18 Apr 14, 2023
@toothacher17
Author

Thanks, @danthe3rd. We managed to switch to flash_attn_unpadded_func and got the correct, expected loss. Maybe after 0.0.19 is released we will switch back to xformers and see if the CUTLASS op is faster than the flash_attn implementation.

Thanks for the quick fix! Cheers!

@danthe3rd danthe3rd unpinned this issue May 11, 2023

nofreewill42 commented Mar 23, 2024

If my understanding is correct, an arbitrary mask (e.g. an off-diagonal, rectangular attention mask) is memory-efficient, but the grey area is still computed, so compute efficiency is lacking?
I'm nervously awaiting this feature :P because I want efficient cross-attention from transcript text to audio, like in this image:
[image: off-diagonal rectangular cross-attention mask between transcript text and audio]

EDIT:
Have I just found the solution to my specific problem with BlockDiagonalGappyKeysMask? :O

https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.fmha.attn_bias.BlockDiagonalGappyKeysMask


@danthe3rd
Contributor

@nofreewill42 I believe the mask you are looking for is this one: it also supports training, whereas the one you found only supports inference.
