[RTX 30xx/Sm86] memory_efficient_attention backward not supported for K>64 (f32 & f16?) #517

Closed
Thomas-MMJ opened this issue Nov 10, 2022 · 22 comments · Fixed by #526
Labels
enhancement New feature or request

Comments

@Thomas-MMJ
Contributor

🐛 Bug

Numerous tests in test_mem_eff_attention.py are failing with assertion errors.

Here is the first one:

pytest ./tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-1,32,32,1,128,128-False-None-BMHK]

E       AssertionError: qkv: out=0.0 and ref=25.226354598999023 (diff=25.177289962768555 > 0)/ atol=0.04654095001184158, rtol=0.0001
E       assert False
E        +  where False = <built-in method allclose of type object at 0x7f7c34343d80>(tensor([[[[[0., 0., 0.,  ..., 0., 0., 0.]],\n\n          [[0., 0., 0.,  ..., 0., 0., 0.]],\n\n          [[0., 0., 0.,  ......0., 0.]],\n\n          [[0., 0., 0.,  ..., 0., 0., 0.]],\n\n          [[0., 0., 0.,  ..., 0., 0., 0.]]]]], device='cuda:0'), tensor([[[[[ 1.0817e-01,  1.3813e-02,  3.2980e-02,  ...,  2.0466e-02,\n            -6.8351e-03,  2.1739e-02]],\n\n       ...[[ 1.0926e-01,  1.0926e-01,  1.0926e-01,  ...,  1.0926e-01,\n             1.0926e-01,  1.0926e-01]]]]], device='cuda:0'), rtol=0.0001, atol=0.04654095001184158)
E        +    where <built-in method allclose of type object at 0x7f7c34343d80> = torch.allclose

tests/test_mem_eff_attention.py:118: AssertionError
================================================================= short test summary info ==================================================================
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-1,32,32,1,128,128-False-None-BMHK] - AssertionError: qkv: out=0.0 and ref=25.226354598999023 (diff=25.177289962768555 > 0)/ atol=0.04654095001184158, rtol=0.0001
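For reference, the check that fails here is torch.allclose, which passes only if |out - ref| <= atol + rtol * |ref| holds for every element. Below is a minimal pure-Python sketch of that rule, using the atol and rtol from the log. The helper name tolerance_ok is illustrative and is not part of xFormers or PyTorch. With an all-zero output, the check fails as soon as any reference element exceeds roughly atol:

```python
def tolerance_ok(out, ref, rtol=1e-4, atol=0.04654095001184158):
    """Element-wise check in the style of torch.allclose:
    every pair must satisfy |out - ref| <= atol + rtol * |ref|."""
    if len(out) != len(ref):
        raise ValueError("out and ref must have the same length")
    return all(abs(o - r) <= atol + rtol * abs(r) for o, r in zip(out, ref))


# The first reference values printed in the log, against an all-zero output.
print(tolerance_ok([0.0, 0.0, 0.0], [0.10817, 0.013813, 0.03298]))  # False
print(tolerance_ok([0.0], [0.013813]))  # True
```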

Environment

Collecting environment information...
PyTorch version: 1.14.0.dev20221107
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.24.3
Libc version: glibc-2.31

Python version: 3.9.13 | packaged by conda-forge | (main, May 27 2022, 16:56:21)  [GCC 10.3.0] (64-bit runtime)
Python platform: Linux-5.15.68.1-microsoft-standard-WSL2-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.7.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3060 Laptop GPU
Nvidia driver version: 522.06
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.5.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] clip-anytorch==2.5.0
[pip3] colossalai==0.1.11rc2+torch1.14cu11.8
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.4
[pip3] pytorch-lightning==1.8.0.post1
[pip3] torch==1.14.0.dev20221107
[pip3] torchaudio==0.14.0.dev20221107
[pip3] torchdiffeq==0.2.2
[pip3] torchmetrics==0.10.2
[pip3] torchvision==0.15.0.dev20221107
[conda] blas                      1.0                         mkl
[conda] clip-anytorch             2.5.0                    pypi_0    pypi
[conda] colossalai                0.1.11rc2+torch1.14cu11.8          pypi_0    pypi
[conda] cudatoolkit               11.7.0              hd8887f6_10    nvidia
[conda] libblas                   3.9.0            16_linux64_mkl    conda-forge
[conda] libcblas                  3.9.0            16_linux64_mkl    conda-forge
[conda] liblapack                 3.9.0            16_linux64_mkl    conda-forge
[conda] liblapacke                3.9.0            16_linux64_mkl    conda-forge
[conda] mkl                       2022.1.0           hc2b9512_224
[conda] numpy                     1.23.4           py39h3d75532_1    conda-forge
[conda] pytorch                   1.14.0.dev20221107 py3.9_cuda11.7_cudnn8.5.0_0    pytorch-nightly
[conda] pytorch-cuda              11.7                 h67b0de4_0    pytorch-nightly
[conda] pytorch-lightning         1.8.0.post1              pypi_0    pypi
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torchaudio                0.14.0.dev20221107      py39_cu117    pytorch-nightly
[conda] torchdiffeq               0.2.2              pyhd8ed1ab_0    conda-forge
[conda] torchmetrics              0.10.2                   pypi_0    pypi
[conda] torchvision               0.15.0.dev20221107      py39_cu117    pytorch-nightly
@danthe3rd
Contributor

Hi, thanks for reporting. Can you also add the output of python -m xformers.info?

@Thomas-MMJ
Contributor Author

Thomas-MMJ commented Nov 10, 2022

Sure,

xFormers 0.0.14.dev
memory_efficient_attention.flshatt:      available - requires GPU with compute capability 7.5+
memory_efficient_attention.cutlass:      available
memory_efficient_attention.small_k:      available
is_triton_available:                     True
is_functorch_available:                  True
pytorch.version:                         1.14.0.dev20221107
pytorch.cuda:                            available
gpu.compute_capability:                  8.6
gpu.name:                                NVIDIA GeForce RTX 3060 Laptop GPU

Note that as of PyTorch 1.13, functorch is part of PyTorch, so it isn't a separate install. From the release notes:

Previously, functorch was released out-of-tree in a separate package. After installing PyTorch, a user will be able to import functorch and use functorch without needing to install another package.

https://pytorch.org/blog/PyTorch-1.13-release/

The output above was produced with is_functorch_available set in __init__.py; the test results are the same.

I can submit a patch to set it automatically for PyTorch >= 1.13.
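A minimal sketch of such a check, assuming it only needs to compare the major.minor part of the PyTorch version string against 1.13. functorch_bundled is a hypothetical helper name, not existing xFormers code:

```python
import re


def functorch_bundled(torch_version: str) -> bool:
    """Return True if this PyTorch version ships functorch in-tree (1.13+).

    Accepts strings such as "1.13.0+cu117" or "1.14.0.dev20221107"; only the
    leading major.minor numbers are compared.
    """
    match = re.match(r"(\d+)\.(\d+)", torch_version)
    if match is None:
        raise ValueError(f"unrecognised version string: {torch_version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) >= (1, 13)


print(functorch_bundled("1.14.0.dev20221107"))  # True
print(functorch_bundled("1.12.1"))  # False
```

In an installed environment, the argument would be torch.__version__. Simply attempting to import functorch is an alternative that avoids version parsing altogether.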

Rerunning with it set, the tests still don't pass.

I'll retest after the latest commits with the CUTLASS update.

@Thomas-MMJ
Copy link
Contributor Author

No improvement with the latest git version.

@danthe3rd
Copy link
Contributor

danthe3rd commented Nov 11, 2022

Thanks for the update. We haven't tested xFormers with PyTorch 1.13 yet (we're still on 1.12.1), so that might be the issue.
I'll have a look, but I'll be pretty busy next week, so this might have to wait a bit :/

NOTE: functorch shouldn't matter for this test

@danthe3rd
Copy link
Contributor

I was able to run some more tests, and I can confirm they pass on PyTorch 1.13 (and today's 1.14 nightly) with CUDA 11.7. Could this be due to the GPU (RTX 3060) having compute capability 8.6?
We mostly test and develop on CC 8.0/7.5/6.1/6.0, so it might be harder to debug what's happening in your case, as I don't have the corresponding hardware to test on.

@danthe3rd danthe3rd changed the title assertion errors in test_mem_eff_attention.py [RTX 3060/Sm75] assertion errors in test_mem_eff_attention.py Nov 11, 2022
@IdiotSandwichTheThird

IdiotSandwichTheThird commented Nov 12, 2022

I can confirm this test (and many of the others mentioned) also fails on my 3090. The environment is the default given in README.md, with the latest build installed from conda.

============================================================================================= short test summary info =============================================================================================
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-1,32,32,1,128,128-False-None-BMHK] - AssertionError: qkv: out=0.0 and ref=25.226354598999023 (diff=25.177289962768555 > 0)/ atol=0.04654095001184158, rtol=0.0001
Collecting environment information...
PyTorch version: 1.13.0+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.1 LTS (x86_64)
GCC version: (Ubuntu 11.3.0-1ubuntu1~22.04) 11.3.0
Clang version: Could not collect
CMake version: version 3.24.3
Libc version: glibc-2.35

Python version: 3.10.6 (main, Oct 24 2022, 16:07:47) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-52-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.7.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3090
Nvidia driver version: 515.65.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy==0.982
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.4
[pip3] pytorch-lightning==1.8.1
[pip3] torch==1.13.0
[pip3] torchaudio==0.13.0
[pip3] torchmetrics==0.10.2
[pip3] torchvision==0.14.0
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               11.6.0              hecad31d_10    conda-forge
[conda] mkl                       2022.1.0           hc2b9512_224  
[conda] pytorch                   1.12.1          py3.10_cuda11.6_cudnn8.3.2_0    pytorch
[conda] pytorch-lightning         1.8.1                    pypi_0    pypi
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torchmetrics              0.10.2                   pypi_0    pypi

xFormers 0.0.15.dev+8367685.d20221112
memory_efficient_attention.flshatt:      not built
memory_efficient_attention.cutlass:      available
memory_efficient_attention.small_k:      available
swiglu.fused.p.cpp:                      available
is_triton_available:                     True
is_functorch_available:                  False
pytorch.version:                         1.13.0+cu117
pytorch.cuda:                            available
gpu.compute_capability:                  8.6
gpu.name:                                NVIDIA GeForce RTX 3090

The issue also appears when building from source.

xFormers 0.0.15.dev+8367685.d20221112
memory_efficient_attention.flshatt:      available - requires GPU with compute capability 7.5+
memory_efficient_attention.cutlass:      available
memory_efficient_attention.small_k:      available
swiglu.fused.p.cpp:                      available
is_triton_available:                     True
is_functorch_available:                  False
pytorch.version:                         1.13.0+cu117
pytorch.cuda:                            available
gpu.compute_capability:                  8.6
gpu.name:                                NVIDIA GeForce RTX 3090

@danthe3rd
Copy link
Contributor

I suspect this might be due to the amount of shared memory: it's 160 KB on Sm80 vs 100 KB on Sm86.
Can you run the test again and give me a paste of the output? I would like to see the list of failed tests, so I can update the dispatcher to not run unsupported kernels.
Also, can you run the tests with CUDA_LAUNCH_BLOCKING=1?
Thanks a lot!
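To illustrate the kind of dispatcher change described here, the sketch below refuses a kernel whose shared-memory footprint exceeds the device budget. The budgets are the 160 KB / 100 KB figures quoted above. The footprint formula (a fixed number of tile_rows x head_dim tiles resident at once) and the tile count are hypothetical and do not reflect the layout of the actual CUTLASS kernels. The sketch also shows why float32 hits the limit first: at 4 bytes per element, it needs twice the shared memory of float16 for the same tile.

```python
# Per-block shared-memory budgets in KB, as quoted in the comment above.
SMEM_BUDGET_KB = {(8, 0): 160, (8, 6): 100}

BYTES_PER_ELEMENT = {"float16": 2, "bfloat16": 2, "float32": 4}


def estimated_smem_kb(tile_rows, head_dim, dtype, num_tiles):
    """Hypothetical estimate: num_tiles tiles of tile_rows x head_dim
    elements kept resident in shared memory at the same time."""
    return num_tiles * tile_rows * head_dim * BYTES_PER_ELEMENT[dtype] / 1024


def kernel_fits(compute_capability, tile_rows, head_dim, dtype, num_tiles=4):
    """Return True if the (hypothetical) footprint fits the device budget."""
    budget = SMEM_BUDGET_KB.get(compute_capability)
    if budget is None:
        return False  # unknown device: refuse rather than launch blindly
    return estimated_smem_kb(tile_rows, head_dim, dtype, num_tiles) <= budget


print(kernel_fits((8, 0), 64, 128, "float32"))  # True: 128 KB <= 160 KB
print(kernel_fits((8, 6), 64, 128, "float32"))  # False: 128 KB > 100 KB
print(kernel_fits((8, 6), 64, 128, "float16"))  # True: 64 KB
```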

@IdiotSandwichTheThird
Copy link

IdiotSandwichTheThird commented Nov 13, 2022

I've attached the full log, and a full log with CUDA_LAUNCH_BLOCKING=1, to this comment.

Running the full test suite normally a second time, the number of errors changed yet again, to 1171 failed. I'm not sure whether this is expected behavior, so I've added 2 additional run logs without CUDA_LAUNCH_BLOCKING here:
log launch blocking.txt
logFull.txt

@Thomas-MMJ
Copy link
Contributor Author

The

RuntimeError: Expected is_sm80 to be true, but got false.

and

torch.cuda.OutOfMemoryError errors are also included in the log above.

Slight variations in GPU usage by other processes, or differences in allocation/fragmentation, can result in the OutOfMemoryError sometimes not occurring, which is probably why the number of errors changes.
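Because the OOM count varies from run to run, one way to compare runs is to bucket the FAILED lines of the short test summary by error type and set the OOMs aside. Below is a minimal sketch. The substrings it matches are taken from the errors mentioned in this thread, and the sample lines are illustrative, not copied from the attached logs:

```python
from collections import Counter


def classify_failures(summary_lines):
    """Bucket pytest 'FAILED <test id> - <error>' summary lines by error type."""
    counts = Counter()
    for line in summary_lines:
        if not line.startswith("FAILED "):
            continue
        if "OutOfMemoryError" in line:
            counts["oom"] += 1  # depends on free VRAM, not on the kernel itself
        elif "Expected is_sm80 to be true" in line:
            counts["requires_sm80"] += 1
        elif "AssertionError" in line:
            counts["wrong_result"] += 1
        else:
            counts["other"] += 1
    return counts


sample = [
    "FAILED tests/test_mem_eff_attention.py::test_backward[a] - AssertionError: qkv: out=0.0",
    "FAILED tests/test_mem_eff_attention.py::test_backward[b] - torch.cuda.OutOfMemoryError: CUDA out of memory.",
    "FAILED tests/test_mem_eff_attention.py::test_backward[c] - RuntimeError: Expected is_sm80 to be true, but got false.",
]
counts = classify_failures(sample)
print(sum(counts.values()) - counts["oom"])  # 2 failures excluding OOM
```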

@Thomas-MMJ
Copy link
Contributor Author

Thomas-MMJ commented Nov 13, 2022

If I compile using

TORCH_CUDA_ARCH_LIST="8.0" python setup.py develop

the tests don't pass either.

An interesting pattern I noticed: it's predominantly the float32 variants that fail, while the float16 and bfloat16 variants usually pass. Running the tests with -v makes the pattern fairly clear.

For instance, here are the first 8 float16 and 8 bfloat16 tests, which pass, compared with the first 8 equivalent float32 tests, which fail (I've removed the skips):

tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float16-1,32,32,1,128,128-False-None-BMK] PASSED [ 63%]
tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float16-1,32,32,1,128,128-False-None-BMHK] PASSED [ 63%]
tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float16-1,32,32,1,128,128-False-LowerTriangularMask-BMK] PASSED [ 63%]
tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float16-1,32,32,1,128,128-False-LowerTriangularMask-BMHK] PASSED [ 63%]
tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float16-1,32,32,1,128,128-True-None-BMK] PASSED [ 63%]
tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float16-1,32,32,1,128,128-True-None-BMHK] PASSED [ 63%]
tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float16-1,32,32,1,128,128-True-LowerTriangularMask-BMK] PASSED [ 63%]
tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float16-1,32,32,1,128,128-True-LowerTriangularMask-BMHK] PASSED [ 63%]
tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.bfloat16-1,32,32,1,128,128-False-None-BMK] PASSED [ 63%]
tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.bfloat16-1,32,32,1,128,128-False-None-BMHK] PASSED [ 63%]
tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.bfloat16-1,32,32,1,128,128-False-LowerTriangularMask-BMK] PASSED [ 63%]
tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.bfloat16-1,32,32,1,128,128-False-LowerTriangularMask-BMHK] PASSED [ 63%]
tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.bfloat16-1,32,32,1,128,128-True-None-BMK] PASSED [ 63%]
tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.bfloat16-1,32,32,1,128,128-True-None-BMHK] PASSED [ 63%]
tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.bfloat16-1,32,32,1,128,128-True-LowerTriangularMask-BMK] PASSED [ 63%]
tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.bfloat16-1,32,32,1,128,128-True-LowerTriangularMask-BMHK] PASSED [ 63%]
tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-1,32,32,1,128,128-False-None-BMK] FAILED [ 63%]
tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-1,32,32,1,128,128-False-None-BMHK] FAILED [ 63%]
tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-1,32,32,1,128,128-False-LowerTriangularMask-BMK] FAILED [ 63%]
tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-1,32,32,1,128,128-False-LowerTriangularMask-BMHK] FAILED [ 63%]
tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-1,32,32,1,128,128-True-None-BMK] FAILED [ 63%]
tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-1,32,32,1,128,128-True-None-BMHK] FAILED [ 63%]
tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-1,32,32,1,128,128-True-LowerTriangularMask-BMK] FAILED [ 63%]
tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-1,32,32,1,128,128-True-LowerTriangularMask-BMHK] FAILED [ 63%]
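This pattern can also be tallied automatically from pytest -v output. Below is a minimal sketch that groups outcomes by the dtype field of the test id. The regular expression assumes the cutlass-cuda-torch.<dtype>-... id format shown above. Run over the 24 lines above, it gives 8 PASSED for float16, 8 PASSED for bfloat16, and 8 FAILED for float32.

```python
import re
from collections import defaultdict

# Matches e.g. "...::test_backward[cutlass-cuda-torch.float32-1,32,...-BMK] FAILED [ 63%]"
VERBOSE_LINE = re.compile(r"\[cutlass-cuda-torch\.(\w+)-[^\]]*\]\s+(PASSED|FAILED|SKIPPED)")


def outcomes_by_dtype(lines):
    """Count PASSED/FAILED/SKIPPED per dtype in pytest -v output."""
    table = defaultdict(lambda: defaultdict(int))
    for line in lines:
        match = VERBOSE_LINE.search(line)
        if match:
            dtype, outcome = match.groups()
            table[dtype][outcome] += 1
    return {dtype: dict(counts) for dtype, counts in table.items()}


lines = [
    "tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float16-1,32,32,1,128,128-False-None-BMK] PASSED [ 63%]",
    "tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-1,32,32,1,128,128-False-None-BMK] FAILED [ 63%]",
]
print(outcomes_by_dtype(lines))  # {'float16': {'PASSED': 1}, 'float32': {'FAILED': 1}}
```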

@Thomas-MMJ
Copy link
Contributor Author

Here are the results for my 3060 GPU (6 GB laptop variant): 1538 failures. The additional failures are due to exceeding VRAM capacity more frequently.

mem_eff_results2.txt

@danthe3rd
Copy link
Contributor

danthe3rd commented Nov 14, 2022

Thanks a lot for the logs, that really helps! From the results, it looks like:

[*] I don't count the OOM errors, as they are unrelated to the kernels

For CUTLASS, it looks like the 3060 GPUs don't have enough shared-memory to run the backward:

I'll try to address that

@danthe3rd
Copy link
Contributor

If you have some time, could you check whether the issue is solved with #526?
(You should still see the OOM issues, but they're less important.)

@Thomas-MMJ
Copy link
Contributor Author

Thomas-MMJ commented Nov 14, 2022

Here is the test run with the Sm86 update (run with CUDA_LAUNCH_BLOCKING=1). There are no more VRAM OOMs (it looks like those tests are skipped now), and there are now 40 failures:

== 40 failed, 14455 passed, 18553 skipped, 408 warnings in 286.51s (0:04:46) ===

Previously it was:

= 1538 failed, 13761 passed, 17749 skipped, 578 warnings in 486.09s (0:08:06) ==

So 694 more passed and 804 more were skipped, a reduction of 1498 failures.
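As a quick check of that arithmetic, the sketch below parses the two summary lines. The total number of tests is 33,048 in both runs, so the 1498 fewer failures split exactly into 694 more passes and 804 more skips. parse_summary is an illustrative helper:

```python
import re


def parse_summary(line):
    """Extract {'failed': n, 'passed': n, ...} from a pytest summary line."""
    return {
        outcome: int(count)
        for count, outcome in re.findall(r"(\d+) (failed|passed|skipped|warnings)", line)
    }


def total_tests(summary):
    """Failed + passed + skipped (warnings are not tests)."""
    return summary["failed"] + summary["passed"] + summary["skipped"]


before = parse_summary("= 1538 failed, 13761 passed, 17749 skipped, 578 warnings in 486.09s (0:08:06) ==")
after = parse_summary("== 40 failed, 14455 passed, 18553 skipped, 408 warnings in 286.51s (0:04:46) ===")

delta = {k: after[k] - before[k] for k in ("passed", "skipped", "failed")}
print(delta)  # {'passed': 694, 'skipped': 804, 'failed': -1498}
print(total_tests(before), total_tests(after))  # 33048 33048
```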

FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-1,256,128,1,32,264-False-None-BMK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-1,256,128,1,32,264-False-None-BMHK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-1,256,128,1,32,264-False-LowerTriangularMask-BMK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-1,256,128,1,32,264-False-LowerTriangularMask-BMHK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-1,256,128,1,32,264-True-None-BMK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-1,256,128,1,32,264-True-None-BMHK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-1,256,128,1,32,264-True-LowerTriangularMask-BMK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-1,256,128,1,32,264-True-LowerTriangularMask-BMHK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-300,256,128,1,32,264-False-None-BMK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-300,256,128,1,32,264-False-None-BMHK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-300,256,128,1,32,264-False-LowerTriangularMask-BMK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-300,256,128,1,32,264-False-LowerTriangularMask-BMHK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-300,256,128,1,32,264-True-None-BMK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-300,256,128,1,32,264-True-None-BMHK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-300,256,128,1,32,264-True-LowerTriangularMask-BMK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-300,256,128,1,32,264-True-LowerTriangularMask-BMHK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-24,389,456,8,16,72-False-None-BMK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-24,389,456,8,16,72-False-None-BMHK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-24,389,456,8,16,72-False-LowerTriangularMask-BMK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-24,389,456,8,16,72-False-LowerTriangularMask-BMHK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-24,389,456,8,16,72-True-None-BMK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-24,389,456,8,16,72-True-None-BMHK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-24,389,456,8,16,72-True-LowerTriangularMask-BMK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-24,389,456,8,16,72-True-LowerTriangularMask-BMHK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-35,170,242,10,32,96-False-None-BMK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-35,170,242,10,32,96-False-None-BMHK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-35,170,242,10,32,96-False-LowerTriangularMask-BMK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-35,170,242,10,32,96-False-LowerTriangularMask-BMHK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-35,170,242,10,32,96-True-None-BMK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-35,170,242,10,32,96-True-None-BMHK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-35,170,242,10,32,96-True-LowerTriangularMask-BMK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-35,170,242,10,32,96-True-LowerTriangularMask-BMHK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-18,449,261,9,32,80-False-None-BMK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-18,449,261,9,32,80-False-None-BMHK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-18,449,261,9,32,80-False-LowerTriangularMask-BMK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-18,449,261,9,32,80-False-LowerTriangularMask-BMHK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-18,449,261,9,32,80-True-None-BMK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-18,449,261,9,32,80-True-None-BMHK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-18,449,261,9,32,80-True-LowerTriangularMask-BMK] - RuntimeError: CUDA error: invalid argument
FAILED tests/test_mem_eff_attention.py::test_backward[cutlass-cuda-torch.float32-18,449,261,9,32,80-True-LowerTriangularMask-BMHK] - RuntimeError: CUDA error: invalid argument
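All 40 remaining failures are float32 backward tests drawn from five shapes (8 variants each). The sketch below assumes the last number in the shape field of the test id is the value head dimension, which is consistent with the value.shape[-1] > 64 condition in the maintainer's reply. Under that assumption, the value exceeds 64 in every failing case. value_head_dim is a hypothetical helper:

```python
import re


def value_head_dim(test_id):
    """Return the last number of the comma-separated shape field of a test id.

    Assumption: that number is the value head dimension (value.shape[-1]).
    """
    match = re.search(r"-((?:\d+,)+\d+)-", test_id)
    if match is None:
        raise ValueError(f"no shape field in {test_id!r}")
    return int(match.group(1).split(",")[-1])


failing_shapes = [
    "1,256,128,1,32,264",
    "300,256,128,1,32,264",
    "24,389,456,8,16,72",
    "35,170,242,10,32,96",
    "18,449,261,9,32,80",
]
ids = [f"test_backward[cutlass-cuda-torch.float32-{s}-False-None-BMK]" for s in failing_shapes]
print([value_head_dim(i) for i in ids])  # [264, 264, 72, 96, 80]
print(all(value_head_dim(i) > 64 for i in ids))  # True
```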

test_sm86_updated.txt

@danthe3rd
Copy link
Contributor

Thanks a lot! This is really helpful! I updated #526, and it should address the remaining failing tests.
Note that this PR only disables the kernels for the cases that cause issues, so calls to memory_efficient_attention will raise a NotImplementedError on Sm86 (typically for value.shape[-1] > 64 when running the backward, i.e. during training) until we implement it properly.
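On the caller's side, here is a minimal sketch of deciding up front whether to use the fused kernel. It encodes only the Sm86 restriction described here (backward with value.shape[-1] > 64). The function name and arguments are hypothetical. In practice, the capability tuple could come from torch.cuda.get_device_capability(). Making the decision before the forward pass matters because the missing operator is the backward one, so the error may only surface once the backward runs.

```python
def can_use_fused_attention(compute_capability, value_head_dim, needs_backward):
    """Hypothetical guard for the Sm86 limitation of the fused backward kernel.

    compute_capability: (major, minor) tuple, e.g. (8, 6) for RTX 30xx.
    Only the case described in this issue is encoded; other GPUs may have
    their own limits.
    """
    if needs_backward and compute_capability == (8, 6) and value_head_dim > 64:
        return False
    return True


print(can_use_fused_attention((8, 6), 80, needs_backward=True))   # False: training, K=80
print(can_use_fused_attention((8, 6), 80, needs_backward=False))  # True: inference only
print(can_use_fused_attention((8, 0), 128, needs_backward=True))  # True: Sm80
```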

@Thomas-MMJ
Copy link
Contributor Author

All tests pass or are skipped now:

============================ 14455 passed, 18593 skipped, 408 warnings in 193.37s (0:03:13) ============================

@danthe3rd
Copy link
Contributor

Awesome! I'll merge the PR then. I'll leave this issue open, as we would still like to support K>64 for the backward on those GPUs (at a lower priority, though).

@IdiotSandwichTheThird
Copy link

Same here:
=============== 14455 passed, 18593 skipped in 92.06s (0:01:32) ================

@danthe3rd
Copy link
Contributor

Awesome! (reopening as we still need to support K>64)

@danthe3rd danthe3rd reopened this Nov 14, 2022
@danthe3rd danthe3rd changed the title [RTX 30X0/Sm86] assertion errors in test_mem_eff_attention.py [RTX 30xx/Sm86] memory_efficient_attention backward not supported for K>64/f32 Nov 15, 2022
@danthe3rd danthe3rd added the enhancement New feature or request label Nov 15, 2022
@danthe3rd danthe3rd changed the title [RTX 30xx/Sm86] memory_efficient_attention backward not supported for K>64/f32 [RTX 30xx/Sm86] memory_efficient_attention backward not supported for K>64 (f32 & f16) Jan 11, 2023
@polym
Copy link

polym commented Mar 22, 2023

I hit the same problem on a 3080 Ti. Has there been any update?

NotImplementedError: No operator found for `memory_efficient_attention_backward` with inputs:
     query       : shape=(16, 1024, 1, 80) (torch.float32)
     key         : shape=(16, 77, 1, 80) (torch.float32)
     value       : shape=(16, 77, 1, 80) (torch.float32)
     attn_bias   : <class 'NoneType'>
     p           : 0.0
`flshattB` is not supported because:
    dtype=torch.float32 (supported: {torch.bfloat16, torch.float16})
    requires a GPU with compute capability == 8.0 for 'query.shape[-1] > 64'
`cutlassB` is not supported because:
    Sm86 does not have enough shared-memory to run this kernel - see https://github.com/facebookresearch/xformers/issues/517
`smallkB` is not supported because:
    max(query.shape[-1] != value.shape[-1]) > 32
    has custom scale
    unsupported embed per head: 80
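The error above lists every candidate operator together with the reasons it was rejected. Below is a minimal sketch of that dispatch pattern, loosely modelled on two of the reasons in the log. The checker functions and pick_backward_op are hypothetical and far simpler than xFormers' real dispatcher:

```python
def flash_backward_reasons(dtype, head_dim, cc):
    """Reasons a flash-style backward would be rejected (modelled on the log)."""
    reasons = []
    if dtype not in ("float16", "bfloat16"):
        reasons.append(f"dtype={dtype} (supported: bfloat16, float16)")
    if head_dim > 64 and cc != (8, 0):
        reasons.append("requires compute capability == 8.0 for head_dim > 64")
    return reasons


def cutlass_backward_reasons(dtype, head_dim, cc):
    """Hypothetical rule modelled on the Sm86 restriction discussed in this issue."""
    if cc == (8, 6) and head_dim > 64:
        return ["Sm86 does not have enough shared-memory to run this kernel"]
    return []


CANDIDATES = [("flshattB", flash_backward_reasons), ("cutlassB", cutlass_backward_reasons)]


def pick_backward_op(dtype, head_dim, cc):
    """Return the first operator with no objections, else raise with every reason."""
    report = []
    for name, check in CANDIDATES:
        reasons = check(dtype, head_dim, cc)
        if not reasons:
            return name
        report.append(f"`{name}` is not supported because:")
        report.extend(f"    {reason}" for reason in reasons)
    raise NotImplementedError("No operator found:\n" + "\n".join(report))


print(pick_backward_op("float16", 64, (8, 6)))  # flshattB
print(pick_backward_op("float32", 64, (8, 6)))  # cutlassB
try:
    pick_backward_op("float32", 80, (8, 6))  # float32, head dim 80, as in the report above
except NotImplementedError as exc:
    print(exc)
```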

@danthe3rd
Copy link
Contributor

danthe3rd commented Mar 22, 2023

Hi,
This should be solved in the latest development releases (see Development Binaries here).
The next release (0.0.17) will include the fix.

@polym
Copy link

polym commented Mar 23, 2023

👍 Looks good now, thanks! I use xFormers to train DreamBooth on a 12 GB GPU and it works.

[0] NVIDIA GeForce RTX 3080 Ti | 74°C, 99 % | 11589 / 12288 MB | python/17978(11342M)

ref: https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#training-on-a-12gb-gpu


4 participants