NaNs when training with attn_bias (f32) #684

Open
zen-d opened this issue Mar 8, 2023 · 14 comments
Labels
bug Something isn't working

Comments

@zen-d

zen-d commented Mar 8, 2023

❓ Questions and Help

Hi, I pass attn_bias to xformers.ops.memory_efficient_attention, but get the following error:

NotImplementedError: No operator found for `memory_efficient_attention_forward` with inputs:
     query       : shape=(831, 43, 32, 8) (torch.float32)
     key         : shape=(831, 43, 32, 8) (torch.float32)
     value       : shape=(831, 43, 32, 8) (torch.float32)
     attn_bias   : <class 'torch.Tensor'>
     p           : 0.0
`flshattF` is not supported because:
    dtype=torch.float32 (supported: {torch.bfloat16, torch.float16})
    attn_bias type is <class 'torch.Tensor'>
`tritonflashattF` is not supported because:
    dtype=torch.float32 (supported: {torch.bfloat16, torch.float16})
    attn_bias type is <class 'torch.Tensor'>
`cutlassF` is not supported because:
    attn_bias.shape[-1] % 8 != 0
`smallkF` is not supported because:
    bias with non-zero stride not supported

In my case, attn_bias is indispensable, and it is hard to always satisfy attn_bias.shape[-1] % 8 == 0. How can I benefit from this repo? Thanks.
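
For context, a minimal sketch of the kind of call that hits this error; the sizes and tensor names below are illustrative (not my actual model), only the float32 dtype and the last attn_bias dimension not being a multiple of 8 matter:

import torch
import xformers.ops as xops

# Illustrative sizes in xformers' [batch, seq_len, n_heads, head_dim] layout;
# seq_len=43 makes attn_bias.shape[-1] % 8 != 0, and float32 rules out the
# f16/bf16-only kernels, so no backend accepts the inputs.
B, M, H, K = 2, 43, 4, 8
q = torch.randn(B, M, H, K, device="cuda", dtype=torch.float32)
k = torch.randn(B, M, H, K, device="cuda", dtype=torch.float32)
v = torch.randn(B, M, H, K, device="cuda", dtype=torch.float32)
attn_bias = torch.randn(B, H, M, M, device="cuda", dtype=torch.float32)

out = xops.memory_efficient_attention(q, k, v, attn_bias=attn_bias)  # raises NotImplementedError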

@zen-d zen-d changed the title from "support for attn_bias of arbitrry format" to "support for attn_bias of arbitrary format" Mar 8, 2023
@danthe3rd
Contributor

Hi,
Thanks for opening this issue. That's something we can work on (see #683).
What type of bias do you need? Is it a learnable bias?

@zen-d
Author

zen-d commented Mar 8, 2023

@danthe3rd Thanks a lot for your prompt reply! #683 is highly related. In that thread I noticed you may be working on it (#683 (comment)).
First, may I know when support for an attn_bias of type torch.Tensor with attn_bias.shape[-1] % 8 != 0 is scheduled? Is it planned for the near future?
Second, if you could also support a learnable attn_bias, it would be even more attractive.

@danthe3rd
Contributor

The bias is currently learnable :) We just need to add the padding support. Hopefully we can get that out next week.

@zen-d
Author

zen-d commented Mar 8, 2023

Wow, fantastic! Looking forward to seeing the padding support soon so the shape constraint can be relaxed.

@danthe3rd danthe3rd added the bug (Something isn't working) and ongoing labels Mar 8, 2023
@danthe3rd
Contributor

It's merged in b6be33a

@zen-d
Author

zen-d commented Mar 13, 2023

@danthe3rd Thanks! Looks good, but I don't have free GPUs at the moment. I will try out the new feature ASAP.

@zen-d
Author

zen-d commented Mar 14, 2023

@danthe3rd By following these hints to do the padding and slicing, I'm able to run the model now. The memory burden is significantly reduced. Thanks for your awesome work! I will continue to monitor the training process and the final accuracy.

HINT: To use an attn_bias with a sequence length that is not a multiple of 8,
you need to ensure memory is aligned by slicing a bigger tensor.
Example: use attn_bias = torch.zeros([1, 1, 5, 8])[:,:,:,:5] instead of torch.zeros([1, 1, 5, 5])
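
A small sketch of how I applied this hint in practice (the helper name and shapes here are just for illustration):

import torch

def make_aligned_attn_bias(B, H, M, device="cuda", dtype=torch.float32):
    # Allocate the last dimension padded up to a multiple of 8, then slice back
    # to the real sequence length; the slice keeps the aligned underlying storage,
    # which is what the cutlassF kernel checks for.
    M_pad = (M + 7) // 8 * 8
    bias = torch.zeros(B, H, M, M_pad, device=device, dtype=dtype)
    return bias[:, :, :, :M]

attn_bias = make_aligned_attn_bias(B=2, H=4, M=43)  # shape [2, 4, 43, 43], backed by width-48 storage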

@zen-d zen-d closed this as completed Mar 14, 2023
@zen-d zen-d reopened this Mar 14, 2023
@zen-d
Author

zen-d commented Mar 14, 2023

Unfortunately, the training diverges partway through (the loss becomes NaN), which did not happen with the original attention implementation. Could you share some insights on that? Thanks.

@danthe3rd
Contributor

danthe3rd commented Mar 14, 2023

I don't have a specific idea for this, but you can pinpoint more precisely where the NaN is coming from with anomaly detection:

torch.autograd.set_detect_anomaly(mode=True, check_nan=True)
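
For example, a minimal sketch of wiring this into a training step (model and batch are placeholders for illustration; check_nan needs a fairly recent PyTorch):

import torch

# Scope anomaly detection to one training step; with check_nan=True the backward
# pass raises as soon as a NaN gradient appears, with a traceback pointing at the
# forward op that produced it.
with torch.autograd.detect_anomaly(check_nan=True):
    loss = model(batch).mean()  # hypothetical model/batch, for illustration only
    loss.backward()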

@zen-d
Author

zen-d commented Mar 15, 2023

Thanks for the suggestion. The only difference in this controlled experiment is the attention implementation, but I am not sure of the specific cause yet. I will dig deeper into the issue. :)

@danthe3rd
Contributor

Also - it looks like this is running in f32? If not, you might want to try training with f32 to see if it's related to numerical precision.

@zen-d
Author

zen-d commented Mar 15, 2023

Yes, to be safe, I am training with FP32 numerical precision for now. (In my experience, AMP training seems more prone to NaNs for Transformer-based models.)

@danthe3rd danthe3rd removed the ongoing label Mar 30, 2023
@danthe3rd danthe3rd changed the title from "support for attn_bias of arbitrary format" to "NaNs when training with attn_bias (f32)" Mar 30, 2023
@Shannen3206

Yes, to be safe, I am training with FP32 numerical precision for now. (In my experience, AMP training seems more prone to NaNs for Transformer-based models.)

I met the same issue, and I found that using fp16 solves this problem.

@Shannen3206

@danthe3rd By following these hints to do the padding and slicing, I'm able to run the model now. The memory burden is significantly reduced. Thanks for your awesome work! I will continue to monitor the training process and the final accuracy.

HINT: To use an attn_bias with a sequence length that is not a multiple of 8,
you need to ensure memory is aligned by slicing a bigger tensor.
Example: use attn_bias = torch.zeros([1, 1, 5, 8])[:,:,:,:5] instead of torch.zeros([1, 1, 5, 5])

Hi,
I found that using this method can make inference slower; see #853.
Do you have any suggestions for a better way?
