Skip to content

Commit

Permalink
Support for the new device API (#565)
Browse files Browse the repository at this point in the history
**Context:**

PennyLane recently added a new device API.

**Description of the Change:**

In this PR we create anothe QJIT device that supports the new device
API. Users define preprocessing by adding transforms to the preprocess
method.

We can add our own transforms in the QJIT device preprocess function, if
necessary.
  • Loading branch information
rmoyard authored Mar 7, 2024
1 parent bc401d9 commit 5e6d6bd
Show file tree
Hide file tree
Showing 8 changed files with 434 additions and 52 deletions.
7 changes: 6 additions & 1 deletion doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

<h3>Improvements</h3>

* Catalyst now supports devices built from the
[new PennyLane device API](https://docs.pennylane.ai/en/stable/code/api/pennylane.devices.Device.html).
[(#565)](https://github.com/PennyLaneAI/catalyst/pull/565)

* Catalyst now supports return statements inside conditionals in `@qjit(autograph=True)` compiled
functions.
[(#583)](https://github.com/PennyLaneAI/catalyst/pull/583)
Expand Down Expand Up @@ -36,7 +40,8 @@

This release contains contributions from (in alphabetical order):

David Ittah.
David Ittah,
Romain Moyard.

# Release 0.5.0

Expand Down
33 changes: 31 additions & 2 deletions doc/dev/custom_devices.rst
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ Integration with Python devices

There are two things that are needed in order to integrate with PennyLane devices:

* Adding a ``get_c_interface`` method to your ``qml.QubitDevice`` class.
* Adding a ``get_c_interface`` method to your ``qml.Device`` or ``qml.devices.Device`` class.
* Adding a ``config`` class variable pointing to your configuration file. This file should be a `toml file <https://toml.io/en/>`_ with fields that describe what gates and features are supported by your device.

If you already have a custom PennyLane device defined in Python and have added a shared object that corresponds to your implementation of the ``QuantumDevice`` class, then all you need to do is to add a ``get_c_interface`` method to your PennyLane device.
Expand All @@ -163,9 +163,11 @@ The ``get_c_interface`` method should be a static method that takes no parameter
The first result of ``get_c_interface`` needs to match the ``<DeviceIdentifier>``
as described in the first section.

With the old device API, you can simply build a QJIT compatible device:

.. code-block:: python
class CustomDevice(qml.QubitDevice):
class CustomDevice(qml.Device):
"""Dummy Device"""
name = "Dummy Device"
Expand Down Expand Up @@ -194,6 +196,33 @@ The ``get_c_interface`` method should be a static method that takes no parameter
def f():
return measure(0)
or with the new device API:

.. code-block:: python
class CustomDevice(qml.devices.Device):
"""Dummy Device"""
config = pathlib.Path("absolute/path/to/configuration/file.toml")
@staticmethod
def get_c_interface():
""" Returns a tuple consisting of the device name, and
the location to the shared object with the C/C++ device implementation.
"""
return "CustomDevice", "absolute/path/to/libdummy_device.so"
def __init__(self, shots=None, wires=None):
super().__init__(wires=wires, shots=shots)
def execute(self, circuits, config):
"""Your normal definitions"""
@qjit
@qml.qnode(CustomDevice(wires=1))
def f():
return measure(0)
Below is an example configuration file with inline descriptions of how to fill out the fields. All
headers and fields are generally required, unless stated otherwise.
Expand Down
46 changes: 33 additions & 13 deletions frontend/catalyst/jax_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pennylane.operation import AnyWires, Operation, Wires
from pennylane.ops import Controlled, ControlledOp, ControlledQubitUnitary
from pennylane.tape import QuantumTape
from pennylane.transforms.core import TransformProgram

import catalyst
from catalyst.jax_extras import (
Expand Down Expand Up @@ -453,7 +454,12 @@ def _bind_native_controlled_op(qrp, op, controlled_wires, controlled_values):
return qrp

qrp = QRegPromise(qreg)
for op in device.expand_fn(quantum_tape):
if isinstance(device, qml.Device):
ops = device.expand_fn(quantum_tape)
else:
ops = quantum_tape

for op in ops:
qrp2 = None
if isinstance(op, HybridOp):
qrp2 = op.trace_quantum(ctx, device, trace, qrp)
Expand Down Expand Up @@ -577,6 +583,7 @@ def trace_quantum_measurements(
qrp: QRegPromise,
outputs: List[Union[MeasurementProcess, DynamicJaxprTracer, Any]],
out_tree: PyTreeDef,
tape: QuantumTape,
) -> Tuple[List[DynamicJaxprTracer], PyTreeDef]:
"""Trace quantum measurement. Accept a list of QNode ouptputs and its Pytree-shape. Process
the quantum measurement outputs, leave other outputs as-is.
Expand All @@ -592,7 +599,11 @@ def trace_quantum_measurements(
out_tree: modified PyTree-shape of the qnode output.
"""
# pylint: disable=too-many-branches
shots = device.shots
if isinstance(device, qml.Device):
shots = device.shots
else:
# TODO: support shot vectors
shots = tape.shots.total_shots
out_classical_tracers = []

for i, o in enumerate(outputs):
Expand Down Expand Up @@ -686,23 +697,23 @@ def is_midcircuit_measurement(op):
return are_batch_transforms_valid


def apply_transform(qnode, tape, flat_results):
def apply_transform(transform_program, tape, flat_results):
"""Apply transform."""

# Some transforms use trainability as a basis for transforming.
# See batch_params
params = tape.get_parameters(trainable_only=False)
tape.trainable_params = qml.math.get_trainable_indices(params)

is_program_transformed = qnode and qnode.transform_program
is_program_transformed = transform_program

if is_program_transformed and qnode.transform_program.is_informative:
if is_program_transformed and transform_program.is_informative:
msg = "Catalyst does not support informative transforms."
raise CompileError(msg)

if is_program_transformed:
is_valid_for_batch = is_transform_valid_for_batch_transforms(tape, flat_results)
tapes, post_processing = qnode.transform_program([tape])
tapes, post_processing = transform_program([tape])
if not is_valid_for_batch and len(tapes) > 1:
msg = "Multiple tapes are generated, but each run might produce different results."
raise CompileError(msg)
Expand Down Expand Up @@ -803,6 +814,7 @@ def trace_quantum_function(
device (QubitDevice): Quantum device to use for quantum computations
args: Positional arguments to pass to ``f``
kwargs: Keyword arguments to pass to ``f``
qnode: The quantum node to be traced, it contains user transforms.
Returns:
closed_jaxpr: JAXPR expression of the function ``f``.
Expand Down Expand Up @@ -838,15 +850,23 @@ def is_leaf(obj):
return_values, is_leaf=is_leaf
)

# TODO: In order to compose transforms, we would need to recursively
# call apply_transform while popping the latest transform applied,
# until there are no more transforms to be applied.
# But first we should clean this up this method a bit more.
tapes, post_processing = apply_transform(qnode, quantum_tape, return_values_flat)
if isinstance(device, qml.devices.Device):
transform_program, _ = device.preprocess()
else:
transform_program = TransformProgram()

# We add pragma because lit test are giving qfunc directly
# But lit tests are not sending coverage results
if qnode: # pragma: no branch
transform_program = qnode.transform_program + transform_program

tapes, post_processing = apply_transform(
transform_program, quantum_tape, return_values_flat
)

# (2) - Quantum tracing
transformed_results = []
is_program_transformed = qnode and qnode.transform_program
is_program_transformed = transform_program

with EvaluationContext.frame_tracing_context(ctx, trace):
# Set up same device and quantum register for all tapes in the program.
Expand All @@ -871,7 +891,7 @@ def is_leaf(obj):
trees = return_values_tree

qrp_out = trace_quantum_tape(tape, device, qreg_in, ctx, trace)
meas, meas_trees = trace_quantum_measurements(device, qrp_out, output, trees)
meas, meas_trees = trace_quantum_measurements(device, qrp_out, output, trees, tape)
qreg_out = qrp_out.actualize()

meas_tracers = [trace.full_raise(m) for m in meas]
Expand Down
7 changes: 5 additions & 2 deletions frontend/catalyst/pennylane_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
trace_quantum_tape,
unify_jaxpr_result_types,
)
from catalyst.qjit_device import QJITDevice
from catalyst.qjit_device import QJITDevice, QJITDeviceNewAPI
from catalyst.tracing.contexts import (
EvaluationContext,
EvaluationMode,
Expand Down Expand Up @@ -161,7 +161,10 @@ def __call__(self, *args, **kwargs):
QFunc._add_toml_file(self.device)
dev_args = QFunc.extract_backend_info(self.device)
config, rest = dev_args[0], dev_args[1:]
device = QJITDevice(config, self.device.shots, self.device.wires, *rest)
if isinstance(self.device, qml.devices.Device):
device = QJITDeviceNewAPI(self.device, config, *rest)
else:
device = QJITDevice(config, self.device.shots, self.device.wires, *rest)
else: # pragma: nocover
# Allow QFunc to still be used by itself for internal testing.
device = self.device
Expand Down
Loading

0 comments on commit 5e6d6bd

Please sign in to comment.