Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ZNE local folding #1006

Merged
merged 106 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
106 commits
Select commit Hold shift + click to select a range
eb90c2d
Folding type in ZNE mitigation as an MLIR enum
cosenal Jul 11, 2024
e904a4c
attempt to bind python enum to mlir
cosenal Jul 18, 2024
5312cd4
Code reuse for local folding algorithm
WrathfulSpatula Jul 18, 2024
c4fded9
allow reading attrs from mitigation dialect
cosenal Jul 19, 2024
ab2a573
style fixes suggested by codefactor
cosenal Jul 22, 2024
fcd6cd0
trim trailing whitespace
cosenal Jul 22, 2024
e033ce8
test folding argument error
cosenal Jul 22, 2024
1762222
more style fixes
cosenal Jul 22, 2024
ffa96c7
document folding argument
cosenal Jul 22, 2024
1dd49b2
run make format
cosenal Jul 22, 2024
da9c63c
Merge branch 'main' into mlir-zne-folding-argument
cosenal Jul 22, 2024
5925d27
Merge branch 'mlir-zne-folding-argument' into local-folding
WrathfulSpatula Jul 22, 2024
c3089bf
Draft random/local folding branch points
WrathfulSpatula Jul 22, 2024
e89bbec
Fix function prototypes
WrathfulSpatula Jul 22, 2024
7e5a85d
Function prototypes
WrathfulSpatula Jul 23, 2024
4bb1983
misc addressing review comments
cosenal Jul 23, 2024
5e84719
make format
WrathfulSpatula Jul 23, 2024
aeebc50
Testing that refactor works
WrathfulSpatula Jul 23, 2024
23a1edd
doc folding enum
cosenal Jul 23, 2024
4711219
Cleaner branching
WrathfulSpatula Jul 23, 2024
0030d11
ditch StrEnum as not supported in Python 3.10
cosenal Jul 23, 2024
d4862e1
Throw on request for random folding (for now)
WrathfulSpatula Jul 23, 2024
567589f
More (incremental) code reuse
WrathfulSpatula Jul 23, 2024
caa1d63
More (incremental) code reuse
WrathfulSpatula Jul 23, 2024
369a3fa
Better (incremental) code reuse
WrathfulSpatula Jul 23, 2024
5e03c0d
Remove unused argument
WrathfulSpatula Jul 23, 2024
fc8ed0f
Merge branch 'main' into mlir-zne-folding-argument
cosenal Jul 23, 2024
afe58f8
update changelog
cosenal Jul 23, 2024
e365323
Helpful comment about *.cpp module function scope
WrathfulSpatula Jul 23, 2024
01fa69c
Don't repeat any code when avoidable
WrathfulSpatula Jul 23, 2024
7dcbb1a
Better code reuse
WrathfulSpatula Jul 23, 2024
2d2f265
Don't need deviceInitOp in functions
WrathfulSpatula Jul 23, 2024
8ef9cde
Don't need deviceInitOp in functions
WrathfulSpatula Jul 23, 2024
f892041
Merge from @cosenal branch
WrathfulSpatula Jul 30, 2024
560b014
Simplify folding functions
WrathfulSpatula Jul 30, 2024
559317e
Incomplete code
WrathfulSpatula Jul 31, 2024
fac618a
MLIR test for ZNE local folding
cosenal Jul 31, 2024
eeb7f8c
Fix function prototypes/calls (doesn't compile)
WrathfulSpatula Aug 1, 2024
3f39679
Fix AdjointOp creation
WrathfulSpatula Aug 1, 2024
1c80a29
Use CallOp
WrathfulSpatula Aug 1, 2024
5d6f07c
Fix walk() on QubitUnitaryOp
WrathfulSpatula Aug 1, 2024
57c1608
Merge main
WrathfulSpatula Aug 1, 2024
4ab0131
Code reuse
WrathfulSpatula Aug 1, 2024
48c5411
Code reuse
WrathfulSpatula Aug 1, 2024
bb12285
Cut redundant rewriter.setInsertionPointAfter()
WrathfulSpatula Aug 1, 2024
5114e28
Advance walk()
WrathfulSpatula Aug 1, 2024
8bedc73
improve test with FileCheck best practices
cosenal Aug 2, 2024
cb7a3d6
Merge branch 'main' into zne-local-folding-tests
cosenal Aug 6, 2024
5b705dd
leftover from merge
cosenal Aug 6, 2024
f008c89
fix test label
cosenal Aug 6, 2024
ffbf1ec
labels in local folding mlir test
cosenal Aug 6, 2024
886ae14
Merge
WrathfulSpatula Aug 6, 2024
f4414af
Merge
WrathfulSpatula Aug 6, 2024
7b189d5
Debug
WrathfulSpatula Aug 6, 2024
942d38b
Generates local folding function
WrathfulSpatula Aug 6, 2024
f668fbb
Cut fnWithoutMeasurementsOp from local folding
WrathfulSpatula Aug 6, 2024
2276c5f
Use builder in for loops
WrathfulSpatula Aug 6, 2024
81d1978
Draft local folding (doesn't work)
WrathfulSpatula Aug 7, 2024
a9d8476
fix indentation
cosenal Aug 13, 2024
3f20d12
Save insertion point
WrathfulSpatula Aug 15, 2024
fdf0f71
Get args from fnFoldedOp
WrathfulSpatula Aug 19, 2024
34c5df9
Pass qubits from result to input
WrathfulSpatula Aug 21, 2024
9d9f205
Refactor Location usage
WrathfulSpatula Aug 21, 2024
feeae40
Fix AdjointOP
WrathfulSpatula Aug 21, 2024
dab455a
Merge main
WrathfulSpatula Aug 22, 2024
7a13ab5
Passing qubit args through ForOp
WrathfulSpatula Aug 22, 2024
8aab961
Use setQubitOperands() on clone
WrathfulSpatula Aug 22, 2024
af0f385
Move LCVs inside fnWithMeasurementsOp
WrathfulSpatula Aug 24, 2024
1f4bd4b
Revert "Move LCVs inside fnWithMeasurementsOp"
WrathfulSpatula Aug 24, 2024
2e24e07
Follow TDD
WrathfulSpatula Aug 26, 2024
a5d0cda
Debug
WrathfulSpatula Aug 26, 2024
9a21a0e
Debug
WrathfulSpatula Aug 26, 2024
88ebb6e
Debug
WrathfulSpatula Aug 26, 2024
e8b9ae3
ForOp
WrathfulSpatula Aug 27, 2024
7a7ca56
Fix local folding ForOp
WrathfulSpatula Aug 27, 2024
e055799
Just copy @circuit
WrathfulSpatula Aug 28, 2024
cbc95cb
Cut unused args
WrathfulSpatula Aug 28, 2024
d97ebc3
All but func::ReturnOp
WrathfulSpatula Aug 28, 2024
7a687f8
Generates code for @circuit.folded
WrathfulSpatula Aug 28, 2024
1aaa73f
No alloc. func. for local folding
WrathfulSpatula Aug 28, 2024
2ba29db
Cut redundant code in ZNE rewrite
WrathfulSpatula Aug 28, 2024
bc0c517
Update mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp
WrathfulSpatula Aug 30, 2024
facfec1
Update mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp
WrathfulSpatula Aug 30, 2024
4d5735b
Update mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp
WrathfulSpatula Aug 30, 2024
9e945b5
Per @rmoyard review
WrathfulSpatula Aug 30, 2024
1e9e95b
Per @rmoyard review
WrathfulSpatula Aug 30, 2024
ea78c00
Per @rmoyard review
WrathfulSpatula Sep 3, 2024
204386d
Per @cosenal review
WrathfulSpatula Sep 3, 2024
301f2a2
Partial fix for unit test
WrathfulSpatula Sep 3, 2024
b10aaef
Partial fix for unit test
WrathfulSpatula Sep 3, 2024
9d296ea
Partial unit test fix
WrathfulSpatula Sep 4, 2024
f5b2c40
Partial unit test fix
WrathfulSpatula Sep 4, 2024
3642b8d
Partial unit test fix
WrathfulSpatula Sep 4, 2024
d81bdda
Passing unit test
WrathfulSpatula Sep 4, 2024
b6acb1c
Unit test for adjoint and more qubits, per @rmoyard review
WrathfulSpatula Sep 5, 2024
51039cb
Merge main
WrathfulSpatula Sep 5, 2024
ce2ab05
Pytest parameterization for local folding
WrathfulSpatula Sep 5, 2024
4dc195c
Update mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp
WrathfulSpatula Sep 6, 2024
53a95c1
Merge branch 'main' into test-local-folding
WrathfulSpatula Sep 6, 2024
3c23dc7
Local folding docs (per @rmoyard review)
WrathfulSpatula Sep 6, 2024
2c1192a
Per @cosenal review
WrathfulSpatula Sep 6, 2024
4bdad2b
Per @cosenal review
WrathfulSpatula Sep 6, 2024
bde26f4
Per @cosenal review
WrathfulSpatula Sep 6, 2024
43ca208
Update doc/releases/changelog-dev.md
WrathfulSpatula Sep 6, 2024
d11cd9d
make format
WrathfulSpatula Sep 6, 2024
9955872
fix indentation in changelog
cosenal Sep 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,40 @@
Array([[1], [0], [1], [1], [0], [1],[0]], dtype=int64))
```

* Zero-Noise Extrapolation (ZNE) local folding: Introduces the option to fold gates locally as well as the existing method of globally. Global folding (as in previous versions) applies the scale factor by forming the inverse of the entire quantum circuit (without measurements) and repeating the circuit with its inverse; local folding inserts per-gate folding sequences directly in place of each gate in the original circuit instead of applying the scale factor to the entire circuit at once. [(#1006)](https://github.com/PennyLaneAI/catalyst/pull/1006)

For example,

```python
import jax
import pennylane as qml
from catalyst import qjit, mitigate_with_zne
from pennylane.transforms import exponential_extrapolate

dev = qml.device("lightning.qubit", wires=4, shots=5)

@qml.qnode(dev)
def circuit():
qml.Hadamard(wires=0)
qml.CNOT(wires=[0, 1])
return qml.expval(qml.PauliY(wires=0))

@qjit(keep_intermediate=True)
def mitigated_circuit():
s = jax.numpy.array([1, 2, 3])
return mitigate_with_zne(
circuit,
scale_factors=s,
extrapolate=exponential_extrapolate,
folding="all" # "all" for local or "global" for the original method (default being "global")
)()
```

```pycon
>>> circuit()
>>> mitigated_circuit()
```

<h3>Improvements</h3>

<h3>Breaking changes</h3>
Expand Down Expand Up @@ -56,3 +90,4 @@ Romain Moyard,
Erick Ochoa Lopez,
Paul Haochen Wang,
Sengthai Heng,
Daniel Strano
4 changes: 3 additions & 1 deletion frontend/catalyst/api_extensions/error_mitigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def mitigate_with_zne(
function.
folding (str): Unitary folding technique to be used to scale the circuit. Possible values:
- global: the global unitary of the input circuit is folded
- all: per-gate folding sequences replace original gates in-place in the circuit

Returns:
Callable: A callable object that computes the mitigated of the wrapped :class:`~.QNode`
Expand Down Expand Up @@ -117,6 +118,7 @@ def workflow(weights, s):
>>> workflow(weights, scale_factors)
Array(-0.19946598, dtype=float64)
"""

kwargs = copy.copy(locals())
kwargs.pop("fn")

Expand Down Expand Up @@ -175,7 +177,7 @@ def __call__(self, *args, **kwargs):
except ValueError as e:
raise ValueError(f"Folding type must be one of {list(map(str, Folding))}") from e
# TODO: remove the following check once #755 is completed
if folding != Folding.GLOBAL:
if folding == Folding.RANDOM:
raise NotImplementedError(f"Folding type {folding.value} is being developed")

results = zne_p.bind(
Expand Down
34 changes: 22 additions & 12 deletions frontend/test/pytest/test_mitigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ def skip_if_exponential_extrapolation_unstable(circuit_param, extrapolation_func

@pytest.mark.parametrize("params", [0.1, 0.2, 0.3, 0.4, 0.5])
@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate])
def test_single_measurement(params, extrapolation):
@pytest.mark.parametrize("folding", ["global", "all"])
def test_single_measurement(params, extrapolation, folding):
"""Test that without noise the same results are returned for single measurements."""
skip_if_exponential_extrapolation_unstable(params, extrapolation)

Expand All @@ -52,15 +53,19 @@ def circuit(x):
@catalyst.qjit
def mitigated_qnode(args):
return catalyst.mitigate_with_zne(
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
circuit,
scale_factors=jax.numpy.array([1, 2, 3]),
extrapolate=extrapolation,
folding=folding,
)(args)

assert np.allclose(mitigated_qnode(params), circuit(params))


@pytest.mark.parametrize("params", [0.1, 0.2, 0.3, 0.4, 0.5])
@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate])
def test_multiple_measurements(params, extrapolation):
@pytest.mark.parametrize("folding", ["global", "all"])
def test_multiple_measurements(params, extrapolation, folding):
"""Test that without noise the same results are returned for multiple measurements"""
skip_if_exponential_extrapolation_unstable(params, extrapolation)

Expand All @@ -78,14 +83,18 @@ def circuit(x):
@catalyst.qjit
def mitigated_qnode(args):
return catalyst.mitigate_with_zne(
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
circuit,
scale_factors=jax.numpy.array([1, 2, 3]),
extrapolate=extrapolation,
folding=folding,
)(args)

assert np.allclose(mitigated_qnode(params), circuit(params))


@pytest.mark.parametrize("params", [0.1, 0.2, 0.3, 0.4, 0.5])
def test_single_measurement_control_flow(params):
@pytest.mark.parametrize("folding", ["global", "all"])
def test_single_measurement_control_flow(params, folding):
"""Test that without noise the same results are returned for single measurement and with
control flow."""
dev = qml.device("lightning.qubit", wires=2)
Expand Down Expand Up @@ -113,9 +122,9 @@ def loop_1(i): # pylint: disable=unused-argument

@catalyst.qjit
def mitigated_qnode(args, n):
return catalyst.mitigate_with_zne(circuit, scale_factors=jax.numpy.array([1, 2, 3]))(
args, n
)
return catalyst.mitigate_with_zne(
circuit, scale_factors=jax.numpy.array([1, 2, 3]), folding=folding
)(args, n)

assert np.allclose(mitigated_qnode(params, 3), catalyst.qjit(circuit)(params, 3))

Expand Down Expand Up @@ -238,15 +247,16 @@ def circuit():
return 0.0

def mitigated_qnode():
return catalyst.mitigate_with_zne(circuit, scale_factors=[], folding="all")()
return catalyst.mitigate_with_zne(circuit, scale_factors=[], folding="random")()

with pytest.raises(NotImplementedError):
catalyst.qjit(mitigated_qnode)


@pytest.mark.parametrize("params", [0.1, 0.2, 0.3, 0.4, 0.5])
@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate])
def test_zne_usage_patterns(params, extrapolation):
@pytest.mark.parametrize("folding", ["global", "all"])
def test_zne_usage_patterns(params, extrapolation, folding):
"""Test usage patterns of catalyst.zne."""
skip_if_exponential_extrapolation_unstable(params, extrapolation)

Expand All @@ -264,13 +274,13 @@ def fn(x):
@catalyst.qjit
def mitigated_qnode_fn_as_argument(args):
return catalyst.mitigate_with_zne(
fn, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
fn, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation, folding=folding
)(args)

@catalyst.qjit
def mitigated_qnode_partial(args):
return catalyst.mitigate_with_zne(
scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation, folding=folding
)(fn)(args)

assert np.allclose(mitigated_qnode_fn_as_argument(params), fn(params))
Expand Down
4 changes: 2 additions & 2 deletions mlir/include/Mitigation/IR/MitigationOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def Folding : I32EnumAttr<"Folding",
"Folding types",
[
I32EnumAttrCase<"global", 1>,
I32EnumAttrCase<"random", 2>,
I32EnumAttrCase<"all", 3>,
I32EnumAttrCase<"all", 2>,
I32EnumAttrCase<"random", 3>,
]> {
let cppNamespace = "catalyst::mitigation";
let genSpecializedAttr = 0;
Expand Down
Loading
Loading