Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a qjit-compatible catalyst.vmap function #497

Merged
merged 34 commits into from
Feb 28, 2024
Merged

Conversation

maliasadi
Copy link
Member

@maliasadi maliasadi commented Feb 6, 2024

  • catalyst.vmap(fun) is added to Catalyst and can apply inside qjitted functions backed by catalyst.for_loop
  • catalyst.vmap dispatches to jax.vmap when is called outsize QJIT
  • support in_axes
  • support out_axes
  • support axis_size
  • Support PyTrees as parameters and return values

[sc-55115]

@maliasadi maliasadi added the frontend Pull requests that update the frontend label Feb 6, 2024
@codecov-commenter
Copy link

codecov-commenter commented Feb 6, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.54%. Comparing base (d191ede) to head (6a8a215).

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #497   +/-   ##
=======================================
  Coverage   99.54%   99.54%           
=======================================
  Files          51       51           
  Lines        8493     8571   +78     
  Branches      572      598   +26     
=======================================
+ Hits         8454     8532   +78     
  Misses         21       21           
  Partials       18       18           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Collaborator

@dime10 dime10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work 🥇 The main function is really well-structured, although due to the complexity of the vmap options some blocks are still a bit unclear to me.

doc/changelog.md Outdated Show resolved Hide resolved
frontend/catalyst/pennylane_extensions.py Outdated Show resolved Hide resolved
frontend/catalyst/pennylane_extensions.py Outdated Show resolved Hide resolved
frontend/catalyst/pennylane_extensions.py Outdated Show resolved Hide resolved
frontend/catalyst/pennylane_extensions.py Outdated Show resolved Hide resolved
frontend/catalyst/pennylane_extensions.py Outdated Show resolved Hide resolved
frontend/catalyst/pennylane_extensions.py Outdated Show resolved Hide resolved
frontend/catalyst/pennylane_extensions.py Outdated Show resolved Hide resolved
frontend/catalyst/pennylane_extensions.py Outdated Show resolved Hide resolved
frontend/catalyst/pennylane_extensions.py Outdated Show resolved Hide resolved
Co-authored-by: David Ittah <dime10@users.noreply.github.com>
@dime10 dime10 added this to the v0.5.0 milestone Feb 21, 2024
Copy link
Collaborator

@dime10 dime10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @maliasadi, this is looking great! I identified a few more issues with the error checking and matching JAX's behaviour, but otherwise the functionality is really nice 🎉

frontend/catalyst/pennylane_extensions.py Outdated Show resolved Hide resolved
frontend/catalyst/pennylane_extensions.py Show resolved Hide resolved
frontend/catalyst/pennylane_extensions.py Outdated Show resolved Hide resolved
frontend/catalyst/pennylane_extensions.py Outdated Show resolved Hide resolved
frontend/catalyst/pennylane_extensions.py Outdated Show resolved Hide resolved
frontend/catalyst/pennylane_extensions.py Outdated Show resolved Hide resolved
frontend/catalyst/pennylane_extensions.py Outdated Show resolved Hide resolved
frontend/catalyst/pennylane_extensions.py Show resolved Hide resolved
frontend/test/pytest/test_vmap.py Outdated Show resolved Hide resolved
frontend/test/pytest/test_vmap.py Show resolved Hide resolved
Copy link
Member

@josh146 josh146 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work @maliasadi 😎

frontend/catalyst/pennylane_extensions.py Outdated Show resolved Hide resolved
frontend/catalyst/pennylane_extensions.py Outdated Show resolved Hide resolved
frontend/catalyst/pennylane_extensions.py Outdated Show resolved Hide resolved
frontend/catalyst/pennylane_extensions.py Outdated Show resolved Hide resolved
frontend/test/pytest/test_vmap.py Show resolved Hide resolved
doc/changelog.md Outdated Show resolved Hide resolved
Copy link
Collaborator

@dime10 dime10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work 💯

frontend/catalyst/pennylane_extensions.py Show resolved Hide resolved
frontend/catalyst/pennylane_extensions.py Outdated Show resolved Hide resolved
frontend/catalyst/pennylane_extensions.py Show resolved Hide resolved
@maliasadi maliasadi merged commit 19b5d77 into main Feb 28, 2024
36 checks passed
@maliasadi maliasadi deleted the maa/catalyst-vmap branch February 28, 2024 21:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
frontend Pull requests that update the frontend
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants