Skip to content

Commit

Permalink
Fix the issue with dispatching to jax.vmap with qjit(vmap) (#569)
Browse files Browse the repository at this point in the history
This PR fixes the issue with compiling
`catalyst.qjit(catalyst.vmap(fn))(*args, **kwargs)`


For example, this example should pass now
``` py
dev = qml.device("lightning.qubit", wires=1)

@qml.qnode(dev)
def circuit(x, y):
    qml.RX(jnp.pi * x[0] + y, wires=0)
    qml.RY(x[1] ** 2, wires=0)
    qml.RX(x[1] * x[2], wires=0)
    return qml.expval(qml.PauliZ(0))

def cost(x, y, z):
    return circuit(x, y) * z

vmapped_cost = qjit(vmap(cost, in_axes=(0, 0, None)))

x = jnp.array([[0.1, 0.2, 0.3],
               [0.4, 0.5, 0.6],
               [0.7, 0.8, 0.9]])
y = jnp.array([jnp.pi, jnp.pi / 2, jnp.pi / 4])

vmapped_cost(x, y, 1)
```
  • Loading branch information
maliasadi authored and erick-xanadu committed Mar 1, 2024
1 parent 5d4f7a9 commit c2b7c11
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 4 deletions.
1 change: 1 addition & 0 deletions doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ This release contains contributions from (in alphabetical order):
* Catalyst now supports QJIT compatible `catalyst.vmap` of hybrid programs.
`catalyst.vmap` offers the vectorization mapping backed by `catalyst.for_loop`.
[(#497)](https://github.com/PennyLaneAI/catalyst/pull/497)
[(#569)](https://github.com/PennyLaneAI/catalyst/pull/569)

For example,

Expand Down
9 changes: 5 additions & 4 deletions frontend/catalyst/pennylane_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2229,10 +2229,7 @@ def postcircuit(y, x, z):
in the output. ``out_axes`` is subject to the same modes as well.
"""

# Dispatch to jax.vmap when it is called outside qjit.
if not EvaluationContext.is_tracing():
return jax.vmap(fn, in_axes, out_axes)

# Check the validity of in_axes and out_axes
if not all(isinstance(l, int) for l in tree_leaves(in_axes)):
raise ValueError(
"Invalid 'in_axes'; it must be an int or a tuple of PyTrees with integer leaves, "
Expand All @@ -2248,6 +2245,10 @@ def postcircuit(y, x, z):
def batched_fn(*args, **kwargs):
"""Vectorization wrapper around the hybrid program using catalyst.for_loop"""

# Dispatch to jax.vmap when it is called outside qjit.
if not EvaluationContext.is_tracing():
return jax.vmap(fn, in_axes, out_axes)(*args, **kwargs)

args_flat, args_tree = tree_flatten(args)
in_axes_flat, _ = tree_flatten(in_axes, is_leaf=lambda x: x is None)

Expand Down
24 changes: 24 additions & 0 deletions frontend/test/pytest/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,30 @@ def circuit(x):
assert jnp.allclose(result[1], expected)
assert jnp.allclose(result[2], expected)

def test_vmap_circuit_inside_without_jax_dispatch(self, backend):
"""Test catalyst.vmap of a hybrid workflow inside QJIT."""

@qml.qnode(qml.device(backend, wires=1))
def circuit(x):
qml.RX(jnp.pi * x[0], wires=0)
qml.RY(x[1] ** 2, wires=0)
qml.RX(x[1] * x[2], wires=0)
return qml.expval(qml.PauliZ(0))

x = jnp.array(
[
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
[0.7, 0.8, 0.9],
]
)

result0 = qjit(vmap(circuit))(x)
result1 = qjit(vmap(circuit, in_axes=(0,)))(x)
expected = jnp.array([0.93005586, 0.00498127, -0.88789978])
assert jnp.allclose(result0, expected)
assert jnp.allclose(result1, expected)

def test_vmap_circuit_in_axes_int(self, backend):
"""Test catalyst.vmap of a hybrid workflow inside QJIT with `in_axes:int`."""

Expand Down

0 comments on commit c2b7c11

Please sign in to comment.