-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix the issue with dispatching to jax.vmap with qjit(vmap) #569
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Neat 👍 Maybe add this PR to the existing changelog entry?
@erick-xanadu Feel free to merge it.. 🙌 |
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## v0.5.0-rc #569 +/- ##
==========================================
Coverage 99.54% 99.54%
==========================================
Files 51 51
Lines 8618 8618
Branches 606 606
==========================================
Hits 8579 8579
Misses 21 21
Partials 18 18 ☔ View full report in Codecov by Sentry. |
Can I merge this? :) |
This PR fixes the issue with compiling `catalyst.qjit(catalyst.vmap(fn))(*args, **kwargs)` For example, this example should pass now ``` py dev = qml.device("lightning.qubit", wires=1) @qml.qnode(dev) def circuit(x, y): qml.RX(jnp.pi * x[0] + y, wires=0) qml.RY(x[1] ** 2, wires=0) qml.RX(x[1] * x[2], wires=0) return qml.expval(qml.PauliZ(0)) def cost(x, y, z): return circuit(x, y) * z vmapped_cost = qjit(vmap(cost, in_axes=(0, 0, None))) x = jnp.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]) y = jnp.array([jnp.pi, jnp.pi / 2, jnp.pi / 4]) vmapped_cost(x, y, 1) ```
This PR fixes the issue with compiling
catalyst.qjit(catalyst.vmap(fn))(*args, **kwargs)
For example, this example should pass now