Performance update #1: loop unfolding + reverse pscan for backward #12
First part of the performance update I'm working on for mamba.py (PyTorch version), for training: a faster parallel scan.
It can make your training 22% faster!
This update improves the speed of the parallel scan with two improvements over the original version:

- The loops (the up sweep and the down sweep) are unfolded: the first steps of the up sweep and the last steps of the down sweep (4 steps in total) are done outside of the loops. These steps work on the largest tensors, so they benefit the most from removing the per-iteration overhead (indexing of `Xa`, reshape, ...). This improvement alone brings a 13% speed improvement for small input lengths like L=64. For longer sequences, like L=1024, it only brings a 1.4% improvement. This is normal: as the input length goes to infinity, the improvement becomes too small to be seen, because it only speeds up 4 steps out of the 2*log2(n) steps that are done in the pscan. But for small sequences, like here in practice, this is cool to have!
- The backward pass now uses a reverse pscan instead of flipping the inputs and outputs of the forward pscan. By removing the torch.flip operations on tensors of shape (B, D, L, N), it gives us a nice speed improvement. (Both ideas are sketched below.)

The memory footprint is not touched and is exactly the same as before. I plan to address this in the second part of the performance update.
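To make the two ideas concrete, here is a minimal sketch of a parallel scan with its first step unfolded out of the loop, together with a reverse (right-to-left) variant. This is a simplified Hillis-Steele-style scan over (B, L) tensors written for illustration only; the actual mamba.py kernel is a Blelloch-style up/down sweep over (B, D, L, N) tensors, and the function names and shapes here are assumptions.

```python
import math

import torch

def pscan(a, x):
    # Computes h_t = a_t * h_{t-1} + x_t (with h_{-1} = 0) in parallel.
    # a, x: (B, L) tensors, L a power of two. Returns h: (B, L).
    a, x = a.clone(), x.clone()
    L = a.size(1)

    # First step unfolded out of the loop: the first (and last) steps
    # touch the largest tensors, so removing the loop overhead
    # (indexing, reshape, ...) there pays off the most.
    x[:, 1:] = x[:, 1:] + a[:, 1:] * x[:, :-1]
    a[:, 1:] = a[:, 1:] * a[:, :-1]

    # Remaining log2(L) - 1 steps, each doubling the combine distance.
    for d in range(1, int(math.log2(L))):
        off = 1 << d
        x[:, off:] = x[:, off:] + a[:, off:] * x[:, :-off]
        a[:, off:] = a[:, off:] * a[:, :-off]
    return x

def pscan_rev(a, x):
    # Reverse scan: h_t = a_t * h_{t+1} + x_t (with h_L = 0), as needed
    # by the backward pass. Scanning right-to-left directly avoids the
    # two torch.flip copies needed to reuse the forward scan.
    a, x = a.clone(), x.clone()
    L = a.size(1)
    for d in range(int(math.log2(L))):
        off = 1 << d
        x[:, :-off] = x[:, :-off] + a[:, :-off] * x[:, off:]
        a[:, :-off] = a[:, :-off] * a[:, off:]
    return x
```

In the unrolled kernel from this PR the same idea applies to the up sweep and down sweep of the Blelloch scan, where the first and last iterations dominate the runtime.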
Some more precise benchmarks (time to do a forward+backward of a whole Mamba model, d_state=16, on an A100 80GB, measured with the torch profiler):

- (B, L, D) = (16, 64, 128), n_layers=8
- (B, L, D) = (16, 1024, 128), n_layers=16
- (B, L, D) = (16, 1024, 1024), n_layers=8
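The PR doesn't inline the profiling script, but a measurement like this could be reproduced with torch.profiler. The sketch below assumes mamba.py's `Mamba`/`MambaConfig` interface (as shown in the repo README, with d_state defaulting to 16) and uses the smallest configuration from the list above.

```python
import torch
from torch.profiler import profile, ProfilerActivity

from mamba import Mamba, MambaConfig  # assumed mamba.py interface

# Smallest benchmark configuration from the list above.
config = MambaConfig(d_model=128, n_layers=8)
model = Mamba(config).cuda()

B, L, D = 16, 64, 128
x = torch.randn(B, L, D, device="cuda")

# Warm-up iterations so one-time CUDA setup is excluded from the profile.
for _ in range(3):
    model(x).sum().backward()

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    model(x).sum().backward()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```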
Forward+backward is still numerically equivalent to the original Mamba implementation.
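That equivalence can be checked against a sequential reference scan; here is a minimal sketch (reusing the illustrative `pscan` from above, in float64 to keep the comparison tight):

```python
import torch

def scan_ref(a, x):
    # Sequential reference for h_t = a_t * h_{t-1} + x_t (h_{-1} = 0).
    h = torch.zeros_like(x[:, 0])
    hs = []
    for t in range(x.size(1)):
        h = a[:, t] * h + x[:, t]
        hs.append(h)
    return torch.stack(hs, dim=1)

B, L = 4, 64
a = torch.rand(B, L, dtype=torch.float64)
x = torch.randn(B, L, dtype=torch.float64)

# The parallel and sequential scans should agree to numerical precision.
assert torch.allclose(pscan(a, x), scan_ref(a, x))
```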
I haven't updated the Performance section of the README.md yet (I will do so after the second part of the update). Note that this update doesn't affect the asymptotic behavior of mamba.py.