
Allow custom softmax scale in memory efficient attention #530

Merged: 11 commits merged into facebookresearch:main on Nov 17, 2022

Conversation

@comaniac (Contributor) commented Nov 15, 2022:

What does this PR do?

Implement #522.

  • Add a new argument has_custom_scale to the memory-efficient attention forward. When it is true, we assume the query states are already scaled in advance, so they won't be scaled again in the kernels. This is needed in particular for T5 models.
  • Add a new flag SUPPORTS_CUSTOM_SCALE to indicate whether a memory-efficient op supports a custom scale. In this PR we only enable it for MemoryEfficientFlashAttentionOp, to align on the API change and run experiments. If everything goes well, I'll support MemoryEfficientAttentionOp in a follow-up PR (I'm not sure whether the CUTLASS one can be supported without first changing the CUTLASS kernel; it'd be great if someone could help confirm).

Updated based on review comments:

  • New attribute has_custom_scale: bool = False in AttentionOpDispatch.
  • New attribute SUPPORTS_CUSTOM_SCALE: bool = False in AttentionOpBase.
  • New argument scale: Optional[float] = None in AttentionOpDispatch.from_arguments and memory_efficient_attention. When None, the default scale value (1.0 / q.shape[-1] ** 0.5) is used; see the usage sketch after this list.
  • Supported ops: All but MemoryEfficientAttentionOp.
  • Unit test: Covered forward and backward of all supported ops.
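
For illustration, here is a minimal usage sketch of the resulting API, assuming the argument name described above; tensor shapes are illustrative, and the default-scale expression simply mirrors the one stated in this list (a sketch, not the library's documentation):

    import torch
    from xformers.ops import memory_efficient_attention

    q = torch.randn(16, 197, 64, device="cuda", dtype=torch.half)
    k, v = torch.randn_like(q), torch.randn_like(q)

    # scale=None keeps the default softmax scaling of 1.0 / q.shape[-1] ** 0.5
    out_default = memory_efficient_attention(q, k, v, scale=None)

    # A custom scale (e.g. 1.0 for T5-style models whose query projections are
    # already scaled) replaces that default inside the kernel.
    out_custom = memory_efficient_attention(q, k, v, scale=1.0)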

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?
  • Did you update the changelog? (if needed)

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

@facebook-github-bot added the CLA Signed label on Nov 15, 2022.
@danthe3rd (Contributor) left a comment:

Hi @comaniac, and thanks a lot for putting up this pull request!
This change looks great - just a few things we need to modify: notably, it would be great to support any scale.

Regarding the CUTLASS kernel, it would indeed require modifying the CUDA code. I can guide you through this if you already know C++ and are interested :)

@comaniac (Contributor, Author):
@danthe3rd thanks for the review. Yes, I'm familiar with C++ (but not CUDA...) and could help support the CUTLASS one if the scaling happens outside of its CUDA kernel.

@danthe3rd (Contributor):
> @danthe3rd thanks for the review. Yes, I'm familiar with C++ (but not CUDA...) and could help support the CUTLASS one if the scaling happens outside of its CUDA kernel.

C++ knowledge is enough. CUDA is basically C++, and to do this modification you just need to pass another parameter all the way to this function and replace this value:

1.0f / cutlass::fast_sqrt(float(p.head_dim)));

But I can give more guidance once we are done with this PR :)
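
For context, the value being replaced is the standard softmax scaling. A small PyTorch reference (purely illustrative, not the kernel code) shows where a user-provided scale would take the place of the 1/sqrt(head_dim) default:

    from typing import Optional

    import torch

    def reference_attention(q, k, v, scale: Optional[float] = None):
        # The kernel's hard-coded 1.0f / cutlass::fast_sqrt(float(p.head_dim))
        # corresponds to this default value.
        if scale is None:
            scale = 1.0 / (q.shape[-1] ** 0.5)
        attn = (q @ k.transpose(-2, -1)) * scale
        return attn.softmax(dim=-1) @ v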

@comaniac (Contributor, Author):
Cool, thanks for pointing that out. I'll make an update later today.

@comaniac (Contributor, Author) commented Nov 16, 2022:

Hi @danthe3rd,

I've implemented your review comments, along with the CUTLASS op (so that I can test it locally, as I don't have an A100 on hand). Here are two issues I'm facing:

  1. It seems I cannot add a float scale to Params in kernel_forward.h; otherwise I run into memory access errors. When I tried adding it as int32_t the error went away, but we can't use an integer type for the scale.
  2. It seems the kernel you pointed out wasn't being executed somehow. I even put XFORMERS_CHECK(0) in it to make it crash intentionally, but it didn't, so I'm not sure which kernel is actually being executed...

Any hints/suggestions would be appreciated. Thanks.

@danthe3rd (Contributor):
I think I know what might be happening. When you modify kernel_forward.h, it will not necessarily re-build the correct files (that's a bug).
You can make sure they are recompiled with:

touch xformers/components/attention/csrc/cuda/mem_eff_attention/kernels/*.cu xformers/components/attention/csrc/cuda/mem_eff_attention/*.cu && python3 setup.py develop

@comaniac (Contributor, Author):
Thanks! That was exactly the reason... now it works.
The last thing I'm dealing with is that we cannot put c10::optional in Params, because that type cannot be placed on the GPU. I guess I'll lift the scale computation to a higher level and store the scale directly as a float in Params.
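
A minimal sketch of what lifting the scale computation to a higher level could look like on the Python/host side, so that Params only ever stores a plain float (the helper name resolve_scale is hypothetical, not taken from the PR):

    from typing import Optional

    import torch

    def resolve_scale(query: torch.Tensor, scale: Optional[float]) -> float:
        # Resolve the optional user-provided scale before launching the kernel,
        # so the CUDA Params struct receives a concrete float rather than an optional.
        return scale if scale is not None else 1.0 / query.shape[-1] ** 0.5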

@danthe3rd (Contributor):
> I guess I'll lift the scale computation to a higher level and store the scale directly as a float in Params.

yes that makes sense :)

@comaniac (Contributor, Author):
I've tested locally that CUTLASS ops work with custom scale. I'll let CI run for the Triton kernels to make sure I didn't break anything. Meanwhile, this PR should be ready for review.

@danthe3rd (Contributor) left a comment:

This looks good! Great work!
Just a few more nits before we can get this merged.

@comaniac (Contributor, Author):
@danthe3rd all comments were addressed. PTAL.
I'll monitor the CI to make sure it's green. Please let me know if any docs or the changelog have to be updated accordingly; otherwise it should be good to go.

@danthe3rd (Contributor) left a comment:

This looks great, thanks!
I'll wait for the CI to turn green and test performance on an A100 before merging, just in case.

@danthe3rd (Contributor):
Looks like some tests are failing - you will need to modify test_logsumexp and test_cu_seqlen_forward to add scale=None

@comaniac (Contributor, Author):
> Looks like some tests are failing - you will need to modify test_logsumexp and test_cu_seqlen_forward to add scale=None

Yeah, I found that... fixed.

@danthe3rd (Contributor) left a comment:

Thanks a lot @comaniac!

I ran some performance tests and it does not seem to affect performance (forward is very slightly faster, backward very slightly slower).

A100 fw
[------------------ attention (attn_bias=<class 'NoneType'>) -----------------]                                  
                                     |  pr530_9b93469d  |    main    |   eager 
1 threads: --------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |        123.6     |     124.3  |    852.2
      f16 B=384, M=197, H=1, K=80    |        115.2     |     115.8  |    740.9
      f16 B=384, M=197, H=1, K=64    |         87.2     |      87.9  |    679.3
      f16 B=1024, M=197, H=1, K=88   |        315.2     |     315.7  |   2219.5
      f16 B=1024, M=197, H=1, K=80   |        293.5     |     294.9  |   1924.5
      f16 B=1024, M=197, H=1, K=64   |        209.3     |     211.7  |   1768.7
      f16 B=512, M=197, H=1, K=80    |        152.4     |     153.0  |    978.6
      f16 B=32, M=197, H=16, K=80    |        153.2     |     153.9  |   1064.8
      f16 B=32, M=197, H=16, K=64    |        112.0     |     113.1  |    981.8
      f16 B=32, M=197, H=16, K=128   |        165.1     |     167.0  |   1719.5
      f16 B=256, M=197, H=1, K=88    |         86.7     |      86.9  |    576.5
      f16 B=16, M=197, H=16, K=88    |         87.3     |      87.6  |    629.6
      f16 B=16, M=197, H=16, K=64    |         59.1     |      59.7  |    507.4
      f16 B=16, M=197, H=16, K=128   |         88.1     |      88.6  |    879.3
      f16 B=1, M=4096, H=160, K=128  |      15071.3     |   15053.1  |  20538.5
      f16 B=2, M=4096, H=160, K=128  |      30077.7     |   30037.7  |  41703.8
      f16 B=1, M=8192, H=160, K=128  |      60176.1     |   60096.2  |         
      f16 B=2, M=8192, H=160, K=128  |     120233.4     |  120091.0  |         
      f16 B=1024, M=82, H=8, K=64    |        447.2     |     450.9  |   1789.4
      f16 B=150, M=256, H=16, K=64   |        503.9     |     511.2  |   1990.6
      f16 B=64, M=256, H=12, K=64    |        170.3     |     172.0  |    671.8
      f16 B=1, M=4096, H=16, K=40    |        873.4     |     882.4  |   1879.5
      f16 B=1, M=16384, H=16, K=40   |      12397.6     |   12568.5  |  29433.2
      f16 B=256, M=4096, H=16, K=64  |     183892.7     |  187014.7  |         
      f16 B=16, M=128, H=16, K=16    |         28.6     |      27.9  |    127.0
      f16 B=16, M=128, H=16, K=32    |         28.3     |      27.8  |    128.2
      f16 B=16, M=128, H=16, K=64    |         28.8     |      28.2  |    127.3
      f16 B=16, M=128, H=16, K=128   |         38.9     |      39.1  |    147.5
      f16 B=16, M=512, H=16, K=16    |        174.7     |     176.9  |    519.5
      f16 B=16, M=512, H=16, K=32    |        180.8     |     182.6  |    568.3
      f16 B=16, M=512, H=16, K=64    |        208.2     |     210.7  |    675.8
      f16 B=16, M=512, H=16, K=128   |        360.4     |     362.2  |    860.2
      f16 B=16, M=1024, H=16, K=16   |        666.0     |     674.1  |   1820.0
      f16 B=16, M=1024, H=16, K=32   |        672.2     |     680.5  |   1913.4
      f16 B=16, M=1024, H=16, K=64   |        768.9     |     780.0  |   2159.2
      f16 B=16, M=1024, H=16, K=128  |       1348.2     |    1352.6  |   2592.8
      f16 B=64, M=128, H=16, K=16    |         52.4     |      53.0  |    204.2
      f16 B=64, M=128, H=16, K=32    |         58.1     |      58.3  |    250.9
      f16 B=64, M=128, H=16, K=64    |         73.4     |      73.6  |    349.7
      f16 B=64, M=128, H=16, K=128   |        130.6     |     131.0  |    538.1
      f16 B=64, M=512, H=16, K=16    |        679.7     |     687.9  |   1883.3
      f16 B=64, M=512, H=16, K=32    |        688.1     |     696.1  |   2071.7
      f16 B=64, M=512, H=16, K=64    |        791.1     |     802.8  |   2472.3
      f16 B=64, M=512, H=16, K=128   |       1415.9     |    1412.6  |   3258.6
      f16 B=64, M=1024, H=16, K=16   |       2604.6     |    2637.9  |   7128.6
      f16 B=64, M=1024, H=16, K=32   |       2625.8     |    2658.1  |   7520.6
      f16 B=64, M=1024, H=16, K=64   |       3000.2     |    3046.3  |   8521.2
      f16 B=64, M=1024, H=16, K=128  |       5358.9     |    5401.8  |  10204.7

Times are in microseconds (us).
A100 bw
[------------- attention backward (attn_bias=<class 'NoneType'>) -------------]                                  
                                     |  pr530_9b93469d  |    main    |  vanilla
1 threads: --------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |        609.7     |     602.9  |   2256.9
      f16 B=384, M=197, H=1, K=80    |        577.2     |     571.9  |   1915.9
      f16 B=384, M=197, H=1, K=64    |        386.9     |     385.4  |   1807.0
      f16 B=1024, M=197, H=1, K=88   |       1548.1     |    1535.4  |   5942.5
      f16 B=1024, M=197, H=1, K=80   |       1465.8     |    1456.2  |   5021.6
      f16 B=1024, M=197, H=1, K=64   |        862.0     |     861.3  |   4737.1
      f16 B=512, M=197, H=1, K=80    |        734.1     |     726.8  |   2539.7
      f16 B=32, M=197, H=16, K=80    |        723.5     |     718.5  |   2578.6
      f16 B=32, M=197, H=16, K=64    |        457.7     |     457.2  |   2432.8
      f16 B=32, M=197, H=16, K=128   |        861.5     |     846.4  |   4497.1
      f16 B=256, M=197, H=1, K=88    |        448.8     |     440.4  |   1526.1
      f16 B=16, M=197, H=16, K=88    |        440.1     |     435.1  |   1538.6
      f16 B=16, M=197, H=16, K=64    |        232.0     |     233.1  |   1245.4
      f16 B=16, M=197, H=16, K=128   |        490.0     |     485.6  |   2266.9
      f16 B=1, M=4096, H=160, K=128  |      63475.5     |   63526.0  |  46356.4
      f16 B=2, M=4096, H=160, K=128  |     100184.6     |  100160.4  |         
      f16 B=1, M=8192, H=160, K=128  |     251758.1     |  251911.2  |         
      f16 B=2, M=8192, H=160, K=128  |     394557.1     |  394659.7  |         
      f16 B=1024, M=82, H=8, K=64    |       1866.4     |    1855.6  |   3826.4
      f16 B=150, M=256, H=16, K=64   |       2105.3     |    2102.4  |   4559.9
      f16 B=64, M=256, H=12, K=64    |        727.6     |     729.2  |   1499.7
      f16 B=1, M=4096, H=16, K=40    |      23539.1     |   23522.2  |   4240.3
      f16 B=1, M=16384, H=16, K=40   |     435278.5     |  437517.8  |         
      f16 B=256, M=4096, H=16, K=64  |     603767.9     |  602664.0  |         
      f16 B=16, M=128, H=16, K=16    |        146.0     |     141.7  |    492.7
      f16 B=16, M=128, H=16, K=32    |        147.6     |     238.3  |    397.8
      f16 B=16, M=128, H=16, K=64    |        142.3     |     138.9  |    303.9
      f16 B=16, M=128, H=16, K=128   |        178.0     |     178.0  |    301.2
      f16 B=16, M=512, H=16, K=16    |        552.9     |     553.3  |   1204.7
      f16 B=16, M=512, H=16, K=32    |        651.5     |     650.8  |   1308.3
      f16 B=16, M=512, H=16, K=64    |        849.4     |     849.4  |   1545.2
      f16 B=16, M=512, H=16, K=128   |       1763.6     |    1762.4  |   1983.9
      f16 B=16, M=1024, H=16, K=16   |       2229.1     |    2224.9  |   4262.5
      f16 B=16, M=1024, H=16, K=32   |       2438.5     |    2438.5  |   4493.4
      f16 B=16, M=1024, H=16, K=64   |       3031.7     |    3025.4  |   5001.1
      f16 B=16, M=1024, H=16, K=128  |       6403.9     |    6398.4  |   5961.8
      f16 B=64, M=128, H=16, K=16    |        161.7     |     161.6  |    439.2
      f16 B=64, M=128, H=16, K=32    |        206.4     |     206.2  |    545.0
      f16 B=64, M=128, H=16, K=64    |        327.3     |     326.4  |    766.2
      f16 B=64, M=128, H=16, K=128   |        614.2     |     614.8  |   1231.2
      f16 B=64, M=512, H=16, K=16    |       1975.7     |    1974.2  |   4487.7
      f16 B=64, M=512, H=16, K=32    |       2355.5     |    2333.9  |   4979.9
      f16 B=64, M=512, H=16, K=64    |       3075.9     |    3076.1  |   5888.7
      f16 B=64, M=512, H=16, K=128   |       6148.0     |    6148.3  |   7706.6
      f16 B=64, M=1024, H=16, K=16   |       7849.9     |    7839.7  |  16909.4
      f16 B=64, M=1024, H=16, K=32   |       8856.1     |    8797.2  |  17904.8
      f16 B=64, M=1024, H=16, K=64   |      11054.6     |   11058.0  |  19959.4
      f16 B=64, M=1024, H=16, K=128  |      21944.5     |   21930.4  |  23716.0

Times are in microseconds (us).

@danthe3rd merged commit c101579 into facebookresearch:main on Nov 17, 2022.
@comaniac deleted the custom_scale branch on Nov 17, 2022.