Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the issue with dispatching to jax.vmap with qjit(vmap) #569

Merged
merged 2 commits into from
Feb 29, 2024

Conversation

maliasadi
Copy link
Member

This PR fixes the issue with compiling catalyst.qjit(catalyst.vmap(fn))(*args, **kwargs)

For example, this example should pass now

dev = qml.device("lightning.qubit", wires=1)

@qml.qnode(dev)
def circuit(x, y):
    qml.RX(jnp.pi * x[0] + y, wires=0)
    qml.RY(x[1] ** 2, wires=0)
    qml.RX(x[1] * x[2], wires=0)
    return qml.expval(qml.PauliZ(0))

def cost(x, y, z):
    return circuit(x, y) * z

vmapped_cost = qjit(vmap(cost, in_axes=(0, 0, None)))

x = jnp.array([[0.1, 0.2, 0.3],
               [0.4, 0.5, 0.6],
               [0.7, 0.8, 0.9]])
y = jnp.array([jnp.pi, jnp.pi / 2, jnp.pi / 4])

vmapped_cost(x, y, 1)

@maliasadi maliasadi added the frontend Pull requests that update the frontend label Feb 29, 2024
@maliasadi maliasadi added this to the v0.5.0 milestone Feb 29, 2024
Copy link
Collaborator

@dime10 dime10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Neat 👍 Maybe add this PR to the existing changelog entry?

@maliasadi
Copy link
Member Author

@erick-xanadu Feel free to merge it.. 🙌

Copy link

codecov bot commented Feb 29, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.54%. Comparing base (1c754a3) to head (419747c).

Additional details and impacted files
@@            Coverage Diff             @@
##           v0.5.0-rc     #569   +/-   ##
==========================================
  Coverage      99.54%   99.54%           
==========================================
  Files             51       51           
  Lines           8618     8618           
  Branches         606      606           
==========================================
  Hits            8579     8579           
  Misses            21       21           
  Partials          18       18           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@josh146
Copy link
Member

josh146 commented Feb 29, 2024

Can I merge this? :)

@erick-xanadu
Copy link
Contributor

Can I merge this? :)

I want this one first: #568

After this, yes. I'll merge both of them as soon as possible and ping you @josh146 :)

@erick-xanadu erick-xanadu merged commit b267d06 into v0.5.0-rc Feb 29, 2024
32 checks passed
@erick-xanadu erick-xanadu deleted the fix-vmap-dispatch branch February 29, 2024 23:39
erick-xanadu pushed a commit that referenced this pull request Mar 1, 2024
This PR fixes the issue with compiling
`catalyst.qjit(catalyst.vmap(fn))(*args, **kwargs)`


For example, this example should pass now
``` py
dev = qml.device("lightning.qubit", wires=1)

@qml.qnode(dev)
def circuit(x, y):
    qml.RX(jnp.pi * x[0] + y, wires=0)
    qml.RY(x[1] ** 2, wires=0)
    qml.RX(x[1] * x[2], wires=0)
    return qml.expval(qml.PauliZ(0))

def cost(x, y, z):
    return circuit(x, y) * z

vmapped_cost = qjit(vmap(cost, in_axes=(0, 0, None)))

x = jnp.array([[0.1, 0.2, 0.3],
               [0.4, 0.5, 0.6],
               [0.7, 0.8, 0.9]])
y = jnp.array([jnp.pi, jnp.pi / 2, jnp.pi / 4])

vmapped_cost(x, y, 1)
```
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
frontend Pull requests that update the frontend
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants