
Performance update #1: loop unfolding + reverse pscan for backward #12

Merged — 14 commits merged into main from perf-update-1 on Feb 9, 2024

Conversation

@alxndrTL (Owner) commented on Feb 9, 2024

First part of the performance update I'm working on for mamba.py (PyTorch version), aimed at training: a faster parallel scan.

It can make your training up to 22% faster!

This update speeds up the parallel scan with two improvements over the original version:

  • loop unfolding: the last 2 steps of the up-sweep, as well as the first 2 steps of the down-sweep, are no longer part of their respective for loops, but are done manually outside the loops (see the first sketch after this list). This avoids a lot of "setup" just to do a few operations (defining Xa, reshaping, ...). This improvement alone brings a 13% speed improvement for small input lengths like L=64. For longer sequences, like L=1024, it only brings a 1.4% improvement. This is expected: as the input length goes to infinity, the gain becomes negligible, because it only speeds up 4 of the 2*log2(L) steps done in the pscan. But for small sequences, as used here in practice, it is nice to have!
  • reverse pscan for the backward pass: up until now, the same pscan was used for both the forward and the backward. It turns out that for the backward pass, we actually need to compute a reverse pscan. What was done before in the backward was to use the same pscan function, but flip the input, do the pscan, and then flip the output. What I've done here is simply code a reverse pscan that works very similarly to the pscan, but goes in reverse with respect to the time dimension (see the second sketch after this list). Because it avoids a bunch of torch.flip operations on tensors of shape (B, D, L, N), it gives us a nice speed improvement.
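
To illustrate the first point, here is a minimal sketch of the loop-unfolding idea on a toy tree reduction; this is not the actual Blelloch pscan from mamba.py, and the function names and power-of-two length assumption are mine, for illustration only. The last steps of the loop, which operate on very small tensors, are written out by hand so they skip the per-step view/reshape setup:

```python
import torch

def tree_sum(x):
    # generic tree reduction: every step pays the view/reshape setup,
    # even when only a handful of elements are left
    # x: (B, L) with L a power of two
    while x.size(1) > 1:
        x = x.view(x.size(0), x.size(1) // 2, 2).sum(dim=2)
    return x[:, 0]

def tree_sum_unfolded(x):
    # same reduction, but the last two steps (4 -> 2 -> 1) are unfolded:
    # written out manually, with no reshaping of tiny tensors
    # x: (B, L) with L a power of two, L >= 4
    while x.size(1) > 4:
        x = x.view(x.size(0), x.size(1) // 2, 2).sum(dim=2)
    return (x[:, 0] + x[:, 1]) + (x[:, 2] + x[:, 3])

x = torch.randn(16, 64)
assert torch.allclose(tree_sum(x), tree_sum_unfolded(x), atol=1e-5)
```

The pscan in this PR applies the same idea: the last 2 up-sweep steps and the first 2 down-sweep steps are specialized by hand instead of going through the generic loop machinery.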
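For the second point, here is a minimal sequential sketch of why a dedicated reverse scan can replace the flip → pscan → flip pattern. This is not the parallel implementation, and the plain recurrence H[t] = A[t]·H[t-1] + X[t] is used for illustration; the exact form used in the actual backward may differ.

```python
import torch

def scan_ref(A, X):
    # sequential reference for the left-to-right recurrence
    # H[t] = A[t] * H[t-1] + X[t], with H[-1] = 0
    H = torch.empty_like(X)
    H[:, :, 0] = X[:, :, 0]
    for t in range(1, X.size(2)):
        H[:, :, t] = A[:, :, t] * H[:, :, t - 1] + X[:, :, t]
    return H

def reverse_scan_ref(A, X):
    # the same recurrence, scanned right to left:
    # G[t] = A[t] * G[t+1] + X[t], with G[L] = 0
    L = X.size(2)
    G = torch.empty_like(X)
    G[:, :, L - 1] = X[:, :, L - 1]
    for t in range(L - 2, -1, -1):
        G[:, :, t] = A[:, :, t] * G[:, :, t + 1] + X[:, :, t]
    return G

B, D, L, N = 2, 3, 8, 4
A = torch.randn(B, D, L, N)
X = torch.randn(B, D, L, N)

# old backward path: flip along time, run the forward scan, flip back
flipped = scan_ref(A.flip(2), X.flip(2)).flip(2)
# new path: scan directly in reverse, no torch.flip on (B, D, L, N) tensors
direct = reverse_scan_ref(A, X)
assert torch.allclose(flipped, direct)
```

The two paths compute the same thing; the reverse scan simply avoids materializing flipped copies of the (B, D, L, N) tensors.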

The memory footprint is untouched: it is exactly the same as before. I plan to address this in the second part of the performance update.

Some more precise benchmarks (time for a forward+backward pass of a whole Mamba model, d_state=16, on an A100 80GB, measured with the torch profiler):

| configuration | before update | new | speed improvement |
| --- | --- | --- | --- |
| (B, L, D) = (16, 64, 128), n_layers=8 | 528.2ms | 408.8ms | 22.6% |
| (B, L, D) = (16, 1024, 128), n_layers=16 | 3.465s | 2.887s | 16.7% |
| (B, L, D) = (16, 1024, 1024), n_layers=8 | 2.727s | 2.527s | 8.8% |

Forward+backward is still numerically equivalent to the original Mamba implementation.

I haven't updated the Performance section of the README.md (I will do it after the second part of the update).
Note that this update doesn't affect the asymptotic behavior of mamba.py.

@alxndrTL merged commit 604493a into main on Feb 9, 2024
@alxndrTL deleted the perf-update-1 branch on February 9, 2024 at 19:55