[RTX/A6000 GPUs] NaNs in backward pass when training with the huggingface diffusers-style trainer and unet. #631

Closed
zoru22 opened this issue Jan 10, 2023 · 29 comments

@zoru22

zoru22 commented Jan 10, 2023

Sorry if this bug report is sub-optimal; I'm going to append additional information. The problem occurs when training a diffusion model with a training script similar or identical to the one used with the huggingface diffusers U-Net.

See https://github.com/harubaru/waifu-diffusion/tree/main/trainer for the trainer.

🐛 Bug

With a randomly initialized unet, bfloat16 and xformers enabled, and PyTorch's anomaly detection turned on, the very first training step fails with:

RuntimeError: Function '_fMHABackward' returned nan values in its 0th output.

If anomaly detection is disabled and the trainer is left running long enough, the loss eventually degrades to NaN.

This happens when training with bfloat16 and xformers enabled. Hardware: A6000 (sm_86) GPUs.
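When anomaly detection is off, a cheap per-step check can catch the problem before the loss silently turns into NaN. A minimal sketch, assuming a hypothetical helper `report_nonfinite_grads` (not part of the trainer) that is called right after `loss.backward()`:

```python
import torch


def report_nonfinite_grads(model: torch.nn.Module) -> list:
    """Return the names of parameters whose gradient contains NaN or Inf."""
    return [
        name
        for name, param in model.named_parameters()
        if param.grad is not None and not torch.isfinite(param.grad).all()
    ]
```

If the returned list is non-empty, the step can be skipped or the run stopped, which is far cheaper than leaving `torch.autograd.set_detect_anomaly(True)` on for a whole training run.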

To Reproduce

Steps to reproduce the behavior:

Enable:
torch.autograd.set_detect_anomaly(True)
Set the data type to bfloat16.

Train on a data set of images...

[...] in main
    scaler.scale(loss).backward()
  File "/home/zoru/.local/lib/python3.10/site-packages/torch/_tensor.py", line 488, in backward
    torch.autograd.backward(
  File "/home/zoru/.local/lib/python3.10/site-packages/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/zoru/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 267, in apply
    return user_fn(self, *args)
  File "/home/zoru/.local/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 157, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/home/zoru/.local/lib/python3.10/site-packages/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Function '_fMHABackward' returned nan values in its 0th output.
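A minimal sketch of the setup described in the steps above, assuming a tiny stand-in model in place of the UNet and plain `torch.autocast` for bfloat16. The real trainer is the one linked earlier and goes through `scaler.scale(loss).backward()`, as the traceback shows; the model, shapes, and the `train_step` name here are illustrative only:

```python
import torch

# Anomaly detection makes autograd raise as soon as a backward function
# returns NaN, at the cost of slower training.
torch.autograd.set_detect_anomaly(True)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical stand-in for the UNet; the real trainer uses a diffusers model.
model = torch.nn.Linear(16, 16).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)


def train_step(batch: torch.Tensor, target: torch.Tensor) -> float:
    optimizer.zero_grad(set_to_none=True)
    # Forward pass runs in bfloat16 via autocast; parameters stay in float32.
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        pred = model(batch)
    loss = torch.nn.functional.mse_loss(pred.float(), target)
    loss.backward()  # with anomaly detection on, a NaN gradient raises here
    optimizer.step()
    return loss.item()


loss_value = train_step(
    torch.randn(4, 16, device=device),
    torch.randn(4, 16, device=device),
)
```

This stand-in does not use xformers, so it is not expected to reproduce the NaNs by itself; it only shows where the anomaly detector and the bfloat16 setting sit in a training step.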

Expected behavior

No NaNs in the backward pass when training with bfloat16.

Environment

  • PyTorch Version (e.g., 1.0): 1.13.1
  • OS (e.g., Linux): Manjaro linux
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: Python 3.10.8
  • CUDA/cuDNN version: NVIDIA-SMI 525.60.11 Driver Version: 525.60.11 CUDA Version: 12.0 (nvcc reports: cuda_11.8.r11.8/compiler.31833905_0)
  • GPU models and configuration: Nvidia A6000 48GB

xformers.info output when NaNs occur (latest commit)

python -m xformers.info
xFormers 0.0.15+6cd1b36.d20230109
memory_efficient_attention.cutlassF:                 available
memory_efficient_attention.cutlassB:                 available
memory_efficient_attention.flshattF:                  available
memory_efficient_attention.flshattB:                  available
memory_efficient_attention.smallkF:                  available
memory_efficient_attention.smallkB:                 available
memory_efficient_attention.tritonflashattF:      available
memory_efficient_attention.tritonflashattB:     available
swiglu.fused.p.cpp:                                                 available
is_triton_available:                                                  True
is_functorch_available:                                           False
pytorch.version:                                                      1.13.1+cu117
pytorch.cuda:                                                           available
gpu.compute_capability:                                       8.6
gpu.name:                                                                NVIDIA RTX A6000

Transformers is version: 4.22.2
Diffusers is: 0.11.1

Additional Information

  • Batch size does not affect the presence of NaNs in the backward pass.
  • Learning rate does not affect whether NaNs appear.
  • With xformers disabled, running bfloat16 torch AMP with the anomaly detector no longer trips, and NaNs no longer occur.
  • Rolling xformers back to commit c733c99 (Nov 29th) or earlier: the anomaly detector no longer trips under torch AMP, and NaNs no longer occur.
  • Rolling xformers forward to affe4da (Dec 9th) or later: the NaNs occur.
  • The pip and conda packages both produce NaNs.

Note: I do not know if c733c99 is the inflection point for when the nans begin to reliably occur, however that is the most recent version which I know works, and have tested.

Edit x1:
Stepping forward to commit 1924b196 (Dec 6), NaNs do not occur. Will drill into it more later.

python -m xformers.info
xFormers 0.0.15.dev+1924b19.d20230110
memory_efficient_attention.flshatt:      available - requires GPU with compute capability 7.5+
memory_efficient_attention.cutlass:      available
memory_efficient_attention.small_k:      available
swiglu.fused.p.cpp:                      available
@zoru22 changed the title from "NaNs in backward pass when training with the huggingface diffusers unet." to "NaNs in backward pass when training with the huggingface diffusers-style trainer and unet." on Jan 10, 2023
@danthe3rd
Contributor

Hi,
That's quite surprising, but might also be related to your GPU - we didn't really test sm86 in depth, and there are known issues for the bw pass (eg #628).
1924b196 is right before we added triton based flash-attention (although it should be enabled on A100 only ...). Maybe you should try removing triton and see if you still have the issue? (pip uninstall triton)

In any case, I also opened #632 to make it easier to debug what's happening in those cases

@zoru22
Author

zoru22 commented Jan 10, 2023

@danthe3rd You caught me as I was rolling forward to commit 07ba3f3a (Dec 8th). I get NaNs there, so it looks related to changes from around that point. The error there is actually different, with or without triton:

Function 'MemoryEfficientAttentionCutlassOpBackward' returned nan values in its 1th output (with and without triton)

Also note: I'm building from source using pip install -v -U . though the last couple times that would miss packages, so I would run python setup.py build develop, and then xformers.info would show flashattn etc as available.

Edit:
Rolling back to f2f3424 (one after the last-known-good commit I'd found) and I get the same error, so yeah, I do think it's a problem with a change introduced with triton.

@danthe3rd
Contributor

I don't understand what's going on...
The relevant commits are:
(1) 1924b19 [good]: No NaNs
(2) f2f3424 [bad]: Introduce Triton kernel for FW
(3) c2d5b37 [bad I assume?]: restrict Triton FW kernel to only A100
So from this point the Triton kernel shouldn't be used at all

Can you go back to master, and print the name of the kernel used for the FW / BW?
You'll need to modify xformers and add print(op.NAME) before those 2 lines for the FW/BW respectively:

out = op.apply(inp, needs_gradient=True)

grads = op.apply(ctx, inp, grad)

@zoru22
Author

zoru22 commented Jan 10, 2023

I made the print statements give a bit more info:

[...]
BWD OP: flshattB
BWD OP: flshattB
FWD OP W GRAD: flshattF
FWD OP W GRAD: flshattF
BWD OP: cutlassB
BWD OP: cutlassB
FWD OP W GRAD: flshattF
FWD OP W GRAD: flshattF
BWD OP: cutlassB

@danthe3rd
Contributor

so the one causing the NaNs should be the last one, aka cutlassB.
I'm a bit surprised that it's not using FlashAttention for the backward as well..

So it's the grad_q variable below which eventually contains NaNs:

(grad_q, grad_k, grad_v,) = cls.OPERATOR(
    grad.to(dtype),
    inp.query,
    inp.key,
    inp.value,
    ctx.get_padded_lse(32, force_pad_inf=force_pad_inf),
    ctx.out.to(dtype),
    causal=causal,
    scale=inp.scale,
)

Can you dump all the values there when there is a nan and send that as a pickle file for instance? (grad, ctx, inp, grad_q, grad_k, grad_v)
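One way to produce such a dump is sketched below. `dump_if_nonfinite` is a hypothetical helper, not part of xformers; it relies only on `torch.save`, which writes a pickle-based file. The keyword names in the usage comment mirror the variables in the snippet above:

```python
import torch


def dump_if_nonfinite(path, **tensors):
    """Save all named values to `path` if any tensor contains NaN or Inf.

    Returns the names of the offending tensors (empty list if all are finite).
    """
    bad = [
        name
        for name, value in tensors.items()
        if torch.is_tensor(value) and not torch.isfinite(value).all()
    ]
    if bad:
        # Move tensors to CPU so the file loads on a machine without a GPU.
        payload = {
            name: value.detach().cpu() if torch.is_tensor(value) else value
            for name, value in tensors.items()
        }
        torch.save(payload, path)
    return bad


# Hypothetical call site, placed right after the cls.OPERATOR(...) call:
# dump_if_nonfinite("nan_dump.pt", grad=grad, query=inp.query, key=inp.key,
#                   value=inp.value, out=ctx.out,
#                   grad_q=grad_q, grad_k=grad_k, grad_v=grad_v)
```

The resulting file can be read back with `torch.load("nan_dump.pt")` and attached to the issue.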

@grapeot

grapeot commented Jan 15, 2023

Thanks for reporting and trying to solve the issue. Just wanted to report I also hit the same issue with A4000, in case this helps prioritization.

I tried rolling back to 1924b1, but still get a NaN:

Traceback (most recent call last):
  File "/home/grapeot/co/diffusersOfficial/examples/dreambooth/train_dreambooth.py", line 847, in <module>
    main(args)
  File "/home/grapeot/co/diffusersOfficial/examples/dreambooth/train_dreambooth.py", line 797, in main
    accelerator.backward(loss)
  File "/home/grapeot/co/diffusersOfficial/py310/lib/python3.10/site-packages/accelerate/accelerator.py", line 1314, in backward
    self.scaler.scale(loss).backward(**kwargs)
  File "/home/grapeot/co/diffusersOfficial/py310/lib/python3.10/site-packages/torch/_tensor.py", line 488, in backward
    torch.autograd.backward(
  File "/home/grapeot/co/diffusersOfficial/py310/lib/python3.10/site-packages/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/grapeot/co/diffusersOfficial/py310/lib/python3.10/site-packages/torch/autograd/function.py", line 267, in apply
    return user_fn(self, *args)
  File "/home/grapeot/co/diffusersOfficial/py310/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 157, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/home/grapeot/co/diffusersOfficial/py310/lib/python3.10/site-packages/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Function 'MemoryEfficientAttentionCutlassOpBackward' returned nan values in its 2th output.

@arpowers

arpowers commented Jan 18, 2023

Also having this issue. The package is the new one, installed with:
pip install --pre -U xformers

Running the same dreambooth example from diffusers, the loss goes to NaN at around 120 iterations in my case. I can also confirm that removing xformers fixed the problem. The issue might be with their specific implementation of xformers and how it works with the latest version (idk).

Wasted nearly a day trying to figure this out; I didn't think the problem was xformers since everything seemed to be working well, apart from getting black outputs at inference time.

Looking forward to a fix!

@knoopx

knoopx commented Jan 20, 2023

xformers-0.0.15.dev343 (1b1fd8a) was the last working published revision on my 3090. Unfortunately, pre-compiled wheels are no longer available.

@danthe3rd
Contributor

I'll be away next week, but after that I want to start working on making it compatible with Sm86

@jlee2109

jlee2109 commented Jan 25, 2023

I'll admit this is speculative, but for what it's worth... I was having a hell of a time getting Dreambooth extension working in A111's stable diffusion UI. I was having some unrelated issues, so I rolled forward to latest versions of everything (iirc, this included an xformers update). I started to quickly get NaN bombed when training the model.

This bug report tipped me off that it might be xformers. I changed only one parameter in the training - switching from xformers to flash_attention as my memory attention option - and the NaN bombs are now gone. I'm still too dumb about how all of this works, but it points my suspicion at xformers (at the very least, there is some incompatibility issue, but that may not be entirely due to xformers).

xformers version 0.0.14.dev0 is the one I had installed (which I'm just now realizing... why do I have a dev build?). I'm running an RTX 4090 if that seems relevant, and with bf16 precision. On Windows 11.

More details from the log....

Python revision: 3.10.6 (tags/v3.10.6:9c7b4bd, Aug  1 2022, 21:53:49) [MSC v.1932 64 bit (AMD64)]
Dreambooth revision: 9f4d931a319056c537d24669cb950d146d1537b0
SD-WebUI revision: 48a15821de768fea76e66f26df83df3fddf18f4b

Checking Dreambooth requirements...
[+] bitsandbytes version 0.35.0 installed.
[+] diffusers version 0.10.2 installed.
[+] transformers version 4.25.1 installed.
[+] xformers version 0.0.14.dev0 installed.
[+] torch version 1.13.1+cu117 installed.
[+] torchvision version 0.14.1+cu117 installed.

@Zuxier

Zuxier commented Jan 26, 2023

RTX 4090 as well, but I have seen the issue across the full spectrum of consumer cards in the reports we have been getting.
To research the issue, I built 0.14dev0, 0.14, 0.15, 0.16.
0.14dev0 is the only one that is able to train properly.
1924b19 doesn't train on my end.
The issue is there on both Windows and Linux.
A guy on Windows with a T4 couldn't train with the newest xformers.
Meanwhile, one Google Colab user confirmed to me that he is able to train on the latest xformers with a T4, but that's the only success story I have.
I assumed the issue was related to shared memory limitations, but from the little I was able to find, the T4 should have the same amount of shared memory as normal Turing cards. Can't really explain that Colab result.

@grapeot

grapeot commented Jan 28, 2023

@Zuxier thanks for providing a way to get around this issue! May I ask how to build v0.14dev0? I tried checking the tags, releases, PyPI package versions, and searching in the commit messages, and couldn't find 0.0.14dev0 or 0.14dev0. Did you build it using a certain commit id? Thanks a lot!

@danthe3rd pinned this issue on Jan 30, 2023
@danthe3rd changed the title from "NaNs in backward pass when training with the huggingface diffusers-style trainer and unet." to "[RTX/A6000 GPUs] NaNs in backward pass when training with the huggingface diffusers-style trainer and unet." on Jan 30, 2023
@danthe3rd
Contributor

I found an A10g GPU to work on this. Do we have specific steps to reproduce or inputs I could try it on? (also interested if you have the input shapes)
I can get CUDA errors (which I'll fix), but I'm unable to produce nans ...

@zoru22
Author

zoru22 commented Jan 30, 2023

Sorry for not following back up, I only have access to a limited number of GPUs and needed to use them for other jobs.

> I'm a bit surprised that it's not using FlashAttention for the backward as well..

The reason for this is because the Stable Diffusion attention head dims are too large for FlashAttention on A6000s (>64. The actual head size is 80, iirc).
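To make the head-size point concrete, here is a small sketch. It assumes the Stable Diffusion v1 UNet uses 8 attention heads with block widths of 320, 640, and 1280 channels (values not given in this thread), and it takes the limit of 64 from the comment above:

```python
# Assumed Stable Diffusion v1 UNet values (not stated in this thread).
NUM_HEADS = 8
BLOCK_CHANNELS = (320, 640, 1280)

# Head-dim limit for FlashAttention on sm_86, as quoted in the comment above.
FLASH_MAX_HEAD_DIM_SM86 = 64

# Per-head dimension at each resolution level: channels split across heads.
head_dims = {channels: channels // NUM_HEADS for channels in BLOCK_CHANNELS}
# head_dims == {320: 40, 640: 80, 1280: 160}

# Head dims that exceed the limit and would need a different backward kernel.
too_large = [dim for dim in head_dims.values() if dim > FLASH_MAX_HEAD_DIM_SM86]
# too_large == [80, 160]
```

Under these assumptions only some attention layers fit within the limit, which would be consistent with the mix of flshattB and cutlassB lines in the printed log.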

> Do we have specific steps to reproduce or inputs I could try it on?

I'll try to set something up, but it might be a while (>5 days)

@danthe3rd
Contributor

> The actual head size is 80

Oh okay. That's weird because with the current code I would expect a CUDA error rather than silent nans ... I'll try to push a fix and let you know to test it again

@Zuxier

Zuxier commented Jan 30, 2023

> I found an A10g GPU to work on this. Do we have specific steps to reproduce or inputs I could try it on? (also interested if you have the input shapes) I can get CUDA errors (which I'll fix), but I'm unable to produce nans ...

here is the original diffusers example. And this is part of their documentation.
Using the diffusers implementation, the final model is the same as the original. Using another implementation, it's NaN.
Both work when using 0.14dev0 (e23b369; not sure about the commits right after).

@nephina

nephina commented Feb 1, 2023

> I found an A10g GPU to work on this. Do we have specific steps to reproduce or inputs I could try it on? (also interested if you have the input shapes) I can get CUDA errors (which I'll fix), but I'm unable to produce nans ...

Here's an example that's failing for me. I've removed the specific thing I'm working on (the densenet) and replaced it with a dummy MSELoss and it still fails in the same way. One weird thing is that if I skip the densenet201 instantiation, it will return:
RuntimeError: Function 'MseLossBackward0' returned nan values in its 1th output.
instead of:
RuntimeError: Function '_fMHABackward' returned nan values in its 1th output.
even though I am not using the densenet in this example pipeline.

I'm working with:
RTX 3090
Driver version: 510.108.03
CUDA 11.6
torch: 1.13.1
torchvision: 0.14.1
xformers: 0.0.16
triton: 2.0.0.dev20221120

Example:

import torch
import diffusers
torch.autograd.set_detect_anomaly(True)

text_embedding = torch.randn(1,77,768).clamp(0,1).half() #the actual thing we want to train
text_embedding = text_embedding.to(torch.device('cuda'))
text_embedding.requires_grad = True #because we want to optimize it
latent_seed = torch.randn(1,4,64,64).half().to(torch.device('cuda'))

unet_conditioner = diffusers.UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4",subfolder='unet', torch_dtype=torch.float16)
unet_conditioner.set_use_memory_efficient_attention_xformers(True)
unet_conditioner.to(torch.device('cuda'))
vae_upscaler = diffusers.AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4",subfolder='vae', torch_dtype=torch.float16)
vae_upscaler.set_use_memory_efficient_attention_xformers(True)
vae_upscaler.to(torch.device('cuda'))

from torchvision.models import densenet201
ranker = densenet201() #torchvision_variational_ranker()
ranker.half().to(torch.device('cuda'))

Loss = torch.nn.MSELoss()
optimizer = torch.optim.AdamW([text_embedding],lr=0.001)

iterations = 10
for i in range(iterations):
    optimizer.zero_grad() #clear gradients left over from the previous iteration
    conditioned_latent = unet_conditioner(latent_seed,1,text_embedding)[0] #feed latents to unet
    image = vae_upscaler.decoder(conditioned_latent)[0]
    loss = Loss(torch.randn_like(image),image)
    loss.backward() #propagate loss back to text embedding
    optimizer.step() #update text embedding
    with torch.no_grad():
        text_embedding.clamp_(0,1) #clamp in place, outside of autograd

@danthe3rd
Contributor

> The reason for this is because the Stable Diffusion attention head dims are too large for FlashAttention on A6000s (>64. The actual head size is 80, iirc).

I've got something working with any value of K for Sm86/Sm89, and it's reasonably fast for K<=96 now. I hope to be able to push it this week or next week

@danthe3rd
Contributor

danthe3rd commented Feb 2, 2023

> I've got something working with any value of K for Sm86/Sm89, and it's reasonably fast for K<=96 now. I hope to be able to push it this week or next week

It's landed as part of 82d5881
The wheels/conda package should be available soon (version 0.0.17.dev441 or above) if you want binaries
Curious to get feedback from you on this :)
Do you know what batch size you are using for the training?

@ArrowM

ArrowM commented Feb 2, 2023

Awesome, I'll try it out this morning. We can control the batch size on training. I think most people are using 1-2, but some of the people with 24GB VRAM cards use 10 or higher for batch size.

@danthe3rd
Contributor

If you have 16 heads, you should use at least 108/16=6 for the batch size (for the memory_efficient backward pass). Below that, the kernel won't entirely use the GPU (it wasn't optimized for this use-case)
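The arithmetic in the comment above can be sketched as follows. This assumes that 108 is the number of streaming multiprocessors (SMs) on an A100 and that the backward kernel runs one thread block per (batch, head) pair; both are interpretations of the comment, not something it states. Under that reading, 108/16 is 6.75, which the comment rounds to 6; rounding up instead gives 7. The function name is hypothetical:

```python
import math


def min_batch_for_full_occupancy(num_sms: int, num_heads: int) -> int:
    """Smallest batch size for which batch * heads covers every SM.

    Assumes one thread block per (batch, head) pair, which is an
    interpretation of the comment above rather than a documented fact.
    """
    return math.ceil(num_sms / num_heads)


# 108 SMs and 16 heads, the numbers used in the comment above.
batch_a100 = min_batch_for_full_occupancy(108, 16)
```

On a given GPU, the SM count can be read from `torch.cuda.get_device_properties(0).multi_processor_count`.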

@Zuxier

Zuxier commented Feb 2, 2023

> > I've got something working with any value of K for Sm86/Sm89, and it's reasonably fast for K<=96 now. I hope to be able to push it this week or next week
>
> It's landed as part of 82d5881 The wheels/conda package should be available soon (version 0.0.17.dev441 or above) if you want binaries Curious to get feedback from you on this :) Do you know what batch size you are using for the training?

First test, it seems to be working 👍
Will check it a little more/ask other people to check and report back.

@ArrowM

ArrowM commented Feb 2, 2023

My testing looks good too. Thanks Dan! For anyone else reading this, the wheels have been posted already

@ArrowM

ArrowM commented Feb 2, 2023

Is there any chance you could add Torch2 builds to the published wheels? We are building our own atm, but it would be nice to have an official version of it.

@danthe3rd
Contributor

Thanks for the feedback :)

> Is there any chance you could add Torch2 builds to the published wheels?

Unfortunately PyTorch's ABI is not stable. This means that an XFormers binary built for a PyTorch version will not be compatible with the next one. A binary build would only be compatible with a single PT nightly, so we would need to build that every day and this would require some additional work to manage that.
We could add this in the future, but for now it's not something we plan to prioritize.

Of course, once PyTorch publishes 2.0 officially, we will change the binary builds to be for 2.0 instead of 1.13.1

@EandrewJones

@danthe3rd Major props here. I spent hours trying to debug this on an a10g yesterday. You came through at the perfect time.

@shirayu

shirayu commented Feb 3, 2023

Thanks!
xformers 0.0.17.dev442 seems to work fine on RTX 3060 Ti (sm_86).

@zoru22
Author

zoru22 commented Feb 4, 2023

@danthe3rd I'm still running tests with different head shapes and sizes, but so far it's looking good for me!

If I don't report back in another day or two, then you can consider my report resolved. Thanks!

@danthe3rd
Contributor

Closing as fixed :)
