Skip to content

Commit

Permalink
Cherry-picking ZNE work from main (#1128)
Browse files Browse the repository at this point in the history
**Context:** We need the ZNE related commits from main.

**Description of the Change:** Cherry-pick them.

---------

Co-authored-by: Daniel Strano <stranoj@gmail.com>
Co-authored-by: Alessandro Cosentino <cosenal@gmail.com>
Co-authored-by: Romain Moyard <rmoyard@gmail.com>
Co-authored-by: Joey Carter <joseph.carter@xanadu.ai>
Co-authored-by: Josh Izaac <josh146@gmail.com>
Co-authored-by: paul0403 <79805239+paul0403@users.noreply.github.com>
  • Loading branch information
7 people authored Sep 12, 2024
1 parent 249d53a commit 2c45243
Show file tree
Hide file tree
Showing 30 changed files with 3,195 additions and 734 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/check-pl-compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ jobs:
git checkout $(git tag | sort -V | tail -1)
- if: ${{ inputs.catalyst == 'release-candidate' }}
run: |
git checkout v0.8.0-rc
git checkout v0.8.1-rc
- name: Install deps
run: |
Expand Down
1 change: 0 additions & 1 deletion doc/dev/jax_integration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ that doesn't work with Catalyst includes:

- ``jax.numpy.polyfit``
- ``jax.numpy.fft``
- ``jax.scipy.linalg``
- ``jax.numpy.ndarray.at[index]`` when ``index`` corresponds to all array
indices.

Expand Down
79 changes: 78 additions & 1 deletion doc/releases/changelog-0.8.1.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,91 @@

<h3>New features</h3>

* The `catalyst.mitigate_with_zne` error mitigation compilation pass now supports
the option to fold gates locally as well as the existing method of globally.
[(#1006)](https://github.com/PennyLaneAI/catalyst/pull/1006)
[(#1129)](https://github.com/PennyLaneAI/catalyst/pull/1129)

While global folding applies the scale factor by forming the inverse of the
entire quantum circuit (without measurements) and repeating
the circuit with its inverse, local folding instead inserts per-gate folding sequences directly in place
of each gate in the original circuit.

For example,

```python
import jax
import pennylane as qml
from catalyst import qjit, mitigate_with_zne
from pennylane.transforms import exponential_extrapolate

dev = qml.device("lightning.qubit", wires=4, shots=5)

@qml.qnode(dev)
def circuit():
qml.Hadamard(wires=0)
qml.CNOT(wires=[0, 1])
return qml.expval(qml.PauliY(wires=0))

@qjit(keep_intermediate=True)
def mitigated_circuit():
s = jax.numpy.array([1, 2, 3])
return mitigate_with_zne(
circuit,
scale_factors=s,
extrapolate=exponential_extrapolate,
folding="local-all" # "local-all" for local on all gates or "global" for the original method (default being "global")
)()
```

```pycon
>>> circuit()
>>> mitigated_circuit()
```

<h3>Improvements</h3>

* Fixes an issue where certain JAX linear algebra functions from `jax.scipy.linalg` gave incorrect
results when invoked from within a qjit block, and adds full support for other `jax.scipy.linalg`
functions.
[(#1097)](https://github.com/PennyLaneAI/catalyst/pull/1097)

The supported linear algebra functions include, but are not limited to:

- [`jax.scipy.linalg.cholesky`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.cholesky.html)
- [`jax.scipy.linalg.expm`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.expm.html)
- [`jax.scipy.linalg.funm`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.funm.html)
- [`jax.scipy.linalg.hessenberg`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.hessenberg.html)
- [`jax.scipy.linalg.lu`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.lu.html)
- [`jax.scipy.linalg.lu_solve`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.lu_solve.html)
- [`jax.scipy.linalg.polar`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.polar.html)
- [`jax.scipy.linalg.qr`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.qr.html)
- [`jax.scipy.linalg.schur`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.schur.html)
- [`jax.scipy.linalg.solve`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.solve.html)
- [`jax.scipy.linalg.sqrtm`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.sqrtm.html)
- [`jax.scipy.linalg.svd`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.svd.html)

<h3>Breaking changes</h3>

<h3>Deprecations</h3>
* The argument `scale_factors` of `mitigate_with_zne` function now follows the proper literature
definition. It now needs to be a list of positive odd integers, as we don't support the fractional
part.
[(#1120)](https://github.com/PennyLaneAI/catalyst/pull/1120)

<h3>Bug fixes</h3>

* Those functions calling the `gather_p` primitive (like `jax.scipy.linalg.expm`)
can now be used in multiple qjits in a single program.
[(#1096)](https://github.com/PennyLaneAI/catalyst/pull/1096)

<h3>Contributors</h3>

This release contains contributions from (in alphabetical order):

Joey Carter,
Alessandro Cosentino,
Paul Haochen Wang,
David Ittah,
Romain Moyard,
Daniel Strano,
Raul Torres.
2 changes: 1 addition & 1 deletion frontend/catalyst/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
Version number (major.minor.patch[-label])
"""

__version__ = "0.8.0"
__version__ = "0.8.1"
31 changes: 20 additions & 11 deletions frontend/catalyst/api_extensions/error_mitigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
from catalyst.jax_primitives import Folding, zne_p


def _is_odd_positive(numbers_list):
return all(isinstance(i, int) and i > 0 and i % 2 != 0 for i in numbers_list)


## API ##
def mitigate_with_zne(
fn=None, *, scale_factors=None, extrapolate=None, extrapolate_kwargs=None, folding="global"
Expand All @@ -47,7 +51,7 @@ def mitigate_with_zne(
Args:
fn (qml.QNode): the circuit to be mitigated.
scale_factors (array[int]): the range of noise scale factors used.
scale_factors (list[int]): the range of noise scale factors used.
extrapolate (Callable): A qjit-compatible function taking two sequences as arguments (scale
factors, and results), and returning a float by performing a fitting procedure.
By default, perfect polynomial fitting :func:`~.polynomial_extrapolate` will be used,
Expand All @@ -56,6 +60,7 @@ def mitigate_with_zne(
function.
folding (str): Unitary folding technique to be used to scale the circuit. Possible values:
- global: the global unitary of the input circuit is folded
- local-all: per-gate folding sequences replace original gates in-place in the circuit
Returns:
Callable: A callable object that computes the mitigated of the wrapped :class:`~.QNode`
Expand Down Expand Up @@ -113,10 +118,11 @@ def workflow(weights, s):
return zne_circuit(weights)
>>> weights = jnp.ones([3, 2, 3])
>>> scale_factors = jnp.array([1, 2, 3])
>>> scale_factors = [1, 3, 5]
>>> workflow(weights, scale_factors)
Array(-0.19946598, dtype=float64)
"""

kwargs = copy.copy(locals())
kwargs.pop("fn")

Expand All @@ -128,7 +134,12 @@ def workflow(weights, s):
elif extrapolate_kwargs is not None:
extrapolate = functools.partial(extrapolate, **extrapolate_kwargs)

return ZNE(fn, scale_factors, extrapolate, folding)
if not _is_odd_positive(scale_factors):
raise ValueError("The scale factors must be positive odd integers: {scale_factors}")

num_folds = jnp.array([jnp.floor((s - 1) / 2) for s in scale_factors], dtype=int)

return ZNE(fn, num_folds, extrapolate, folding)


## IMPL ##
Expand All @@ -147,15 +158,15 @@ class ZNE:
def __init__(
self,
fn: Callable,
scale_factors: jnp.ndarray,
num_folds: jnp.ndarray,
extrapolate: Callable[[Sequence[float], Sequence[float]], float],
folding: str,
):
if not isinstance(fn, qml.QNode):
raise TypeError(f"A QNode is expected, got the classical function {fn}")
self.fn = fn
self.__name__ = f"zne.{getattr(fn, '__name__', 'unknown')}"
self.scale_factors = scale_factors
self.num_folds = num_folds
self.extrapolate = extrapolate
self.folding = folding

Expand All @@ -175,14 +186,12 @@ def __call__(self, *args, **kwargs):
except ValueError as e:
raise ValueError(f"Folding type must be one of {list(map(str, Folding))}") from e
# TODO: remove the following check once #755 is completed
if folding != Folding.GLOBAL:
if folding == Folding.RANDOM:
raise NotImplementedError(f"Folding type {folding.value} is being developed")

results = zne_p.bind(
*args_data, self.scale_factors, folding=folding, jaxpr=jaxpr, fn=self.fn
)
float_scale_factors = jnp.array(self.scale_factors, dtype=float)
results = self.extrapolate(float_scale_factors, results[0])
results = zne_p.bind(*args_data, self.num_folds, folding=folding, jaxpr=jaxpr, fn=self.fn)
float_num_folds = jnp.array(self.num_folds, dtype=float)
results = self.extrapolate(float_num_folds, results[0])
# Single measurement
if results.shape == ():
return results
Expand Down
55 changes: 0 additions & 55 deletions frontend/catalyst/jax_extras/jax_scipy_linalg_warnings.py

This file was deleted.

14 changes: 14 additions & 0 deletions frontend/catalyst/jax_extras/patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@
import jax
from jax._src.lax.lax import _nary_lower_hlo
from jax._src.lax.slicing import (
_argnum_weak_type,
_gather_dtype_rule,
_gather_shape_computation,
_is_sorted,
_no_duplicate_dims,
_rank,
_sorted_dims_in_range,
standard_primitive,
)
from jax._src.lib.mlir.dialects import hlo
from jax.core import AbstractValue, Tracer, concrete_aval
Expand All @@ -35,6 +38,7 @@
"_gather_shape_rule_dynamic",
"_sin_lowering2",
"_cos_lowering2",
"gather2_p",
)


Expand Down Expand Up @@ -186,6 +190,16 @@ def _gather_shape_rule_dynamic(
return _gather_shape_computation(indices, dimension_numbers, slice_sizes)


# TODO: See the `_gather_shape_rule_dynamic` comment. Remove once the upstream change is
# applied.
gather2_p = standard_primitive(
_gather_shape_rule_dynamic,
_gather_dtype_rule,
"gather",
weak_type_rule=_argnum_weak_type(0),
)


def _sin_lowering2(ctx, x):
"""Use hlo.sine lowering instead of the new sin lowering from jax 0.4.28"""
return _nary_lower_hlo(hlo.sine, ctx, x)
Expand Down
17 changes: 2 additions & 15 deletions frontend/catalyst/jax_extras/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,7 @@
)
from jax._src.lax.control_flow import _initial_style_jaxpr
from jax._src.lax.lax import _abstractify, cos_p, sin_p
from jax._src.lax.slicing import (
_argnum_weak_type,
_gather_dtype_rule,
_gather_lower,
standard_primitive,
)
from jax._src.lax.slicing import _gather_lower
from jax._src.linear_util import annotate
from jax._src.pjit import _extract_implicit_args, _flat_axes_specs
from jax._src.source_info_util import current as jax_current
Expand Down Expand Up @@ -99,8 +94,8 @@

from catalyst.jax_extras.patches import (
_cos_lowering2,
_gather_shape_rule_dynamic,
_sin_lowering2,
gather2_p,
get_aval2,
)
from catalyst.logging import debug_logger
Expand Down Expand Up @@ -514,14 +509,6 @@ def abstractify(args, kwargs):
in_type = infer_lambda_input_type(axes_specs, flat_args)
return in_type, in_tree

# TODO: See the `_gather_shape_rule_dynamic` comment. Remove once the upstream change is
# applied.
gather2_p = standard_primitive(
_gather_shape_rule_dynamic,
_gather_dtype_rule,
"gather",
weak_type_rule=_argnum_weak_type(0),
)
register_lowering(gather2_p, _gather_lower)

# TBD
Expand Down
10 changes: 5 additions & 5 deletions frontend/catalyst/jax_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ class Folding(Enum):
"""

GLOBAL = "global"
RANDOM = "random"
ALL = "all"
RANDOM = "local-random"
ALL = "local-all"


##############
Expand Down Expand Up @@ -930,7 +930,7 @@ def _folding_attribute(ctx, folding):
ctx = ctx.module_context.context
return ir.OpaqueAttr.get(
"mitigation",
("folding " + Folding(folding).value).encode("utf-8"),
("folding " + Folding(folding).name.lower()).encode("utf-8"),
ir.NoneType.get(ctx),
ctx,
)
Expand All @@ -950,13 +950,13 @@ def _zne_lowering(ctx, *args, folding, jaxpr, fn):
symbol_name = func_op.name.value
output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out))
flat_output_types = util.flatten(output_types)
scale_factors = args[-1]
num_folds = args[-1]
return ZneOp(
flat_output_types,
ir.FlatSymbolRefAttr.get(symbol_name),
mlir.flatten_lowering_ir_args(args[0:-1]),
_folding_attribute(ctx, folding),
scale_factors,
num_folds,
).results


Expand Down
Loading

0 comments on commit 2c45243

Please sign in to comment.