-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add a qjit-compatible catalyst.vmap
function
#497
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #497 +/- ##
=======================================
Coverage 99.54% 99.54%
=======================================
Files 51 51
Lines 8493 8571 +78
Branches 572 598 +26
=======================================
+ Hits 8454 8532 +78
Misses 21 21
Partials 18 18 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice work 🥇 The main function is really well-structured, although due to the complexity of the vmap options some blocks are still a bit unclear to me.
Co-authored-by: David Ittah <dime10@users.noreply.github.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @maliasadi, this is looking great! I identified a few more issues with the error checking and matching JAX's behaviour, but otherwise the functionality is really nice 🎉
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice work @maliasadi 😎
Co-authored-by: Josh Izaac <josh146@gmail.com> Co-authored-by: David Ittah <dime10@users.noreply.github.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work 💯
catalyst.vmap(fun)
is added to Catalyst and can apply inside qjitted functions backed bycatalyst.for_loop
catalyst.vmap
dispatches tojax.vmap
when is called outsize QJITin_axes
out_axes
axis_size
[sc-55115]