tensor masks failing with memory_efficient_attention #683
Hi @Infrared1029 Taking a step back, adding a
Hi @danthe3rd

```python
import torch
import xformers.ops as xops

q, k, v = torch.randn(32, 10, 1, 16), torch.randn(32, 10, 1, 16), torch.randn(32, 10, 1, 16)
q, k, v = q.cuda(), k.cuda(), v.cuda()
mask = torch.zeros((32, 10, 10)).cuda()
xops.memory_efficient_attention(q, k, v, attn_bias=mask)
```

Also, thanks for the attention-bias classes suggestion; I'm aware of those, I'm just testing more custom cases.
The most recent development version should give you more information on what's going on.
I believe the same script should work if your sequence length is divisible by 4 (f32) or 8 (f16/bf16). In theory, this could work with seqlen 10 if your attn_bias is padded, but this is not implemented in xFormers at the moment... Is your sequence length going to be 10? cc @jfc4050
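The divisibility requirement above can be checked with a small helper. This is a minimal sketch (the function name `pad_to_multiple` is my own, not an xFormers API) that rounds a sequence length up to the next multiple of 4 (f32) or 8 (f16/bf16):

```python
def pad_to_multiple(seqlen: int, multiple: int) -> int:
    """Round seqlen up to the next multiple of `multiple`
    (4 for f32, 8 for f16/bf16)."""
    return ((seqlen + multiple - 1) // multiple) * multiple

print(pad_to_multiple(10, 8))  # -> 16
print(pad_to_multiple(10, 4))  # -> 12
```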
hmm, same issue still:

```python
q, k, v = torch.randn(32, 16, 1, 16), torch.randn(32, 16, 1, 16), torch.randn(32, 16, 1, 16)
q, k, v = q.cuda(), k.cuda(), v.cuda()
mask = torch.zeros((32, 16, 16)).cuda()
xops.memory_efficient_attention(q, k, v, attn_bias=mask)
```

yields the same error. For reference, this is the output of
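When debugging failures like this, it can help to have a plain reference implementation of attention with an additive bias to compare outputs against. Below is a minimal NumPy sketch (my own illustration, not xFormers code) using the same `(B, M, H, K)` input layout and a `(B*H, M, M)` bias that `memory_efficient_attention` expects:

```python
import numpy as np

def attention_with_bias(q, k, v, bias):
    """Reference softmax attention with an additive bias.
    q, k, v: (B, M, H, K); bias: (B*H, M, M), added to the logits."""
    B, M, H, K = q.shape
    # Fold heads into the batch dimension: (B*H, M, K)
    q_ = q.transpose(0, 2, 1, 3).reshape(B * H, M, K)
    k_ = k.transpose(0, 2, 1, 3).reshape(B * H, M, K)
    v_ = v.transpose(0, 2, 1, 3).reshape(B * H, M, K)
    scores = q_ @ k_.transpose(0, 2, 1) / np.sqrt(K) + bias
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    out = weights @ v_  # (B*H, M, K)
    return out.reshape(B, H, M, K).transpose(0, 2, 1, 3)
```

With a zero bias this reduces to ordinary scaled dot-product attention, so its output can be diffed against the CUDA kernel's result on the same (CPU-copied) tensors.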
Actually, after trying out the dev version
Oh right, the bias support was merged after we cut the branch for 0.0.16.
Thanks a lot @danthe3rd, I was just writing a custom
- Pre-0.0.17: supported under very strict conditions by a deprecated kernel.
- 0.0.17+: attention bias is supported only if seqlen is a multiple of 4 or 8.
Awesome, thanks a lot @danthe3rd for the help, and for the quick replies; saved me lots of time :)
Do we have support for a seqlen that is not a multiple of 4 or 8? I am working on a speech recognition problem, where seqlen is generally very high and might not always be divisible by 4 or 8.
This has been added in v0.0.18, however with a twist. If your sequence length is
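For sequence lengths that don't meet the divisibility requirement, one workaround is to pad the keys/values and the bias yourself, masking the padded key positions with a large negative bias so softmax ignores them. The helper below (`pad_bias_and_kv` is a hypothetical name of my own; this is not how xFormers does it internally) sketches the idea in NumPy:

```python
import numpy as np

NEG_INF = -1e9  # large negative value standing in for -inf in the bias

def pad_bias_and_kv(k, v, bias, multiple=8):
    """Pad keys/values (B, M, H, K) and additive bias (B*H, Mq, M) along
    the key sequence axis up to a multiple of `multiple`. Padded key
    positions receive NEG_INF bias so they get ~zero attention weight."""
    B, M, H, K = k.shape
    pad = (-M) % multiple
    if pad == 0:
        return k, v, bias
    k_p = np.concatenate([k, np.zeros((B, pad, H, K), k.dtype)], axis=1)
    v_p = np.concatenate([v, np.zeros((B, pad, H, K), v.dtype)], axis=1)
    bias_p = np.concatenate(
        [bias, np.full(bias.shape[:-1] + (pad,), NEG_INF, bias.dtype)],
        axis=-1)
    return k_p, v_p, bias_p
```

For example, seqlen 10 with `multiple=8` pads the keys/values to 16, and the bias gains six NEG_INF columns per query row.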
❓ Questions and Help
I'm playing around with xformers and this is probably a noobish question, but why is this code snippet failing?