
Is bfloat16 supported? #231

Closed
colehawkins opened this issue Mar 9, 2022 · 9 comments · Fixed by #272

Comments

@colehawkins
Contributor

I tried to run this example: https://github.com/facebookresearch/xformers/blob/main/HOWTO.md#blocksparseattention but changed the dtype to bfloat16 and cast everything to bfloat16. The error I get is:

  File "/home/ubuntu/long_context/tmp.py", line 45, in <module>
    att_val = multi_head(query=query, key=query, value=query, att_mask=causal_mask)
  File "/home/ubuntu/anaconda3/envs/xf/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/long_context/xformers/xformers/components/multi_head_dispatch.py", line 201, in forward
    y = self.attention(
  File "/home/ubuntu/anaconda3/envs/xf/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/long_context/xformers/xformers/components/attention/blocksparse.py", line 195, in forward
    sparse_att_mat = self.sparse_softmax(
  File "/home/ubuntu/anaconda3/envs/xf/lib/python3.9/site-packages/triton/ops/blocksparse/softmax.py", line 229, in __call__
    lut, maxlut = self.make_lut(x.device)
  File "/home/ubuntu/anaconda3/envs/xf/lib/python3.9/site-packages/triton/ops/blocksparse/softmax.py", line 208, in make_lut
    self.lut_cache[key] = _softmax.make_lut(self.layout, self.block, device)
  File "/home/ubuntu/anaconda3/envs/xf/lib/python3.9/site-packages/triton/ops/blocksparse/softmax.py", line 118, in make_lut
    offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
RuntimeError: "cumsum_out_cpu" not implemented for 'BFloat16'
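
For reference, here is roughly what I'm running: a minimal sketch based on the HOWTO example, where the only change from the linked snippet is the bfloat16 dtype/cast (exact shapes and config values are illustrative):

```python
import torch

from xformers.components import MultiHeadDispatch
from xformers.components.attention import BlockSparseAttention

# Sketch of the HOWTO block-sparse example with the dtype switched to bfloat16.
# Shapes and config values are illustrative.
BATCH, HEADS, SEQ, EMB, BLOCK_SIZE = 2, 8, 2048, 1024, 32
dtype = torch.bfloat16

blocks = SEQ // BLOCK_SIZE
causal_layout = torch.tril(torch.ones([HEADS, blocks, blocks]))
causal_mask = torch.tril(torch.ones((SEQ, SEQ), device="cuda", dtype=dtype))

attention = BlockSparseAttention(
    layout=causal_layout, block_size=BLOCK_SIZE, dropout=0.0, num_heads=HEADS
)
multi_head = (
    MultiHeadDispatch(
        seq_len=SEQ,
        dim_model=EMB,
        residual_dropout=0.0,
        num_heads=HEADS,
        attention=attention,
    )
    .cuda()
    .to(dtype)  # cast everything to bfloat16 instead of .half()
)

query = torch.randn((BATCH, SEQ, EMB), device="cuda", dtype=dtype, requires_grad=True)

# This is the call that raises the error above
att_val = multi_head(query=query, key=query, value=query, att_mask=causal_mask)
```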

Triton 2.0 appears to support bfloat16, but I'm not sure whether the Triton 1.x used by xformers should support bfloat16, or whether xformers should/will support it.
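
In case it helps narrow things down, the failing line from the traceback seems to boil down to a CPU cumsum over a bfloat16 tensor (presumably the layout buffer gets cast to bfloat16 along with everything else):

```python
import torch

# Minimal isolation of the failing call in triton's make_lut: a cumsum over a
# CPU bfloat16 tensor. On the torch build from the traceback above this raises
#   RuntimeError: "cumsum_out_cpu" not implemented for 'BFloat16'
sizes = torch.ones(8, dtype=torch.bfloat16)  # CPU tensor, as in make_lut
offsets = torch.zeros_like(sizes)
offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
```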

@blefaudeux
Contributor

Thanks for filing the issue! Right now Triton 1.x is the only stable version, so that's what we support, but the plan is to move to Triton 2 as soon as it's out, and bfloat16 support is indeed tied to that. cc @ptillet @dianaml0

@blefaudeux
Contributor

One thing we could do is support it in /experimental: that part is still CI-tested but not part of the pip package, and it already depends on Triton 2.0. Thoughts, @dianaml0?

@dianaml0
Contributor

That could work, since we'll eventually want to support it anyway. Do we know when Triton 2.0 will be out?

@colehawkins
Contributor Author

@dianaml0 @blefaudeux Happy to pitch in by supporting block-sparse attention through the Triton 2.0 dev branch and writing some tests.

@ptillet

ptillet commented Mar 15, 2022

@dianaml0 I could probably release a v2.0 alpha if it helps. I believe most of the features are there, but there should be more docs and bugfixes before the final release :)

@blefaudeux
Contributor

That would be great, @colehawkins, pull me in anytime if I can help! Anything in experimental/tests already runs with Triton 2.0 on the CI, I hope that helps. I'll also be migrating the memory-efficient implementation there soon, hoping to finally get it to work with Triton 2.

@colehawkins
Contributor Author

@blefaudeux My plan is to move the block-sparse attention over (https://github.com/facebookresearch/xformers/blob/main/xformers/components/attention/blocksparse.py).

Then I will move over the tests below:
https://github.com/facebookresearch/xformers/blob/main/tests/test_attention_patterns.py
https://github.com/facebookresearch/xformers/blob/main/tests/test_triton_blocksparse.py

I will skip the attention-pattern tests that do not touch blocksparse.

@ptillet it seems that the xformers test_triton_blocksparse.py replicates the Triton tests in the 2.0 dev branch, which do not pass when I run them due to some API changes (e.g. the device argument). I can notify you on the PR in case the updates I make are useful.
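
For dtype coverage I'm thinking of something along these lines in the ported test. This is a smoke-test sketch only, reusing the HOWTO-style setup; the real test would mirror test_triton_blocksparse.py and compare against a dense reference rather than just checking shape and dtype.

```python
import pytest
import torch

from xformers.components import MultiHeadDispatch
from xformers.components.attention import BlockSparseAttention


# Sketch of a dtype-parametrized smoke test for the ported component; shapes
# are illustrative and the assertions stand in for a proper
# dense-vs-blocksparse comparison.
@pytest.mark.skipif(not torch.cuda.is_available(), reason="blocksparse kernels need CUDA")
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_blocksparse_dtype_smoke(dtype):
    heads, seq, emb, block = 4, 512, 256, 32
    blocks = seq // block
    layout = torch.tril(torch.ones([heads, blocks, blocks], dtype=torch.long))

    attention = BlockSparseAttention(
        layout=layout, block_size=block, dropout=0.0, num_heads=heads
    )
    multi_head = (
        MultiHeadDispatch(
            seq_len=seq,
            dim_model=emb,
            residual_dropout=0.0,
            num_heads=heads,
            attention=attention,
        )
        .cuda()
        .to(dtype)
    )

    query = torch.randn((2, seq, emb), device="cuda", dtype=dtype)
    out = multi_head(query=query, key=query, value=query)

    assert out.shape == query.shape
    assert out.dtype == dtype
```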

@blefaudeux
Contributor


Bingo, sounds great, thank you already!

@blefaudeux mentioned this issue Mar 24, 2022
@blefaudeux linked a pull request Mar 24, 2022 that will close this issue
@blefaudeux linked a pull request Apr 19, 2022 that will close this issue
@blefaudeux
Contributor

Should be good to go!
