
[fix][speed] Better projections - correctness + speed #119

Merged: 1 commit merged into vit_comp_bench on Nov 22, 2021

Conversation

blefaudeux (Contributor) commented Nov 20, 2021

What does this PR do?

  • fixes the self-attention optimization (project q, k and v all in one go) not being triggered
  • contiguous tensors are only needed for sparse attention, so the contiguous() cost is removed when it is not needed
  • fixes and simplifies a broken path in the factory / residual and normalization, related to the above: only normalize once if the same tensor is passed, and propagate the same id down

With these changes/fixes, vanilla xformers is 5-10% faster than timm on a vanilla ViT (up from being 10% slower), per the benchmark from the other PR, and microGPT trains to something decent in 20 minutes on a laptop (3080 / fp16)
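For illustration, here is a minimal sketch of the "project all in one go" fast path this PR re-enables. This is not the actual xformers in_proj code; the class and variable names are made up. The idea: when the exact same tensor is passed as query, key and value, a single fused linear layer produces q, k and v in one matmul, and split() only returns views.

```python
import torch
import torch.nn as nn


class InProjSketch(nn.Module):
    """Toy input projection: fused q/k/v projection when self-attention is detected."""

    def __init__(self, dim: int):
        super().__init__()
        # one weight holding the q, k and v projections, so the fused path is a single GEMM
        self.qkv = nn.Linear(dim, 3 * dim)
        self.dim = dim

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        if id(q) == id(k) == id(v):
            # self-attention fast path: project everything in one go,
            # then split into three zero-copy views
            return self.qkv(q).split(self.dim, dim=-1)
        # cross-attention: fall back to three per-tensor projections
        w_q, w_k, w_v = self.qkv.weight.split(self.dim, dim=0)
        b_q, b_k, b_v = self.qkv.bias.split(self.dim, dim=0)
        return (
            nn.functional.linear(q, w_q, b_q),
            nn.functional.linear(k, w_k, b_k),
            nn.functional.linear(v, w_v, b_v),
        )


# usage: x = torch.randn(2, 16, 64); q, k, v = InProjSketch(64)(x, x, x)  # fused path
```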

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Nov 20, 2021
@blefaudeux blefaudeux changed the base branch from main to vit_comp_bench November 20, 2021 09:18
@@ -258,7 +258,7 @@ def matmul_with_mask(self, a, b):
         column_indices = self.column_indices
         out = _sddmm.apply(
             a,
-            b.transpose(-2, -1),
+            b.transpose(-2, -1).contiguous(),
blefaudeux (author):
contiguous is only needed here and below; we were forcing this all the time before. Note that with a good projection kernel this could come for free, which could be nice

maybe you should leave a comment/breadcrumb trail saying that in the code?

blefaudeux (author):
will do, or maybe just file an issue; basically the projection at the beginning of MHA could be hardened a little for speed
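As a rough illustration of the point above (a hypothetical standalone function, not the xformers _sddmm wrapper): the transpose alone only produces a strided view, which the dense matmul handles without a copy, so the .contiguous() copy can be paid only on the sparse path that actually requires it.

```python
import torch


def matmul_with_optional_mask(a: torch.Tensor, b: torch.Tensor, sparse_mask=None):
    if sparse_mask is None:
        # dense path: transpose is just a view, no copy needed, ATen handles the strides
        return a @ b.transpose(-2, -1)
    # sparse path only: the sparse kernel expects contiguous memory,
    # so the memcopy happens only when it is actually required
    b_t = b.transpose(-2, -1).contiguous()
    # stand-in for the real sparse SDDMM kernel: dense matmul followed by masking
    return (a @ b_t) * sparse_mask
```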

qkv.split(self.out_features, dim=-1),
)
return q, k, v
qkv = qkv.split(self.out_features, -1)
blefaudeux (author):
this was actually slow (a memcopy) and not needed for non-sparse attention

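A small, self-contained illustration of the memcopy being removed (the shapes are made up, and this is not the in_proj_container code): split() on its own returns views into the fused projection output, whereas forcing .contiguous() on each chunk triggers three extra copies that dense attention never needs.

```python
import torch

# (batch, tokens, 3 * embedding dim), ViT-ish sizes chosen for the example
qkv = torch.randn(8, 197, 3 * 384)
dim = 384

# previous behaviour (roughly): one memcopy per chunk, paid on every forward pass
q, k, v = (t.contiguous() for t in qkv.split(dim, dim=-1))

# after this PR: keep the zero-copy views; only the sparse path
# makes tensors contiguous later, where it is actually needed
q, k, v = qkv.split(dim, dim=-1)
```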

def __init__(self, layer: nn.Module):
super().__init__()
self.layer = layer

def forward(self, inputs: Union[torch.Tensor, List[torch.Tensor]], *args, **kwargs):
blefaudeux (author), Nov 20, 2021:
this was super error-prone (@stephenroller spotted that long ago, I should have caught it then), since inputs and args can mix

def forward(self, *args, **kwargs):
# Could be that the same tensor has been passed multiple times
# in that case we'll just normalize once
list_ids = [id(inp) for inp in args]
blefaudeux (author):
seems small, but here if the same tensor was actually passed multiple times (self-attention), we would normalize it 3 times and lose the shared id(), which in turn means that the attention layer would not optimize for self-attention
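A rough sketch of the dedup logic described here (a hypothetical module, not the actual residual.py code): normalize each distinct input tensor once, keyed by id(), so that q, k and v stay the same object downstream and the attention layer can still detect self-attention.

```python
import torch
import torch.nn as nn


class PreNormSketch(nn.Module):
    def __init__(self, dim: int, sublayer: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.sublayer = sublayer

    def forward(self, *args, **kwargs):
        # normalize once per unique tensor id, then rebuild the argument list,
        # so repeated inputs map to the same normalized object
        cache = {}
        normalized = []
        for inp in args:
            if id(inp) not in cache:
                cache[id(inp)] = self.norm(inp)
            normalized.append(cache[id(inp)])
        return self.sublayer(*normalized, **kwargs)
```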

@@ -364,8 +364,8 @@ def forward(
else:
target_q, target_k, target_v = target, target, target

x = self.wrap_att([target_q, target_k, target_v], att_mask=decoder_att_mask)
x = self.wrap_cross([x, memory, memory], att_mask=encoder_att_mask)
x = self.wrap_att(target_q, target_k, target_v, att_mask=decoder_att_mask)
blefaudeux (author):
related to the "input<>args" cleanup

blefaudeux (author):

@fmassa @dianaml0 I think that the code in block_factory and residual is probably way too complicated and error-prone; it could be worth another pass if deemed important (this is orthogonal to the parts-zoo approach). In any case, with these changes a default xformers is competitive in terms of speed vs. the timm ViT.

@blefaudeux blefaudeux changed the title [DRAFT] Better projections - correctness + speed [fix][speed] Better projections - correctness + speed Nov 20, 2021
codecov-commenter:

Codecov Report

Merging #119 (36d64f2) into vit_comp_bench (9499a7a) will decrease coverage by 0.01%.
The diff coverage is 100.00%.


@@                Coverage Diff                 @@
##           vit_comp_bench     #119      +/-   ##
==================================================
- Coverage           87.22%   87.21%   -0.02%     
==================================================
  Files                  49       49              
  Lines                2490     2487       -3     
==================================================
- Hits                 2172     2169       -3     
  Misses                318      318              
Flag Coverage Δ
Python 87.21% <100.00%> (-0.02%) ⬇️

Flags with carried forward coverage won't be shown.

Impacted Files Coverage Δ
xformers/components/attention/_sputnik_sparse.py 95.55% <ø> (ø)
xformers/components/attention/core.py 89.47% <100.00%> (ø)
xformers/components/in_proj_container.py 73.61% <100.00%> (ø)
xformers/components/residual.py 95.00% <100.00%> (-0.35%) ⬇️
xformers/factory/block_factory.py 92.81% <100.00%> (ø)


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 9499a7a...36d64f2.

blefaudeux (author):

Ping reviewers (and the PR below); this significantly speeds up a code path used in the examples and by some folks around.

dianaml0 (Contributor):

> @fmassa @dianaml0 I think that the code in block_factory and residual is probably way too complicated and error-prone; it could be worth another pass if deemed important (this is orthogonal to the parts-zoo approach). In any case, with these changes a default xformers is competitive in terms of speed vs. the timm ViT.

Seems like much of the complication was introduced to make reversible layers work? But I agree it's probably a good idea to see if there's a cleaner way to do it.

dianaml0 (Contributor) left a comment:
This is really great!! LGTM!

blefaudeux (author):

> @fmassa @dianaml0 I think that the code in block_factory and residual is probably way too complicated and error-prone; it could be worth another pass if deemed important (this is orthogonal to the parts-zoo approach). In any case, with these changes a default xformers is competitive in terms of speed vs. the timm ViT.
>
> Seems like much of the complication was introduced to make reversible layers work? But I agree it's probably a good idea to see if there's a cleaner way to do it.

Yes, for reversible, I agree, it makes the signal paths a lot more complex unfortunately :( No strong opinion on that; I don't think this was benchmarked enough to have a definitive answer. It's probably possible to have better code and keep reversible though.

@blefaudeux blefaudeux merged commit 5d72a66 into vit_comp_bench Nov 22, 2021
@blefaudeux blefaudeux deleted the better_projections branch December 1, 2021 18:52
xwhan pushed a commit to xwhan/xformers that referenced this pull request Feb 8, 2022
follow up from facebookresearch#117, macro blocks mask inputs, not attentions (facebookresearch#119)

* follow up from facebookresearch#117, macro blocks mask inputs, not attentions
* matching unit test