Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the issue with dispatching to jax.vmap with qjit(vmap) #569

Merged
merged 2 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
* Catalyst now supports QJIT compatible `catalyst.vmap` of hybrid programs.
`catalyst.vmap` offers the vectorization mapping backed by `catalyst.for_loop`.
[(#497)](https://github.com/PennyLaneAI/catalyst/pull/497)
[(#569)](https://github.com/PennyLaneAI/catalyst/pull/569)

For example,

Expand Down
9 changes: 5 additions & 4 deletions frontend/catalyst/pennylane_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2229,10 +2229,7 @@ def postcircuit(y, x, z):
in the output. ``out_axes`` is subject to the same modes as well.
"""

# Dispatch to jax.vmap when it is called outside qjit.
if not EvaluationContext.is_tracing():
return jax.vmap(fn, in_axes, out_axes)

# Check the validity of in_axes and out_axes
if not all(isinstance(l, int) for l in tree_leaves(in_axes)):
raise ValueError(
"Invalid 'in_axes'; it must be an int or a tuple of PyTrees with integer leaves, "
Expand All @@ -2248,6 +2245,10 @@ def postcircuit(y, x, z):
def batched_fn(*args, **kwargs):
"""Vectorization wrapper around the hybrid program using catalyst.for_loop"""

# Dispatch to jax.vmap when it is called outside qjit.
if not EvaluationContext.is_tracing():
return jax.vmap(fn, in_axes, out_axes)(*args, **kwargs)

args_flat, args_tree = tree_flatten(args)
in_axes_flat, _ = tree_flatten(in_axes, is_leaf=lambda x: x is None)

Expand Down
24 changes: 24 additions & 0 deletions frontend/test/pytest/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,30 @@ def circuit(x):
assert jnp.allclose(result[1], expected)
assert jnp.allclose(result[2], expected)

def test_vmap_circuit_inside_without_jax_dispatch(self, backend):
"""Test catalyst.vmap of a hybrid workflow inside QJIT."""

@qml.qnode(qml.device(backend, wires=1))
def circuit(x):
qml.RX(jnp.pi * x[0], wires=0)
qml.RY(x[1] ** 2, wires=0)
qml.RX(x[1] * x[2], wires=0)
return qml.expval(qml.PauliZ(0))

x = jnp.array(
[
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
[0.7, 0.8, 0.9],
]
)

result0 = qjit(vmap(circuit))(x)
result1 = qjit(vmap(circuit, in_axes=(0,)))(x)
expected = jnp.array([0.93005586, 0.00498127, -0.88789978])
assert jnp.allclose(result0, expected)
assert jnp.allclose(result1, expected)

def test_vmap_circuit_in_axes_int(self, backend):
"""Test catalyst.vmap of a hybrid workflow inside QJIT with `in_axes:int`."""

Expand Down
Loading