From eb90c2d9a81bd3dbf780420726528e6bb9445233 Mon Sep 17 00:00:00 2001 From: Alessandro Cosentino Date: Thu, 11 Jul 2024 18:11:06 +0200 Subject: [PATCH 01/94] Folding type in ZNE mitigation as an MLIR enum --- .../api_extensions/error_mitigation.py | 12 ++++++++--- frontend/catalyst/jax_primitives.py | 7 +++++-- mlir/include/Mitigation/IR/CMakeLists.txt | 7 +++++++ mlir/include/Mitigation/IR/MitigationOps.h | 3 +++ mlir/include/Mitigation/IR/MitigationOps.td | 20 +++++++++++++++++-- mlir/lib/Mitigation/IR/CMakeLists.txt | 2 +- mlir/lib/Mitigation/IR/MitigationDialect.cpp | 20 +++++++++++++++++++ mlir/lib/Mitigation/IR/MitigationOps.cpp | 7 +++++-- mlir/test/Mitigation/zne.mlir | 2 +- 9 files changed, 69 insertions(+), 11 deletions(-) diff --git a/frontend/catalyst/api_extensions/error_mitigation.py b/frontend/catalyst/api_extensions/error_mitigation.py index 280f98aa0d..f59436925d 100644 --- a/frontend/catalyst/api_extensions/error_mitigation.py +++ b/frontend/catalyst/api_extensions/error_mitigation.py @@ -29,9 +29,13 @@ from catalyst.jax_primitives import zne_p +from enum import IntEnum + +class Folding(IntEnum): + GLOBAL = "global" ## API ## -def mitigate_with_zne(fn=None, *, scale_factors=None, extrapolate=None, extrapolate_kwargs=None): +def mitigate_with_zne(fn=None, *, scale_factors=None, extrapolate=None, extrapolate_kwargs=None, folding=Folding.GLOBAL): """A :func:`~.qjit` compatible error mitigation of an input circuit using zero-noise extrapolation. @@ -97,7 +101,7 @@ def mitigated_circuit(args, n): elif extrapolate_kwargs is not None: extrapolate = functools.partial(extrapolate, **extrapolate_kwargs) - return ZNE(fn, scale_factors, extrapolate) + return ZNE(fn, scale_factors, extrapolate, folding) ## IMPL ## @@ -118,6 +122,7 @@ def __init__( fn: Callable, scale_factors: jnp.ndarray, extrapolate: Callable[[Sequence[float], Sequence[float]], float], + folding: Folding ): if not isinstance(fn, qml.QNode): raise TypeError(f"A QNode is expected, got the classical function {fn}") @@ -125,6 +130,7 @@ def __init__( self.__name__ = f"zne.{getattr(fn, '__name__', 'unknown')}" self.scale_factors = scale_factors self.extrapolate = extrapolate + self.folding = folding def __call__(self, *args, **kwargs): """Specifies the an actual call to the folded circuit.""" @@ -137,7 +143,7 @@ def __call__(self, *args, **kwargs): if len(set_dtypes) != 1 or set_dtypes.pop().kind != "f": raise TypeError("All expectation and classical values dtypes must match and be float.") args_data, _ = tree_flatten(args) - results = zne_p.bind(*args_data, self.scale_factors, jaxpr=jaxpr, fn=self.fn) + results = zne_p.bind(*args_data, self.folding, self.scale_factors, jaxpr=jaxpr, fn=self.fn) float_scale_factors = jnp.array(self.scale_factors, dtype=float) results = self.extrapolate(float_scale_factors, results[0]) # Single measurement diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index acb8bb8e34..fdac7d701f 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -741,11 +741,14 @@ def _zne_lowering(ctx, *args, jaxpr, fn): symbol_name = func_op.name.value output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out)) flat_output_types = util.flatten(output_types) + folding = args[-2] + scale_factors = args[-1] return ZneOp( flat_output_types, ir.FlatSymbolRefAttr.get(symbol_name), - mlir.flatten_lowering_ir_args(args[0:-1]), - args[-1], + mlir.flatten_lowering_ir_args(args[0:-2]), + folding.value, + scale_factors, ).results diff --git a/mlir/include/Mitigation/IR/CMakeLists.txt b/mlir/include/Mitigation/IR/CMakeLists.txt index 0c86fe2c55..c50821a6b4 100644 --- a/mlir/include/Mitigation/IR/CMakeLists.txt +++ b/mlir/include/Mitigation/IR/CMakeLists.txt @@ -1,3 +1,10 @@ add_mlir_dialect(MitigationOps mitigation) add_mlir_doc(MitigationDialect MitigationDialect Mitigation/ -gen-dialect-doc) add_mlir_doc(MitigationOps MitigationOps Mitigation/ -gen-op-doc) + +set(LLVM_TARGET_DEFINITIONS MitigationOps.td) +mlir_tablegen(MitigationEnums.h.inc -gen-enum-decls) +mlir_tablegen(MitigationEnums.cpp.inc -gen-enum-defs) +mlir_tablegen(MitigationAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=mitigation) +mlir_tablegen(MitigationAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=mitigation) +add_public_tablegen_target(MLIRMitigationEnumsIncGen) diff --git a/mlir/include/Mitigation/IR/MitigationOps.h b/mlir/include/Mitigation/IR/MitigationOps.h index 61ae25ed58..02e4170664 100644 --- a/mlir/include/Mitigation/IR/MitigationOps.h +++ b/mlir/include/Mitigation/IR/MitigationOps.h @@ -23,5 +23,8 @@ #include "mlir/Bytecode/BytecodeOpInterface.h" +#include "Mitigation/IR/MitigationEnums.h.inc" +#define GET_ATTRDEF_CLASSES +#include "Mitigation/IR/MitigationAttributes.h.inc" #define GET_OP_CLASSES #include "Mitigation/IR/MitigationOps.h.inc" diff --git a/mlir/include/Mitigation/IR/MitigationOps.td b/mlir/include/Mitigation/IR/MitigationOps.td index d91ce13284..ccaad50cd2 100644 --- a/mlir/include/Mitigation/IR/MitigationOps.td +++ b/mlir/include/Mitigation/IR/MitigationOps.td @@ -15,13 +15,28 @@ #ifndef MITIGATION_OPS #define MITIGATION_OPS +include "mlir/IR/EnumAttr.td" +include "mlir/IR/OpBase.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/IR/BuiltinAttributes.td" -include "mlir/IR/OpBase.td" include "Mitigation/IR/MitigationDialect.td" +def Folding : I32EnumAttr<"Folding", + "Folding types", + [ + I32EnumAttrCase<"global", 1>, + I32EnumAttrCase<"random", 2>, + I32EnumAttrCase<"all", 3>, + ]> { + let cppNamespace = "catalyst::mitigation"; + let genSpecializedAttr = 0; +} + +def FoldingAttr : EnumAttr; + + def ZneOp : Mitigation_Op<"zne", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "Compute a quantum function with ZNE (Zero Noise Extrapolation) error mitigation."; @@ -32,12 +47,13 @@ def ZneOp : Mitigation_Op<"zne", [DeclareOpInterfaceMethods, let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$args, + FoldingAttr:$folding, RankedTensorOf<[AnySignlessIntegerOrIndex]>:$scaleFactors ); let results = (outs Variadic]>>); let assemblyFormat = [{ - $callee `(` $args `)` `scaleFactors` `(` $scaleFactors `:` type($scaleFactors) `)` attr-dict `:` functional-type($args, results) + $callee `(` $args `)` `folding` `(` $folding `)` `scaleFactors` `(` $scaleFactors `:` type($scaleFactors) `)` attr-dict `:` functional-type($args, results) }]; } diff --git a/mlir/lib/Mitigation/IR/CMakeLists.txt b/mlir/lib/Mitigation/IR/CMakeLists.txt index f3db5a0066..ad0d4a567e 100644 --- a/mlir/lib/Mitigation/IR/CMakeLists.txt +++ b/mlir/lib/Mitigation/IR/CMakeLists.txt @@ -7,5 +7,5 @@ add_mlir_library(MLIRMitigation DEPENDS MLIRMitigationOpsIncGen - + MLIRMitigationEnumsIncGen ) diff --git a/mlir/lib/Mitigation/IR/MitigationDialect.cpp b/mlir/lib/Mitigation/IR/MitigationDialect.cpp index 293ed4eea7..d4bc05ab66 100644 --- a/mlir/lib/Mitigation/IR/MitigationDialect.cpp +++ b/mlir/lib/Mitigation/IR/MitigationDialect.cpp @@ -12,7 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "mlir/IR/Builders.h" #include "mlir/Transforms/InliningUtils.h" +#include "mlir/IR/DialectImplementation.h" // needed for generated type parser +#include "llvm/ADT/TypeSwitch.h" // needed for generated type parser #include "Mitigation/IR/MitigationDialect.h" #include "Mitigation/IR/MitigationOps.h" @@ -44,9 +47,26 @@ struct MitigationInlinerInterface : public DialectInlinerInterface { void MitigationDialect::initialize() { + addTypes< +#define GET_TYPEDEF_LIST +#include "Mitigation/IR/MitigationOpsTypes.cpp.inc" + >(); + + addAttributes< +#define GET_ATTRDEF_LIST +#include "Mitigation/IR/MitigationAttributes.cpp.inc" + >(); + addOperations< #define GET_OP_LIST #include "Mitigation/IR/MitigationOps.cpp.inc" >(); + addInterface(); } + +#define GET_TYPEDEF_CLASSES +#include "Mitigation/IR/MitigationOpsTypes.cpp.inc" + +#define GET_ATTRDEF_CLASSES +#include "Mitigation/IR/MitigationAttributes.cpp.inc" diff --git a/mlir/lib/Mitigation/IR/MitigationOps.cpp b/mlir/lib/Mitigation/IR/MitigationOps.cpp index 91ff5dc317..21df8a9396 100644 --- a/mlir/lib/Mitigation/IR/MitigationOps.cpp +++ b/mlir/lib/Mitigation/IR/MitigationOps.cpp @@ -14,17 +14,20 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/OpImplementation.h" +#include "llvm/ADT/TypeSwitch.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "Mitigation/IR/MitigationDialect.h" #include "Mitigation/IR/MitigationOps.h" +using namespace mlir; +using namespace catalyst::mitigation; + +#include "Mitigation/IR/MitigationEnums.cpp.inc" #define GET_OP_CLASSES #include "Mitigation/IR/MitigationOps.cpp.inc" -using namespace mlir; -using namespace catalyst::mitigation; //===----------------------------------------------------------------------===// // SymbolUserOpInterface diff --git a/mlir/test/Mitigation/zne.mlir b/mlir/test/Mitigation/zne.mlir index a44f35b9cc..0bfe041575 100644 --- a/mlir/test/Mitigation/zne.mlir +++ b/mlir/test/Mitigation/zne.mlir @@ -99,6 +99,6 @@ func.func @simpleCircuit(%arg0: tensor<3xf64>) -> f64 attributes {qnode} { // CHECK: return [[results]] : tensor<5xf64> func.func @zneCallScalarScalar(%arg0: tensor<3xf64>) -> tensor<5xf64> { %scaleFactors = arith.constant dense<[1, 2, 3, 4, 5]> : tensor<5xindex> - %0 = mitigation.zne @simpleCircuit(%arg0) scaleFactors (%scaleFactors : tensor<5xindex>) : (tensor<3xf64>) -> tensor<5xf64> + %0 = mitigation.zne @simpleCircuit(%arg0) folding (global) scaleFactors (%scaleFactors : tensor<5xindex>) : (tensor<3xf64>) -> tensor<5xf64> func.return %0 : tensor<5xf64> } From e904a4c436c2248ce346e0f0e78cd45bd7bb5adb Mon Sep 17 00:00:00 2001 From: Alessandro Cosentino Date: Thu, 18 Jul 2024 16:56:17 +0200 Subject: [PATCH 02/94] attempt to bind python enum to mlir --- .../catalyst/api_extensions/error_mitigation.py | 14 +++++++++----- frontend/catalyst/jax_primitives.py | 11 +++++++++-- mlir/include/Mitigation/IR/MitigationOps.h | 1 + .../Transforms/MitigationMethods/Zne.cpp | 5 ++++- 4 files changed, 23 insertions(+), 8 deletions(-) diff --git a/frontend/catalyst/api_extensions/error_mitigation.py b/frontend/catalyst/api_extensions/error_mitigation.py index f59436925d..6b65a2cf8d 100644 --- a/frontend/catalyst/api_extensions/error_mitigation.py +++ b/frontend/catalyst/api_extensions/error_mitigation.py @@ -32,10 +32,12 @@ from enum import IntEnum class Folding(IntEnum): - GLOBAL = "global" + GLOBAL = 1 + RANDOM = 2 + ALL = 3 ## API ## -def mitigate_with_zne(fn=None, *, scale_factors=None, extrapolate=None, extrapolate_kwargs=None, folding=Folding.GLOBAL): +def mitigate_with_zne(fn=None, *, scale_factors=None, extrapolate=None, extrapolate_kwargs=None, folding='global'): """A :func:`~.qjit` compatible error mitigation of an input circuit using zero-noise extrapolation. @@ -55,6 +57,7 @@ def mitigate_with_zne(fn=None, *, scale_factors=None, extrapolate=None, extrapol By default, perfect polynomial fitting will be used. extrapolate_kwargs (dict[str, Any]): Keyword arguments to be passed to the extrapolation function. + folding (str): The unitary folding technique to be used to scale the circuit Returns: Callable: A callable object that computes the mitigated of the wrapped :class:`qml.QNode` @@ -122,7 +125,7 @@ def __init__( fn: Callable, scale_factors: jnp.ndarray, extrapolate: Callable[[Sequence[float], Sequence[float]], float], - folding: Folding + folding: str ): if not isinstance(fn, qml.QNode): raise TypeError(f"A QNode is expected, got the classical function {fn}") @@ -134,7 +137,7 @@ def __init__( def __call__(self, *args, **kwargs): """Specifies the an actual call to the folded circuit.""" - jaxpr = jaxpr = jax.make_jaxpr(self.fn)(*args) + jaxpr = jax.make_jaxpr(self.fn)(*args) shapes = [out_val.shape for out_val in jaxpr.out_avals] dtypes = [out_val.dtype for out_val in jaxpr.out_avals] set_dtypes = set(dtypes) @@ -143,7 +146,8 @@ def __call__(self, *args, **kwargs): if len(set_dtypes) != 1 or set_dtypes.pop().kind != "f": raise TypeError("All expectation and classical values dtypes must match and be float.") args_data, _ = tree_flatten(args) - results = zne_p.bind(*args_data, self.folding, self.scale_factors, jaxpr=jaxpr, fn=self.fn) + folding = Folding[self.folding.upper()].value + results = zne_p.bind(*args_data, folding, self.scale_factors, jaxpr=jaxpr, fn=self.fn) float_scale_factors = jnp.array(self.scale_factors, dtype=float) results = self.extrapolate(float_scale_factors, results[0]) # Single measurement diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index fdac7d701f..a910885ed0 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -728,6 +728,12 @@ def _zne_abstract_eval(*args, jaxpr, fn): # pylint: disable=unused-argument return [core.ShapedArray(shape, jaxpr.out_avals[0].dtype)] +def _folding_attribute(ctx, folding: str): + ctx = ctx.module_context.context + return ir.OpaqueAttr.get( + "mitigation", ("folding " + folding).encode("utf-8"), ir.NoneType.get(ctx), ctx + ) + def _zne_lowering(ctx, *args, jaxpr, fn): """Lowering function to the ZNE opearation. Args: @@ -746,8 +752,9 @@ def _zne_lowering(ctx, *args, jaxpr, fn): return ZneOp( flat_output_types, ir.FlatSymbolRefAttr.get(symbol_name), - mlir.flatten_lowering_ir_args(args[0:-2]), - folding.value, + mlir.flatten_lowering_ir_args(args), + # TODO: Once this works, change hardcoded value to actual value from input + _folding_attribute(ctx, "global"), scale_factors, ).results diff --git a/mlir/include/Mitigation/IR/MitigationOps.h b/mlir/include/Mitigation/IR/MitigationOps.h index 02e4170664..dcbbd2ea4f 100644 --- a/mlir/include/Mitigation/IR/MitigationOps.h +++ b/mlir/include/Mitigation/IR/MitigationOps.h @@ -23,6 +23,7 @@ #include "mlir/Bytecode/BytecodeOpInterface.h" +#include "Mitigation/IR/MitigationDialect.h" #include "Mitigation/IR/MitigationEnums.h.inc" #define GET_ATTRDEF_CLASSES #include "Mitigation/IR/MitigationAttributes.h.inc" diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 3fe3014902..d1f74ede13 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -45,6 +45,9 @@ void ZneLowering::rewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const RankedTensorType scaleFactorType = scaleFactors.getType().cast(); const auto sizeInt = scaleFactorType.getDimSize(0); + // Folding type + auto folding = op.getFolding(); + // Create the folded circuit function FlatSymbolRefAttr foldedCircuitRefAttr = getOrInsertFoldedCircuit(loc, rewriter, op, scaleFactorType.getElementType()); @@ -149,7 +152,7 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew func::FuncOp fnAllocOp = SymbolTable::lookupNearestSymbolFrom(op, quantumAllocRefAttr); - // Get the number of qubits + // Get the number of qubits quantum::AllocOp allocOp = *fnOp.getOps().begin(); std::optional numberQubitsOptional = allocOp.getNqubitsAttr(); int64_t numberQubits = numberQubitsOptional.value_or(0); From 5312cd46a1e824f92051c1654b54411a4c325814 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Thu, 18 Jul 2024 16:26:53 -0400 Subject: [PATCH 03/94] Code reuse for local folding algorithm --- .../Transforms/MitigationMethods/Zne.cpp | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index d1f74ede13..2d85f18478 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -46,11 +46,14 @@ void ZneLowering::rewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const const auto sizeInt = scaleFactorType.getDimSize(0); // Folding type - auto folding = op.getFolding(); + auto foldingAlgorithm = op.getFoldingAlgorithm(); + // TODO: Just cast this to an integer, by here: + // 1 - Global + // 2 - Local // Create the folded circuit function FlatSymbolRefAttr foldedCircuitRefAttr = - getOrInsertFoldedCircuit(loc, rewriter, op, scaleFactorType.getElementType()); + getOrInsertFoldedCircuit(loc, rewriter, op, scaleFactorType.getElementType(), foldingAlgorithm); func::FuncOp foldedCircuit = SymbolTable::lookupNearestSymbolFrom(op, foldedCircuitRefAttr); @@ -128,7 +131,8 @@ void ZneLowering::rewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const } FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRewriter &rewriter, - mitigation::ZneOp op, Type scalarType) + mitigation::ZneOp op, Type scalarType + int foldingAlgorithm) { MLIRContext *ctx = rewriter.getContext(); @@ -152,7 +156,7 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew func::FuncOp fnAllocOp = SymbolTable::lookupNearestSymbolFrom(op, quantumAllocRefAttr); - // Get the number of qubits + // Get the number of qubits quantum::AllocOp allocOp = *fnOp.getOps().begin(); std::optional numberQubitsOptional = allocOp.getNqubitsAttr(); int64_t numberQubits = numberQubitsOptional.value_or(0); @@ -162,6 +166,10 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew StringAttr name = deviceInitOp.getNameAttr(); StringAttr kwargs = deviceInitOp.getKwargsAttr(); + if (foldingType == 2) { + return localFolding(/* TODO: what args? */); + } + // Function without measurements: Create function without measurements and with qreg as last // argument FlatSymbolRefAttr fnWithoutMeasurementsRefAttr = From c4fded9aceef274feca4d2c7a2acd7ed4c673fb9 Mon Sep 17 00:00:00 2001 From: Alessandro Cosentino Date: Fri, 19 Jul 2024 19:29:25 +0200 Subject: [PATCH 04/94] allow reading attrs from mitigation dialect --- .../api_extensions/error_mitigation.py | 21 +++++++++++-------- frontend/catalyst/jax_primitives.py | 21 +++++++++++-------- .../Mitigation/IR/MitigationDialect.td | 2 +- 3 files changed, 25 insertions(+), 19 deletions(-) diff --git a/frontend/catalyst/api_extensions/error_mitigation.py b/frontend/catalyst/api_extensions/error_mitigation.py index 6b65a2cf8d..e575598597 100644 --- a/frontend/catalyst/api_extensions/error_mitigation.py +++ b/frontend/catalyst/api_extensions/error_mitigation.py @@ -27,14 +27,8 @@ import pennylane as qml from jax._src.tree_util import tree_flatten -from catalyst.jax_primitives import zne_p +from catalyst.jax_primitives import zne_p, Folding -from enum import IntEnum - -class Folding(IntEnum): - GLOBAL = 1 - RANDOM = 2 - ALL = 3 ## API ## def mitigate_with_zne(fn=None, *, scale_factors=None, extrapolate=None, extrapolate_kwargs=None, folding='global'): @@ -146,8 +140,17 @@ def __call__(self, *args, **kwargs): if len(set_dtypes) != 1 or set_dtypes.pop().kind != "f": raise TypeError("All expectation and classical values dtypes must match and be float.") args_data, _ = tree_flatten(args) - folding = Folding[self.folding.upper()].value - results = zne_p.bind(*args_data, folding, self.scale_factors, jaxpr=jaxpr, fn=self.fn) + try: + folding=Folding[self.folding] + except KeyError as e: + raise KeyError(f"Folding type must be one of {Folding._member_names_}") from e + results = zne_p.bind( + *args_data, + self.scale_factors, + folding=folding, + jaxpr=jaxpr, + fn=self.fn + ) float_scale_factors = jnp.array(self.scale_factors, dtype=float) results = self.extrapolate(float_scale_factors, results[0]) # Single measurement diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index a910885ed0..b83579526b 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -17,6 +17,7 @@ import sys from dataclasses import dataclass +from enum import IntEnum from itertools import chain from typing import Any, Dict, Iterable, List, Union @@ -93,6 +94,7 @@ from catalyst.utils.extra_bindings import FromElementsOp, TensorExtractOp from catalyst.utils.types import convert_shaped_arrays_to_tensors + # pylint: disable=unused-argument,too-many-lines,too-many-statements,protected-access ######### @@ -190,6 +192,9 @@ def _obs_lowering(aval): mlir.ir_type_handlers[AbstractObs] = _obs_lowering + +Folding = IntEnum('Folding', ['global', 'random', 'all']) + ############## # Primitives # ############## @@ -716,25 +721,25 @@ def _vjp_lowering(ctx, *args, jaxpr, fn, grad_params): @zne_p.def_impl -def _zne_def_impl(ctx, *args, jaxpr, fn): # pragma: no cover +def _zne_def_impl(ctx, *args, folding, jaxpr, fn): # pragma: no cover raise NotImplementedError() @zne_p.def_abstract_eval -def _zne_abstract_eval(*args, jaxpr, fn): # pylint: disable=unused-argument +def _zne_abstract_eval(*args, folding, jaxpr, fn): # pylint: disable=unused-argument shape = list(args[-1].shape) if len(jaxpr.out_avals) > 1: shape.append(len(jaxpr.out_avals)) return [core.ShapedArray(shape, jaxpr.out_avals[0].dtype)] -def _folding_attribute(ctx, folding: str): +def _folding_attribute(ctx, folding): ctx = ctx.module_context.context return ir.OpaqueAttr.get( - "mitigation", ("folding " + folding).encode("utf-8"), ir.NoneType.get(ctx), ctx + "mitigation", ("folding " + Folding(folding).name).encode("utf-8"), ir.NoneType.get(ctx), ctx ) -def _zne_lowering(ctx, *args, jaxpr, fn): +def _zne_lowering(ctx, *args, folding, jaxpr, fn): """Lowering function to the ZNE opearation. Args: ctx: the MLIR context @@ -747,14 +752,12 @@ def _zne_lowering(ctx, *args, jaxpr, fn): symbol_name = func_op.name.value output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out)) flat_output_types = util.flatten(output_types) - folding = args[-2] scale_factors = args[-1] return ZneOp( flat_output_types, ir.FlatSymbolRefAttr.get(symbol_name), - mlir.flatten_lowering_ir_args(args), - # TODO: Once this works, change hardcoded value to actual value from input - _folding_attribute(ctx, "global"), + mlir.flatten_lowering_ir_args(args[0:-1]), + _folding_attribute(ctx, folding), scale_factors, ).results diff --git a/mlir/include/Mitigation/IR/MitigationDialect.td b/mlir/include/Mitigation/IR/MitigationDialect.td index 7d9da23b83..93f055a62c 100644 --- a/mlir/include/Mitigation/IR/MitigationDialect.td +++ b/mlir/include/Mitigation/IR/MitigationDialect.td @@ -34,7 +34,7 @@ def Mitigation_Dialect : Dialect { /// Use the default type printing/parsing hooks, otherwise we would explicitly define them. // let useDefaultTypePrinterParser = 1; - // let useDefaultAttributePrinterParser = 1; + let useDefaultAttributePrinterParser = 1; } //===----------------------------------------------------------------------===// From ab2a57397536d3f48cd493695dd44c9339de509e Mon Sep 17 00:00:00 2001 From: Alessandro Cosentino Date: Mon, 22 Jul 2024 10:16:59 +0200 Subject: [PATCH 05/94] style fixes suggested by codefactor --- .../api_extensions/error_mitigation.py | 21 ++++++++++++------- frontend/catalyst/jax_primitives.py | 4 +++- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/frontend/catalyst/api_extensions/error_mitigation.py b/frontend/catalyst/api_extensions/error_mitigation.py index e575598597..e14274d1a9 100644 --- a/frontend/catalyst/api_extensions/error_mitigation.py +++ b/frontend/catalyst/api_extensions/error_mitigation.py @@ -31,7 +31,14 @@ ## API ## -def mitigate_with_zne(fn=None, *, scale_factors=None, extrapolate=None, extrapolate_kwargs=None, folding='global'): +def mitigate_with_zne( + fn=None, + *, + scale_factors=None, + extrapolate=None, + extrapolate_kwargs=None, + folding="global" +): """A :func:`~.qjit` compatible error mitigation of an input circuit using zero-noise extrapolation. @@ -141,14 +148,14 @@ def __call__(self, *args, **kwargs): raise TypeError("All expectation and classical values dtypes must match and be float.") args_data, _ = tree_flatten(args) try: - folding=Folding[self.folding] + folding = Folding[self.folding] except KeyError as e: - raise KeyError(f"Folding type must be one of {Folding._member_names_}") from e + raise KeyError(f"Folding type must be one of {Folding._member_names_}") from e results = zne_p.bind( - *args_data, - self.scale_factors, - folding=folding, - jaxpr=jaxpr, + *args_data, + self.scale_factors, + folding=folding, + jaxpr=jaxpr, fn=self.fn ) float_scale_factors = jnp.array(self.scale_factors, dtype=float) diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index b83579526b..c92c6e4a05 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -736,7 +736,9 @@ def _zne_abstract_eval(*args, folding, jaxpr, fn): # pylint: disable=unused-arg def _folding_attribute(ctx, folding): ctx = ctx.module_context.context return ir.OpaqueAttr.get( - "mitigation", ("folding " + Folding(folding).name).encode("utf-8"), ir.NoneType.get(ctx), ctx + "mitigation", ("folding " + Folding(folding).name).encode("utf-8"), + ir.NoneType.get(ctx), + ctx ) def _zne_lowering(ctx, *args, folding, jaxpr, fn): From fcd6cd0935df076de78f3f695febf41cee42b4ed Mon Sep 17 00:00:00 2001 From: Alessandro Cosentino Date: Mon, 22 Jul 2024 10:22:00 +0200 Subject: [PATCH 06/94] trim trailing whitespace --- frontend/catalyst/api_extensions/error_mitigation.py | 12 ++++++------ frontend/catalyst/jax_primitives.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/frontend/catalyst/api_extensions/error_mitigation.py b/frontend/catalyst/api_extensions/error_mitigation.py index e14274d1a9..ea8b9a583f 100644 --- a/frontend/catalyst/api_extensions/error_mitigation.py +++ b/frontend/catalyst/api_extensions/error_mitigation.py @@ -32,11 +32,11 @@ ## API ## def mitigate_with_zne( - fn=None, - *, - scale_factors=None, - extrapolate=None, - extrapolate_kwargs=None, + fn=None, + *, + scale_factors=None, + extrapolate=None, + extrapolate_kwargs=None, folding="global" ): """A :func:`~.qjit` compatible error mitigation of an input circuit using zero-noise @@ -58,7 +58,7 @@ def mitigate_with_zne( By default, perfect polynomial fitting will be used. extrapolate_kwargs (dict[str, Any]): Keyword arguments to be passed to the extrapolation function. - folding (str): The unitary folding technique to be used to scale the circuit + folding (str): The unitary folding technique to be used to scale the circuit Returns: Callable: A callable object that computes the mitigated of the wrapped :class:`qml.QNode` diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index c92c6e4a05..4d0ae4e211 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -736,8 +736,8 @@ def _zne_abstract_eval(*args, folding, jaxpr, fn): # pylint: disable=unused-arg def _folding_attribute(ctx, folding): ctx = ctx.module_context.context return ir.OpaqueAttr.get( - "mitigation", ("folding " + Folding(folding).name).encode("utf-8"), - ir.NoneType.get(ctx), + "mitigation", ("folding " + Folding(folding).name).encode("utf-8"), + ir.NoneType.get(ctx), ctx ) From e033ce8fc31fd7ef303c87c9cc187abffc3a296b Mon Sep 17 00:00:00 2001 From: Alessandro Cosentino Date: Mon, 22 Jul 2024 11:30:38 +0200 Subject: [PATCH 07/94] test folding argument error --- frontend/test/pytest/test_mitigation.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/frontend/test/pytest/test_mitigation.py b/frontend/test/pytest/test_mitigation.py index b948822f8d..1207c8c422 100644 --- a/frontend/test/pytest/test_mitigation.py +++ b/frontend/test/pytest/test_mitigation.py @@ -196,6 +196,24 @@ def mitigated_qnode(args): mitigated_qnode(0.1) +def test_folding_type_error(): + """Test that value of folding argument is from allowed list""" + dev = qml.device("lightning.qubit", wires=2) + + @qml.qnode(device=dev) + def circuit(): + return 0.0 + + @catalyst.qjit + def mitigated_qnode(*args): + return catalyst.mitigate_with_zne( + circuit, scale_factors=[], folding="bad-folding-type-value" + )() + + with pytest.raises(KeyError, match="Folding type must be"): + mitigated_qnode() + + @pytest.mark.parametrize("params", [0.1, 0.2, 0.3, 0.4, 0.5]) def test_zne_usage_patterns(params): """Test usage patterns of catalyst.zne.""" From 176222275fb19ff59eaa09e0510b022278e5d2d1 Mon Sep 17 00:00:00 2001 From: Alessandro Cosentino Date: Mon, 22 Jul 2024 11:35:36 +0200 Subject: [PATCH 08/94] more style fixes --- frontend/catalyst/jax_primitives.py | 3 ++- frontend/test/pytest/test_mitigation.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 4d0ae4e211..ded2aec96b 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -736,7 +736,8 @@ def _zne_abstract_eval(*args, folding, jaxpr, fn): # pylint: disable=unused-arg def _folding_attribute(ctx, folding): ctx = ctx.module_context.context return ir.OpaqueAttr.get( - "mitigation", ("folding " + Folding(folding).name).encode("utf-8"), + "mitigation", + ("folding " + Folding(folding).name).encode("utf-8"), ir.NoneType.get(ctx), ctx ) diff --git a/frontend/test/pytest/test_mitigation.py b/frontend/test/pytest/test_mitigation.py index 1207c8c422..f4526b1700 100644 --- a/frontend/test/pytest/test_mitigation.py +++ b/frontend/test/pytest/test_mitigation.py @@ -205,7 +205,7 @@ def circuit(): return 0.0 @catalyst.qjit - def mitigated_qnode(*args): + def mitigated_qnode(*args): # unused dummy argument to force lazy evaluation of the function return catalyst.mitigate_with_zne( circuit, scale_factors=[], folding="bad-folding-type-value" )() From ffa96c76d622886021ac937115c6a4ec33f48fea Mon Sep 17 00:00:00 2001 From: Alessandro Cosentino Date: Mon, 22 Jul 2024 15:18:08 +0200 Subject: [PATCH 09/94] document folding argument --- frontend/catalyst/api_extensions/error_mitigation.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/api_extensions/error_mitigation.py b/frontend/catalyst/api_extensions/error_mitigation.py index ea8b9a583f..78bf24895d 100644 --- a/frontend/catalyst/api_extensions/error_mitigation.py +++ b/frontend/catalyst/api_extensions/error_mitigation.py @@ -58,7 +58,10 @@ def mitigate_with_zne( By default, perfect polynomial fitting will be used. extrapolate_kwargs (dict[str, Any]): Keyword arguments to be passed to the extrapolation function. - folding (str): The unitary folding technique to be used to scale the circuit + folding (str): Unitary folding technique to be used to scale the circuit. Possible values: + - global: global unitary of the input circuit is folded + - all: all gates locally folded + - random: random subset of gates of the input circuits locally folded Returns: Callable: A callable object that computes the mitigated of the wrapped :class:`qml.QNode` @@ -151,6 +154,10 @@ def __call__(self, *args, **kwargs): folding = Folding[self.folding] except KeyError as e: raise KeyError(f"Folding type must be one of {Folding._member_names_}") from e + # TODO: remove the following check once #755 is completed + if folding != Folding.global: + raise NotImplementedError(f"Folding type {folding.name}" is being developed") + results = zne_p.bind( *args_data, self.scale_factors, From 1dd49b2623fd5157c7d9a8ee0829f2d895bcdd1f Mon Sep 17 00:00:00 2001 From: Alessandro Cosentino Date: Mon, 22 Jul 2024 15:43:24 +0200 Subject: [PATCH 10/94] run make format --- .../api_extensions/error_mitigation.py | 27 +++++++------------ frontend/catalyst/jax_primitives.py | 7 +++-- mlir/include/Mitigation/IR/MitigationOps.td | 6 ++--- mlir/lib/Mitigation/IR/MitigationDialect.cpp | 4 +-- mlir/lib/Mitigation/IR/MitigationOps.cpp | 1 - .../Transforms/MitigationMethods/Zne.cpp | 2 +- mlir/test/Mitigation/zne.mlir | 2 +- 7 files changed, 19 insertions(+), 30 deletions(-) diff --git a/frontend/catalyst/api_extensions/error_mitigation.py b/frontend/catalyst/api_extensions/error_mitigation.py index 78bf24895d..d3c8a1d9b6 100644 --- a/frontend/catalyst/api_extensions/error_mitigation.py +++ b/frontend/catalyst/api_extensions/error_mitigation.py @@ -27,17 +27,12 @@ import pennylane as qml from jax._src.tree_util import tree_flatten -from catalyst.jax_primitives import zne_p, Folding +from catalyst.jax_primitives import Folding, zne_p ## API ## def mitigate_with_zne( - fn=None, - *, - scale_factors=None, - extrapolate=None, - extrapolate_kwargs=None, - folding="global" + fn=None, *, scale_factors=None, extrapolate=None, extrapolate_kwargs=None, folding="GLOBAL" ): """A :func:`~.qjit` compatible error mitigation of an input circuit using zero-noise extrapolation. @@ -59,9 +54,9 @@ def mitigate_with_zne( extrapolate_kwargs (dict[str, Any]): Keyword arguments to be passed to the extrapolation function. folding (str): Unitary folding technique to be used to scale the circuit. Possible values: - - global: global unitary of the input circuit is folded - - all: all gates locally folded - - random: random subset of gates of the input circuits locally folded + - GLOBAL: global unitary of the input circuit is folded + - ALL: all gates locally folded + - RANDOM: random subset of gates of the input circuits locally folded Returns: Callable: A callable object that computes the mitigated of the wrapped :class:`qml.QNode` @@ -129,7 +124,7 @@ def __init__( fn: Callable, scale_factors: jnp.ndarray, extrapolate: Callable[[Sequence[float], Sequence[float]], float], - folding: str + folding: str, ): if not isinstance(fn, qml.QNode): raise TypeError(f"A QNode is expected, got the classical function {fn}") @@ -155,15 +150,11 @@ def __call__(self, *args, **kwargs): except KeyError as e: raise KeyError(f"Folding type must be one of {Folding._member_names_}") from e # TODO: remove the following check once #755 is completed - if folding != Folding.global: - raise NotImplementedError(f"Folding type {folding.name}" is being developed") + if folding != Folding.GLOBAL: + raise NotImplementedError(f"Folding type {folding.name} is being developed") results = zne_p.bind( - *args_data, - self.scale_factors, - folding=folding, - jaxpr=jaxpr, - fn=self.fn + *args_data, self.scale_factors, folding=folding, jaxpr=jaxpr, fn=self.fn ) float_scale_factors = jnp.array(self.scale_factors, dtype=float) results = self.extrapolate(float_scale_factors, results[0]) diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index ded2aec96b..ecf274f91a 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -94,7 +94,6 @@ from catalyst.utils.extra_bindings import FromElementsOp, TensorExtractOp from catalyst.utils.types import convert_shaped_arrays_to_tensors - # pylint: disable=unused-argument,too-many-lines,too-many-statements,protected-access ######### @@ -192,8 +191,7 @@ def _obs_lowering(aval): mlir.ir_type_handlers[AbstractObs] = _obs_lowering - -Folding = IntEnum('Folding', ['global', 'random', 'all']) +Folding = IntEnum("Folding", ["GLOBAL", "RANDOM", "ALL"]) ############## # Primitives # @@ -739,9 +737,10 @@ def _folding_attribute(ctx, folding): "mitigation", ("folding " + Folding(folding).name).encode("utf-8"), ir.NoneType.get(ctx), - ctx + ctx, ) + def _zne_lowering(ctx, *args, folding, jaxpr, fn): """Lowering function to the ZNE opearation. Args: diff --git a/mlir/include/Mitigation/IR/MitigationOps.td b/mlir/include/Mitigation/IR/MitigationOps.td index ccaad50cd2..1d6b00b912 100644 --- a/mlir/include/Mitigation/IR/MitigationOps.td +++ b/mlir/include/Mitigation/IR/MitigationOps.td @@ -26,9 +26,9 @@ include "Mitigation/IR/MitigationDialect.td" def Folding : I32EnumAttr<"Folding", "Folding types", [ - I32EnumAttrCase<"global", 1>, - I32EnumAttrCase<"random", 2>, - I32EnumAttrCase<"all", 3>, + I32EnumAttrCase<"GLOBAL", 1>, + I32EnumAttrCase<"RANDOM", 2>, + I32EnumAttrCase<"ALL", 3>, ]> { let cppNamespace = "catalyst::mitigation"; let genSpecializedAttr = 0; diff --git a/mlir/lib/Mitigation/IR/MitigationDialect.cpp b/mlir/lib/Mitigation/IR/MitigationDialect.cpp index d4bc05ab66..b9b14d01db 100644 --- a/mlir/lib/Mitigation/IR/MitigationDialect.cpp +++ b/mlir/lib/Mitigation/IR/MitigationDialect.cpp @@ -13,9 +13,9 @@ // limitations under the License. #include "mlir/IR/Builders.h" -#include "mlir/Transforms/InliningUtils.h" #include "mlir/IR/DialectImplementation.h" // needed for generated type parser -#include "llvm/ADT/TypeSwitch.h" // needed for generated type parser +#include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/TypeSwitch.h" // needed for generated type parser #include "Mitigation/IR/MitigationDialect.h" #include "Mitigation/IR/MitigationOps.h" diff --git a/mlir/lib/Mitigation/IR/MitigationOps.cpp b/mlir/lib/Mitigation/IR/MitigationOps.cpp index 21df8a9396..af526263cc 100644 --- a/mlir/lib/Mitigation/IR/MitigationOps.cpp +++ b/mlir/lib/Mitigation/IR/MitigationOps.cpp @@ -28,7 +28,6 @@ using namespace catalyst::mitigation; #define GET_OP_CLASSES #include "Mitigation/IR/MitigationOps.cpp.inc" - //===----------------------------------------------------------------------===// // SymbolUserOpInterface //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index d1f74ede13..6bf429061e 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -152,7 +152,7 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew func::FuncOp fnAllocOp = SymbolTable::lookupNearestSymbolFrom(op, quantumAllocRefAttr); - // Get the number of qubits + // Get the number of qubits quantum::AllocOp allocOp = *fnOp.getOps().begin(); std::optional numberQubitsOptional = allocOp.getNqubitsAttr(); int64_t numberQubits = numberQubitsOptional.value_or(0); diff --git a/mlir/test/Mitigation/zne.mlir b/mlir/test/Mitigation/zne.mlir index 0bfe041575..237ad7e7d5 100644 --- a/mlir/test/Mitigation/zne.mlir +++ b/mlir/test/Mitigation/zne.mlir @@ -99,6 +99,6 @@ func.func @simpleCircuit(%arg0: tensor<3xf64>) -> f64 attributes {qnode} { // CHECK: return [[results]] : tensor<5xf64> func.func @zneCallScalarScalar(%arg0: tensor<3xf64>) -> tensor<5xf64> { %scaleFactors = arith.constant dense<[1, 2, 3, 4, 5]> : tensor<5xindex> - %0 = mitigation.zne @simpleCircuit(%arg0) folding (global) scaleFactors (%scaleFactors : tensor<5xindex>) : (tensor<3xf64>) -> tensor<5xf64> + %0 = mitigation.zne @simpleCircuit(%arg0) folding (GLOBAL) scaleFactors (%scaleFactors : tensor<5xindex>) : (tensor<3xf64>) -> tensor<5xf64> func.return %0 : tensor<5xf64> } From c3089bfc0e0158933a2f0445884bd0b326f18338 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Mon, 22 Jul 2024 14:46:23 -0400 Subject: [PATCH 11/94] Draft random/local folding branch points --- .../Transforms/MitigationMethods/Zne.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 2d85f18478..7769d8a27d 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -46,10 +46,7 @@ void ZneLowering::rewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const const auto sizeInt = scaleFactorType.getDimSize(0); // Folding type - auto foldingAlgorithm = op.getFoldingAlgorithm(); - // TODO: Just cast this to an integer, by here: - // 1 - Global - // 2 - Local + auto foldingAlgorithm = op.getFolding(); // Create the folded circuit function FlatSymbolRefAttr foldedCircuitRefAttr = @@ -132,7 +129,7 @@ void ZneLowering::rewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRewriter &rewriter, mitigation::ZneOp op, Type scalarType - int foldingAlgorithm) + Folding foldingAlgorithm) { MLIRContext *ctx = rewriter.getContext(); @@ -166,8 +163,11 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew StringAttr name = deviceInitOp.getNameAttr(); StringAttr kwargs = deviceInitOp.getKwargsAttr(); - if (foldingType == 2) { - return localFolding(/* TODO: what args? */); + if (foldingAlgorithm == Folding(2)) { + return randomLocalFolding(/* TODO: what args? */); + } + if (foldingAlgorithm == Folding(3)) { + return allLocalFolding(/* TODO: what args? */); } // Function without measurements: Create function without measurements and with qreg as last From e89bbece590cc179770e2afd36507271a54fc8ac Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Mon, 22 Jul 2024 15:15:43 -0400 Subject: [PATCH 12/94] Fix function prototypes --- mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp | 2 +- mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.hpp | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 7769d8a27d..c35cd9ee71 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -128,7 +128,7 @@ void ZneLowering::rewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const } FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRewriter &rewriter, - mitigation::ZneOp op, Type scalarType + mitigation::ZneOp op, Type scalarType, Folding foldingAlgorithm) { MLIRContext *ctx = rewriter.getContext(); diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.hpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.hpp index 2f9388c36f..6496116cea 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.hpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.hpp @@ -32,7 +32,8 @@ struct ZneLowering : public OpRewritePattern { private: static FlatSymbolRefAttr getOrInsertFoldedCircuit(Location loc, PatternRewriter &builder, - mitigation::ZneOp op, Type scalarType); + mitigation::ZneOp op, Type scalarType, + Folding foldingAlgorithm); static FlatSymbolRefAttr getOrInsertQuantumAlloc(Location loc, PatternRewriter &rewriter, mitigation::ZneOp op); static FlatSymbolRefAttr From 7e5a85d004b2d004e85d343d93f60b310ee9d6b2 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Tue, 23 Jul 2024 11:25:11 -0400 Subject: [PATCH 13/94] Function prototypes --- .../Transforms/MitigationMethods/Zne.cpp | 64 +++++++++++++++---- 1 file changed, 53 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index c35cd9ee71..cc80f5e738 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -126,16 +126,35 @@ void ZneLowering::rewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const // Replace the original results rewriter.replaceOp(op, resultValues); } +FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, + Type scalarType, ModuleOp moduleOp, + std::string fnFoldedName, func::FuncOp fnOp, + TypeRange originalTypes, Type qregType, + func::FuncOp fnAllocOp, int64_t numberQubits, + quantum::DeviceInitOp deviceInitOp) +{ + // TODO: Implement. + return FlatSymbolRefAttr(); +} +FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, + Type scalarType, ModuleOp moduleOp, + std::string fnFoldedName, func::FuncOp fnOp, + TypeRange originalTypes, Type qregType, + func::FuncOp fnAllocOp, int64_t numberQubits, + quantum::DeviceInitOp deviceInitOp) +{ + // TODO: Implement. + return FlatSymbolRefAttr(); +} FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRewriter &rewriter, mitigation::ZneOp op, Type scalarType, Folding foldingAlgorithm) { - MLIRContext *ctx = rewriter.getContext(); - OpBuilder::InsertionGuard guard(rewriter); ModuleOp moduleOp = op->getParentOfType(); std::string fnFoldedName = op.getCallee().str() + ".folded"; + MLIRContext *ctx = rewriter.getContext(); if (moduleOp.lookupSymbol(fnFoldedName)) { return SymbolRefAttr::get(ctx, fnFoldedName); @@ -144,7 +163,7 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew // Original function func::FuncOp fnOp = SymbolTable::lookupNearestSymbolFrom(op, op.getCalleeAttr()); TypeRange originalTypes = op.getArgs().getTypes(); - Type qregType = quantum::QuregType::get(rewriter.getContext()); + Type qregType = quantum::QuregType::get(ctx); // Set insertion in the module rewriter.setInsertionPointToStart(moduleOp.getBody()); @@ -154,22 +173,45 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew SymbolTable::lookupNearestSymbolFrom(op, quantumAllocRefAttr); // Get the number of qubits - quantum::AllocOp allocOp = *fnOp.getOps().begin(); - std::optional numberQubitsOptional = allocOp.getNqubitsAttr(); - int64_t numberQubits = numberQubitsOptional.value_or(0); + const int64_t numberQubits = (*fnOp.getOps().begin()).getNqubitsAttr().value_or(0); // Get the device quantum::DeviceInitOp deviceInitOp = *fnOp.getOps().begin(); - StringAttr lib = deviceInitOp.getLibAttr(); - StringAttr name = deviceInitOp.getNameAttr(); - StringAttr kwargs = deviceInitOp.getKwargsAttr(); if (foldingAlgorithm == Folding(2)) { - return randomLocalFolding(/* TODO: what args? */); + return randomLocalFolding( + loc, + rewriter, + scalarType, + moduleOp, + fnFoldedName, + fnOp, + originalTypes, + qregType, + fnAllocOp, + numberQubits, + deviceInitOp + ); } if (foldingAlgorithm == Folding(3)) { - return allLocalFolding(/* TODO: what args? */); + return allLocalFolding( + loc, + rewriter, + scalarType, + moduleOp, + fnFoldedName, + fnOp, + originalTypes, + qregType, + fnAllocOp, + numberQubits, + deviceInitOp + ); } + StringAttr lib = deviceInitOp.getLibAttr(); + StringAttr name = deviceInitOp.getNameAttr(); + StringAttr kwargs = deviceInitOp.getKwargsAttr(); + // Function without measurements: Create function without measurements and with qreg as last // argument FlatSymbolRefAttr fnWithoutMeasurementsRefAttr = From 4bb1983ccdc90600791c6f719cdd43b7f7d59a38 Mon Sep 17 00:00:00 2001 From: Alessandro Cosentino Date: Tue, 23 Jul 2024 17:36:04 +0200 Subject: [PATCH 14/94] misc addressing review comments --- .../api_extensions/error_mitigation.py | 12 ++++------ frontend/catalyst/jax_primitives.py | 10 +++++--- frontend/test/pytest/test_mitigation.py | 24 +++++++++++++++---- mlir/include/Mitigation/IR/MitigationOps.td | 10 ++++---- mlir/lib/Mitigation/IR/MitigationOps.cpp | 6 ++--- .../Transforms/MitigationMethods/Zne.cpp | 3 --- mlir/test/Mitigation/zne.mlir | 2 +- 7 files changed, 40 insertions(+), 27 deletions(-) diff --git a/frontend/catalyst/api_extensions/error_mitigation.py b/frontend/catalyst/api_extensions/error_mitigation.py index d3c8a1d9b6..f8168f6484 100644 --- a/frontend/catalyst/api_extensions/error_mitigation.py +++ b/frontend/catalyst/api_extensions/error_mitigation.py @@ -32,7 +32,7 @@ ## API ## def mitigate_with_zne( - fn=None, *, scale_factors=None, extrapolate=None, extrapolate_kwargs=None, folding="GLOBAL" + fn=None, *, scale_factors=None, extrapolate=None, extrapolate_kwargs=None, folding="global" ): """A :func:`~.qjit` compatible error mitigation of an input circuit using zero-noise extrapolation. @@ -54,9 +54,7 @@ def mitigate_with_zne( extrapolate_kwargs (dict[str, Any]): Keyword arguments to be passed to the extrapolation function. folding (str): Unitary folding technique to be used to scale the circuit. Possible values: - - GLOBAL: global unitary of the input circuit is folded - - ALL: all gates locally folded - - RANDOM: random subset of gates of the input circuits locally folded + - global: the global unitary of the input circuit is folded Returns: Callable: A callable object that computes the mitigated of the wrapped :class:`qml.QNode` @@ -146,9 +144,9 @@ def __call__(self, *args, **kwargs): raise TypeError("All expectation and classical values dtypes must match and be float.") args_data, _ = tree_flatten(args) try: - folding = Folding[self.folding] - except KeyError as e: - raise KeyError(f"Folding type must be one of {Folding._member_names_}") from e + folding = Folding(self.folding) + except ValueError as e: + raise ValueError(f"Folding type must be one of {list(map(str, Folding))}") from e # TODO: remove the following check once #755 is completed if folding != Folding.GLOBAL: raise NotImplementedError(f"Folding type {folding.name} is being developed") diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index a789070dd8..411ddf8ca5 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -17,7 +17,7 @@ import sys from dataclasses import dataclass -from enum import IntEnum +from enum import StrEnum from itertools import chain from typing import Any, Dict, Iterable, List, Union @@ -196,7 +196,11 @@ def _obs_lowering(aval): mlir.ir_type_handlers[AbstractObs] = _obs_lowering -Folding = IntEnum("Folding", ["GLOBAL", "RANDOM", "ALL"]) +class Folding(StrEnum): + GLOBAL = "global" + RANDOM = "random" + ALL = "all" + ############## # Primitives # @@ -742,7 +746,7 @@ def _folding_attribute(ctx, folding): ctx = ctx.module_context.context return ir.OpaqueAttr.get( "mitigation", - ("folding " + Folding(folding).name).encode("utf-8"), + ("folding " + Folding(folding)).encode("utf-8"), ir.NoneType.get(ctx), ctx, ) diff --git a/frontend/test/pytest/test_mitigation.py b/frontend/test/pytest/test_mitigation.py index f4526b1700..97ec785525 100644 --- a/frontend/test/pytest/test_mitigation.py +++ b/frontend/test/pytest/test_mitigation.py @@ -196,7 +196,7 @@ def mitigated_qnode(args): mitigated_qnode(0.1) -def test_folding_type_error(): +def test_folding_type_not_supported(): """Test that value of folding argument is from allowed list""" dev = qml.device("lightning.qubit", wires=2) @@ -204,14 +204,28 @@ def test_folding_type_error(): def circuit(): return 0.0 - @catalyst.qjit - def mitigated_qnode(*args): # unused dummy argument to force lazy evaluation of the function + def mitigated_qnode(): return catalyst.mitigate_with_zne( circuit, scale_factors=[], folding="bad-folding-type-value" )() - with pytest.raises(KeyError, match="Folding type must be"): - mitigated_qnode() + with pytest.raises(ValueError, match="Folding type must be"): + catalyst.qjit(mitigated_qnode) + + +def test_folding_type_not_implemented(): + """Test value of folding argument supported but not yet developed""" + dev = qml.device("lightning.qubit", wires=2) + + @qml.qnode(device=dev) + def circuit(): + return 0.0 + + def mitigated_qnode(): + return catalyst.mitigate_with_zne(circuit, scale_factors=[], folding="all")() + + with pytest.raises(NotImplementedError): + catalyst.qjit(mitigated_qnode) @pytest.mark.parametrize("params", [0.1, 0.2, 0.3, 0.4, 0.5]) diff --git a/mlir/include/Mitigation/IR/MitigationOps.td b/mlir/include/Mitigation/IR/MitigationOps.td index 1d6b00b912..8cb682d26b 100644 --- a/mlir/include/Mitigation/IR/MitigationOps.td +++ b/mlir/include/Mitigation/IR/MitigationOps.td @@ -15,20 +15,20 @@ #ifndef MITIGATION_OPS #define MITIGATION_OPS +include "mlir/IR/BuiltinAttributes.td" include "mlir/IR/EnumAttr.td" include "mlir/IR/OpBase.td" -include "mlir/Interfaces/CallInterfaces.td" include "mlir/IR/SymbolInterfaces.td" -include "mlir/IR/BuiltinAttributes.td" +include "mlir/Interfaces/CallInterfaces.td" include "Mitigation/IR/MitigationDialect.td" def Folding : I32EnumAttr<"Folding", "Folding types", [ - I32EnumAttrCase<"GLOBAL", 1>, - I32EnumAttrCase<"RANDOM", 2>, - I32EnumAttrCase<"ALL", 3>, + I32EnumAttrCase<"global", 1>, + I32EnumAttrCase<"random", 2>, + I32EnumAttrCase<"all", 3>, ]> { let cppNamespace = "catalyst::mitigation"; let genSpecializedAttr = 0; diff --git a/mlir/lib/Mitigation/IR/MitigationOps.cpp b/mlir/lib/Mitigation/IR/MitigationOps.cpp index af526263cc..10439cfacd 100644 --- a/mlir/lib/Mitigation/IR/MitigationOps.cpp +++ b/mlir/lib/Mitigation/IR/MitigationOps.cpp @@ -21,13 +21,13 @@ #include "Mitigation/IR/MitigationDialect.h" #include "Mitigation/IR/MitigationOps.h" -using namespace mlir; -using namespace catalyst::mitigation; - #include "Mitigation/IR/MitigationEnums.cpp.inc" #define GET_OP_CLASSES #include "Mitigation/IR/MitigationOps.cpp.inc" +using namespace mlir; +using namespace catalyst::mitigation; + //===----------------------------------------------------------------------===// // SymbolUserOpInterface //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 6bf429061e..3fe3014902 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -45,9 +45,6 @@ void ZneLowering::rewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const RankedTensorType scaleFactorType = scaleFactors.getType().cast(); const auto sizeInt = scaleFactorType.getDimSize(0); - // Folding type - auto folding = op.getFolding(); - // Create the folded circuit function FlatSymbolRefAttr foldedCircuitRefAttr = getOrInsertFoldedCircuit(loc, rewriter, op, scaleFactorType.getElementType()); diff --git a/mlir/test/Mitigation/zne.mlir b/mlir/test/Mitigation/zne.mlir index 237ad7e7d5..0bfe041575 100644 --- a/mlir/test/Mitigation/zne.mlir +++ b/mlir/test/Mitigation/zne.mlir @@ -99,6 +99,6 @@ func.func @simpleCircuit(%arg0: tensor<3xf64>) -> f64 attributes {qnode} { // CHECK: return [[results]] : tensor<5xf64> func.func @zneCallScalarScalar(%arg0: tensor<3xf64>) -> tensor<5xf64> { %scaleFactors = arith.constant dense<[1, 2, 3, 4, 5]> : tensor<5xindex> - %0 = mitigation.zne @simpleCircuit(%arg0) folding (GLOBAL) scaleFactors (%scaleFactors : tensor<5xindex>) : (tensor<3xf64>) -> tensor<5xf64> + %0 = mitigation.zne @simpleCircuit(%arg0) folding (global) scaleFactors (%scaleFactors : tensor<5xindex>) : (tensor<3xf64>) -> tensor<5xf64> func.return %0 : tensor<5xf64> } From 5e84719803cc3236873afc1ade213c6d32a9e843 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Tue, 23 Jul 2024 11:48:47 -0400 Subject: [PATCH 15/94] make format --- .../Transforms/MitigationMethods/Zne.cpp | 57 ++++++------------- 1 file changed, 16 insertions(+), 41 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index cc80f5e738..ec5d141cba 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -49,8 +49,8 @@ void ZneLowering::rewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const auto foldingAlgorithm = op.getFolding(); // Create the folded circuit function - FlatSymbolRefAttr foldedCircuitRefAttr = - getOrInsertFoldedCircuit(loc, rewriter, op, scaleFactorType.getElementType(), foldingAlgorithm); + FlatSymbolRefAttr foldedCircuitRefAttr = getOrInsertFoldedCircuit( + loc, rewriter, op, scaleFactorType.getElementType(), foldingAlgorithm); func::FuncOp foldedCircuit = SymbolTable::lookupNearestSymbolFrom(op, foldedCircuitRefAttr); @@ -126,23 +126,19 @@ void ZneLowering::rewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const // Replace the original results rewriter.replaceOp(op, resultValues); } -FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, - Type scalarType, ModuleOp moduleOp, - std::string fnFoldedName, func::FuncOp fnOp, - TypeRange originalTypes, Type qregType, - func::FuncOp fnAllocOp, int64_t numberQubits, - quantum::DeviceInitOp deviceInitOp) +FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, Type scalarType, + ModuleOp moduleOp, std::string fnFoldedName, func::FuncOp fnOp, + TypeRange originalTypes, Type qregType, func::FuncOp fnAllocOp, + int64_t numberQubits, quantum::DeviceInitOp deviceInitOp) { // TODO: Implement. return FlatSymbolRefAttr(); } -FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, - Type scalarType, ModuleOp moduleOp, - std::string fnFoldedName, func::FuncOp fnOp, - TypeRange originalTypes, Type qregType, - func::FuncOp fnAllocOp, int64_t numberQubits, - quantum::DeviceInitOp deviceInitOp) +FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, Type scalarType, + ModuleOp moduleOp, std::string fnFoldedName, func::FuncOp fnOp, + TypeRange originalTypes, Type qregType, func::FuncOp fnAllocOp, + int64_t numberQubits, quantum::DeviceInitOp deviceInitOp) { // TODO: Implement. return FlatSymbolRefAttr(); @@ -173,39 +169,18 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew SymbolTable::lookupNearestSymbolFrom(op, quantumAllocRefAttr); // Get the number of qubits - const int64_t numberQubits = (*fnOp.getOps().begin()).getNqubitsAttr().value_or(0); + const int64_t numberQubits = + (*fnOp.getOps().begin()).getNqubitsAttr().value_or(0); // Get the device quantum::DeviceInitOp deviceInitOp = *fnOp.getOps().begin(); if (foldingAlgorithm == Folding(2)) { - return randomLocalFolding( - loc, - rewriter, - scalarType, - moduleOp, - fnFoldedName, - fnOp, - originalTypes, - qregType, - fnAllocOp, - numberQubits, - deviceInitOp - ); + return randomLocalFolding(loc, rewriter, scalarType, moduleOp, fnFoldedName, fnOp, + originalTypes, qregType, fnAllocOp, numberQubits, deviceInitOp); } if (foldingAlgorithm == Folding(3)) { - return allLocalFolding( - loc, - rewriter, - scalarType, - moduleOp, - fnFoldedName, - fnOp, - originalTypes, - qregType, - fnAllocOp, - numberQubits, - deviceInitOp - ); + return allLocalFolding(loc, rewriter, scalarType, moduleOp, fnFoldedName, fnOp, + originalTypes, qregType, fnAllocOp, numberQubits, deviceInitOp); } StringAttr lib = deviceInitOp.getLibAttr(); From aeebc509fb99c209608e9b282d61540a9928233a Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Tue, 23 Jul 2024 11:59:22 -0400 Subject: [PATCH 16/94] Testing that refactor works --- .../Transforms/MitigationMethods/Zne.cpp | 115 +++++++++++++++--- 1 file changed, 98 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index ec5d141cba..5254e0a723 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -129,7 +129,9 @@ void ZneLowering::rewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, Type scalarType, ModuleOp moduleOp, std::string fnFoldedName, func::FuncOp fnOp, TypeRange originalTypes, Type qregType, func::FuncOp fnAllocOp, - int64_t numberQubits, quantum::DeviceInitOp deviceInitOp) + int64_t numberQubits, quantum::DeviceInitOp deviceInitOp, + func::FuncOp fnWithoutMeasurementsOp, + func::FuncOp fnWithMeasurementsOp) { // TODO: Implement. return FlatSymbolRefAttr(); @@ -138,10 +140,87 @@ FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, Ty FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, Type scalarType, ModuleOp moduleOp, std::string fnFoldedName, func::FuncOp fnOp, TypeRange originalTypes, Type qregType, func::FuncOp fnAllocOp, - int64_t numberQubits, quantum::DeviceInitOp deviceInitOp) + int64_t numberQubits, quantum::DeviceInitOp deviceInitOp, + func::FuncOp fnWithoutMeasurementsOp, + func::FuncOp fnWithMeasurementsOp) { - // TODO: Implement. - return FlatSymbolRefAttr(); + MLIRContext *ctx = rewriter.getContext(); + StringAttr lib = deviceInitOp.getLibAttr(); + StringAttr name = deviceInitOp.getNameAttr(); + StringAttr kwargs = deviceInitOp.getKwargsAttr(); + + // Function folded: Create the folded circuit (withoutMeasurement * + // Adjoint(withoutMeasurement))**scalar_factor * withMeasurements + rewriter.setInsertionPointToStart(moduleOp.getBody()); + SmallVector typesFolded(originalTypes.begin(), originalTypes.end()); + Type indexType = rewriter.getIndexType(); + typesFolded.push_back(indexType); + FunctionType fnFoldedType = FunctionType::get(ctx, /*inputs=*/ + typesFolded, + /*outputs=*/fnOp.getResultTypes()); + + func::FuncOp fnFoldedOp = rewriter.create(loc, fnFoldedName, fnFoldedType); + fnFoldedOp.setPrivate(); + + Block *foldedBloc = fnFoldedOp.addEntryBlock(); + rewriter.setInsertionPointToStart(foldedBloc); + // Add device + rewriter.create(loc, lib, name, kwargs); + TypedAttr numberQubitsAttr = rewriter.getI64IntegerAttr(numberQubits); + Value numberQubitsValue = rewriter.create(loc, numberQubitsAttr); + Value allocQreg = rewriter.create(loc, fnAllocOp, numberQubitsValue).getResult(0); + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + int64_t sizeArgs = fnFoldedOp.getArguments().size(); + Value size = fnFoldedOp.getArgument(sizeArgs - 1); + // Add scf for loop to create the folding + Value loopedQreg = + rewriter + .create( + loc, c0, size, c1, /*iterArgsInit=*/allocQreg, + [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { + Value qreg = iterArgs.front(); + std::vector argsAndQreg(fnFoldedOp.getArguments().begin(), + fnFoldedOp.getArguments().end()); + argsAndQreg.pop_back(); + argsAndQreg.push_back(qreg); + + // Call the function without measurements + Value fnWithoutMeasurementsQreg = + builder.create(loc, fnWithoutMeasurementsOp, argsAndQreg) + .getResult(0); + + // Call the function without measurements in an adjoint region + auto adjointOp = builder.create(loc, qregType, + fnWithoutMeasurementsQreg); + Region *adjointRegion = &adjointOp.getRegion(); + Block *adjointBlock = builder.createBlock(adjointRegion, {}, qregType, loc); + + std::vector argsAndQregAdjoint(fnFoldedOp.getArguments().begin(), + fnFoldedOp.getArguments().end()); + argsAndQregAdjoint.pop_back(); + argsAndQregAdjoint.push_back(adjointBlock->getArgument(0)); + Value fnWithoutMeasurementsAdjointQreg = + builder + .create(loc, fnWithoutMeasurementsOp, argsAndQregAdjoint) + .getResult(0); + builder.create(loc, fnWithoutMeasurementsAdjointQreg); + builder.setInsertionPointAfter(adjointOp); + builder.create(loc, adjointOp.getResult()); + }) + .getResult(0); + std::vector argsAndRegMeasurement(fnFoldedOp.getArguments().begin(), + fnFoldedOp.getArguments().end()); + argsAndRegMeasurement.pop_back(); + argsAndRegMeasurement.push_back(loopedQreg); + ValueRange funcFolded = + rewriter.create(loc, fnWithMeasurementsOp, argsAndRegMeasurement) + .getResults(); + // Remove device + rewriter.create(loc); + rewriter.create(loc, funcFolded); + return SymbolRefAttr::get(ctx, fnFoldedName); } FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRewriter &rewriter, mitigation::ZneOp op, Type scalarType, @@ -174,19 +253,6 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew // Get the device quantum::DeviceInitOp deviceInitOp = *fnOp.getOps().begin(); - if (foldingAlgorithm == Folding(2)) { - return randomLocalFolding(loc, rewriter, scalarType, moduleOp, fnFoldedName, fnOp, - originalTypes, qregType, fnAllocOp, numberQubits, deviceInitOp); - } - if (foldingAlgorithm == Folding(3)) { - return allLocalFolding(loc, rewriter, scalarType, moduleOp, fnFoldedName, fnOp, - originalTypes, qregType, fnAllocOp, numberQubits, deviceInitOp); - } - - StringAttr lib = deviceInitOp.getLibAttr(); - StringAttr name = deviceInitOp.getNameAttr(); - StringAttr kwargs = deviceInitOp.getKwargsAttr(); - // Function without measurements: Create function without measurements and with qreg as last // argument FlatSymbolRefAttr fnWithoutMeasurementsRefAttr = @@ -200,6 +266,21 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew func::FuncOp fnWithMeasurementsOp = SymbolTable::lookupNearestSymbolFrom(op, fnWithMeasurementsRefAttr); + if (foldingAlgorithm == Folding(2)) { + return randomLocalFolding(loc, rewriter, scalarType, moduleOp, fnFoldedName, fnOp, + originalTypes, qregType, fnAllocOp, numberQubits, deviceInitOp, + fnWithoutMeasurementsOp, fnWithMeasurementsOp); + } + // if (foldingAlgorithm == Folding(3)) { + return allLocalFolding(loc, rewriter, scalarType, moduleOp, fnFoldedName, fnOp, originalTypes, + qregType, fnAllocOp, numberQubits, deviceInitOp, fnWithoutMeasurementsOp, + fnWithMeasurementsOp); + // } + + StringAttr lib = deviceInitOp.getLibAttr(); + StringAttr name = deviceInitOp.getNameAttr(); + StringAttr kwargs = deviceInitOp.getKwargsAttr(); + // Function folded: Create the folded circuit (withoutMeasurement * // Adjoint(withoutMeasurement))**scalar_factor * withMeasurements rewriter.setInsertionPointToStart(moduleOp.getBody()); From 23a1edde1b50e7b15a0116b126653187ba657bf0 Mon Sep 17 00:00:00 2001 From: Alessandro Cosentino Date: Tue, 23 Jul 2024 18:03:11 +0200 Subject: [PATCH 17/94] doc folding enum --- frontend/catalyst/api_extensions/error_mitigation.py | 2 +- frontend/catalyst/jax_primitives.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/api_extensions/error_mitigation.py b/frontend/catalyst/api_extensions/error_mitigation.py index f8168f6484..ef356bee9d 100644 --- a/frontend/catalyst/api_extensions/error_mitigation.py +++ b/frontend/catalyst/api_extensions/error_mitigation.py @@ -149,7 +149,7 @@ def __call__(self, *args, **kwargs): raise ValueError(f"Folding type must be one of {list(map(str, Folding))}") from e # TODO: remove the following check once #755 is completed if folding != Folding.GLOBAL: - raise NotImplementedError(f"Folding type {folding.name} is being developed") + raise NotImplementedError(f"Folding type {folding.value} is being developed") results = zne_p.bind( *args_data, self.scale_factors, folding=folding, jaxpr=jaxpr, fn=self.fn diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 411ddf8ca5..a386f57552 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -197,6 +197,10 @@ def _obs_lowering(aval): class Folding(StrEnum): + """ + Folding types supported by ZNE mitigation + """ + GLOBAL = "global" RANDOM = "random" ALL = "all" From 47112190ebfcb4661ed56c582cc327c5279d7a77 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Tue, 23 Jul 2024 12:07:55 -0400 Subject: [PATCH 18/94] Cleaner branching --- .../Transforms/MitigationMethods/Zne.cpp | 128 +++++------------- 1 file changed, 32 insertions(+), 96 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 5254e0a723..d3d2d7eba6 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -126,23 +126,12 @@ void ZneLowering::rewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const // Replace the original results rewriter.replaceOp(op, resultValues); } -FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, Type scalarType, - ModuleOp moduleOp, std::string fnFoldedName, func::FuncOp fnOp, - TypeRange originalTypes, Type qregType, func::FuncOp fnAllocOp, - int64_t numberQubits, quantum::DeviceInitOp deviceInitOp, - func::FuncOp fnWithoutMeasurementsOp, - func::FuncOp fnWithMeasurementsOp) -{ - // TODO: Implement. - return FlatSymbolRefAttr(); -} - -FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, Type scalarType, - ModuleOp moduleOp, std::string fnFoldedName, func::FuncOp fnOp, - TypeRange originalTypes, Type qregType, func::FuncOp fnAllocOp, - int64_t numberQubits, quantum::DeviceInitOp deviceInitOp, - func::FuncOp fnWithoutMeasurementsOp, - func::FuncOp fnWithMeasurementsOp) +FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, Type scalarType, + ModuleOp moduleOp, std::string fnFoldedName, func::FuncOp fnOp, + TypeRange originalTypes, Type qregType, func::FuncOp fnAllocOp, + int64_t numberQubits, quantum::DeviceInitOp deviceInitOp, + func::FuncOp fnWithoutMeasurementsOp, + func::FuncOp fnWithMeasurementsOp) { MLIRContext *ctx = rewriter.getContext(); StringAttr lib = deviceInitOp.getLibAttr(); @@ -222,6 +211,26 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, Type rewriter.create(loc, funcFolded); return SymbolRefAttr::get(ctx, fnFoldedName); } +FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, Type scalarType, + ModuleOp moduleOp, std::string fnFoldedName, func::FuncOp fnOp, + TypeRange originalTypes, Type qregType, func::FuncOp fnAllocOp, + int64_t numberQubits, quantum::DeviceInitOp deviceInitOp, + func::FuncOp fnWithoutMeasurementsOp, + func::FuncOp fnWithMeasurementsOp) +{ + // TODO: Implement. + return FlatSymbolRefAttr(); +} +FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, Type scalarType, + ModuleOp moduleOp, std::string fnFoldedName, func::FuncOp fnOp, + TypeRange originalTypes, Type qregType, func::FuncOp fnAllocOp, + int64_t numberQubits, quantum::DeviceInitOp deviceInitOp, + func::FuncOp fnWithoutMeasurementsOp, + func::FuncOp fnWithMeasurementsOp) +{ + // TODO: Implement. + return FlatSymbolRefAttr(); +} FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRewriter &rewriter, mitigation::ZneOp op, Type scalarType, Folding foldingAlgorithm) @@ -266,93 +275,20 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew func::FuncOp fnWithMeasurementsOp = SymbolTable::lookupNearestSymbolFrom(op, fnWithMeasurementsRefAttr); + if (foldingAlgorithm == Folding(1)) { + return globalFolding(loc, rewriter, scalarType, moduleOp, fnFoldedName, fnOp, originalTypes, + qregType, fnAllocOp, numberQubits, deviceInitOp, + fnWithoutMeasurementsOp, fnWithMeasurementsOp); + } if (foldingAlgorithm == Folding(2)) { return randomLocalFolding(loc, rewriter, scalarType, moduleOp, fnFoldedName, fnOp, originalTypes, qregType, fnAllocOp, numberQubits, deviceInitOp, fnWithoutMeasurementsOp, fnWithMeasurementsOp); } - // if (foldingAlgorithm == Folding(3)) { + // Else, if (foldingAlgorithm == Folding(3)): return allLocalFolding(loc, rewriter, scalarType, moduleOp, fnFoldedName, fnOp, originalTypes, qregType, fnAllocOp, numberQubits, deviceInitOp, fnWithoutMeasurementsOp, fnWithMeasurementsOp); - // } - - StringAttr lib = deviceInitOp.getLibAttr(); - StringAttr name = deviceInitOp.getNameAttr(); - StringAttr kwargs = deviceInitOp.getKwargsAttr(); - - // Function folded: Create the folded circuit (withoutMeasurement * - // Adjoint(withoutMeasurement))**scalar_factor * withMeasurements - rewriter.setInsertionPointToStart(moduleOp.getBody()); - SmallVector typesFolded(originalTypes.begin(), originalTypes.end()); - Type indexType = rewriter.getIndexType(); - typesFolded.push_back(indexType); - FunctionType fnFoldedType = FunctionType::get(ctx, /*inputs=*/ - typesFolded, - /*outputs=*/fnOp.getResultTypes()); - - func::FuncOp fnFoldedOp = rewriter.create(loc, fnFoldedName, fnFoldedType); - fnFoldedOp.setPrivate(); - - Block *foldedBloc = fnFoldedOp.addEntryBlock(); - rewriter.setInsertionPointToStart(foldedBloc); - // Add device - rewriter.create(loc, lib, name, kwargs); - TypedAttr numberQubitsAttr = rewriter.getI64IntegerAttr(numberQubits); - Value numberQubitsValue = rewriter.create(loc, numberQubitsAttr); - Value allocQreg = rewriter.create(loc, fnAllocOp, numberQubitsValue).getResult(0); - - Value c0 = rewriter.create(loc, 0); - Value c1 = rewriter.create(loc, 1); - int64_t sizeArgs = fnFoldedOp.getArguments().size(); - Value size = fnFoldedOp.getArgument(sizeArgs - 1); - // Add scf for loop to create the folding - Value loopedQreg = - rewriter - .create( - loc, c0, size, c1, /*iterArgsInit=*/allocQreg, - [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { - Value qreg = iterArgs.front(); - std::vector argsAndQreg(fnFoldedOp.getArguments().begin(), - fnFoldedOp.getArguments().end()); - argsAndQreg.pop_back(); - argsAndQreg.push_back(qreg); - - // Call the function without measurements - Value fnWithoutMeasurementsQreg = - builder.create(loc, fnWithoutMeasurementsOp, argsAndQreg) - .getResult(0); - - // Call the function without measurements in an adjoint region - auto adjointOp = builder.create(loc, qregType, - fnWithoutMeasurementsQreg); - Region *adjointRegion = &adjointOp.getRegion(); - Block *adjointBlock = builder.createBlock(adjointRegion, {}, qregType, loc); - - std::vector argsAndQregAdjoint(fnFoldedOp.getArguments().begin(), - fnFoldedOp.getArguments().end()); - argsAndQregAdjoint.pop_back(); - argsAndQregAdjoint.push_back(adjointBlock->getArgument(0)); - Value fnWithoutMeasurementsAdjointQreg = - builder - .create(loc, fnWithoutMeasurementsOp, argsAndQregAdjoint) - .getResult(0); - builder.create(loc, fnWithoutMeasurementsAdjointQreg); - builder.setInsertionPointAfter(adjointOp); - builder.create(loc, adjointOp.getResult()); - }) - .getResult(0); - std::vector argsAndRegMeasurement(fnFoldedOp.getArguments().begin(), - fnFoldedOp.getArguments().end()); - argsAndRegMeasurement.pop_back(); - argsAndRegMeasurement.push_back(loopedQreg); - ValueRange funcFolded = - rewriter.create(loc, fnWithMeasurementsOp, argsAndRegMeasurement) - .getResults(); - // Remove device - rewriter.create(loc); - rewriter.create(loc, funcFolded); - return SymbolRefAttr::get(ctx, fnFoldedName); } FlatSymbolRefAttr ZneLowering::getOrInsertQuantumAlloc(Location loc, PatternRewriter &rewriter, mitigation::ZneOp op) From 0030d111e78fbdcfb83fc3e972c64a2d341cc269 Mon Sep 17 00:00:00 2001 From: Alessandro Cosentino Date: Tue, 23 Jul 2024 18:25:07 +0200 Subject: [PATCH 19/94] ditch StrEnum as not supported in Python 3.10 --- frontend/catalyst/jax_primitives.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index a386f57552..9f815186e5 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -17,7 +17,7 @@ import sys from dataclasses import dataclass -from enum import StrEnum +from enum import Enum from itertools import chain from typing import Any, Dict, Iterable, List, Union @@ -196,7 +196,7 @@ def _obs_lowering(aval): mlir.ir_type_handlers[AbstractObs] = _obs_lowering -class Folding(StrEnum): +class Folding(Enum): """ Folding types supported by ZNE mitigation """ @@ -750,7 +750,7 @@ def _folding_attribute(ctx, folding): ctx = ctx.module_context.context return ir.OpaqueAttr.get( "mitigation", - ("folding " + Folding(folding)).encode("utf-8"), + ("folding " + Folding(folding).value).encode("utf-8"), ir.NoneType.get(ctx), ctx, ) From d4862e1cae05fc41f75e14a1f93e740bb1eb68f7 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Tue, 23 Jul 2024 14:09:59 -0400 Subject: [PATCH 20/94] Throw on request for random folding (for now) --- frontend/catalyst/api_extensions/error_mitigation.py | 4 ++++ .../Mitigation/Transforms/MitigationMethods/Zne.cpp | 10 ++++++++++ 2 files changed, 14 insertions(+) diff --git a/frontend/catalyst/api_extensions/error_mitigation.py b/frontend/catalyst/api_extensions/error_mitigation.py index d3c8a1d9b6..f594118d9f 100644 --- a/frontend/catalyst/api_extensions/error_mitigation.py +++ b/frontend/catalyst/api_extensions/error_mitigation.py @@ -92,6 +92,10 @@ def mitigated_circuit(args, n): s = jax.numpy.array([1, 2, 3]) return mitigate_with_zne(circuit, scale_factors=s)(args, n) """ + + if folding.upper() == "RANDOM": + raise NotImplementedError("Random global folding not yet implemented!") + kwargs = copy.copy(locals()) kwargs.pop("fn") diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index d3d2d7eba6..5963e0eb3a 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -219,6 +219,10 @@ FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, Ty func::FuncOp fnWithMeasurementsOp) { // TODO: Implement. + + // Can't throw, because disabled by compilation. + //throw std::logic_error("Random local folding not implemented!"); + return FlatSymbolRefAttr(); } FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, Type scalarType, @@ -229,6 +233,12 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, Type func::FuncOp fnWithMeasurementsOp) { // TODO: Implement. + + // MLIRContext *ctx = rewriter.getContext(); + // StringAttr lib = deviceInitOp.getLibAttr(); + // StringAttr name = deviceInitOp.getNameAttr(); + // StringAttr kwargs = deviceInitOp.getKwargsAttr(); + return FlatSymbolRefAttr(); } FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRewriter &rewriter, From 567589f43ddaf79b3d3557eb8445eac8b5c1e53c Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Tue, 23 Jul 2024 14:43:15 -0400 Subject: [PATCH 21/94] More (incremental) code reuse --- .../Transforms/MitigationMethods/Zne.cpp | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 5963e0eb3a..7311da2eb0 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -128,8 +128,9 @@ void ZneLowering::rewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const } FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, Type scalarType, ModuleOp moduleOp, std::string fnFoldedName, func::FuncOp fnOp, - TypeRange originalTypes, Type qregType, func::FuncOp fnAllocOp, - int64_t numberQubits, quantum::DeviceInitOp deviceInitOp, + SmallVector typesFolded, Type qregType, + func::FuncOp fnAllocOp, int64_t numberQubits, + quantum::DeviceInitOp deviceInitOp, func::FuncOp fnWithoutMeasurementsOp, func::FuncOp fnWithMeasurementsOp) { @@ -141,7 +142,6 @@ FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, Type sc // Function folded: Create the folded circuit (withoutMeasurement * // Adjoint(withoutMeasurement))**scalar_factor * withMeasurements rewriter.setInsertionPointToStart(moduleOp.getBody()); - SmallVector typesFolded(originalTypes.begin(), originalTypes.end()); Type indexType = rewriter.getIndexType(); typesFolded.push_back(indexType); FunctionType fnFoldedType = FunctionType::get(ctx, /*inputs=*/ @@ -213,22 +213,24 @@ FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, Type sc } FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, Type scalarType, ModuleOp moduleOp, std::string fnFoldedName, func::FuncOp fnOp, - TypeRange originalTypes, Type qregType, func::FuncOp fnAllocOp, - int64_t numberQubits, quantum::DeviceInitOp deviceInitOp, + SmallVector typesFolded, Type qregType, + func::FuncOp fnAllocOp, int64_t numberQubits, + quantum::DeviceInitOp deviceInitOp, func::FuncOp fnWithoutMeasurementsOp, func::FuncOp fnWithMeasurementsOp) { // TODO: Implement. // Can't throw, because disabled by compilation. - //throw std::logic_error("Random local folding not implemented!"); + // throw std::logic_error("Random local folding not implemented!"); return FlatSymbolRefAttr(); } FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, Type scalarType, ModuleOp moduleOp, std::string fnFoldedName, func::FuncOp fnOp, - TypeRange originalTypes, Type qregType, func::FuncOp fnAllocOp, - int64_t numberQubits, quantum::DeviceInitOp deviceInitOp, + SmallVector typesFolded, Type qregType, + func::FuncOp fnAllocOp, int64_t numberQubits, + quantum::DeviceInitOp deviceInitOp, func::FuncOp fnWithoutMeasurementsOp, func::FuncOp fnWithMeasurementsOp) { @@ -285,18 +287,20 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew func::FuncOp fnWithMeasurementsOp = SymbolTable::lookupNearestSymbolFrom(op, fnWithMeasurementsRefAttr); + SmallVector typesFolded(originalTypes.begin(), originalTypes.end()); + if (foldingAlgorithm == Folding(1)) { - return globalFolding(loc, rewriter, scalarType, moduleOp, fnFoldedName, fnOp, originalTypes, + return globalFolding(loc, rewriter, scalarType, moduleOp, fnFoldedName, fnOp, typesFolded, qregType, fnAllocOp, numberQubits, deviceInitOp, fnWithoutMeasurementsOp, fnWithMeasurementsOp); } if (foldingAlgorithm == Folding(2)) { return randomLocalFolding(loc, rewriter, scalarType, moduleOp, fnFoldedName, fnOp, - originalTypes, qregType, fnAllocOp, numberQubits, deviceInitOp, + typesFolded, qregType, fnAllocOp, numberQubits, deviceInitOp, fnWithoutMeasurementsOp, fnWithMeasurementsOp); } // Else, if (foldingAlgorithm == Folding(3)): - return allLocalFolding(loc, rewriter, scalarType, moduleOp, fnFoldedName, fnOp, originalTypes, + return allLocalFolding(loc, rewriter, scalarType, moduleOp, fnFoldedName, fnOp, typesFolded, qregType, fnAllocOp, numberQubits, deviceInitOp, fnWithoutMeasurementsOp, fnWithMeasurementsOp); } From caa1d6334bc5a7587257752ac9bca5085d92c0ba Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Tue, 23 Jul 2024 14:46:26 -0400 Subject: [PATCH 22/94] More (incremental) code reuse --- .../Transforms/MitigationMethods/Zne.cpp | 42 +++++++++---------- 1 file changed, 19 insertions(+), 23 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 7311da2eb0..1ffb8e3d50 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -126,13 +126,11 @@ void ZneLowering::rewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const // Replace the original results rewriter.replaceOp(op, resultValues); } -FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, Type scalarType, - ModuleOp moduleOp, std::string fnFoldedName, func::FuncOp fnOp, - SmallVector typesFolded, Type qregType, - func::FuncOp fnAllocOp, int64_t numberQubits, - quantum::DeviceInitOp deviceInitOp, - func::FuncOp fnWithoutMeasurementsOp, - func::FuncOp fnWithMeasurementsOp) +FlatSymbolRefAttr +globalFolding(Location loc, PatternRewriter &rewriter, Type scalarType, std::string fnFoldedName, + func::FuncOp fnOp, SmallVector typesFolded, Type qregType, + func::FuncOp fnAllocOp, int64_t numberQubits, quantum::DeviceInitOp deviceInitOp, + func::FuncOp fnWithoutMeasurementsOp, func::FuncOp fnWithMeasurementsOp) { MLIRContext *ctx = rewriter.getContext(); StringAttr lib = deviceInitOp.getLibAttr(); @@ -141,7 +139,6 @@ FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, Type sc // Function folded: Create the folded circuit (withoutMeasurement * // Adjoint(withoutMeasurement))**scalar_factor * withMeasurements - rewriter.setInsertionPointToStart(moduleOp.getBody()); Type indexType = rewriter.getIndexType(); typesFolded.push_back(indexType); FunctionType fnFoldedType = FunctionType::get(ctx, /*inputs=*/ @@ -212,7 +209,7 @@ FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, Type sc return SymbolRefAttr::get(ctx, fnFoldedName); } FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, Type scalarType, - ModuleOp moduleOp, std::string fnFoldedName, func::FuncOp fnOp, + std::string fnFoldedName, func::FuncOp fnOp, SmallVector typesFolded, Type qregType, func::FuncOp fnAllocOp, int64_t numberQubits, quantum::DeviceInitOp deviceInitOp, @@ -226,13 +223,11 @@ FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, Ty return FlatSymbolRefAttr(); } -FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, Type scalarType, - ModuleOp moduleOp, std::string fnFoldedName, func::FuncOp fnOp, - SmallVector typesFolded, Type qregType, - func::FuncOp fnAllocOp, int64_t numberQubits, - quantum::DeviceInitOp deviceInitOp, - func::FuncOp fnWithoutMeasurementsOp, - func::FuncOp fnWithMeasurementsOp) +FlatSymbolRefAttr +allLocalFolding(Location loc, PatternRewriter &rewriter, Type scalarType, std::string fnFoldedName, + func::FuncOp fnOp, SmallVector typesFolded, Type qregType, + func::FuncOp fnAllocOp, int64_t numberQubits, quantum::DeviceInitOp deviceInitOp, + func::FuncOp fnWithoutMeasurementsOp, func::FuncOp fnWithMeasurementsOp) { // TODO: Implement. @@ -288,20 +283,21 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew SymbolTable::lookupNearestSymbolFrom(op, fnWithMeasurementsRefAttr); SmallVector typesFolded(originalTypes.begin(), originalTypes.end()); + rewriter.setInsertionPointToStart(moduleOp.getBody()); if (foldingAlgorithm == Folding(1)) { - return globalFolding(loc, rewriter, scalarType, moduleOp, fnFoldedName, fnOp, typesFolded, - qregType, fnAllocOp, numberQubits, deviceInitOp, - fnWithoutMeasurementsOp, fnWithMeasurementsOp); + return globalFolding(loc, rewriter, scalarType, fnFoldedName, fnOp, typesFolded, qregType, + fnAllocOp, numberQubits, deviceInitOp, fnWithoutMeasurementsOp, + fnWithMeasurementsOp); } if (foldingAlgorithm == Folding(2)) { - return randomLocalFolding(loc, rewriter, scalarType, moduleOp, fnFoldedName, fnOp, - typesFolded, qregType, fnAllocOp, numberQubits, deviceInitOp, + return randomLocalFolding(loc, rewriter, scalarType, fnFoldedName, fnOp, typesFolded, + qregType, fnAllocOp, numberQubits, deviceInitOp, fnWithoutMeasurementsOp, fnWithMeasurementsOp); } // Else, if (foldingAlgorithm == Folding(3)): - return allLocalFolding(loc, rewriter, scalarType, moduleOp, fnFoldedName, fnOp, typesFolded, - qregType, fnAllocOp, numberQubits, deviceInitOp, fnWithoutMeasurementsOp, + return allLocalFolding(loc, rewriter, scalarType, fnFoldedName, fnOp, typesFolded, qregType, + fnAllocOp, numberQubits, deviceInitOp, fnWithoutMeasurementsOp, fnWithMeasurementsOp); } FlatSymbolRefAttr ZneLowering::getOrInsertQuantumAlloc(Location loc, PatternRewriter &rewriter, From 369a3face5437d89e5712872182214e1f656477d Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Tue, 23 Jul 2024 15:01:45 -0400 Subject: [PATCH 23/94] Better (incremental) code reuse --- .../Transforms/MitigationMethods/Zne.cpp | 41 ++++++++++--------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 1ffb8e3d50..378ad467cf 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -126,13 +126,15 @@ void ZneLowering::rewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const // Replace the original results rewriter.replaceOp(op, resultValues); } -FlatSymbolRefAttr -globalFolding(Location loc, PatternRewriter &rewriter, Type scalarType, std::string fnFoldedName, - func::FuncOp fnOp, SmallVector typesFolded, Type qregType, - func::FuncOp fnAllocOp, int64_t numberQubits, quantum::DeviceInitOp deviceInitOp, - func::FuncOp fnWithoutMeasurementsOp, func::FuncOp fnWithMeasurementsOp) +FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, Type scalarType, + std::string fnFoldedName, func::FuncOp fnOp, + SmallVector typesFolded, func::FuncOp fnAllocOp, + int64_t numberQubits, quantum::DeviceInitOp deviceInitOp, + func::FuncOp fnWithoutMeasurementsOp, + func::FuncOp fnWithMeasurementsOp) { MLIRContext *ctx = rewriter.getContext(); + Type qregType = quantum::QuregType::get(ctx); StringAttr lib = deviceInitOp.getLibAttr(); StringAttr name = deviceInitOp.getNameAttr(); StringAttr kwargs = deviceInitOp.getKwargsAttr(); @@ -210,9 +212,8 @@ globalFolding(Location loc, PatternRewriter &rewriter, Type scalarType, std::str } FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, Type scalarType, std::string fnFoldedName, func::FuncOp fnOp, - SmallVector typesFolded, Type qregType, - func::FuncOp fnAllocOp, int64_t numberQubits, - quantum::DeviceInitOp deviceInitOp, + SmallVector typesFolded, func::FuncOp fnAllocOp, + int64_t numberQubits, quantum::DeviceInitOp deviceInitOp, func::FuncOp fnWithoutMeasurementsOp, func::FuncOp fnWithMeasurementsOp) { @@ -223,11 +224,12 @@ FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, Ty return FlatSymbolRefAttr(); } -FlatSymbolRefAttr -allLocalFolding(Location loc, PatternRewriter &rewriter, Type scalarType, std::string fnFoldedName, - func::FuncOp fnOp, SmallVector typesFolded, Type qregType, - func::FuncOp fnAllocOp, int64_t numberQubits, quantum::DeviceInitOp deviceInitOp, - func::FuncOp fnWithoutMeasurementsOp, func::FuncOp fnWithMeasurementsOp) +FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, Type scalarType, + std::string fnFoldedName, func::FuncOp fnOp, + SmallVector typesFolded, func::FuncOp fnAllocOp, + int64_t numberQubits, quantum::DeviceInitOp deviceInitOp, + func::FuncOp fnWithoutMeasurementsOp, + func::FuncOp fnWithMeasurementsOp) { // TODO: Implement. @@ -254,7 +256,6 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew // Original function func::FuncOp fnOp = SymbolTable::lookupNearestSymbolFrom(op, op.getCalleeAttr()); TypeRange originalTypes = op.getArgs().getTypes(); - Type qregType = quantum::QuregType::get(ctx); // Set insertion in the module rewriter.setInsertionPointToStart(moduleOp.getBody()); @@ -286,18 +287,18 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew rewriter.setInsertionPointToStart(moduleOp.getBody()); if (foldingAlgorithm == Folding(1)) { - return globalFolding(loc, rewriter, scalarType, fnFoldedName, fnOp, typesFolded, qregType, - fnAllocOp, numberQubits, deviceInitOp, fnWithoutMeasurementsOp, + return globalFolding(loc, rewriter, scalarType, fnFoldedName, fnOp, typesFolded, fnAllocOp, + numberQubits, deviceInitOp, fnWithoutMeasurementsOp, fnWithMeasurementsOp); } if (foldingAlgorithm == Folding(2)) { return randomLocalFolding(loc, rewriter, scalarType, fnFoldedName, fnOp, typesFolded, - qregType, fnAllocOp, numberQubits, deviceInitOp, - fnWithoutMeasurementsOp, fnWithMeasurementsOp); + fnAllocOp, numberQubits, deviceInitOp, fnWithoutMeasurementsOp, + fnWithMeasurementsOp); } // Else, if (foldingAlgorithm == Folding(3)): - return allLocalFolding(loc, rewriter, scalarType, fnFoldedName, fnOp, typesFolded, qregType, - fnAllocOp, numberQubits, deviceInitOp, fnWithoutMeasurementsOp, + return allLocalFolding(loc, rewriter, scalarType, fnFoldedName, fnOp, typesFolded, fnAllocOp, + numberQubits, deviceInitOp, fnWithoutMeasurementsOp, fnWithMeasurementsOp); } FlatSymbolRefAttr ZneLowering::getOrInsertQuantumAlloc(Location loc, PatternRewriter &rewriter, From 5e03c0d20e194b4eae2f4cc4e7547f3d20384923 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Tue, 23 Jul 2024 15:07:19 -0400 Subject: [PATCH 24/94] Remove unused argument --- .../Transforms/MitigationMethods/Zne.cpp | 35 +++++++++---------- .../Transforms/MitigationMethods/Zne.hpp | 2 +- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 378ad467cf..08c2ff22e8 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -49,8 +49,8 @@ void ZneLowering::rewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const auto foldingAlgorithm = op.getFolding(); // Create the folded circuit function - FlatSymbolRefAttr foldedCircuitRefAttr = getOrInsertFoldedCircuit( - loc, rewriter, op, scaleFactorType.getElementType(), foldingAlgorithm); + FlatSymbolRefAttr foldedCircuitRefAttr = + getOrInsertFoldedCircuit(loc, rewriter, op, foldingAlgorithm); func::FuncOp foldedCircuit = SymbolTable::lookupNearestSymbolFrom(op, foldedCircuitRefAttr); @@ -126,10 +126,10 @@ void ZneLowering::rewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const // Replace the original results rewriter.replaceOp(op, resultValues); } -FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, Type scalarType, - std::string fnFoldedName, func::FuncOp fnOp, - SmallVector typesFolded, func::FuncOp fnAllocOp, - int64_t numberQubits, quantum::DeviceInitOp deviceInitOp, +FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, std::string fnFoldedName, + func::FuncOp fnOp, SmallVector typesFolded, + func::FuncOp fnAllocOp, int64_t numberQubits, + quantum::DeviceInitOp deviceInitOp, func::FuncOp fnWithoutMeasurementsOp, func::FuncOp fnWithMeasurementsOp) { @@ -210,7 +210,7 @@ FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, Type sc rewriter.create(loc, funcFolded); return SymbolRefAttr::get(ctx, fnFoldedName); } -FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, Type scalarType, +FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, std::string fnFoldedName, func::FuncOp fnOp, SmallVector typesFolded, func::FuncOp fnAllocOp, int64_t numberQubits, quantum::DeviceInitOp deviceInitOp, @@ -224,10 +224,10 @@ FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, Ty return FlatSymbolRefAttr(); } -FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, Type scalarType, - std::string fnFoldedName, func::FuncOp fnOp, - SmallVector typesFolded, func::FuncOp fnAllocOp, - int64_t numberQubits, quantum::DeviceInitOp deviceInitOp, +FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std::string fnFoldedName, + func::FuncOp fnOp, SmallVector typesFolded, + func::FuncOp fnAllocOp, int64_t numberQubits, + quantum::DeviceInitOp deviceInitOp, func::FuncOp fnWithoutMeasurementsOp, func::FuncOp fnWithMeasurementsOp) { @@ -241,7 +241,7 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, Type return FlatSymbolRefAttr(); } FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRewriter &rewriter, - mitigation::ZneOp op, Type scalarType, + mitigation::ZneOp op, Folding foldingAlgorithm) { OpBuilder::InsertionGuard guard(rewriter); @@ -287,19 +287,18 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew rewriter.setInsertionPointToStart(moduleOp.getBody()); if (foldingAlgorithm == Folding(1)) { - return globalFolding(loc, rewriter, scalarType, fnFoldedName, fnOp, typesFolded, fnAllocOp, + return globalFolding(loc, rewriter, fnFoldedName, fnOp, typesFolded, fnAllocOp, numberQubits, deviceInitOp, fnWithoutMeasurementsOp, fnWithMeasurementsOp); } if (foldingAlgorithm == Folding(2)) { - return randomLocalFolding(loc, rewriter, scalarType, fnFoldedName, fnOp, typesFolded, - fnAllocOp, numberQubits, deviceInitOp, fnWithoutMeasurementsOp, + return randomLocalFolding(loc, rewriter, fnFoldedName, fnOp, typesFolded, fnAllocOp, + numberQubits, deviceInitOp, fnWithoutMeasurementsOp, fnWithMeasurementsOp); } // Else, if (foldingAlgorithm == Folding(3)): - return allLocalFolding(loc, rewriter, scalarType, fnFoldedName, fnOp, typesFolded, fnAllocOp, - numberQubits, deviceInitOp, fnWithoutMeasurementsOp, - fnWithMeasurementsOp); + return allLocalFolding(loc, rewriter, fnFoldedName, fnOp, typesFolded, fnAllocOp, numberQubits, + deviceInitOp, fnWithoutMeasurementsOp, fnWithMeasurementsOp); } FlatSymbolRefAttr ZneLowering::getOrInsertQuantumAlloc(Location loc, PatternRewriter &rewriter, mitigation::ZneOp op) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.hpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.hpp index 6496116cea..32d871e974 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.hpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.hpp @@ -32,7 +32,7 @@ struct ZneLowering : public OpRewritePattern { private: static FlatSymbolRefAttr getOrInsertFoldedCircuit(Location loc, PatternRewriter &builder, - mitigation::ZneOp op, Type scalarType, + mitigation::ZneOp op, Folding foldingAlgorithm); static FlatSymbolRefAttr getOrInsertQuantumAlloc(Location loc, PatternRewriter &rewriter, mitigation::ZneOp op); From afe58f835165710c0ac1ee7935c4eec3383c49ad Mon Sep 17 00:00:00 2001 From: Alessandro Cosentino Date: Tue, 23 Jul 2024 21:29:44 +0200 Subject: [PATCH 25/94] update changelog --- doc/changelog.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/doc/changelog.md b/doc/changelog.md index 7405eac51e..86487b3af8 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -207,10 +207,16 @@ instead of a tree. This means that we need to manually trace each term and finally multiply it with the coefficients to create a Hamiltonian. +* The function `mitigate_with_zne` accomodates a `folding` input argument for specifying the type of + circuit folding technique to be used by the error-mitigation routine + (only `global` value is supported to date.) + [(#946)](https://github.com/PennyLaneAI/catalyst/pull/946) +

Contributors

This release contains contributions from (in alphabetical order): +Alessandro Cosentino, Kunwar Maheep Singh, Mehrdad Malekmohammadi, Romain Moyard, From e365323b55392d81328e662c88feb80bc2eabf80 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Tue, 23 Jul 2024 15:57:02 -0400 Subject: [PATCH 26/94] Helpful comment about *.cpp module function scope --- mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 08c2ff22e8..1610fccc68 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -126,6 +126,7 @@ void ZneLowering::rewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const // Replace the original results rewriter.replaceOp(op, resultValues); } +// In *.cpp module only, to keep extraneous headers out of *.hpp FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, std::string fnFoldedName, func::FuncOp fnOp, SmallVector typesFolded, func::FuncOp fnAllocOp, int64_t numberQubits, @@ -210,6 +211,7 @@ FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, std::st rewriter.create(loc, funcFolded); return SymbolRefAttr::get(ctx, fnFoldedName); } +// In *.cpp module only, to keep extraneous headers out of *.hpp FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, std::string fnFoldedName, func::FuncOp fnOp, SmallVector typesFolded, func::FuncOp fnAllocOp, @@ -224,6 +226,7 @@ FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, return FlatSymbolRefAttr(); } +// In *.cpp module only, to keep extraneous headers out of *.hpp FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std::string fnFoldedName, func::FuncOp fnOp, SmallVector typesFolded, func::FuncOp fnAllocOp, int64_t numberQubits, From 01fa69c3d5340eafbb8462a2add5fafe1fb0e796 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Tue, 23 Jul 2024 16:19:49 -0400 Subject: [PATCH 27/94] Don't repeat any code when avoidable --- .../Transforms/MitigationMethods/Zne.cpp | 58 ++++++++++--------- 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 1610fccc68..bb6ec9f593 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -128,18 +128,13 @@ void ZneLowering::rewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const } // In *.cpp module only, to keep extraneous headers out of *.hpp FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, std::string fnFoldedName, - func::FuncOp fnOp, SmallVector typesFolded, - func::FuncOp fnAllocOp, int64_t numberQubits, - quantum::DeviceInitOp deviceInitOp, + MLIRContext *ctx, StringAttr lib, StringAttr name, + StringAttr kwargs, Type qregType, func::FuncOp fnOp, + SmallVector typesFolded, func::FuncOp fnAllocOp, + int64_t numberQubits, quantum::DeviceInitOp deviceInitOp, func::FuncOp fnWithoutMeasurementsOp, func::FuncOp fnWithMeasurementsOp) { - MLIRContext *ctx = rewriter.getContext(); - Type qregType = quantum::QuregType::get(ctx); - StringAttr lib = deviceInitOp.getLibAttr(); - StringAttr name = deviceInitOp.getNameAttr(); - StringAttr kwargs = deviceInitOp.getKwargsAttr(); - // Function folded: Create the folded circuit (withoutMeasurement * // Adjoint(withoutMeasurement))**scalar_factor * withMeasurements Type indexType = rewriter.getIndexType(); @@ -212,12 +207,12 @@ FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, std::st return SymbolRefAttr::get(ctx, fnFoldedName); } // In *.cpp module only, to keep extraneous headers out of *.hpp -FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, - std::string fnFoldedName, func::FuncOp fnOp, - SmallVector typesFolded, func::FuncOp fnAllocOp, - int64_t numberQubits, quantum::DeviceInitOp deviceInitOp, - func::FuncOp fnWithoutMeasurementsOp, - func::FuncOp fnWithMeasurementsOp) +FlatSymbolRefAttr +randomLocalFolding(Location loc, PatternRewriter &rewriter, std::string fnFoldedName, + MLIRContext *ctx, StringAttr lib, StringAttr name, StringAttr kwargs, + Type qregType, func::FuncOp fnOp, SmallVector typesFolded, + func::FuncOp fnAllocOp, int64_t numberQubits, quantum::DeviceInitOp deviceInitOp, + func::FuncOp fnWithoutMeasurementsOp, func::FuncOp fnWithMeasurementsOp) { // TODO: Implement. @@ -228,9 +223,10 @@ FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, } // In *.cpp module only, to keep extraneous headers out of *.hpp FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std::string fnFoldedName, - func::FuncOp fnOp, SmallVector typesFolded, - func::FuncOp fnAllocOp, int64_t numberQubits, - quantum::DeviceInitOp deviceInitOp, + MLIRContext *ctx, StringAttr lib, StringAttr name, + StringAttr kwargs, Type qregType, func::FuncOp fnOp, + SmallVector typesFolded, func::FuncOp fnAllocOp, + int64_t numberQubits, quantum::DeviceInitOp deviceInitOp, func::FuncOp fnWithoutMeasurementsOp, func::FuncOp fnWithMeasurementsOp) { @@ -258,7 +254,6 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew // Original function func::FuncOp fnOp = SymbolTable::lookupNearestSymbolFrom(op, op.getCalleeAttr()); - TypeRange originalTypes = op.getArgs().getTypes(); // Set insertion in the module rewriter.setInsertionPointToStart(moduleOp.getBody()); @@ -273,6 +268,13 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew // Get the device quantum::DeviceInitOp deviceInitOp = *fnOp.getOps().begin(); + Type qregType = quantum::QuregType::get(ctx); + TypeRange originalTypes = op.getArgs().getTypes(); + SmallVector typesFolded(originalTypes.begin(), originalTypes.end()); + StringAttr lib = deviceInitOp.getLibAttr(); + StringAttr name = deviceInitOp.getNameAttr(); + StringAttr kwargs = deviceInitOp.getKwargsAttr(); + // Function without measurements: Create function without measurements and with qreg as last // argument FlatSymbolRefAttr fnWithoutMeasurementsRefAttr = @@ -286,22 +288,22 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew func::FuncOp fnWithMeasurementsOp = SymbolTable::lookupNearestSymbolFrom(op, fnWithMeasurementsRefAttr); - SmallVector typesFolded(originalTypes.begin(), originalTypes.end()); rewriter.setInsertionPointToStart(moduleOp.getBody()); if (foldingAlgorithm == Folding(1)) { - return globalFolding(loc, rewriter, fnFoldedName, fnOp, typesFolded, fnAllocOp, - numberQubits, deviceInitOp, fnWithoutMeasurementsOp, - fnWithMeasurementsOp); + return globalFolding(loc, rewriter, fnFoldedName, ctx, lib, name, kwargs, qregType, fnOp, + typesFolded, fnAllocOp, numberQubits, deviceInitOp, + fnWithoutMeasurementsOp, fnWithMeasurementsOp); } if (foldingAlgorithm == Folding(2)) { - return randomLocalFolding(loc, rewriter, fnFoldedName, fnOp, typesFolded, fnAllocOp, - numberQubits, deviceInitOp, fnWithoutMeasurementsOp, - fnWithMeasurementsOp); + return randomLocalFolding(loc, rewriter, fnFoldedName, ctx, lib, name, kwargs, qregType, + fnOp, typesFolded, fnAllocOp, numberQubits, deviceInitOp, + fnWithoutMeasurementsOp, fnWithMeasurementsOp); } // Else, if (foldingAlgorithm == Folding(3)): - return allLocalFolding(loc, rewriter, fnFoldedName, fnOp, typesFolded, fnAllocOp, numberQubits, - deviceInitOp, fnWithoutMeasurementsOp, fnWithMeasurementsOp); + return allLocalFolding(loc, rewriter, fnFoldedName, ctx, lib, name, kwargs, qregType, fnOp, + typesFolded, fnAllocOp, numberQubits, deviceInitOp, + fnWithoutMeasurementsOp, fnWithMeasurementsOp); } FlatSymbolRefAttr ZneLowering::getOrInsertQuantumAlloc(Location loc, PatternRewriter &rewriter, mitigation::ZneOp op) From 7dcbb1a8022f87bd37640f3ed0977e8c2aa36453 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Tue, 23 Jul 2024 16:43:13 -0400 Subject: [PATCH 28/94] Better code reuse --- .../Transforms/MitigationMethods/Zne.cpp | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index bb6ec9f593..294c74bc99 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -137,8 +137,6 @@ FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, std::st { // Function folded: Create the folded circuit (withoutMeasurement * // Adjoint(withoutMeasurement))**scalar_factor * withMeasurements - Type indexType = rewriter.getIndexType(); - typesFolded.push_back(indexType); FunctionType fnFoldedType = FunctionType::get(ctx, /*inputs=*/ typesFolded, /*outputs=*/fnOp.getResultTypes()); @@ -232,11 +230,6 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: { // TODO: Implement. - // MLIRContext *ctx = rewriter.getContext(); - // StringAttr lib = deviceInitOp.getLibAttr(); - // StringAttr name = deviceInitOp.getNameAttr(); - // StringAttr kwargs = deviceInitOp.getKwargsAttr(); - return FlatSymbolRefAttr(); } FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRewriter &rewriter, @@ -269,12 +262,15 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew quantum::DeviceInitOp deviceInitOp = *fnOp.getOps().begin(); Type qregType = quantum::QuregType::get(ctx); - TypeRange originalTypes = op.getArgs().getTypes(); - SmallVector typesFolded(originalTypes.begin(), originalTypes.end()); StringAttr lib = deviceInitOp.getLibAttr(); StringAttr name = deviceInitOp.getNameAttr(); StringAttr kwargs = deviceInitOp.getKwargsAttr(); + TypeRange originalTypes = op.getArgs().getTypes(); + SmallVector typesFolded(originalTypes.begin(), originalTypes.end()); + Type indexType = rewriter.getIndexType(); + typesFolded.push_back(indexType); + // Function without measurements: Create function without measurements and with qreg as last // argument FlatSymbolRefAttr fnWithoutMeasurementsRefAttr = From 2d2f26580c8e174acb79f062db2dee77bc38d940 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Tue, 23 Jul 2024 16:51:50 -0400 Subject: [PATCH 29/94] Don't need deviceInitOp in functions --- .../Transforms/MitigationMethods/Zne.cpp | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 294c74bc99..1acc8bc527 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -129,7 +129,7 @@ void ZneLowering::rewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const // In *.cpp module only, to keep extraneous headers out of *.hpp FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, std::string fnFoldedName, MLIRContext *ctx, StringAttr lib, StringAttr name, - StringAttr kwargs, Type qregType, func::FuncOp fnOp, + StringAttr kwargs, Type qregType, FunctionType fnFoldedType, SmallVector typesFolded, func::FuncOp fnAllocOp, int64_t numberQubits, quantum::DeviceInitOp deviceInitOp, func::FuncOp fnWithoutMeasurementsOp, @@ -137,10 +137,6 @@ FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, std::st { // Function folded: Create the folded circuit (withoutMeasurement * // Adjoint(withoutMeasurement))**scalar_factor * withMeasurements - FunctionType fnFoldedType = FunctionType::get(ctx, /*inputs=*/ - typesFolded, - /*outputs=*/fnOp.getResultTypes()); - func::FuncOp fnFoldedOp = rewriter.create(loc, fnFoldedName, fnFoldedType); fnFoldedOp.setPrivate(); @@ -208,7 +204,7 @@ FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, std::st FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, std::string fnFoldedName, MLIRContext *ctx, StringAttr lib, StringAttr name, StringAttr kwargs, - Type qregType, func::FuncOp fnOp, SmallVector typesFolded, + Type qregType, FunctionType fnFoldedType, SmallVector typesFolded, func::FuncOp fnAllocOp, int64_t numberQubits, quantum::DeviceInitOp deviceInitOp, func::FuncOp fnWithoutMeasurementsOp, func::FuncOp fnWithMeasurementsOp) { @@ -222,7 +218,7 @@ randomLocalFolding(Location loc, PatternRewriter &rewriter, std::string fnFolded // In *.cpp module only, to keep extraneous headers out of *.hpp FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std::string fnFoldedName, MLIRContext *ctx, StringAttr lib, StringAttr name, - StringAttr kwargs, Type qregType, func::FuncOp fnOp, + StringAttr kwargs, Type qregType, FunctionType fnFoldedType, SmallVector typesFolded, func::FuncOp fnAllocOp, int64_t numberQubits, quantum::DeviceInitOp deviceInitOp, func::FuncOp fnWithoutMeasurementsOp, @@ -286,19 +282,23 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew rewriter.setInsertionPointToStart(moduleOp.getBody()); + FunctionType fnFoldedType = FunctionType::get(ctx, /*inputs=*/ + typesFolded, + /*outputs=*/fnOp.getResultTypes()); + if (foldingAlgorithm == Folding(1)) { - return globalFolding(loc, rewriter, fnFoldedName, ctx, lib, name, kwargs, qregType, fnOp, - typesFolded, fnAllocOp, numberQubits, deviceInitOp, + return globalFolding(loc, rewriter, fnFoldedName, ctx, lib, name, kwargs, qregType, + fnFoldedType, typesFolded, fnAllocOp, numberQubits, deviceInitOp, fnWithoutMeasurementsOp, fnWithMeasurementsOp); } if (foldingAlgorithm == Folding(2)) { return randomLocalFolding(loc, rewriter, fnFoldedName, ctx, lib, name, kwargs, qregType, - fnOp, typesFolded, fnAllocOp, numberQubits, deviceInitOp, + fnFoldedType, typesFolded, fnAllocOp, numberQubits, deviceInitOp, fnWithoutMeasurementsOp, fnWithMeasurementsOp); } // Else, if (foldingAlgorithm == Folding(3)): - return allLocalFolding(loc, rewriter, fnFoldedName, ctx, lib, name, kwargs, qregType, fnOp, - typesFolded, fnAllocOp, numberQubits, deviceInitOp, + return allLocalFolding(loc, rewriter, fnFoldedName, ctx, lib, name, kwargs, qregType, + fnFoldedType, typesFolded, fnAllocOp, numberQubits, deviceInitOp, fnWithoutMeasurementsOp, fnWithMeasurementsOp); } FlatSymbolRefAttr ZneLowering::getOrInsertQuantumAlloc(Location loc, PatternRewriter &rewriter, From 8ef9cdea7a2b96eaa385d1b5e00bc050d9a38456 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Tue, 23 Jul 2024 16:55:31 -0400 Subject: [PATCH 30/94] Don't need deviceInitOp in functions --- .../Transforms/MitigationMethods/Zne.cpp | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 1acc8bc527..8eb4e3d6ec 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -131,8 +131,7 @@ FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, std::st MLIRContext *ctx, StringAttr lib, StringAttr name, StringAttr kwargs, Type qregType, FunctionType fnFoldedType, SmallVector typesFolded, func::FuncOp fnAllocOp, - int64_t numberQubits, quantum::DeviceInitOp deviceInitOp, - func::FuncOp fnWithoutMeasurementsOp, + int64_t numberQubits, func::FuncOp fnWithoutMeasurementsOp, func::FuncOp fnWithMeasurementsOp) { // Function folded: Create the folded circuit (withoutMeasurement * @@ -201,12 +200,13 @@ FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, std::st return SymbolRefAttr::get(ctx, fnFoldedName); } // In *.cpp module only, to keep extraneous headers out of *.hpp -FlatSymbolRefAttr -randomLocalFolding(Location loc, PatternRewriter &rewriter, std::string fnFoldedName, - MLIRContext *ctx, StringAttr lib, StringAttr name, StringAttr kwargs, - Type qregType, FunctionType fnFoldedType, SmallVector typesFolded, - func::FuncOp fnAllocOp, int64_t numberQubits, quantum::DeviceInitOp deviceInitOp, - func::FuncOp fnWithoutMeasurementsOp, func::FuncOp fnWithMeasurementsOp) +FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, + std::string fnFoldedName, MLIRContext *ctx, StringAttr lib, + StringAttr name, StringAttr kwargs, Type qregType, + FunctionType fnFoldedType, SmallVector typesFolded, + func::FuncOp fnAllocOp, int64_t numberQubits, + func::FuncOp fnWithoutMeasurementsOp, + func::FuncOp fnWithMeasurementsOp) { // TODO: Implement. @@ -220,8 +220,7 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: MLIRContext *ctx, StringAttr lib, StringAttr name, StringAttr kwargs, Type qregType, FunctionType fnFoldedType, SmallVector typesFolded, func::FuncOp fnAllocOp, - int64_t numberQubits, quantum::DeviceInitOp deviceInitOp, - func::FuncOp fnWithoutMeasurementsOp, + int64_t numberQubits, func::FuncOp fnWithoutMeasurementsOp, func::FuncOp fnWithMeasurementsOp) { // TODO: Implement. @@ -288,17 +287,17 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew if (foldingAlgorithm == Folding(1)) { return globalFolding(loc, rewriter, fnFoldedName, ctx, lib, name, kwargs, qregType, - fnFoldedType, typesFolded, fnAllocOp, numberQubits, deviceInitOp, + fnFoldedType, typesFolded, fnAllocOp, numberQubits, fnWithoutMeasurementsOp, fnWithMeasurementsOp); } if (foldingAlgorithm == Folding(2)) { return randomLocalFolding(loc, rewriter, fnFoldedName, ctx, lib, name, kwargs, qregType, - fnFoldedType, typesFolded, fnAllocOp, numberQubits, deviceInitOp, + fnFoldedType, typesFolded, fnAllocOp, numberQubits, fnWithoutMeasurementsOp, fnWithMeasurementsOp); } // Else, if (foldingAlgorithm == Folding(3)): return allLocalFolding(loc, rewriter, fnFoldedName, ctx, lib, name, kwargs, qregType, - fnFoldedType, typesFolded, fnAllocOp, numberQubits, deviceInitOp, + fnFoldedType, typesFolded, fnAllocOp, numberQubits, fnWithoutMeasurementsOp, fnWithMeasurementsOp); } FlatSymbolRefAttr ZneLowering::getOrInsertQuantumAlloc(Location loc, PatternRewriter &rewriter, From 560b014b8fa61806585d40dd06dfffa2c2ffc936 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Tue, 30 Jul 2024 11:06:29 -0400 Subject: [PATCH 31/94] Simplify folding functions --- .../Transforms/MitigationMethods/Zne.cpp | 70 +++++++++---------- 1 file changed, 34 insertions(+), 36 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 8eb4e3d6ec..7ac75afcc2 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -128,25 +128,14 @@ void ZneLowering::rewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const } // In *.cpp module only, to keep extraneous headers out of *.hpp FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, std::string fnFoldedName, - MLIRContext *ctx, StringAttr lib, StringAttr name, - StringAttr kwargs, Type qregType, FunctionType fnFoldedType, - SmallVector typesFolded, func::FuncOp fnAllocOp, - int64_t numberQubits, func::FuncOp fnWithoutMeasurementsOp, + StringAttr lib, StringAttr name, StringAttr kwargs, Type qregType, + FunctionType fnFoldedType, SmallVector typesFolded, + func::FuncOp fnFoldedOp, Value allocQreg, + func::FuncOp fnWithoutMeasurementsOp, func::FuncOp fnWithMeasurementsOp) { // Function folded: Create the folded circuit (withoutMeasurement * // Adjoint(withoutMeasurement))**scalar_factor * withMeasurements - func::FuncOp fnFoldedOp = rewriter.create(loc, fnFoldedName, fnFoldedType); - fnFoldedOp.setPrivate(); - - Block *foldedBloc = fnFoldedOp.addEntryBlock(); - rewriter.setInsertionPointToStart(foldedBloc); - // Add device - rewriter.create(loc, lib, name, kwargs); - TypedAttr numberQubitsAttr = rewriter.getI64IntegerAttr(numberQubits); - Value numberQubitsValue = rewriter.create(loc, numberQubitsAttr); - Value allocQreg = rewriter.create(loc, fnAllocOp, numberQubitsValue).getResult(0); - Value c0 = rewriter.create(loc, 0); Value c1 = rewriter.create(loc, 1); int64_t sizeArgs = fnFoldedOp.getArguments().size(); @@ -197,16 +186,14 @@ FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, std::st // Remove device rewriter.create(loc); rewriter.create(loc, funcFolded); - return SymbolRefAttr::get(ctx, fnFoldedName); + return SymbolRefAttr::get(rewriter.getContext(), fnFoldedName); } // In *.cpp module only, to keep extraneous headers out of *.hpp FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, - std::string fnFoldedName, MLIRContext *ctx, StringAttr lib, - StringAttr name, StringAttr kwargs, Type qregType, - FunctionType fnFoldedType, SmallVector typesFolded, - func::FuncOp fnAllocOp, int64_t numberQubits, - func::FuncOp fnWithoutMeasurementsOp, - func::FuncOp fnWithMeasurementsOp) + std::string fnFoldedName, StringAttr lib, StringAttr name, + StringAttr kwargs, Type qregType, FunctionType fnFoldedType, + SmallVector typesFolded, func::FuncOp fnFoldedOp, + Value allocQreg, func::FuncOp fnWithMeasurementsOp) { // TODO: Implement. @@ -217,13 +204,14 @@ FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, } // In *.cpp module only, to keep extraneous headers out of *.hpp FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std::string fnFoldedName, - MLIRContext *ctx, StringAttr lib, StringAttr name, - StringAttr kwargs, Type qregType, FunctionType fnFoldedType, - SmallVector typesFolded, func::FuncOp fnAllocOp, - int64_t numberQubits, func::FuncOp fnWithoutMeasurementsOp, + StringAttr lib, StringAttr name, StringAttr kwargs, Type qregType, + FunctionType fnFoldedType, SmallVector typesFolded, + func::FuncOp fnFoldedOp, Value allocQreg, func::FuncOp fnWithMeasurementsOp) { - // TODO: Implement. + fnWithMeasurementsOp.walk([&](mlir::Operation *op) { + // process Operation `op`. + }); return FlatSymbolRefAttr(); } @@ -285,20 +273,30 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew typesFolded, /*outputs=*/fnOp.getResultTypes()); + func::FuncOp fnFoldedOp = rewriter.create(loc, fnFoldedName, fnFoldedType); + fnFoldedOp.setPrivate(); + + Block *foldedBlock = fnFoldedOp.addEntryBlock(); + rewriter.setInsertionPointToStart(foldedBlock); + // Add device + rewriter.create(loc, lib, name, kwargs); + TypedAttr numberQubitsAttr = rewriter.getI64IntegerAttr(numberQubits); + Value numberQubitsValue = rewriter.create(loc, numberQubitsAttr); + Value allocQreg = rewriter.create(loc, fnAllocOp, numberQubitsValue).getResult(0); + if (foldingAlgorithm == Folding(1)) { - return globalFolding(loc, rewriter, fnFoldedName, ctx, lib, name, kwargs, qregType, - fnFoldedType, typesFolded, fnAllocOp, numberQubits, - fnWithoutMeasurementsOp, fnWithMeasurementsOp); + return globalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, qregType, fnFoldedType, + typesFolded, fnFoldedOp, allocQreg, fnWithoutMeasurementsOp, + fnWithMeasurementsOp); } if (foldingAlgorithm == Folding(2)) { - return randomLocalFolding(loc, rewriter, fnFoldedName, ctx, lib, name, kwargs, qregType, - fnFoldedType, typesFolded, fnAllocOp, numberQubits, - fnWithoutMeasurementsOp, fnWithMeasurementsOp); + return randomLocalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, qregType, + fnFoldedType, typesFolded, fnFoldedOp, allocQreg, + fnWithMeasurementsOp); } // Else, if (foldingAlgorithm == Folding(3)): - return allLocalFolding(loc, rewriter, fnFoldedName, ctx, lib, name, kwargs, qregType, - fnFoldedType, typesFolded, fnAllocOp, numberQubits, - fnWithoutMeasurementsOp, fnWithMeasurementsOp); + return allLocalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, qregType, fnFoldedType, + typesFolded, fnFoldedOp, allocQreg, fnWithMeasurementsOp); } FlatSymbolRefAttr ZneLowering::getOrInsertQuantumAlloc(Location loc, PatternRewriter &rewriter, mitigation::ZneOp op) From 559317eef6e12f575d6716d1b806a2fae3bfdcab Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Wed, 31 Jul 2024 11:12:07 -0400 Subject: [PATCH 32/94] Incomplete code --- .../Transforms/MitigationMethods/Zne.cpp | 49 ++++++++++++++----- 1 file changed, 36 insertions(+), 13 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 7ac75afcc2..20b3fc41c8 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -130,12 +130,16 @@ void ZneLowering::rewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, std::string fnFoldedName, StringAttr lib, StringAttr name, StringAttr kwargs, Type qregType, FunctionType fnFoldedType, SmallVector typesFolded, - func::FuncOp fnFoldedOp, Value allocQreg, + func::FuncOp fnFoldedOp, func::FuncOp fnAllocOp, const int64_t numberQubits, func::FuncOp fnWithoutMeasurementsOp, func::FuncOp fnWithMeasurementsOp) { // Function folded: Create the folded circuit (withoutMeasurement * // Adjoint(withoutMeasurement))**scalar_factor * withMeasurements + TypedAttr numberQubitsAttr = rewriter.getI64IntegerAttr(numberQubits); + Value numberQubitsValue = rewriter.create(loc, numberQubitsAttr); + Value allocQreg = rewriter.create(loc, fnAllocOp, numberQubitsValue).getResult(0); + Value c0 = rewriter.create(loc, 0); Value c1 = rewriter.create(loc, 1); int64_t sizeArgs = fnFoldedOp.getArguments().size(); @@ -193,7 +197,7 @@ FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, std::string fnFoldedName, StringAttr lib, StringAttr name, StringAttr kwargs, Type qregType, FunctionType fnFoldedType, SmallVector typesFolded, func::FuncOp fnFoldedOp, - Value allocQreg, func::FuncOp fnWithMeasurementsOp) + func::FuncOp fnAllocOp, const int64_t numberQubits, func::FuncOp fnWithMeasurementsOp) { // TODO: Implement. @@ -206,14 +210,36 @@ FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std::string fnFoldedName, StringAttr lib, StringAttr name, StringAttr kwargs, Type qregType, FunctionType fnFoldedType, SmallVector typesFolded, - func::FuncOp fnFoldedOp, Value allocQreg, + func::FuncOp fnFoldedOp, func::FuncOp fnAllocOp, const int64_t numberQubits, func::FuncOp fnWithMeasurementsOp) { - fnWithMeasurementsOp.walk([&](mlir::Operation *op) { - // process Operation `op`. - }); + TypedAttr numberQubitsAttr = rewriter.getI64IntegerAttr(numberQubits); + Value numberQubitsValue = rewriter.create(loc, numberQubitsAttr); + rewriter.create(loc, fnAllocOp, numberQubitsValue); - return FlatSymbolRefAttr(); + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + int64_t sizeArgs = fnFoldedOp.getArguments().size(); + Value size = fnFoldedOp.getArgument(sizeArgs - 1); + + fnWithMeasurementsOp.walk([&](quantum::QubitUnitaryOp *op) { + // TODO: Skip measurements and control structures. + // Add scf for loop to create the folding + rewriter.setInsertionPointAfter(op); + rewriter + .create( + loc, c0, size, c1, ValueRange(), + [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { + // Call the function without measurements in an adjoint region + auto adjointOp = builder.create(loc, qregType,*op); + auto origOp = builder.create(loc, qregType,adjointOp); + builder.setInsertionPointAfter(origOp); + builder.create(loc, origOp.getResult()); + }); + }); + // Remove device + rewriter.create(loc); + return SymbolRefAttr::get(rewriter.getContext(), fnFoldedName); } FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRewriter &rewriter, mitigation::ZneOp op, @@ -280,23 +306,20 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew rewriter.setInsertionPointToStart(foldedBlock); // Add device rewriter.create(loc, lib, name, kwargs); - TypedAttr numberQubitsAttr = rewriter.getI64IntegerAttr(numberQubits); - Value numberQubitsValue = rewriter.create(loc, numberQubitsAttr); - Value allocQreg = rewriter.create(loc, fnAllocOp, numberQubitsValue).getResult(0); if (foldingAlgorithm == Folding(1)) { return globalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, qregType, fnFoldedType, - typesFolded, fnFoldedOp, allocQreg, fnWithoutMeasurementsOp, + typesFolded, fnFoldedOp, numberQubits, fnAllocOp, fnWithoutMeasurementsOp, fnWithMeasurementsOp); } if (foldingAlgorithm == Folding(2)) { return randomLocalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, qregType, - fnFoldedType, typesFolded, fnFoldedOp, allocQreg, + fnFoldedType, typesFolded, fnFoldedOp, numberQubits, fnAllocOp, fnWithMeasurementsOp); } // Else, if (foldingAlgorithm == Folding(3)): return allLocalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, qregType, fnFoldedType, - typesFolded, fnFoldedOp, allocQreg, fnWithMeasurementsOp); + typesFolded, fnFoldedOp, numberQubits, fnAllocOp, fnWithMeasurementsOp); } FlatSymbolRefAttr ZneLowering::getOrInsertQuantumAlloc(Location loc, PatternRewriter &rewriter, mitigation::ZneOp op) From fac618a3b0321996e2217790e392957b2f3095b7 Mon Sep 17 00:00:00 2001 From: Alessandro Cosentino Date: Wed, 31 Jul 2024 17:35:19 +0200 Subject: [PATCH 33/94] MLIR test for ZNE local folding --- mlir/test/Mitigation/ZneFoldingAllTest.mlir | 94 +++++++++++++++++++ .../{zne.mlir => ZneFoldingGlobalTest.mlir} | 0 2 files changed, 94 insertions(+) create mode 100644 mlir/test/Mitigation/ZneFoldingAllTest.mlir rename mlir/test/Mitigation/{zne.mlir => ZneFoldingGlobalTest.mlir} (100%) diff --git a/mlir/test/Mitigation/ZneFoldingAllTest.mlir b/mlir/test/Mitigation/ZneFoldingAllTest.mlir new file mode 100644 index 0000000000..5b18e022d8 --- /dev/null +++ b/mlir/test/Mitigation/ZneFoldingAllTest.mlir @@ -0,0 +1,94 @@ +// Copyright 2023 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// RUN: quantum-opt %s --lower-mitigation --split-input-file --verify-diagnostics | FileCheck %s + +func.func @circuit() -> tensor attributes {qnode} { + quantum.device ["rtd_lightning.so", "LightningQubit", "{shots: 0}"] + %0 = quantum.alloc( 2) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %out_qubits = quantum.custom "Hadamard"() %1 : !quantum.bit + %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit + %out_qubits_0:2 = quantum.custom "CNOT"() %out_qubits, %2 : !quantum.bit, !quantum.bit + %3 = quantum.namedobs %out_qubits_0#0[ PauliY] : !quantum.obs + %4 = quantum.expval %3 : f64 + %from_elements = tensor.from_elements %4 : tensor + %5 = quantum.insert %0[ 0], %out_qubits_0#0 : !quantum.reg, !quantum.bit + %6 = quantum.insert %5[ 1], %out_qubits_0#1 : !quantum.reg, !quantum.bit + quantum.dealloc %6 : !quantum.reg + quantum.device_release + return %from_elements : tensor +} + + // CHECK: func.func private @circuit.folded(%arg0: index) -> tensor { + // CHECK: %c2_i64 = arith.constant 2 : i64 + // CHECK: %idx0 = index.constant 0 + // CHECK: %idx1 = index.constant 1 + // CHECK: quantum.device["rtd_lightning.so", "LightningQubit", "{shots: 0}"] + // CHECK: %0 = call @circuit.quantumAlloc(%c2_i64) : (i64) -> !quantum.reg + // CHECK: %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + // CHECK: %out_qubits = quantum.custom "Hadamard"() %1 : !quantum.bit + // CHECK: %2 = scf.for %arg1 = %idx0 to %arg0 step %idx1 -> (!quantum.bit) { + // CHECK: %out_qubits = quantum.custom "Hadamard"() %out_qubits {adjoint} : !quantum.bit + // CHECK: %out_qubits = quantum.custom "Hadamard"() %out_qubits : !quantum.bit + // CHECK: scf.yield %out_qubits: !quantum.bit + // CHECK: } + // CHECK: %3 = quantum.extract %arg0[ 1] : !quantum.reg -> !quantum.bit + // CHECK: %out_qubits_0:2 = quantum.custom "CNOT"() %2, %3 : !quantum.bit, !quantum.bit + // CHECK: %4 = scf.for %arg1 = %idx0 to %arg0 step %idx1 -> (!quantum.bit, !quantum.bit) { + // CHECK: %out_qubits_0:2 = quantum.custom "CNOT"() %out_qubits_0#0, %out_qubits_0#1 {adjoint} : !quantum.bit, !quantum.bit + // CHECK: %out_qubits_0:2 = quantum.custom "CNOT"() %out_qubits_0#0, %out_qubits_0#1 : !quantum.bit, !quantum.bit + // CHECK: scf.yield %out_qubits_0 : (!quantum.bit, !quantum.bit) + // CHECK: } + // CHECK: %5 = quantum.namedobs %out_qubits_0#0[ PauliY] : !quantum.obs + // CHECK: %6 = quantum.expval %5 : f64 + // CHECK: %from_elements = tensor.from_elements %6 : tensor + // CHECK: %7 = quantum.insert %0[ 0], %out_qubits_0#0 : !quantum.reg, !quantum.bit + // CHECK: %8 = quantum.insert %7[ 1], %out_qubits_0#1 : !quantum.reg, !quantum.bit + // CHECK: quantum.dealloc %8 : !quantum.reg + // CHECK: quantum.device_release + // CHECK: return %from_elements : tensor + // CHECK: } + + // CHECK: func.func private @circuit.quantumAlloc(%arg0: i64) -> !quantum.reg { + // CHECK: %0 = quantum.alloc(%arg0) : !quantum.reg + // CHECK: return %0 : !quantum.reg + // CHECK: } + + +func.func @mitigated_circuit() -> tensor<3xf64> { + %scaleFactors = arith.constant dense<[1, 2, 3]> : tensor<3xindex> + %0 = mitigation.zne @circuit() folding (all) scaleFactors (%scaleFactors : tensor<3xindex>) : () -> tensor<3xf64> + func.return %0 : tensor<3xf64> +} +//CHECK: func.func @mitigated_circuit() -> tensor<3xf64> { + //CHECK: %idx0 = index.constant 0 + //CHECK: %idx1 = index.constant 1 + //CHECK: %idx3 = index.constant 3 + //CHECK: %cst = arith.constant dense<[1, 2, 3]> : tensor<3xindex> + //CHECK: %0 = tensor.empty() : tensor<3xf64> + //CHECK: %1 = scf.for %arg0 = %idx0 to %idx3 step %idx1 iter_args(%arg1 = %0) -> (tensor<3xf64>) { + //CHECK: %extracted = tensor.extract %cst[%arg0] : tensor<3xindex> + //CHECK: %2 = func.call @circuit.folded(%extracted) : (index) -> tensor + //CHECK: %extracted_0 = tensor.extract %2[] : tensor + //CHECK: %from_elements = tensor.from_elements %extracted_0 : tensor<1xf64> + //CHECK: %3 = scf.for %arg2 = %idx0 to %idx1 step %idx1 iter_args(%arg3 = %arg1) -> (tensor<3xf64>) { + //CHECK: %extracted_1 = tensor.extract %from_elements[%arg2] : tensor<1xf64> + //CHECK: %inserted = tensor.insert %extracted_1 into %arg3[%arg0] : tensor<3xf64> + //CHECK: scf.yield %inserted : tensor<3xf64> + //CHECK: } + //CHECK: scf.yield %3 : tensor<3xf64> + //CHECK: } + //CHECK: return %1 : tensor<3xf64> +//CHECK: } diff --git a/mlir/test/Mitigation/zne.mlir b/mlir/test/Mitigation/ZneFoldingGlobalTest.mlir similarity index 100% rename from mlir/test/Mitigation/zne.mlir rename to mlir/test/Mitigation/ZneFoldingGlobalTest.mlir From eeb7f8cf8b7a9ab8a207a3d02f02363557e6b6e5 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Thu, 1 Aug 2024 09:29:24 -0400 Subject: [PATCH 34/94] Fix function prototypes/calls (doesn't compile) --- .../Transforms/MitigationMethods/Zne.cpp | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 20b3fc41c8..bb4c833806 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -130,8 +130,8 @@ void ZneLowering::rewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, std::string fnFoldedName, StringAttr lib, StringAttr name, StringAttr kwargs, Type qregType, FunctionType fnFoldedType, SmallVector typesFolded, - func::FuncOp fnFoldedOp, func::FuncOp fnAllocOp, const int64_t numberQubits, - func::FuncOp fnWithoutMeasurementsOp, + func::FuncOp fnFoldedOp, func::FuncOp fnAllocOp, + const int64_t numberQubits, func::FuncOp fnWithoutMeasurementsOp, func::FuncOp fnWithMeasurementsOp) { // Function folded: Create the folded circuit (withoutMeasurement * @@ -197,7 +197,8 @@ FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, std::string fnFoldedName, StringAttr lib, StringAttr name, StringAttr kwargs, Type qregType, FunctionType fnFoldedType, SmallVector typesFolded, func::FuncOp fnFoldedOp, - func::FuncOp fnAllocOp, const int64_t numberQubits, func::FuncOp fnWithMeasurementsOp) + func::FuncOp fnAllocOp, const int64_t numberQubits, + func::FuncOp fnWithMeasurementsOp) { // TODO: Implement. @@ -210,8 +211,8 @@ FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std::string fnFoldedName, StringAttr lib, StringAttr name, StringAttr kwargs, Type qregType, FunctionType fnFoldedType, SmallVector typesFolded, - func::FuncOp fnFoldedOp, func::FuncOp fnAllocOp, const int64_t numberQubits, - func::FuncOp fnWithMeasurementsOp) + func::FuncOp fnFoldedOp, func::FuncOp fnAllocOp, + const int64_t numberQubits, func::FuncOp fnWithMeasurementsOp) { TypedAttr numberQubitsAttr = rewriter.getI64IntegerAttr(numberQubits); Value numberQubitsValue = rewriter.create(loc, numberQubitsAttr); @@ -225,17 +226,16 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: fnWithMeasurementsOp.walk([&](quantum::QubitUnitaryOp *op) { // TODO: Skip measurements and control structures. // Add scf for loop to create the folding - rewriter.setInsertionPointAfter(op); - rewriter - .create( - loc, c0, size, c1, ValueRange(), - [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { - // Call the function without measurements in an adjoint region - auto adjointOp = builder.create(loc, qregType,*op); - auto origOp = builder.create(loc, qregType,adjointOp); - builder.setInsertionPointAfter(origOp); - builder.create(loc, origOp.getResult()); - }); + rewriter.setInsertionPointAfter((mlir::Operation *)op); + rewriter.create( + loc, c0, size, c1, ValueRange(), + [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { + // Call the function without measurements in an adjoint region + auto adjointOp = builder.create(loc, qregType, *op); + auto origOp = builder.create(loc, qregType, adjointOp); + builder.setInsertionPointAfter(origOp); + builder.create(loc, origOp.getResult()); + }); }); // Remove device rewriter.create(loc); @@ -309,17 +309,17 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew if (foldingAlgorithm == Folding(1)) { return globalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, qregType, fnFoldedType, - typesFolded, fnFoldedOp, numberQubits, fnAllocOp, fnWithoutMeasurementsOp, - fnWithMeasurementsOp); + typesFolded, fnFoldedOp, fnAllocOp, numberQubits, + fnWithoutMeasurementsOp, fnWithMeasurementsOp); } if (foldingAlgorithm == Folding(2)) { return randomLocalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, qregType, - fnFoldedType, typesFolded, fnFoldedOp, numberQubits, fnAllocOp, + fnFoldedType, typesFolded, fnFoldedOp, fnAllocOp, numberQubits, fnWithMeasurementsOp); } // Else, if (foldingAlgorithm == Folding(3)): return allLocalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, qregType, fnFoldedType, - typesFolded, fnFoldedOp, numberQubits, fnAllocOp, fnWithMeasurementsOp); + typesFolded, fnFoldedOp, fnAllocOp, numberQubits, fnWithMeasurementsOp); } FlatSymbolRefAttr ZneLowering::getOrInsertQuantumAlloc(Location loc, PatternRewriter &rewriter, mitigation::ZneOp op) From 3f3967959a45c68d5bc7d570067ddb9c4fdeb696 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Thu, 1 Aug 2024 09:43:50 -0400 Subject: [PATCH 35/94] Fix AdjointOp creation --- mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index bb4c833806..21e59e141d 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -231,8 +231,8 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: loc, c0, size, c1, ValueRange(), [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { // Call the function without measurements in an adjoint region - auto adjointOp = builder.create(loc, qregType, *op); - auto origOp = builder.create(loc, qregType, adjointOp); + auto adjointOp = builder.create(loc, qregType, (*op).getResult(0)); + auto origOp = builder.create(loc, qregType, adjointOp.getResult()); builder.setInsertionPointAfter(origOp); builder.create(loc, origOp.getResult()); }); From 1c80a2933fe08fb6c419f81dbc38ceded8dd9bbe Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Thu, 1 Aug 2024 09:55:42 -0400 Subject: [PATCH 36/94] Use CallOp --- mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 21e59e141d..609fc70bd5 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -238,7 +238,15 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: }); }); // Remove device + std::vector argsAndRegMeasurement(fnFoldedOp.getArguments().begin(), + fnFoldedOp.getArguments().end()); + argsAndRegMeasurement.pop_back(); + argsAndRegMeasurement.push_back(fnWithMeasurementsOp->getResult(0)); + ValueRange funcFolded = + rewriter.create(loc, fnWithMeasurementsOp, argsAndRegMeasurement) + .getResults(); rewriter.create(loc); + rewriter.create(loc, funcFolded); return SymbolRefAttr::get(rewriter.getContext(), fnFoldedName); } FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRewriter &rewriter, From 5d6f07c7e1d1210e6ec00cd4a5730e9a7e3e95d6 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Thu, 1 Aug 2024 11:08:18 -0400 Subject: [PATCH 37/94] Fix walk() on QubitUnitaryOp --- mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 609fc70bd5..1730de1fac 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -223,16 +223,17 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: int64_t sizeArgs = fnFoldedOp.getArguments().size(); Value size = fnFoldedOp.getArgument(sizeArgs - 1); - fnWithMeasurementsOp.walk([&](quantum::QubitUnitaryOp *op) { + fnWithMeasurementsOp.walk([&](quantum::QubitUnitaryOp op) { // TODO: Skip measurements and control structures. // Add scf for loop to create the folding - rewriter.setInsertionPointAfter((mlir::Operation *)op); + rewriter.setInsertionPointAfter(op); rewriter.create( loc, c0, size, c1, ValueRange(), [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { // Call the function without measurements in an adjoint region - auto adjointOp = builder.create(loc, qregType, (*op).getResult(0)); - auto origOp = builder.create(loc, qregType, adjointOp.getResult()); + auto adjointOp = builder.create(loc, qregType, op.getResult(0)); + auto origOp = + builder.create(loc, qregType, adjointOp.getResult()); builder.setInsertionPointAfter(origOp); builder.create(loc, origOp.getResult()); }); From 4ab01313f261c7ffaab1688875f7043f1248b43b Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Thu, 1 Aug 2024 14:22:50 -0400 Subject: [PATCH 38/94] Code reuse --- .../Transforms/MitigationMethods/Zne.cpp | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 102910e5e7..2bcdefb3ec 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -132,7 +132,7 @@ FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, std::st FunctionType fnFoldedType, SmallVector typesFolded, func::FuncOp fnFoldedOp, func::FuncOp fnAllocOp, const int64_t numberQubits, func::FuncOp fnWithoutMeasurementsOp, - func::FuncOp fnWithMeasurementsOp) + func::FuncOp fnWithMeasurementsOp, Value c0, Value c1) { // Function folded: Create the folded circuit (withoutMeasurement * // Adjoint(withoutMeasurement))**scalar_factor * withMeasurements @@ -140,8 +140,6 @@ FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, std::st Value numberQubitsValue = rewriter.create(loc, numberQubitsAttr); Value allocQreg = rewriter.create(loc, fnAllocOp, numberQubitsValue).getResult(0); - Value c0 = rewriter.create(loc, 0); - Value c1 = rewriter.create(loc, 1); int64_t sizeArgs = fnFoldedOp.getArguments().size(); Value size = fnFoldedOp.getArgument(sizeArgs - 1); // Add scf for loop to create the folding @@ -198,7 +196,7 @@ FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, StringAttr kwargs, Type qregType, FunctionType fnFoldedType, SmallVector typesFolded, func::FuncOp fnFoldedOp, func::FuncOp fnAllocOp, const int64_t numberQubits, - func::FuncOp fnWithMeasurementsOp) + func::FuncOp fnWithMeasurementsOp, Value c0, Value c1) { // TODO: Implement. @@ -212,14 +210,13 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: StringAttr lib, StringAttr name, StringAttr kwargs, Type qregType, FunctionType fnFoldedType, SmallVector typesFolded, func::FuncOp fnFoldedOp, func::FuncOp fnAllocOp, - const int64_t numberQubits, func::FuncOp fnWithMeasurementsOp) + const int64_t numberQubits, func::FuncOp fnWithMeasurementsOp, + Value c0, Value c1) { TypedAttr numberQubitsAttr = rewriter.getI64IntegerAttr(numberQubits); Value numberQubitsValue = rewriter.create(loc, numberQubitsAttr); rewriter.create(loc, fnAllocOp, numberQubitsValue); - Value c0 = rewriter.create(loc, 0); - Value c1 = rewriter.create(loc, 1); int64_t sizeArgs = fnFoldedOp.getArguments().size(); Value size = fnFoldedOp.getArgument(sizeArgs - 1); @@ -313,22 +310,26 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew Block *foldedBlock = fnFoldedOp.addEntryBlock(); rewriter.setInsertionPointToStart(foldedBlock); + // Loop control variables + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); // Add device rewriter.create(loc, lib, name, kwargs); if (foldingAlgorithm == Folding(1)) { return globalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, qregType, fnFoldedType, typesFolded, fnFoldedOp, fnAllocOp, numberQubits, - fnWithoutMeasurementsOp, fnWithMeasurementsOp); + fnWithoutMeasurementsOp, fnWithMeasurementsOp, c0, c1); } if (foldingAlgorithm == Folding(2)) { return randomLocalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, qregType, fnFoldedType, typesFolded, fnFoldedOp, fnAllocOp, numberQubits, - fnWithMeasurementsOp); + fnWithMeasurementsOp, c0, c1); } // Else, if (foldingAlgorithm == Folding(3)): return allLocalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, qregType, fnFoldedType, - typesFolded, fnFoldedOp, fnAllocOp, numberQubits, fnWithMeasurementsOp); + typesFolded, fnFoldedOp, fnAllocOp, numberQubits, fnWithMeasurementsOp, + c0, c1); } FlatSymbolRefAttr ZneLowering::getOrInsertQuantumAlloc(Location loc, PatternRewriter &rewriter, mitigation::ZneOp op) From 48c5411645db10e8a44e7a14a17d3c63cbe233c1 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Thu, 1 Aug 2024 14:48:46 -0400 Subject: [PATCH 39/94] Code reuse --- .../Transforms/MitigationMethods/Zne.cpp | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 2bcdefb3ec..da9a6e21ec 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -131,13 +131,11 @@ FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, std::st StringAttr lib, StringAttr name, StringAttr kwargs, Type qregType, FunctionType fnFoldedType, SmallVector typesFolded, func::FuncOp fnFoldedOp, func::FuncOp fnAllocOp, - const int64_t numberQubits, func::FuncOp fnWithoutMeasurementsOp, + Value numberQubitsValue, func::FuncOp fnWithoutMeasurementsOp, func::FuncOp fnWithMeasurementsOp, Value c0, Value c1) { // Function folded: Create the folded circuit (withoutMeasurement * // Adjoint(withoutMeasurement))**scalar_factor * withMeasurements - TypedAttr numberQubitsAttr = rewriter.getI64IntegerAttr(numberQubits); - Value numberQubitsValue = rewriter.create(loc, numberQubitsAttr); Value allocQreg = rewriter.create(loc, fnAllocOp, numberQubitsValue).getResult(0); int64_t sizeArgs = fnFoldedOp.getArguments().size(); @@ -195,7 +193,7 @@ FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, std::string fnFoldedName, StringAttr lib, StringAttr name, StringAttr kwargs, Type qregType, FunctionType fnFoldedType, SmallVector typesFolded, func::FuncOp fnFoldedOp, - func::FuncOp fnAllocOp, const int64_t numberQubits, + func::FuncOp fnAllocOp, Value numberQubitsValue, func::FuncOp fnWithMeasurementsOp, Value c0, Value c1) { // TODO: Implement. @@ -210,11 +208,9 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: StringAttr lib, StringAttr name, StringAttr kwargs, Type qregType, FunctionType fnFoldedType, SmallVector typesFolded, func::FuncOp fnFoldedOp, func::FuncOp fnAllocOp, - const int64_t numberQubits, func::FuncOp fnWithMeasurementsOp, + Value numberQubitsValue, func::FuncOp fnWithMeasurementsOp, Value c0, Value c1) { - TypedAttr numberQubitsAttr = rewriter.getI64IntegerAttr(numberQubits); - Value numberQubitsValue = rewriter.create(loc, numberQubitsAttr); rewriter.create(loc, fnAllocOp, numberQubitsValue); int64_t sizeArgs = fnFoldedOp.getArguments().size(); @@ -310,6 +306,8 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew Block *foldedBlock = fnFoldedOp.addEntryBlock(); rewriter.setInsertionPointToStart(foldedBlock); + TypedAttr numberQubitsAttr = rewriter.getI64IntegerAttr(numberQubits); + Value numberQubitsValue = rewriter.create(loc, numberQubitsAttr); // Loop control variables Value c0 = rewriter.create(loc, 0); Value c1 = rewriter.create(loc, 1); @@ -318,18 +316,18 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew if (foldingAlgorithm == Folding(1)) { return globalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, qregType, fnFoldedType, - typesFolded, fnFoldedOp, fnAllocOp, numberQubits, + typesFolded, fnFoldedOp, fnAllocOp, numberQubitsValue, fnWithoutMeasurementsOp, fnWithMeasurementsOp, c0, c1); } if (foldingAlgorithm == Folding(2)) { return randomLocalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, qregType, - fnFoldedType, typesFolded, fnFoldedOp, fnAllocOp, numberQubits, - fnWithMeasurementsOp, c0, c1); + fnFoldedType, typesFolded, fnFoldedOp, fnAllocOp, + numberQubitsValue, fnWithMeasurementsOp, c0, c1); } // Else, if (foldingAlgorithm == Folding(3)): return allLocalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, qregType, fnFoldedType, - typesFolded, fnFoldedOp, fnAllocOp, numberQubits, fnWithMeasurementsOp, - c0, c1); + typesFolded, fnFoldedOp, fnAllocOp, numberQubitsValue, + fnWithMeasurementsOp, c0, c1); } FlatSymbolRefAttr ZneLowering::getOrInsertQuantumAlloc(Location loc, PatternRewriter &rewriter, mitigation::ZneOp op) From bb12285c67e2e41e3ba1f2397a3297f83f75e4e4 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Thu, 1 Aug 2024 15:38:32 -0400 Subject: [PATCH 40/94] Cut redundant rewriter.setInsertionPointAfter() --- mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index da9a6e21ec..c162c17cea 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -227,7 +227,6 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: auto adjointOp = builder.create(loc, qregType, op.getResult(0)); auto origOp = builder.create(loc, qregType, adjointOp.getResult()); - builder.setInsertionPointAfter(origOp); builder.create(loc, origOp.getResult()); }); }); From 5114e2822b57e240a25f7d3fe55032f8a551a0c4 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Thu, 1 Aug 2024 16:01:31 -0400 Subject: [PATCH 41/94] Advance walk() --- mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index c162c17cea..861e9fa713 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -229,6 +229,7 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: builder.create(loc, qregType, adjointOp.getResult()); builder.create(loc, origOp.getResult()); }); + return WalkResult::advance(); }); // Remove device std::vector argsAndRegMeasurement(fnFoldedOp.getArguments().begin(), From 8bedc732a90014658e7a9ea0de3812c1621e3a8c Mon Sep 17 00:00:00 2001 From: Alessandro Cosentino Date: Fri, 2 Aug 2024 19:45:23 +0200 Subject: [PATCH 42/94] improve test with FileCheck best practices --- mlir/test/Mitigation/ZneFoldingAllTest.mlir | 103 ++++++++---------- .../test/Mitigation/ZneFoldingGlobalTest.mlir | 93 ++++++++-------- 2 files changed, 94 insertions(+), 102 deletions(-) diff --git a/mlir/test/Mitigation/ZneFoldingAllTest.mlir b/mlir/test/Mitigation/ZneFoldingAllTest.mlir index 5b18e022d8..10352c4e47 100644 --- a/mlir/test/Mitigation/ZneFoldingAllTest.mlir +++ b/mlir/test/Mitigation/ZneFoldingAllTest.mlir @@ -14,6 +14,38 @@ // RUN: quantum-opt %s --lower-mitigation --split-input-file --verify-diagnostics | FileCheck %s +// CHECK: func.func private @circuit.folded(%arg0: index) -> tensor { + // CHECK: [[nQubits:%.+]] = arith.constant 2 + // CHECK: [[c0:%.+]] = index.constant 0 + // CHECK: [[c1:%.+]] = index.constant 1 + // CHECK: quantum.device["rtd_lightning.so", "LightningQubit", "{shots: 0}"] + // CHECK: [[qReg:%.+]] = call @circuit.quantumAlloc([[nQubits]]) : (i64) -> !quantum.reg + // CHECK: [[q0:%.+]] = quantum.extract [[qReg]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[q0_out:%.+]] = quantum.custom "Hadamard"() [[q0]] : !quantum.bit + // CHECK: [[q0_out_1:%.+]] = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] -> (!quantum.bit) { + // CHECK: [[q0_out]] = quantum.custom "Hadamard"() [[q0_out]] {adjoint} : !quantum.bit + // CHECK: [[q0_out]] = quantum.custom "Hadamard"() [[q0_out]] : !quantum.bit + // CHECK: scf.yield [[q0_out]]: !quantum.bit + // CHECK: [[%q1:%.+]] = quantum.extract [[qReg]][ 1] : !quantum.reg -> !quantum.bit + // CHECK: [[q01_out:%.+]] = quantum.custom "CNOT"() [[q0_out_1]],[[q1]] : !quantum.bit, !quantum.bit + // CHECK: [[q01_out2:%.+]] = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] -> (!quantum.bit, !quantum.bit) { + // CHECK: [[q01_out]]:2 = quantum.custom "CNOT"() [[q01_out]]#0, [[q01_out]]#1 {adjoint} : !quantum.bit, !quantum.bit + // CHECK: [[q01_out]]:2 = quantum.custom "CNOT"() [[q01_out]]#0, [[q01_out]]#1 : !quantum.bit, !quantum.bit + // CHECK: scf.yield [[q01_out]] : (!quantum.bit, !quantum.bit) + // CHECK: [[%q2:%.+]] = quantum.namedobs [[q01_out2]]#0[ PauliY] : !quantum.obs + // CHECK: [[results:%.+]] = quantum.expval [[q1]] : f64 + // CHECK: [[tensorRes:%.+]] = tensor.from_elements [[result]] : tensor + // CHECK: [[%q2:%.+]] = quantum.insert %0[ 0], [[q01_out2]]#0 : !quantum.reg, !quantum.bit + // CHECK: [[%q3:%.+]] = quantum.insert %7[ 1], [[q01_out2]]#1 : !quantum.reg, !quantum.bit + // CHECK: quantum.dealloc [[q2]] : !quantum.reg + // CHECK: quantum.device_release + // CHECK: return [[tensorRes]] + +// CHECK: func.func private @simpleCircuit.quantumAlloc(%arg0: i64) -> !quantum.reg { + // CHECK: [[allocQreg:%.+]] = quantum.alloc(%arg0) : !quantum.reg + // CHECK: return [[allocQreg]] : !quantum.reg + +//CHECK-LABEL: func.func @circuit func.func @circuit() -> tensor attributes {qnode} { quantum.device ["rtd_lightning.so", "LightningQubit", "{shots: 0}"] %0 = quantum.alloc( 2) : !quantum.reg @@ -31,64 +63,25 @@ func.func @circuit() -> tensor attributes {qnode} { return %from_elements : tensor } - // CHECK: func.func private @circuit.folded(%arg0: index) -> tensor { - // CHECK: %c2_i64 = arith.constant 2 : i64 - // CHECK: %idx0 = index.constant 0 - // CHECK: %idx1 = index.constant 1 - // CHECK: quantum.device["rtd_lightning.so", "LightningQubit", "{shots: 0}"] - // CHECK: %0 = call @circuit.quantumAlloc(%c2_i64) : (i64) -> !quantum.reg - // CHECK: %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit - // CHECK: %out_qubits = quantum.custom "Hadamard"() %1 : !quantum.bit - // CHECK: %2 = scf.for %arg1 = %idx0 to %arg0 step %idx1 -> (!quantum.bit) { - // CHECK: %out_qubits = quantum.custom "Hadamard"() %out_qubits {adjoint} : !quantum.bit - // CHECK: %out_qubits = quantum.custom "Hadamard"() %out_qubits : !quantum.bit - // CHECK: scf.yield %out_qubits: !quantum.bit - // CHECK: } - // CHECK: %3 = quantum.extract %arg0[ 1] : !quantum.reg -> !quantum.bit - // CHECK: %out_qubits_0:2 = quantum.custom "CNOT"() %2, %3 : !quantum.bit, !quantum.bit - // CHECK: %4 = scf.for %arg1 = %idx0 to %arg0 step %idx1 -> (!quantum.bit, !quantum.bit) { - // CHECK: %out_qubits_0:2 = quantum.custom "CNOT"() %out_qubits_0#0, %out_qubits_0#1 {adjoint} : !quantum.bit, !quantum.bit - // CHECK: %out_qubits_0:2 = quantum.custom "CNOT"() %out_qubits_0#0, %out_qubits_0#1 : !quantum.bit, !quantum.bit - // CHECK: scf.yield %out_qubits_0 : (!quantum.bit, !quantum.bit) - // CHECK: } - // CHECK: %5 = quantum.namedobs %out_qubits_0#0[ PauliY] : !quantum.obs - // CHECK: %6 = quantum.expval %5 : f64 - // CHECK: %from_elements = tensor.from_elements %6 : tensor - // CHECK: %7 = quantum.insert %0[ 0], %out_qubits_0#0 : !quantum.reg, !quantum.bit - // CHECK: %8 = quantum.insert %7[ 1], %out_qubits_0#1 : !quantum.reg, !quantum.bit - // CHECK: quantum.dealloc %8 : !quantum.reg - // CHECK: quantum.device_release - // CHECK: return %from_elements : tensor - // CHECK: } - - // CHECK: func.func private @circuit.quantumAlloc(%arg0: i64) -> !quantum.reg { - // CHECK: %0 = quantum.alloc(%arg0) : !quantum.reg - // CHECK: return %0 : !quantum.reg - // CHECK: } - +//CHECK-LABEL: func.func @mitigated_circuit() + //CHECK: [[c0:%.+]] = index.constant 0 + //CHECK: [[c1:%.+]] = index.constant 1 + //CHECK: [[c3:%.+]] = index.constant 3 + //CHECK: [[dense3:%.+]] = arith.constant dense<[1, 2, 3]> + //CHECK: [[emptyRes:%.+]] = tensor.empty() : tensor<3xf64> + //CHECK: [[results:%.+]] = scf.for [[idx:%.+]] = [[c0]] to [[c3]] step [[c1]] iter_args(%arg1 = [[emptyRes]]) -> (tensor<3xf64>) { + //CHECK: [[scaleFactor:%.+]] = tensor.extract [[dense3]][[[idx]]] : tensor<3xindex> + //CHECK: [[intermediateRes:%.+]] = func.call @circuit.folded([[scaleFactor]]) : (index) -> tensor + //CHECK: [[tensorRes:%.+]] = tensor.from_elements [[intermediateRes]] : tensor<1xf64> + //CHECK: [[resultsFor:%.+]] = scf.for %arg2 = [[c0]] to [[c1]] step [[c1]] iter_args(%arg3 = %arg1) -> (tensor<3xf64>) { + //CHECK: [[extracted:%.+]] = tensor.extract [[tensorRes]][%arg3] : tensor<1xf64> + //CHECK: [[insertedRes:%.+]] = tensor.insert [[extracted]] into %arg3[%arg1] : tensor<5xf64> + //CHECK: scf.yield [[insertedRes]] + //CHECK: scf.yield [[resultsFor]] + //CHECK: return [[results]] func.func @mitigated_circuit() -> tensor<3xf64> { %scaleFactors = arith.constant dense<[1, 2, 3]> : tensor<3xindex> %0 = mitigation.zne @circuit() folding (all) scaleFactors (%scaleFactors : tensor<3xindex>) : () -> tensor<3xf64> func.return %0 : tensor<3xf64> } -//CHECK: func.func @mitigated_circuit() -> tensor<3xf64> { - //CHECK: %idx0 = index.constant 0 - //CHECK: %idx1 = index.constant 1 - //CHECK: %idx3 = index.constant 3 - //CHECK: %cst = arith.constant dense<[1, 2, 3]> : tensor<3xindex> - //CHECK: %0 = tensor.empty() : tensor<3xf64> - //CHECK: %1 = scf.for %arg0 = %idx0 to %idx3 step %idx1 iter_args(%arg1 = %0) -> (tensor<3xf64>) { - //CHECK: %extracted = tensor.extract %cst[%arg0] : tensor<3xindex> - //CHECK: %2 = func.call @circuit.folded(%extracted) : (index) -> tensor - //CHECK: %extracted_0 = tensor.extract %2[] : tensor - //CHECK: %from_elements = tensor.from_elements %extracted_0 : tensor<1xf64> - //CHECK: %3 = scf.for %arg2 = %idx0 to %idx1 step %idx1 iter_args(%arg3 = %arg1) -> (tensor<3xf64>) { - //CHECK: %extracted_1 = tensor.extract %from_elements[%arg2] : tensor<1xf64> - //CHECK: %inserted = tensor.insert %extracted_1 into %arg3[%arg0] : tensor<3xf64> - //CHECK: scf.yield %inserted : tensor<3xf64> - //CHECK: } - //CHECK: scf.yield %3 : tensor<3xf64> - //CHECK: } - //CHECK: return %1 : tensor<3xf64> -//CHECK: } diff --git a/mlir/test/Mitigation/ZneFoldingGlobalTest.mlir b/mlir/test/Mitigation/ZneFoldingGlobalTest.mlir index 0bfe041575..d12e79c891 100644 --- a/mlir/test/Mitigation/ZneFoldingGlobalTest.mlir +++ b/mlir/test/Mitigation/ZneFoldingGlobalTest.mlir @@ -14,6 +14,47 @@ // RUN: quantum-opt %s --lower-mitigation --split-input-file --verify-diagnostics | FileCheck %s +// CHECK-LABEL: func.func private @simpleCircuit.folded(%arg0: tensor<3xf64>, %arg1: index) -> f64 { + // CHECK: [[nQubits:%.+]] = arith.constant 1 + // CHECK: [[c0:%.+]] = index.constant 0 + // CHECK: [[c1:%.+]] = index.constant 1 + // CHECK: quantum.device["rtd_lightning.so", "LightningQubit", "{shots: 0}"] + // CHECK: [[qReg:%.+]] = call @simpleCircuit.quantumAlloc([[nQubits]]) : (i64) -> !quantum.reg + // CHECK: [[outQregFor:%.+]] = scf.for %arg2 = [[c0]] to %arg1 step [[c1]] iter_args([[inQreg:%.+]] = [[qReg]]) -> (!quantum.reg) { + // CHECK: [[outQreg1:%.+]] = func.call @simpleCircuit.withoutMeasurements(%arg0, [[inQreg]]) : (tensor<3xf64>, !quantum.reg) -> !quantum.reg + // CHECK: [[outQreg2:%.+]] = quantum.adjoint([[outQreg1]]) : !quantum.reg { + // CHECK: ^bb0(%arg4: !quantum.reg): + // CHECK: [[callWithoutMeasurements:%.+]] = func.call @simpleCircuit.withoutMeasurements(%arg0, %arg4) : (tensor<3xf64>, !quantum.reg) -> !quantum.reg + // CHECK: quantum.yield [[callWithoutMeasurements]] : !quantum.reg + // CHECK: scf.yield [[outQreg2]] : !quantum.reg + // CHECK: [[results:%.+]] = call @simpleCircuit.withMeasurements(%arg0, [[outQregFor]]) : (tensor<3xf64>, !quantum.reg) -> f64 + // CHECK: quantum.device_release + // CHECK: return [[results]] + +// CHECK-LABEL: func.func private @simpleCircuit.quantumAlloc(%arg0: i64) -> !quantum.reg { + // CHECK: [[allocQreg:%.+]] = quantum.alloc(%arg0) : !quantum.reg + // CHECK: return [[allocQreg]] : !quantum.reg + +// CHECK-LABEL: func.func private @simpleCircuit.withoutMeasurements(%arg0: tensor<3xf64>, %arg1: !quantum.reg) -> !quantum.reg { + // CHECK: [[q_0:%.+]] = quantum.extract %arg1[ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[q_1:%.+]] = quantum.custom "h"() [[q_0]] : !quantum.bit + // CHECK: [[q_2:%.+]] = quantum.custom "rz"({{.*}}) [[q_1]] : !quantum.bit + // CHECK: [[q_3:%.+]] = quantum.custom "u3"({{.*}}, {{.*}}, {{.*}}) [[q_2]] : !quantum.bit + // CHECK: [[q_4:%.+]] = quantum.insert %arg1[ 0], [[q_3]] : !quantum.reg, !quantum.bit + // CHECK: return [[q_4]] : !quantum.reg + +// CHECK-LABEL: func.func private @simpleCircuit.withMeasurements(%arg0: tensor<3xf64>, %arg1: !quantum.reg) -> f64 { + // CHECK: [[q_0:%.+]] = quantum.extract %arg1[ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[q_1:%.+]] = quantum.custom "h"() [[q_0]] : !quantum.bit + // CHECK: [[q_2:%.+]] = quantum.custom "rz"({{.*}}) [[q_1]] : !quantum.bit + // CHECK: [[q_3:%.+]] = quantum.custom "u3"({{.*}}, {{.*}}, {{.*}}) [[q_2]] : !quantum.bit + // CHECK: [[q_4:%.+]] = quantum.insert %arg1[ 0], [[q_3]] : !quantum.reg, !quantum.bit + // CHECK: [[q_5:%.+]] = quantum.namedobs [[q_3]][ PauliX] : !quantum.obs + // CHECK: [[results:%.+]] = quantum.expval [[q_5]] : f64 + // CHECK: quantum.dealloc [[q_4]] : !quantum.reg + // CHECK: return [[results]] : f64 + +// CHECK-LABEL: func.func @simpleCircuit func.func @simpleCircuit(%arg0: tensor<3xf64>) -> f64 attributes {qnode} { quantum.device ["rtd_lightning.so", "LightningQubit", "{shots: 0}"] %c0 = arith.constant 0 : index @@ -39,53 +80,11 @@ func.func @simpleCircuit(%arg0: tensor<3xf64>) -> f64 attributes {qnode} { func.return %expval : f64 } -// CHECK: func.func private @simpleCircuit.folded(%arg0: tensor<3xf64>, %arg1: index) -> f64 { - // CHECK: [[nQubits:%.+]] = arith.constant 1 : i64 - // CHECK: %idx0 = index.constant 0 - // CHECK: %idx1 = index.constant 1 - // CHECK: quantum.device["rtd_lightning.so", "LightningQubit", "{shots: 0}"] - // CHECK: [[qReg:%.+]] = call @simpleCircuit.quantumAlloc([[nQubits]]) : (i64) -> !quantum.reg - // CHECK: [[outQregFor:%.+]] = scf.for %arg2 = %idx0 to %arg1 step %idx1 iter_args([[inQreg:%.+]] = [[qReg]]) -> (!quantum.reg) { - // CHECK: [[outQreg1:%.+]] = func.call @simpleCircuit.withoutMeasurements(%arg0, [[inQreg]]) : (tensor<3xf64>, !quantum.reg) -> !quantum.reg - // CHECK: [[outQreg2:%.+]] = quantum.adjoint([[outQreg1]]) : !quantum.reg { - // CHECK: ^bb0(%arg4: !quantum.reg): - // CHECK: [[callWithoutMeasurements:%.+]] = func.call @simpleCircuit.withoutMeasurements(%arg0, %arg4) : (tensor<3xf64>, !quantum.reg) -> !quantum.reg - // CHECK: quantum.yield [[callWithoutMeasurements]] : !quantum.reg - // CHECK: scf.yield [[outQreg2]] : !quantum.reg - // CHECK: [[results:%.+]] = call @simpleCircuit.withMeasurements(%arg0, [[outQregFor]]) : (tensor<3xf64>, !quantum.reg) -> f64 - // CHECK: quantum.device_release - // CHECK: return [[results]] : f64 - -// CHECK: func.func private @simpleCircuit.quantumAlloc(%arg0: i64) -> !quantum.reg { - // CHECK: [[allocQreg:%.+]] = quantum.alloc(%arg0) : !quantum.reg - // CHECK: return [[allocQreg]] : !quantum.reg - -// CHECK: func.func private @simpleCircuit.withoutMeasurements(%arg0: tensor<3xf64>, %arg1: !quantum.reg) -> !quantum.reg { - // CHECK: [[q_0:%.+]] = quantum.extract %arg1[ 0] : !quantum.reg -> !quantum.bit - // CHECK: [[q_1:%.+]] = quantum.custom "h"() [[q_0]] : !quantum.bit - // CHECK: [[q_2:%.+]] = quantum.custom "rz"({{.*}}) [[q_1]] : !quantum.bit - // CHECK: [[q_3:%.+]] = quantum.custom "u3"({{.*}}, {{.*}}, {{.*}}) [[q_2]] : !quantum.bit - // CHECK: [[q_4:%.+]] = quantum.insert %arg1[ 0], [[q_3]] : !quantum.reg, !quantum.bit - // CHECK: return [[q_4]] : !quantum.reg - -// CHECK: func.func private @simpleCircuit.withMeasurements(%arg0: tensor<3xf64>, %arg1: !quantum.reg) -> f64 { - // CHECK: [[q_0:%.+]] = quantum.extract %arg1[ 0] : !quantum.reg -> !quantum.bit - // CHECK: [[q_1:%.+]] = quantum.custom "h"() [[q_0]] : !quantum.bit - // CHECK: [[q_2:%.+]] = quantum.custom "rz"({{.*}}) [[q_1]] : !quantum.bit - // CHECK: [[q_3:%.+]] = quantum.custom "u3"({{.*}}, {{.*}}, {{.*}}) [[q_2]] : !quantum.bit - // CHECK: [[q_4:%.+]] = quantum.insert %arg1[ 0], [[q_3]] : !quantum.reg, !quantum.bit - // CHECK: [[q_5:%.+]] = quantum.namedobs [[q_3]][ PauliX] : !quantum.obs - // CHECK: [[resulst:%.+]] = quantum.expval [[q_5]] : f64 - // CHECK: quantum.dealloc [[q_4]] : !quantum.reg - // CHECK: return [[resulst]] : f64 - -// CHECK: func.func @simpleCircuit(%arg0: tensor<3xf64>) -> f64 attributes {qnode} { - -// CHECK: func.func @zneCallScalarScalar(%arg0: tensor<3xf64>) -> tensor<5xf64> { +// CHECK-LABEL: func.func @zneCallScalarScalar(%arg0: tensor<3xf64>) -> tensor<5xf64> { // CHECK: [[c0:%.+]] = index.constant 0 // CHECK: [[c1:%.+]] = index.constant 1 // CHECK: [[c5:%.+]] = index.constant 5 - // CHECK: [[dense5:%.+]] = arith.constant dense<[1, 2, 3, 4, 5]> : tensor<5xindex> + // CHECK: [[dense5:%.+]] = arith.constant dense<[1, 2, 3, 4, 5]> // CHECK: [[emptyRes:%.+]] = tensor.empty() : tensor<5xf64> // CHECK: [[results:%.+]] = scf.for [[idx:%.+]] = [[c0]] to [[c5]] step [[c1]] iter_args(%arg2 = [[emptyRes]]) -> (tensor<5xf64>) { // CHECK: [[scalarFactor:%.+]] = tensor.extract [[dense5]][[[idx]]] : tensor<5xindex> @@ -94,9 +93,9 @@ func.func @simpleCircuit(%arg0: tensor<3xf64>) -> f64 attributes {qnode} { // CHECK: [[resultsFor:%.+]] = scf.for [[idxJ:%.+]] = [[c0]] to [[c1]] step [[c1]] iter_args(%arg4 = %arg2) -> (tensor<5xf64>) { // CHECK: [[extracted:%.+]] = tensor.extract [[tensorRes]][%arg3] : tensor<1xf64> // CHECK: [[insertedRes:%.+]] = tensor.insert [[extracted]] into %arg4[%arg1] : tensor<5xf64> - // CHECK: scf.yield [[insertedRes]] : tensor<5xf64> - // CHECK: scf.yield [[resultsFor]] : tensor<5xf64> - // CHECK: return [[results]] : tensor<5xf64> + // CHECK: scf.yield [[insertedRes]] + // CHECK: scf.yield [[resultsFor]] + // CHECK: return [[results]] func.func @zneCallScalarScalar(%arg0: tensor<3xf64>) -> tensor<5xf64> { %scaleFactors = arith.constant dense<[1, 2, 3, 4, 5]> : tensor<5xindex> %0 = mitigation.zne @simpleCircuit(%arg0) folding (global) scaleFactors (%scaleFactors : tensor<5xindex>) : (tensor<3xf64>) -> tensor<5xf64> From 5b705dd5ef7150dc48751dc9ab3df953bc4f5842 Mon Sep 17 00:00:00 2001 From: Alessandro Cosentino Date: Tue, 6 Aug 2024 14:26:23 +0200 Subject: [PATCH 43/94] leftover from merge --- .../test/Mitigation/ZneFoldingGlobalTest.mlir | 44 +------------------ 1 file changed, 1 insertion(+), 43 deletions(-) diff --git a/mlir/test/Mitigation/ZneFoldingGlobalTest.mlir b/mlir/test/Mitigation/ZneFoldingGlobalTest.mlir index c608ae32fb..5e01831871 100644 --- a/mlir/test/Mitigation/ZneFoldingGlobalTest.mlir +++ b/mlir/test/Mitigation/ZneFoldingGlobalTest.mlir @@ -54,7 +54,7 @@ // CHECK: quantum.dealloc [[q_4]] : !quantum.reg // CHECK: return [[results]] : f64 -// CHECK-LABEL: func.func @simpleCircuit +// CHECK-LABEL: func.func @simpleCircuit() func.func @simpleCircuit(%arg0: tensor<3xf64>) -> f64 attributes {qnode} { quantum.device ["rtd_lightning.so", "LightningQubit", "{shots: 0}"] %c0 = arith.constant 0 : index @@ -80,48 +80,6 @@ func.func @simpleCircuit(%arg0: tensor<3xf64>) -> f64 attributes {qnode} { func.return %expval : f64 } -// CHECK: func.func private @simpleCircuit.folded(%arg0: tensor<3xf64>, %arg1: index) -> f64 { - // CHECK-DAG: [[nQubits:%.+]] = arith.constant 1 : i64 - // CHECK-DAG: %idx0 = index.constant 0 - // CHECK-DAG: %idx1 = index.constant 1 - // CHECK: quantum.device["rtd_lightning.so", "LightningQubit", "{shots: 0}"] - // CHECK: [[qReg:%.+]] = call @simpleCircuit.quantumAlloc([[nQubits]]) : (i64) -> !quantum.reg - // CHECK: [[outQregFor:%.+]] = scf.for %arg2 = %idx0 to %arg1 step %idx1 iter_args([[inQreg:%.+]] = [[qReg]]) -> (!quantum.reg) { - // CHECK: [[outQreg1:%.+]] = func.call @simpleCircuit.withoutMeasurements(%arg0, [[inQreg]]) : (tensor<3xf64>, !quantum.reg) -> !quantum.reg - // CHECK: [[outQreg2:%.+]] = quantum.adjoint([[outQreg1]]) : !quantum.reg { - // CHECK: ^bb0(%arg4: !quantum.reg): - // CHECK: [[callWithoutMeasurements:%.+]] = func.call @simpleCircuit.withoutMeasurements(%arg0, %arg4) : (tensor<3xf64>, !quantum.reg) -> !quantum.reg - // CHECK: quantum.yield [[callWithoutMeasurements]] : !quantum.reg - // CHECK: scf.yield [[outQreg2]] : !quantum.reg - // CHECK: [[results:%.+]] = call @simpleCircuit.withMeasurements(%arg0, [[outQregFor]]) : (tensor<3xf64>, !quantum.reg) -> f64 - // CHECK: quantum.device_release - // CHECK: return [[results]] : f64 - -// CHECK: func.func private @simpleCircuit.quantumAlloc(%arg0: i64) -> !quantum.reg { - // CHECK: [[allocQreg:%.+]] = quantum.alloc(%arg0) : !quantum.reg - // CHECK: return [[allocQreg]] : !quantum.reg - -// CHECK: func.func private @simpleCircuit.withoutMeasurements(%arg0: tensor<3xf64>, %arg1: !quantum.reg) -> !quantum.reg { - // CHECK: [[q_0:%.+]] = quantum.extract %arg1[ 0] : !quantum.reg -> !quantum.bit - // CHECK: [[q_1:%.+]] = quantum.custom "h"() [[q_0]] : !quantum.bit - // CHECK: [[q_2:%.+]] = quantum.custom "rz"({{.*}}) [[q_1]] : !quantum.bit - // CHECK: [[q_3:%.+]] = quantum.custom "u3"({{.*}}, {{.*}}, {{.*}}) [[q_2]] : !quantum.bit - // CHECK: [[q_4:%.+]] = quantum.insert %arg1[ 0], [[q_3]] : !quantum.reg, !quantum.bit - // CHECK: return [[q_4]] : !quantum.reg - -// CHECK: func.func private @simpleCircuit.withMeasurements(%arg0: tensor<3xf64>, %arg1: !quantum.reg) -> f64 { - // CHECK: [[q_0:%.+]] = quantum.extract %arg1[ 0] : !quantum.reg -> !quantum.bit - // CHECK: [[q_1:%.+]] = quantum.custom "h"() [[q_0]] : !quantum.bit - // CHECK: [[q_2:%.+]] = quantum.custom "rz"({{.*}}) [[q_1]] : !quantum.bit - // CHECK: [[q_3:%.+]] = quantum.custom "u3"({{.*}}, {{.*}}, {{.*}}) [[q_2]] : !quantum.bit - // CHECK: [[q_4:%.+]] = quantum.insert %arg1[ 0], [[q_3]] : !quantum.reg, !quantum.bit - // CHECK: [[q_5:%.+]] = quantum.namedobs [[q_3]][ PauliX] : !quantum.obs - // CHECK: [[resulst:%.+]] = quantum.expval [[q_5]] : f64 - // CHECK: quantum.dealloc [[q_4]] : !quantum.reg - // CHECK: return [[resulst]] : f64 - -// CHECK: func.func @simpleCircuit(%arg0: tensor<3xf64>) -> f64 attributes {qnode} { - // CHECK: func.func @zneCallScalarScalar(%arg0: tensor<3xf64>) -> tensor<5xf64> { // CHECK-DAG: [[c0:%.+]] = index.constant 0 // CHECK-DAG: [[c1:%.+]] = index.constant 1 From f008c89285a61b51dc1c70d23df9cf8f273731e8 Mon Sep 17 00:00:00 2001 From: Alessandro Cosentino Date: Tue, 6 Aug 2024 14:30:31 +0200 Subject: [PATCH 44/94] fix test label --- mlir/test/Mitigation/ZneFoldingGlobalTest.mlir | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/test/Mitigation/ZneFoldingGlobalTest.mlir b/mlir/test/Mitigation/ZneFoldingGlobalTest.mlir index 5e01831871..ab4a806ecb 100644 --- a/mlir/test/Mitigation/ZneFoldingGlobalTest.mlir +++ b/mlir/test/Mitigation/ZneFoldingGlobalTest.mlir @@ -15,9 +15,9 @@ // RUN: quantum-opt %s --lower-mitigation --split-input-file --verify-diagnostics | FileCheck %s // CHECK-LABEL: func.func private @simpleCircuit.folded(%arg0: tensor<3xf64>, %arg1: index) -> f64 { - // CHECK: [[nQubits:%.+]] = arith.constant 1 - // CHECK: [[c0:%.+]] = index.constant 0 - // CHECK: [[c1:%.+]] = index.constant 1 + // CHECK-DAG: [[nQubits:%.+]] = arith.constant 1 + // CHECK-DAG: [[c0:%.+]] = index.constant 0 + // CHECK-DAG: [[c1:%.+]] = index.constant 1 // CHECK: quantum.device["rtd_lightning.so", "LightningQubit", "{shots: 0}"] // CHECK: [[qReg:%.+]] = call @simpleCircuit.quantumAlloc([[nQubits]]) : (i64) -> !quantum.reg // CHECK: [[outQregFor:%.+]] = scf.for %arg2 = [[c0]] to %arg1 step [[c1]] iter_args([[inQreg:%.+]] = [[qReg]]) -> (!quantum.reg) { @@ -54,7 +54,7 @@ // CHECK: quantum.dealloc [[q_4]] : !quantum.reg // CHECK: return [[results]] : f64 -// CHECK-LABEL: func.func @simpleCircuit() +// CHECK-LABEL: func.func @simpleCircuit(%arg0: tensor<3xf64>) -> f64 attributes {qnode} { func.func @simpleCircuit(%arg0: tensor<3xf64>) -> f64 attributes {qnode} { quantum.device ["rtd_lightning.so", "LightningQubit", "{shots: 0}"] %c0 = arith.constant 0 : index From ffbf1ec847a11b9d6578598684afe2e119533db8 Mon Sep 17 00:00:00 2001 From: Alessandro Cosentino Date: Tue, 6 Aug 2024 14:35:07 +0200 Subject: [PATCH 45/94] labels in local folding mlir test --- mlir/test/Mitigation/ZneFoldingAllTest.mlir | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/mlir/test/Mitigation/ZneFoldingAllTest.mlir b/mlir/test/Mitigation/ZneFoldingAllTest.mlir index 10352c4e47..2b0cf18936 100644 --- a/mlir/test/Mitigation/ZneFoldingAllTest.mlir +++ b/mlir/test/Mitigation/ZneFoldingAllTest.mlir @@ -14,7 +14,7 @@ // RUN: quantum-opt %s --lower-mitigation --split-input-file --verify-diagnostics | FileCheck %s -// CHECK: func.func private @circuit.folded(%arg0: index) -> tensor { +// CHECK-LABEL: func.func private @circuit.folded(%arg0: index) -> tensor { // CHECK: [[nQubits:%.+]] = arith.constant 2 // CHECK: [[c0:%.+]] = index.constant 0 // CHECK: [[c1:%.+]] = index.constant 1 @@ -26,26 +26,26 @@ // CHECK: [[q0_out]] = quantum.custom "Hadamard"() [[q0_out]] {adjoint} : !quantum.bit // CHECK: [[q0_out]] = quantum.custom "Hadamard"() [[q0_out]] : !quantum.bit // CHECK: scf.yield [[q0_out]]: !quantum.bit - // CHECK: [[%q1:%.+]] = quantum.extract [[qReg]][ 1] : !quantum.reg -> !quantum.bit + // CHECK: [[q1:%.+]] = quantum.extract [[qReg]][ 1] : !quantum.reg -> !quantum.bit // CHECK: [[q01_out:%.+]] = quantum.custom "CNOT"() [[q0_out_1]],[[q1]] : !quantum.bit, !quantum.bit // CHECK: [[q01_out2:%.+]] = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] -> (!quantum.bit, !quantum.bit) { // CHECK: [[q01_out]]:2 = quantum.custom "CNOT"() [[q01_out]]#0, [[q01_out]]#1 {adjoint} : !quantum.bit, !quantum.bit // CHECK: [[q01_out]]:2 = quantum.custom "CNOT"() [[q01_out]]#0, [[q01_out]]#1 : !quantum.bit, !quantum.bit // CHECK: scf.yield [[q01_out]] : (!quantum.bit, !quantum.bit) - // CHECK: [[%q2:%.+]] = quantum.namedobs [[q01_out2]]#0[ PauliY] : !quantum.obs + // CHECK: [[q2:%.+]] = quantum.namedobs [[q01_out2]]#0[ PauliY] : !quantum.obs // CHECK: [[results:%.+]] = quantum.expval [[q1]] : f64 // CHECK: [[tensorRes:%.+]] = tensor.from_elements [[result]] : tensor - // CHECK: [[%q2:%.+]] = quantum.insert %0[ 0], [[q01_out2]]#0 : !quantum.reg, !quantum.bit - // CHECK: [[%q3:%.+]] = quantum.insert %7[ 1], [[q01_out2]]#1 : !quantum.reg, !quantum.bit + // CHECK: [[q2:%.+]] = quantum.insert %0[ 0], [[q01_out2]]#0 : !quantum.reg, !quantum.bit + // CHECK: [[q3:%.+]] = quantum.insert %7[ 1], [[q01_out2]]#1 : !quantum.reg, !quantum.bit // CHECK: quantum.dealloc [[q2]] : !quantum.reg // CHECK: quantum.device_release // CHECK: return [[tensorRes]] -// CHECK: func.func private @simpleCircuit.quantumAlloc(%arg0: i64) -> !quantum.reg { +// CHECK-LABEL: func.func private @circuit.quantumAlloc(%arg0: i64) -> !quantum.reg { // CHECK: [[allocQreg:%.+]] = quantum.alloc(%arg0) : !quantum.reg // CHECK: return [[allocQreg]] : !quantum.reg -//CHECK-LABEL: func.func @circuit +//CHECK-LABEL: func.func @circuit -> tensor attributes {qnode} { func.func @circuit() -> tensor attributes {qnode} { quantum.device ["rtd_lightning.so", "LightningQubit", "{shots: 0}"] %0 = quantum.alloc( 2) : !quantum.reg @@ -63,7 +63,6 @@ func.func @circuit() -> tensor attributes {qnode} { return %from_elements : tensor } - //CHECK-LABEL: func.func @mitigated_circuit() //CHECK: [[c0:%.+]] = index.constant 0 //CHECK: [[c1:%.+]] = index.constant 1 From 7b189d5bdac3d68905fbe00cd50e14fccab44ddc Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Tue, 6 Aug 2024 13:50:00 -0400 Subject: [PATCH 46/94] Debug --- .../Mitigation/Transforms/MitigationMethods/Zne.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 861e9fa713..fb69c6a79a 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -211,13 +211,12 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: Value numberQubitsValue, func::FuncOp fnWithMeasurementsOp, Value c0, Value c1) { - rewriter.create(loc, fnAllocOp, numberQubitsValue); + Value allocQreg = rewriter.create(loc, fnAllocOp, numberQubitsValue).getResult(0); int64_t sizeArgs = fnFoldedOp.getArguments().size(); Value size = fnFoldedOp.getArgument(sizeArgs - 1); fnWithMeasurementsOp.walk([&](quantum::QubitUnitaryOp op) { - // TODO: Skip measurements and control structures. // Add scf for loop to create the folding rewriter.setInsertionPointAfter(op); rewriter.create( @@ -231,14 +230,21 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: }); return WalkResult::advance(); }); + + std::vector argsAndQreg(fnWithMeasurementsOp.getArguments().begin(), + fnWithMeasurementsOp.getArguments().end()); + argsAndQreg.pop_back(); + argsAndQreg.push_back(allocQreg); + Value loopedQreg = rewriter.create(loc, fnWithMeasurementsOp, argsAndQreg).getResult(0); // Remove device std::vector argsAndRegMeasurement(fnFoldedOp.getArguments().begin(), fnFoldedOp.getArguments().end()); argsAndRegMeasurement.pop_back(); - argsAndRegMeasurement.push_back(fnWithMeasurementsOp->getResult(0)); + argsAndRegMeasurement.push_back(loopedQreg); ValueRange funcFolded = rewriter.create(loc, fnWithMeasurementsOp, argsAndRegMeasurement) .getResults(); + // Remove device rewriter.create(loc); rewriter.create(loc, funcFolded); return SymbolRefAttr::get(rewriter.getContext(), fnFoldedName); From 942d38b402b6f221542fffdb5f076c05804eb5c9 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Tue, 6 Aug 2024 14:03:48 -0400 Subject: [PATCH 47/94] Generates local folding function --- .../Mitigation/Transforms/MitigationMethods/Zne.cpp | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index fb69c6a79a..6a04bdf2fb 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -235,18 +235,10 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: fnWithMeasurementsOp.getArguments().end()); argsAndQreg.pop_back(); argsAndQreg.push_back(allocQreg); - Value loopedQreg = rewriter.create(loc, fnWithMeasurementsOp, argsAndQreg).getResult(0); - // Remove device - std::vector argsAndRegMeasurement(fnFoldedOp.getArguments().begin(), - fnFoldedOp.getArguments().end()); - argsAndRegMeasurement.pop_back(); - argsAndRegMeasurement.push_back(loopedQreg); - ValueRange funcFolded = - rewriter.create(loc, fnWithMeasurementsOp, argsAndRegMeasurement) - .getResults(); + Value result = rewriter.create(loc, fnWithMeasurementsOp, argsAndQreg).getResult(0); // Remove device rewriter.create(loc); - rewriter.create(loc, funcFolded); + rewriter.create(loc, result); return SymbolRefAttr::get(rewriter.getContext(), fnFoldedName); } FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRewriter &rewriter, From f668fbbd2ad42bc28c451b045df378544af373a5 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Tue, 6 Aug 2024 14:10:16 -0400 Subject: [PATCH 48/94] Cut fnWithoutMeasurementsOp from local folding --- .../Mitigation/Transforms/MitigationMethods/Zne.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 6a04bdf2fb..f9b7246fba 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -282,10 +282,13 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew // Function without measurements: Create function without measurements and with qreg as last // argument - FlatSymbolRefAttr fnWithoutMeasurementsRefAttr = - getOrInsertFnWithoutMeasurements(loc, rewriter, op); - func::FuncOp fnWithoutMeasurementsOp = - SymbolTable::lookupNearestSymbolFrom(op, fnWithoutMeasurementsRefAttr); + func::FuncOp fnWithoutMeasurementsOp; + if (foldingAlgorithm == Folding(1)) { + FlatSymbolRefAttr fnWithoutMeasurementsRefAttr = + getOrInsertFnWithoutMeasurements(loc, rewriter, op); + fnWithoutMeasurementsOp = + SymbolTable::lookupNearestSymbolFrom(op, fnWithoutMeasurementsRefAttr); + } // Function with measurements: Modify the original function to take a quantum register as last // arg and keep measurements From 2276c5fd658b279b36a4f9326f731832bec7a32b Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Tue, 6 Aug 2024 15:11:42 -0400 Subject: [PATCH 49/94] Use builder in for loops --- .../Transforms/MitigationMethods/Zne.cpp | 24 ++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index f9b7246fba..2669811d33 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -217,17 +217,24 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: Value size = fnFoldedOp.getArgument(sizeArgs - 1); fnWithMeasurementsOp.walk([&](quantum::QubitUnitaryOp op) { - // Add scf for loop to create the folding + // Insert a for loop immediately after every quantum::QubitUnitaryOp + auto innerLoc = op->getLoc(); rewriter.setInsertionPointAfter(op); rewriter.create( - loc, c0, size, c1, ValueRange(), - [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { - // Call the function without measurements in an adjoint region - auto adjointOp = builder.create(loc, qregType, op.getResult(0)); + innerLoc, c0, size, c1, ValueRange(), + [&](OpBuilder &builder, Location forLoc, Value i, ValueRange iterArgs) { + // Set insertion point within the loop + builder.setInsertionPointToEnd(builder.getBlock()); + // Repeat the adjoint and original operation, after the existing QubitUnitaryOp + auto adjointOp = + builder.create(forLoc, qregType, op.getResult(0)) + .getResult(); auto origOp = - builder.create(loc, qregType, adjointOp.getResult()); - builder.create(loc, origOp.getResult()); + builder.create(forLoc, qregType, adjointOp).getResult(); + // Yield operation + builder.create(forLoc, origOp); }); + return WalkResult::advance(); }); @@ -235,7 +242,8 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: fnWithMeasurementsOp.getArguments().end()); argsAndQreg.pop_back(); argsAndQreg.push_back(allocQreg); - Value result = rewriter.create(loc, fnWithMeasurementsOp, argsAndQreg).getResult(0); + Value result = + rewriter.create(loc, fnWithMeasurementsOp, argsAndQreg).getResult(0); // Remove device rewriter.create(loc); rewriter.create(loc, result); From 81d19781129e7eab952057eb4f6556c41e104bf6 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Wed, 7 Aug 2024 14:54:38 -0400 Subject: [PATCH 50/94] Draft local folding (doesn't work) --- .../Transforms/MitigationMethods/Zne.cpp | 34 +++++++++++-------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 2669811d33..1d6ca9f8e6 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -211,42 +211,48 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: Value numberQubitsValue, func::FuncOp fnWithMeasurementsOp, Value c0, Value c1) { + // Allocate qubits Value allocQreg = rewriter.create(loc, fnAllocOp, numberQubitsValue).getResult(0); int64_t sizeArgs = fnFoldedOp.getArguments().size(); Value size = fnFoldedOp.getArgument(sizeArgs - 1); - fnWithMeasurementsOp.walk([&](quantum::QubitUnitaryOp op) { - // Insert a for loop immediately after every quantum::QubitUnitaryOp + // Walk through the operations in fnWithMeasurementsOp + fnWithMeasurementsOp.walk([&](quantum::QuantumGate op) { + // Insert a for loop immediately before each quantum::QuantumGate auto innerLoc = op->getLoc(); - rewriter.setInsertionPointAfter(op); + rewriter.setInsertionPoint(op); rewriter.create( innerLoc, c0, size, c1, ValueRange(), [&](OpBuilder &builder, Location forLoc, Value i, ValueRange iterArgs) { // Set insertion point within the loop builder.setInsertionPointToEnd(builder.getBlock()); - // Repeat the adjoint and original operation, after the existing QubitUnitaryOp - auto adjointOp = - builder.create(forLoc, qregType, op.getResult(0)) - .getResult(); - auto origOp = - builder.create(forLoc, qregType, adjointOp).getResult(); - // Yield operation - builder.create(forLoc, origOp); + // Create adjoint and original operations + auto origOp = builder.clone(*op)->getResult(0); + auto adjointOp = builder.create(forLoc, qregType, origOp).getResult(); + // Yield the result of the original operation + builder.create(forLoc, adjointOp); }); return WalkResult::advance(); }); + // Prepare the arguments for the final call std::vector argsAndQreg(fnWithMeasurementsOp.getArguments().begin(), fnWithMeasurementsOp.getArguments().end()); argsAndQreg.pop_back(); argsAndQreg.push_back(allocQreg); - Value result = - rewriter.create(loc, fnWithMeasurementsOp, argsAndQreg).getResult(0); - // Remove device + + // Insert the call to fnWithMeasurementsOp + rewriter.setInsertionPointAfter(fnWithMeasurementsOp.getBody().front().getTerminator()); + Value result = rewriter.create(loc, fnWithMeasurementsOp, argsAndQreg).getResult(0); + + // Insert the device release operation rewriter.create(loc); + // Return rewriter.create(loc, result); + + // Return the function symbol reference return SymbolRefAttr::get(rewriter.getContext(), fnFoldedName); } FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRewriter &rewriter, From a9d84768350d32d2a7ed95e00852887643fb95b8 Mon Sep 17 00:00:00 2001 From: Alessandro Cosentino Date: Tue, 13 Aug 2024 15:47:28 -0400 Subject: [PATCH 51/94] fix indentation --- mlir/test/Mitigation/ZneFoldingAllTest.mlir | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/mlir/test/Mitigation/ZneFoldingAllTest.mlir b/mlir/test/Mitigation/ZneFoldingAllTest.mlir index 4a95b88a04..c7075b529d 100644 --- a/mlir/test/Mitigation/ZneFoldingAllTest.mlir +++ b/mlir/test/Mitigation/ZneFoldingAllTest.mlir @@ -70,17 +70,17 @@ func.func @circuit() -> tensor attributes {qnode} { //CHECK: [[dense3:%.+]] = arith.constant dense<[1, 2, 3]> //CHECK: [[emptyRes:%.+]] = tensor.empty() : tensor<3xf64> //CHECK: [[results:%.+]] = scf.for [[idx:%.+]] = [[c0]] to [[c3]] step [[c1]] iter_args(%arg1 = [[emptyRes]]) -> (tensor<3xf64>) { - //CHECK: [[scaleFactor:%.+]] = tensor.extract [[dense3]][[[idx]]] : tensor<3xindex> - //CHECK: [[intermediateRes:%.+]] = func.call @circuit.folded([[scaleFactor]]) : (index) -> tensor - //CHECK: [[tensorRes:%.+]] = tensor.from_elements [[intermediateRes]] : tensor<1xf64> - //CHECK: [[resultsFor:%.+]] = scf.for %arg2 = [[c0]] to [[c1]] step [[c1]] iter_args(%arg3 = %arg1) -> (tensor<3xf64>) { - //CHECK: [[extracted:%.+]] = tensor.extract [[tensorRes]][%arg3] : tensor<1xf64> - //CHECK: [[insertedRes:%.+]] = tensor.insert [[extracted]] into %arg3[%arg1] : tensor<5xf64> - //CHECK: scf.yield [[insertedRes]] - //CHECK: scf.yield [[resultsFor]] + //CHECK: [[scaleFactor:%.+]] = tensor.extract [[dense3]][[[idx]]] : tensor<3xindex> + //CHECK: [[intermediateRes:%.+]] = func.call @circuit.folded([[scaleFactor]]) : (index) -> tensor + //CHECK: [[tensorRes:%.+]] = tensor.from_elements [[intermediateRes]] : tensor<1xf64> + //CHECK: [[resultsFor:%.+]] = scf.for %arg2 = [[c0]] to [[c1]] step [[c1]] iter_args(%arg3 = %arg1) -> (tensor<3xf64>) { + //CHECK: [[extracted:%.+]] = tensor.extract [[tensorRes]][%arg3] : tensor<1xf64> + //CHECK: [[insertedRes:%.+]] = tensor.insert [[extracted]] into %arg3[%arg1] : tensor<5xf64> + //CHECK: scf.yield [[insertedRes]] + //CHECK: scf.yield [[resultsFor]] //CHECK: return [[results]] func.func @mitigated_circuit() -> tensor<3xf64> { %scaleFactors = arith.constant dense<[1, 2, 3]> : tensor<3xindex> %0 = mitigation.zne @circuit() folding (all) scaleFactors (%scaleFactors : tensor<3xindex>) : () -> tensor<3xf64> func.return %0 : tensor<3xf64> -} \ No newline at end of file +} From 3f20d120c9355d57b00a20599ea9e52cc7e3733e Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Thu, 15 Aug 2024 10:53:53 -0400 Subject: [PATCH 52/94] Save insertion point --- .../Transforms/MitigationMethods/Zne.cpp | 46 +++++++++++-------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 1d6ca9f8e6..7b423fb484 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -217,25 +217,32 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: int64_t sizeArgs = fnFoldedOp.getArguments().size(); Value size = fnFoldedOp.getArgument(sizeArgs - 1); - // Walk through the operations in fnWithMeasurementsOp - fnWithMeasurementsOp.walk([&](quantum::QuantumGate op) { - // Insert a for loop immediately before each quantum::QuantumGate - auto innerLoc = op->getLoc(); - rewriter.setInsertionPoint(op); - rewriter.create( - innerLoc, c0, size, c1, ValueRange(), - [&](OpBuilder &builder, Location forLoc, Value i, ValueRange iterArgs) { - // Set insertion point within the loop - builder.setInsertionPointToEnd(builder.getBlock()); - // Create adjoint and original operations - auto origOp = builder.clone(*op)->getResult(0); - auto adjointOp = builder.create(forLoc, qregType, origOp).getResult(); - // Yield the result of the original operation - builder.create(forLoc, adjointOp); - }); - - return WalkResult::advance(); - }); + if (true) { + // Save the current insertion point + PatternRewriter::InsertionGuard guard(rewriter); + + // Walk through the operations in fnWithMeasurementsOp + fnWithMeasurementsOp.walk([&](quantum::QuantumGate op) { + // Insert a for loop immediately before each quantum::QuantumGate + auto innerLoc = op->getLoc(); + rewriter.setInsertionPoint(op); + rewriter.create( + innerLoc, c0, size, c1, ValueRange(), + [&](OpBuilder &builder, Location forLoc, Value i, ValueRange iterArgs) { + // Set insertion point within the loop + builder.setInsertionPointToEnd(builder.getBlock()); + // Create adjoint and original operations + auto origOp = builder.clone(*op)->getResult(0); + auto adjointOp = builder.create(forLoc, qregType, origOp).getResult(); + // Yield the result of the original operation + builder.create(forLoc, adjointOp); + }); + + return WalkResult::advance(); + }); + + // Restore original insertion point when PatternRewriter::InsertionGuard goes out of scope + } // Prepare the arguments for the final call std::vector argsAndQreg(fnWithMeasurementsOp.getArguments().begin(), @@ -244,7 +251,6 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: argsAndQreg.push_back(allocQreg); // Insert the call to fnWithMeasurementsOp - rewriter.setInsertionPointAfter(fnWithMeasurementsOp.getBody().front().getTerminator()); Value result = rewriter.create(loc, fnWithMeasurementsOp, argsAndQreg).getResult(0); // Insert the device release operation From fdf0f7157e81a218e04b50661b57610dd85b94c0 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Mon, 19 Aug 2024 15:26:09 -0400 Subject: [PATCH 53/94] Get args from fnFoldedOp --- mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 7b423fb484..e135754c94 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -245,8 +245,8 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: } // Prepare the arguments for the final call - std::vector argsAndQreg(fnWithMeasurementsOp.getArguments().begin(), - fnWithMeasurementsOp.getArguments().end()); + std::vector argsAndQreg(fnFoldedOp.getArguments().begin(), + fnFoldedOp.getArguments().end()); argsAndQreg.pop_back(); argsAndQreg.push_back(allocQreg); From 34c5df988e92b087d2f36ff4a7f8267cd2a72252 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Wed, 21 Aug 2024 15:19:03 -0400 Subject: [PATCH 54/94] Pass qubits from result to input --- .../Transforms/MitigationMethods/Zne.cpp | 49 +++++++++++++------ 1 file changed, 35 insertions(+), 14 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index e135754c94..effcf2a99c 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -223,20 +223,40 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: // Walk through the operations in fnWithMeasurementsOp fnWithMeasurementsOp.walk([&](quantum::QuantumGate op) { - // Insert a for loop immediately before each quantum::QuantumGate - auto innerLoc = op->getLoc(); rewriter.setInsertionPoint(op); - rewriter.create( - innerLoc, c0, size, c1, ValueRange(), - [&](OpBuilder &builder, Location forLoc, Value i, ValueRange iterArgs) { - // Set insertion point within the loop - builder.setInsertionPointToEnd(builder.getBlock()); - // Create adjoint and original operations - auto origOp = builder.clone(*op)->getResult(0); - auto adjointOp = builder.create(forLoc, qregType, origOp).getResult(); - // Yield the result of the original operation - builder.create(forLoc, adjointOp); - }); + auto innerLoc = op->getLoc(); + + std::vector opQubitArgs(op.getQubitOperands()); + std::vector opArgs(op->getOperands().begin(), op->getOperands().end()); + + // Insert a for loop immediately before each quantum::QuantumGate + const auto forVal = + rewriter + .create( + innerLoc, c0, size, c1, /*iterArgsInit=*/opQubitArgs, + [&](OpBuilder &builder, Location forLoc, Value i, ValueRange iterArgs) { + // Set insertion point within the loop + builder.setInsertionPointToEnd(builder.getBlock()); + + // When we clone the original operation, + // we replace opQubitArgs with the current qubits in iterArgs + IRMapping irm; + irm.map(opQubitArgs, iterArgs); + + // Create adjoint and original operations + auto origOp = builder.clone(*op, irm)->getResult(0); + auto adjointOp = + builder.create(forLoc, qregType, origOp) + .getResult(); + + // Yield the result of the original operation + builder.create(forLoc, adjointOp); + }) + .getResult(0); + + opArgs.pop_back(); + opArgs.push_back(forVal); + op->setOperands(opArgs); return WalkResult::advance(); }); @@ -251,7 +271,8 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: argsAndQreg.push_back(allocQreg); // Insert the call to fnWithMeasurementsOp - Value result = rewriter.create(loc, fnWithMeasurementsOp, argsAndQreg).getResult(0); + Value result = + rewriter.create(loc, fnWithMeasurementsOp, argsAndQreg).getResult(0); // Insert the device release operation rewriter.create(loc); From 9d9f2051b1a510634ba718048947ec5332a133a1 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Wed, 21 Aug 2024 15:24:07 -0400 Subject: [PATCH 55/94] Refactor Location usage --- .../Mitigation/Transforms/MitigationMethods/Zne.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index effcf2a99c..082fa9b631 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -224,7 +224,7 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: // Walk through the operations in fnWithMeasurementsOp fnWithMeasurementsOp.walk([&](quantum::QuantumGate op) { rewriter.setInsertionPoint(op); - auto innerLoc = op->getLoc(); + auto loc = op->getLoc(); std::vector opQubitArgs(op.getQubitOperands()); std::vector opArgs(op->getOperands().begin(), op->getOperands().end()); @@ -233,8 +233,8 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: const auto forVal = rewriter .create( - innerLoc, c0, size, c1, /*iterArgsInit=*/opQubitArgs, - [&](OpBuilder &builder, Location forLoc, Value i, ValueRange iterArgs) { + loc, c0, size, c1, /*iterArgsInit=*/opQubitArgs, + [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { // Set insertion point within the loop builder.setInsertionPointToEnd(builder.getBlock()); @@ -246,11 +246,11 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: // Create adjoint and original operations auto origOp = builder.clone(*op, irm)->getResult(0); auto adjointOp = - builder.create(forLoc, qregType, origOp) + builder.create(loc, qregType, origOp) .getResult(); // Yield the result of the original operation - builder.create(forLoc, adjointOp); + builder.create(loc, adjointOp); }) .getResult(0); From feeae40209437d3e819e607e3f7c3c56c62fdfb2 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Wed, 21 Aug 2024 15:38:07 -0400 Subject: [PATCH 56/94] Fix AdjointOP --- .../Transforms/MitigationMethods/Zne.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 082fa9b631..0d252d5381 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -128,7 +128,7 @@ void ZneLowering::rewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const } // In *.cpp module only, to keep extraneous headers out of *.hpp FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, std::string fnFoldedName, - StringAttr lib, StringAttr name, StringAttr kwargs, Type qregType, + StringAttr lib, StringAttr name, StringAttr kwargs, FunctionType fnFoldedType, SmallVector typesFolded, func::FuncOp fnFoldedOp, func::FuncOp fnAllocOp, Value numberQubitsValue, func::FuncOp fnWithoutMeasurementsOp, @@ -136,6 +136,7 @@ FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, std::st { // Function folded: Create the folded circuit (withoutMeasurement * // Adjoint(withoutMeasurement))**scalar_factor * withMeasurements + Type qregType = quantum::QuregType::get(rewriter.getContext()); Value allocQreg = rewriter.create(loc, fnAllocOp, numberQubitsValue).getResult(0); int64_t sizeArgs = fnFoldedOp.getArguments().size(); @@ -191,7 +192,7 @@ FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, std::st // In *.cpp module only, to keep extraneous headers out of *.hpp FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, std::string fnFoldedName, StringAttr lib, StringAttr name, - StringAttr kwargs, Type qregType, FunctionType fnFoldedType, + StringAttr kwargs, FunctionType fnFoldedType, SmallVector typesFolded, func::FuncOp fnFoldedOp, func::FuncOp fnAllocOp, Value numberQubitsValue, func::FuncOp fnWithMeasurementsOp, Value c0, Value c1) @@ -205,7 +206,7 @@ FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, } // In *.cpp module only, to keep extraneous headers out of *.hpp FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std::string fnFoldedName, - StringAttr lib, StringAttr name, StringAttr kwargs, Type qregType, + StringAttr lib, StringAttr name, StringAttr kwargs, FunctionType fnFoldedType, SmallVector typesFolded, func::FuncOp fnFoldedOp, func::FuncOp fnAllocOp, Value numberQubitsValue, func::FuncOp fnWithMeasurementsOp, @@ -246,7 +247,7 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: // Create adjoint and original operations auto origOp = builder.clone(*op, irm)->getResult(0); auto adjointOp = - builder.create(loc, qregType, origOp) + builder.create(loc, origOp.getType(), origOp) .getResult(); // Yield the result of the original operation @@ -311,7 +312,6 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew // Get the device quantum::DeviceInitOp deviceInitOp = *fnOp.getOps().begin(); - Type qregType = quantum::QuregType::get(ctx); StringAttr lib = deviceInitOp.getLibAttr(); StringAttr name = deviceInitOp.getNameAttr(); StringAttr kwargs = deviceInitOp.getKwargsAttr(); @@ -357,17 +357,17 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew rewriter.create(loc, lib, name, kwargs); if (foldingAlgorithm == Folding(1)) { - return globalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, qregType, fnFoldedType, + return globalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, fnFoldedType, typesFolded, fnFoldedOp, fnAllocOp, numberQubitsValue, fnWithoutMeasurementsOp, fnWithMeasurementsOp, c0, c1); } if (foldingAlgorithm == Folding(2)) { - return randomLocalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, qregType, + return randomLocalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, fnFoldedType, typesFolded, fnFoldedOp, fnAllocOp, numberQubitsValue, fnWithMeasurementsOp, c0, c1); } // Else, if (foldingAlgorithm == Folding(3)): - return allLocalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, qregType, fnFoldedType, + return allLocalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, fnFoldedType, typesFolded, fnFoldedOp, fnAllocOp, numberQubitsValue, fnWithMeasurementsOp, c0, c1); } From 7a13ab5a7819648372ada006b1113bbaaca68f44 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Thu, 22 Aug 2024 16:47:55 -0400 Subject: [PATCH 57/94] Passing qubit args through ForOp --- .../Transforms/MitigationMethods/Zne.cpp | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 0d252d5381..c88d8f9d93 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -212,6 +212,7 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: Value numberQubitsValue, func::FuncOp fnWithMeasurementsOp, Value c0, Value c1) { + Type qregType = quantum::QuregType::get(rewriter.getContext()); // Allocate qubits Value allocQreg = rewriter.create(loc, fnAllocOp, numberQubitsValue).getResult(0); @@ -227,8 +228,7 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: rewriter.setInsertionPoint(op); auto loc = op->getLoc(); - std::vector opQubitArgs(op.getQubitOperands()); - std::vector opArgs(op->getOperands().begin(), op->getOperands().end()); + const std::vector opQubitArgs = op.getQubitOperands(); // Insert a for loop immediately before each quantum::QuantumGate const auto forVal = @@ -245,19 +245,19 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: irm.map(opQubitArgs, iterArgs); // Create adjoint and original operations - auto origOp = builder.clone(*op, irm)->getResult(0); + auto origOp = builder.clone(*op, irm)->getResults(); + #if 0 auto adjointOp = - builder.create(loc, origOp.getType(), origOp) - .getResult(); + builder.create(loc, qregType, origOp) + .getResults(); + #endif - // Yield the result of the original operation - builder.create(loc, adjointOp); + // Yield the qubits. + builder.create(loc, origOp); }) - .getResult(0); + .getResults(); - opArgs.pop_back(); - opArgs.push_back(forVal); - op->setOperands(opArgs); + op.setQubitOperands(forVal); return WalkResult::advance(); }); From 8aab9617d940b57dcbdfa570193326c8ee3e1391 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Thu, 22 Aug 2024 17:01:07 -0400 Subject: [PATCH 58/94] Use setQubitOperands() on clone --- .../Transforms/MitigationMethods/Zne.cpp | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index c88d8f9d93..6476941775 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -239,21 +239,19 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: // Set insertion point within the loop builder.setInsertionPointToEnd(builder.getBlock()); - // When we clone the original operation, - // we replace opQubitArgs with the current qubits in iterArgs - IRMapping irm; - irm.map(opQubitArgs, iterArgs); - // Create adjoint and original operations - auto origOp = builder.clone(*op, irm)->getResults(); + quantum::QuantumGate origOp = dyn_cast(builder.clone(*op)); + origOp.setQubitOperands(iterArgs); + auto origOpVal = origOp->getResults(); + #if 0 - auto adjointOp = - builder.create(loc, qregType, origOp) + auto adjointOpVal = + builder.create(loc, qregType, origOpVal) .getResults(); #endif // Yield the qubits. - builder.create(loc, origOp); + builder.create(loc, origOpVal); }) .getResults(); From af0f3850feda55cf417ff2d7b872364f0b49305e Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Sat, 24 Aug 2024 08:56:35 -0400 Subject: [PATCH 59/94] Move LCVs inside fnWithMeasurementsOp --- .../Transforms/MitigationMethods/Zne.cpp | 38 ++++++++++--------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 6476941775..698e052e08 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -132,11 +132,16 @@ FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, std::st FunctionType fnFoldedType, SmallVector typesFolded, func::FuncOp fnFoldedOp, func::FuncOp fnAllocOp, Value numberQubitsValue, func::FuncOp fnWithoutMeasurementsOp, - func::FuncOp fnWithMeasurementsOp, Value c0, Value c1) + func::FuncOp fnWithMeasurementsOp) { // Function folded: Create the folded circuit (withoutMeasurement * // Adjoint(withoutMeasurement))**scalar_factor * withMeasurements Type qregType = quantum::QuregType::get(rewriter.getContext()); + // Loop control variables + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + // Add device + rewriter.create(loc, lib, name, kwargs); Value allocQreg = rewriter.create(loc, fnAllocOp, numberQubitsValue).getResult(0); int64_t sizeArgs = fnFoldedOp.getArguments().size(); @@ -195,7 +200,7 @@ FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, StringAttr kwargs, FunctionType fnFoldedType, SmallVector typesFolded, func::FuncOp fnFoldedOp, func::FuncOp fnAllocOp, Value numberQubitsValue, - func::FuncOp fnWithMeasurementsOp, Value c0, Value c1) + func::FuncOp fnWithMeasurementsOp) { // TODO: Implement. @@ -209,13 +214,15 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: StringAttr lib, StringAttr name, StringAttr kwargs, FunctionType fnFoldedType, SmallVector typesFolded, func::FuncOp fnFoldedOp, func::FuncOp fnAllocOp, - Value numberQubitsValue, func::FuncOp fnWithMeasurementsOp, - Value c0, Value c1) + Value numberQubitsValue, func::FuncOp fnWithMeasurementsOp) { Type qregType = quantum::QuregType::get(rewriter.getContext()); + // Add device + rewriter.create(loc, lib, name, kwargs); // Allocate qubits Value allocQreg = rewriter.create(loc, fnAllocOp, numberQubitsValue).getResult(0); + // TODO: Can't use this argument in fnWithMeasurementsOp! int64_t sizeArgs = fnFoldedOp.getArguments().size(); Value size = fnFoldedOp.getArgument(sizeArgs - 1); @@ -223,11 +230,13 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: // Save the current insertion point PatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&(fnWithMeasurementsOp.getRegion().front())); + // Loop control variables + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + // Walk through the operations in fnWithMeasurementsOp fnWithMeasurementsOp.walk([&](quantum::QuantumGate op) { - rewriter.setInsertionPoint(op); - auto loc = op->getLoc(); - const std::vector opQubitArgs = op.getQubitOperands(); // Insert a for loop immediately before each quantum::QuantumGate @@ -240,18 +249,18 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: builder.setInsertionPointToEnd(builder.getBlock()); // Create adjoint and original operations + #if 0 quantum::QuantumGate origOp = dyn_cast(builder.clone(*op)); origOp.setQubitOperands(iterArgs); auto origOpVal = origOp->getResults(); - #if 0 auto adjointOpVal = builder.create(loc, qregType, origOpVal) .getResults(); #endif // Yield the qubits. - builder.create(loc, origOpVal); + builder.create(loc, iterArgs); }) .getResults(); @@ -348,26 +357,21 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew rewriter.setInsertionPointToStart(foldedBlock); TypedAttr numberQubitsAttr = rewriter.getI64IntegerAttr(numberQubits); Value numberQubitsValue = rewriter.create(loc, numberQubitsAttr); - // Loop control variables - Value c0 = rewriter.create(loc, 0); - Value c1 = rewriter.create(loc, 1); - // Add device - rewriter.create(loc, lib, name, kwargs); if (foldingAlgorithm == Folding(1)) { return globalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, fnFoldedType, typesFolded, fnFoldedOp, fnAllocOp, numberQubitsValue, - fnWithoutMeasurementsOp, fnWithMeasurementsOp, c0, c1); + fnWithoutMeasurementsOp, fnWithMeasurementsOp); } if (foldingAlgorithm == Folding(2)) { return randomLocalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, fnFoldedType, typesFolded, fnFoldedOp, fnAllocOp, - numberQubitsValue, fnWithMeasurementsOp, c0, c1); + numberQubitsValue, fnWithMeasurementsOp); } // Else, if (foldingAlgorithm == Folding(3)): return allLocalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, fnFoldedType, typesFolded, fnFoldedOp, fnAllocOp, numberQubitsValue, - fnWithMeasurementsOp, c0, c1); + fnWithMeasurementsOp); } FlatSymbolRefAttr ZneLowering::getOrInsertQuantumAlloc(Location loc, PatternRewriter &rewriter, mitigation::ZneOp op) From 1f4bd4b490f5e1d85cf9f3a8eda6721d5c52f64c Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Sat, 24 Aug 2024 10:23:37 -0400 Subject: [PATCH 60/94] Revert "Move LCVs inside fnWithMeasurementsOp" This reverts commit af0f3850feda55cf417ff2d7b872364f0b49305e. --- .../Transforms/MitigationMethods/Zne.cpp | 38 +++++++++---------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 698e052e08..6476941775 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -132,16 +132,11 @@ FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, std::st FunctionType fnFoldedType, SmallVector typesFolded, func::FuncOp fnFoldedOp, func::FuncOp fnAllocOp, Value numberQubitsValue, func::FuncOp fnWithoutMeasurementsOp, - func::FuncOp fnWithMeasurementsOp) + func::FuncOp fnWithMeasurementsOp, Value c0, Value c1) { // Function folded: Create the folded circuit (withoutMeasurement * // Adjoint(withoutMeasurement))**scalar_factor * withMeasurements Type qregType = quantum::QuregType::get(rewriter.getContext()); - // Loop control variables - Value c0 = rewriter.create(loc, 0); - Value c1 = rewriter.create(loc, 1); - // Add device - rewriter.create(loc, lib, name, kwargs); Value allocQreg = rewriter.create(loc, fnAllocOp, numberQubitsValue).getResult(0); int64_t sizeArgs = fnFoldedOp.getArguments().size(); @@ -200,7 +195,7 @@ FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, StringAttr kwargs, FunctionType fnFoldedType, SmallVector typesFolded, func::FuncOp fnFoldedOp, func::FuncOp fnAllocOp, Value numberQubitsValue, - func::FuncOp fnWithMeasurementsOp) + func::FuncOp fnWithMeasurementsOp, Value c0, Value c1) { // TODO: Implement. @@ -214,15 +209,13 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: StringAttr lib, StringAttr name, StringAttr kwargs, FunctionType fnFoldedType, SmallVector typesFolded, func::FuncOp fnFoldedOp, func::FuncOp fnAllocOp, - Value numberQubitsValue, func::FuncOp fnWithMeasurementsOp) + Value numberQubitsValue, func::FuncOp fnWithMeasurementsOp, + Value c0, Value c1) { Type qregType = quantum::QuregType::get(rewriter.getContext()); - // Add device - rewriter.create(loc, lib, name, kwargs); // Allocate qubits Value allocQreg = rewriter.create(loc, fnAllocOp, numberQubitsValue).getResult(0); - // TODO: Can't use this argument in fnWithMeasurementsOp! int64_t sizeArgs = fnFoldedOp.getArguments().size(); Value size = fnFoldedOp.getArgument(sizeArgs - 1); @@ -230,13 +223,11 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: // Save the current insertion point PatternRewriter::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&(fnWithMeasurementsOp.getRegion().front())); - // Loop control variables - Value c0 = rewriter.create(loc, 0); - Value c1 = rewriter.create(loc, 1); - // Walk through the operations in fnWithMeasurementsOp fnWithMeasurementsOp.walk([&](quantum::QuantumGate op) { + rewriter.setInsertionPoint(op); + auto loc = op->getLoc(); + const std::vector opQubitArgs = op.getQubitOperands(); // Insert a for loop immediately before each quantum::QuantumGate @@ -249,18 +240,18 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: builder.setInsertionPointToEnd(builder.getBlock()); // Create adjoint and original operations - #if 0 quantum::QuantumGate origOp = dyn_cast(builder.clone(*op)); origOp.setQubitOperands(iterArgs); auto origOpVal = origOp->getResults(); + #if 0 auto adjointOpVal = builder.create(loc, qregType, origOpVal) .getResults(); #endif // Yield the qubits. - builder.create(loc, iterArgs); + builder.create(loc, origOpVal); }) .getResults(); @@ -357,21 +348,26 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew rewriter.setInsertionPointToStart(foldedBlock); TypedAttr numberQubitsAttr = rewriter.getI64IntegerAttr(numberQubits); Value numberQubitsValue = rewriter.create(loc, numberQubitsAttr); + // Loop control variables + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + // Add device + rewriter.create(loc, lib, name, kwargs); if (foldingAlgorithm == Folding(1)) { return globalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, fnFoldedType, typesFolded, fnFoldedOp, fnAllocOp, numberQubitsValue, - fnWithoutMeasurementsOp, fnWithMeasurementsOp); + fnWithoutMeasurementsOp, fnWithMeasurementsOp, c0, c1); } if (foldingAlgorithm == Folding(2)) { return randomLocalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, fnFoldedType, typesFolded, fnFoldedOp, fnAllocOp, - numberQubitsValue, fnWithMeasurementsOp); + numberQubitsValue, fnWithMeasurementsOp, c0, c1); } // Else, if (foldingAlgorithm == Folding(3)): return allLocalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, fnFoldedType, typesFolded, fnFoldedOp, fnAllocOp, numberQubitsValue, - fnWithMeasurementsOp); + fnWithMeasurementsOp, c0, c1); } FlatSymbolRefAttr ZneLowering::getOrInsertQuantumAlloc(Location loc, PatternRewriter &rewriter, mitigation::ZneOp op) From 2e24e07d9462cc06d635856f4284f7a2c9e160a4 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Mon, 26 Aug 2024 11:11:19 -0400 Subject: [PATCH 61/94] Follow TDD --- .../Transforms/MitigationMethods/Zne.cpp | 128 +++++++----------- 1 file changed, 51 insertions(+), 77 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 6476941775..1e0704c833 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -194,8 +194,8 @@ FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, std::string fnFoldedName, StringAttr lib, StringAttr name, StringAttr kwargs, FunctionType fnFoldedType, SmallVector typesFolded, func::FuncOp fnFoldedOp, - func::FuncOp fnAllocOp, Value numberQubitsValue, - func::FuncOp fnWithMeasurementsOp, Value c0, Value c1) + func::FuncOp fnAllocOp, Value numberQubitsValue, Value c0, + Value c1) { // TODO: Implement. @@ -209,74 +209,35 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: StringAttr lib, StringAttr name, StringAttr kwargs, FunctionType fnFoldedType, SmallVector typesFolded, func::FuncOp fnFoldedOp, func::FuncOp fnAllocOp, - Value numberQubitsValue, func::FuncOp fnWithMeasurementsOp, - Value c0, Value c1) + Value numberQubitsValue, Value c0, Value c1) { - Type qregType = quantum::QuregType::get(rewriter.getContext()); - // Allocate qubits - Value allocQreg = rewriter.create(loc, fnAllocOp, numberQubitsValue).getResult(0); - - int64_t sizeArgs = fnFoldedOp.getArguments().size(); - Value size = fnFoldedOp.getArgument(sizeArgs - 1); - - if (true) { - // Save the current insertion point - PatternRewriter::InsertionGuard guard(rewriter); - - // Walk through the operations in fnWithMeasurementsOp - fnWithMeasurementsOp.walk([&](quantum::QuantumGate op) { - rewriter.setInsertionPoint(op); - auto loc = op->getLoc(); + // Type qregType = quantum::QuregType::get(rewriter.getContext()); - const std::vector opQubitArgs = op.getQubitOperands(); + // int64_t sizeArgs = fnFoldedOp.getArguments().size(); + // Value size = fnFoldedOp.getArgument(sizeArgs - 1); - // Insert a for loop immediately before each quantum::QuantumGate - const auto forVal = - rewriter - .create( - loc, c0, size, c1, /*iterArgsInit=*/opQubitArgs, - [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { - // Set insertion point within the loop - builder.setInsertionPointToEnd(builder.getBlock()); - - // Create adjoint and original operations - quantum::QuantumGate origOp = dyn_cast(builder.clone(*op)); - origOp.setQubitOperands(iterArgs); - auto origOpVal = origOp->getResults(); - - #if 0 - auto adjointOpVal = - builder.create(loc, qregType, origOpVal) - .getResults(); - #endif - - // Yield the qubits. - builder.create(loc, origOpVal); - }) - .getResults(); - - op.setQubitOperands(forVal); - - return WalkResult::advance(); - }); - - // Restore original insertion point when PatternRewriter::InsertionGuard goes out of scope - } - - // Prepare the arguments for the final call - std::vector argsAndQreg(fnFoldedOp.getArguments().begin(), - fnFoldedOp.getArguments().end()); - argsAndQreg.pop_back(); - argsAndQreg.push_back(allocQreg); - - // Insert the call to fnWithMeasurementsOp - Value result = - rewriter.create(loc, fnWithMeasurementsOp, argsAndQreg).getResult(0); + // Walk through the operations in fnWithMeasurementsOp + // fnWithMeasurementsOp.walk([&](mlir::Operation *op) { + // return WalkResult::advance(); + // }); + // Return + // rewriter.create(loc, result); // Insert the device release operation + // fnFoldedOp.walk([&](func::ReturnOp returnOp) { + // rewriter.setInsertionPoint(returnOp); + // }); + + // rewriter.setInsertionPointToEnd(&(fnFoldedOp.getRegion().back())); rewriter.create(loc); - // Return - rewriter.create(loc, result); + // fnFoldedOp.walk([&](tensor::FromElementsOp fromElementsOp) { + // rewriter.create(loc, fromElementsOp.getResult()); + // }); + RankedTensorType resultType = cast(fnFoldedOp.getResultTypes().front()); + // Initialize the results as empty tensor + Value results = + rewriter.create(loc, resultType.getShape(), resultType.getElementType()); + rewriter.create(loc, results); // Return the function symbol reference return SymbolRefAttr::get(rewriter.getContext(), fnFoldedName); @@ -322,19 +283,19 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew // Function without measurements: Create function without measurements and with qreg as last // argument func::FuncOp fnWithoutMeasurementsOp; + func::FuncOp fnWithMeasurementsOp; if (foldingAlgorithm == Folding(1)) { FlatSymbolRefAttr fnWithoutMeasurementsRefAttr = getOrInsertFnWithoutMeasurements(loc, rewriter, op); fnWithoutMeasurementsOp = SymbolTable::lookupNearestSymbolFrom(op, fnWithoutMeasurementsRefAttr); + // Function with measurements: Modify the original function to take a quantum register as + // last arg and keep measurements + FlatSymbolRefAttr fnWithMeasurementsRefAttr = + getOrInsertFnWithMeasurements(loc, rewriter, op); + fnWithMeasurementsOp = + SymbolTable::lookupNearestSymbolFrom(op, fnWithMeasurementsRefAttr); } - - // Function with measurements: Modify the original function to take a quantum register as last - // arg and keep measurements - FlatSymbolRefAttr fnWithMeasurementsRefAttr = getOrInsertFnWithMeasurements(loc, rewriter, op); - func::FuncOp fnWithMeasurementsOp = - SymbolTable::lookupNearestSymbolFrom(op, fnWithMeasurementsRefAttr); - rewriter.setInsertionPointToStart(moduleOp.getBody()); FunctionType fnFoldedType = FunctionType::get(ctx, /*inputs=*/ @@ -343,7 +304,6 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew func::FuncOp fnFoldedOp = rewriter.create(loc, fnFoldedName, fnFoldedType); fnFoldedOp.setPrivate(); - Block *foldedBlock = fnFoldedOp.addEntryBlock(); rewriter.setInsertionPointToStart(foldedBlock); TypedAttr numberQubitsAttr = rewriter.getI64IntegerAttr(numberQubits); @@ -351,6 +311,22 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew // Loop control variables Value c0 = rewriter.create(loc, 0); Value c1 = rewriter.create(loc, 1); + if (foldingAlgorithm != Folding(1)) { + func::FuncOp circOp = + SymbolTable::lookupNearestSymbolFrom(op, op.getCalleeAttr()); + rewriter.cloneRegionBefore(circOp.getBody(), fnFoldedOp.getBody(), fnFoldedOp.end()); + Block *fnFoldedBlock = &fnFoldedOp.front(); + quantum::DeviceInitOp deviceInitOp = *fnFoldedOp.getOps().begin(); + rewriter.eraseOp(deviceInitOp); + quantum::DeviceReleaseOp deviceReleaseOp = + *fnFoldedOp.getOps().begin(); + rewriter.eraseOp(deviceReleaseOp); + quantum::AllocOp allocOpWithMeasurements = *fnFoldedOp.getOps().begin(); + auto lastArgQregIndex = fnFoldedBlock->getArguments().size(); + allocOpWithMeasurements.replaceAllUsesWith( + fnFoldedBlock->getArgument(lastArgQregIndex - 1)); + rewriter.eraseOp(allocOpWithMeasurements); + } // Add device rewriter.create(loc, lib, name, kwargs); @@ -360,14 +336,12 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew fnWithoutMeasurementsOp, fnWithMeasurementsOp, c0, c1); } if (foldingAlgorithm == Folding(2)) { - return randomLocalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, - fnFoldedType, typesFolded, fnFoldedOp, fnAllocOp, - numberQubitsValue, fnWithMeasurementsOp, c0, c1); + return randomLocalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, fnFoldedType, + typesFolded, fnFoldedOp, fnAllocOp, numberQubitsValue, c0, c1); } // Else, if (foldingAlgorithm == Folding(3)): return allLocalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, fnFoldedType, - typesFolded, fnFoldedOp, fnAllocOp, numberQubitsValue, - fnWithMeasurementsOp, c0, c1); + typesFolded, fnFoldedOp, fnAllocOp, numberQubitsValue, c0, c1); } FlatSymbolRefAttr ZneLowering::getOrInsertQuantumAlloc(Location loc, PatternRewriter &rewriter, mitigation::ZneOp op) From a5d0cdacaac13a69bfa94c3b21ed2fbdb9617ea8 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Mon, 26 Aug 2024 11:59:57 -0400 Subject: [PATCH 62/94] Debug --- .../Transforms/MitigationMethods/Zne.cpp | 26 +++++++++---------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 1e0704c833..39b2562ddc 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -230,14 +230,9 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: // rewriter.setInsertionPointToEnd(&(fnFoldedOp.getRegion().back())); rewriter.create(loc); - // fnFoldedOp.walk([&](tensor::FromElementsOp fromElementsOp) { - // rewriter.create(loc, fromElementsOp.getResult()); - // }); - RankedTensorType resultType = cast(fnFoldedOp.getResultTypes().front()); - // Initialize the results as empty tensor - Value results = - rewriter.create(loc, resultType.getShape(), resultType.getElementType()); - rewriter.create(loc, results); + fnFoldedOp.walk([&](tensor::FromElementsOp fromElementsOp) { + rewriter.create(loc, fromElementsOp.getResult()); + }); // Return the function symbol reference return SymbolRefAttr::get(rewriter.getContext(), fnFoldedName); @@ -315,20 +310,23 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew func::FuncOp circOp = SymbolTable::lookupNearestSymbolFrom(op, op.getCalleeAttr()); rewriter.cloneRegionBefore(circOp.getBody(), fnFoldedOp.getBody(), fnFoldedOp.end()); - Block *fnFoldedBlock = &fnFoldedOp.front(); quantum::DeviceInitOp deviceInitOp = *fnFoldedOp.getOps().begin(); + rewriter.create(loc, lib, name, kwargs); rewriter.eraseOp(deviceInitOp); + // Add device quantum::DeviceReleaseOp deviceReleaseOp = *fnFoldedOp.getOps().begin(); rewriter.eraseOp(deviceReleaseOp); quantum::AllocOp allocOpWithMeasurements = *fnFoldedOp.getOps().begin(); - auto lastArgQregIndex = fnFoldedBlock->getArguments().size(); - allocOpWithMeasurements.replaceAllUsesWith( - fnFoldedBlock->getArgument(lastArgQregIndex - 1)); + // Allocate + Value allocQreg = + rewriter.create(loc, fnAllocOp, numberQubitsValue).getResult(0); + allocOpWithMeasurements.replaceAllUsesWith(allocQreg); rewriter.eraseOp(allocOpWithMeasurements); } - // Add device - rewriter.create(loc, lib, name, kwargs); + else { + rewriter.create(loc, lib, name, kwargs); + } if (foldingAlgorithm == Folding(1)) { return globalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, fnFoldedType, From 9a21a0e60a647eef4ef7784522257ec4270db232 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Mon, 26 Aug 2024 12:42:59 -0400 Subject: [PATCH 63/94] Debug --- .../Transforms/MitigationMethods/Zne.cpp | 35 ++++++++++--------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 39b2562ddc..9364661b67 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -229,10 +229,13 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: // }); // rewriter.setInsertionPointToEnd(&(fnFoldedOp.getRegion().back())); - rewriter.create(loc); - fnFoldedOp.walk([&](tensor::FromElementsOp fromElementsOp) { - rewriter.create(loc, fromElementsOp.getResult()); - }); + // fnFoldedOp.walk([&](tensor::FromElementsOp fromElementsOp) { + // rewriter.create(loc, fromElementsOp.getResult()); + // }); + RankedTensorType resultType = cast(fnFoldedOp.getResultTypes().front()); + Value results = + rewriter.create(loc, resultType.getShape(), resultType.getElementType()); + rewriter.create(loc, results); // Return the function symbol reference return SymbolRefAttr::get(rewriter.getContext(), fnFoldedName); @@ -299,33 +302,31 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew func::FuncOp fnFoldedOp = rewriter.create(loc, fnFoldedName, fnFoldedType); fnFoldedOp.setPrivate(); - Block *foldedBlock = fnFoldedOp.addEntryBlock(); - rewriter.setInsertionPointToStart(foldedBlock); - TypedAttr numberQubitsAttr = rewriter.getI64IntegerAttr(numberQubits); - Value numberQubitsValue = rewriter.create(loc, numberQubitsAttr); + rewriter.setInsertionPointToStart(fnFoldedOp.addEntryBlock()); // Loop control variables Value c0 = rewriter.create(loc, 0); Value c1 = rewriter.create(loc, 1); - if (foldingAlgorithm != Folding(1)) { + TypedAttr numberQubitsAttr = rewriter.getI64IntegerAttr(numberQubits); + Value numberQubitsValue = rewriter.create(loc, numberQubitsAttr); + if (foldingAlgorithm == Folding(1)) { + rewriter.create(loc, lib, name, kwargs); + } + else { func::FuncOp circOp = SymbolTable::lookupNearestSymbolFrom(op, op.getCalleeAttr()); rewriter.cloneRegionBefore(circOp.getBody(), fnFoldedOp.getBody(), fnFoldedOp.end()); quantum::DeviceInitOp deviceInitOp = *fnFoldedOp.getOps().begin(); rewriter.create(loc, lib, name, kwargs); rewriter.eraseOp(deviceInitOp); - // Add device - quantum::DeviceReleaseOp deviceReleaseOp = - *fnFoldedOp.getOps().begin(); - rewriter.eraseOp(deviceReleaseOp); quantum::AllocOp allocOpWithMeasurements = *fnFoldedOp.getOps().begin(); - // Allocate Value allocQreg = rewriter.create(loc, fnAllocOp, numberQubitsValue).getResult(0); allocOpWithMeasurements.replaceAllUsesWith(allocQreg); rewriter.eraseOp(allocOpWithMeasurements); - } - else { - rewriter.create(loc, lib, name, kwargs); + quantum::DeviceReleaseOp deviceReleaseOp = + *fnFoldedOp.getOps().begin(); + rewriter.eraseOp(deviceReleaseOp); + rewriter.create(loc); } if (foldingAlgorithm == Folding(1)) { From 88ebb6e24aae38e28f86069319840e527be7fb6c Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Mon, 26 Aug 2024 15:12:18 -0400 Subject: [PATCH 64/94] Debug --- .../Transforms/MitigationMethods/Zne.cpp | 29 ++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 9364661b67..66746bab7a 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -228,15 +228,6 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: // rewriter.setInsertionPoint(returnOp); // }); - // rewriter.setInsertionPointToEnd(&(fnFoldedOp.getRegion().back())); - // fnFoldedOp.walk([&](tensor::FromElementsOp fromElementsOp) { - // rewriter.create(loc, fromElementsOp.getResult()); - // }); - RankedTensorType resultType = cast(fnFoldedOp.getResultTypes().front()); - Value results = - rewriter.create(loc, resultType.getShape(), resultType.getElementType()); - rewriter.create(loc, results); - // Return the function symbol reference return SymbolRefAttr::get(rewriter.getContext(), fnFoldedName); } @@ -302,7 +293,8 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew func::FuncOp fnFoldedOp = rewriter.create(loc, fnFoldedName, fnFoldedType); fnFoldedOp.setPrivate(); - rewriter.setInsertionPointToStart(fnFoldedOp.addEntryBlock()); + auto fnFoldedOpBlock = fnFoldedOp.addEntryBlock(); + rewriter.setInsertionPointToStart(fnFoldedOpBlock); // Loop control variables Value c0 = rewriter.create(loc, 0); Value c1 = rewriter.create(loc, 1); @@ -312,21 +304,32 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew rewriter.create(loc, lib, name, kwargs); } else { - func::FuncOp circOp = - SymbolTable::lookupNearestSymbolFrom(op, op.getCalleeAttr()); - rewriter.cloneRegionBefore(circOp.getBody(), fnFoldedOp.getBody(), fnFoldedOp.end()); + rewriter.cloneRegionBefore(fnOp.getBody(), fnFoldedOp.getBody(), fnFoldedOp.end()); + quantum::DeviceInitOp deviceInitOp = *fnFoldedOp.getOps().begin(); rewriter.create(loc, lib, name, kwargs); rewriter.eraseOp(deviceInitOp); + quantum::AllocOp allocOpWithMeasurements = *fnFoldedOp.getOps().begin(); Value allocQreg = rewriter.create(loc, fnAllocOp, numberQubitsValue).getResult(0); allocOpWithMeasurements.replaceAllUsesWith(allocQreg); rewriter.eraseOp(allocOpWithMeasurements); + quantum::DeviceReleaseOp deviceReleaseOp = *fnFoldedOp.getOps().begin(); + rewriter.setInsertionPoint(deviceReleaseOp); rewriter.eraseOp(deviceReleaseOp); + + rewriter.setInsertionPointToEnd(fnFoldedOpBlock); rewriter.create(loc); + + // TODO: Why doesn't this next line work as ReturnOp Value? + // Value results = (*fnFoldedOp.getOps().begin()).getResult(); + RankedTensorType resultType = cast(fnFoldedOp.getResultTypes().front()); + Value results = + rewriter.create(loc, resultType.getShape(), resultType.getElementType()); + rewriter.create(loc, results); } if (foldingAlgorithm == Folding(1)) { From e8b9ae32e44792dfa6902c169e8ae49ac470e858 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Tue, 27 Aug 2024 10:18:21 -0400 Subject: [PATCH 65/94] ForOp --- .../Transforms/MitigationMethods/Zne.cpp | 48 ++++++++++++++----- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 66746bab7a..1f0b65ae0f 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -213,20 +213,44 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: { // Type qregType = quantum::QuregType::get(rewriter.getContext()); - // int64_t sizeArgs = fnFoldedOp.getArguments().size(); - // Value size = fnFoldedOp.getArgument(sizeArgs - 1); + int64_t sizeArgs = fnFoldedOp.getArguments().size(); + Value size = fnFoldedOp.getArgument(sizeArgs - 1); // Walk through the operations in fnWithMeasurementsOp - // fnWithMeasurementsOp.walk([&](mlir::Operation *op) { - // return WalkResult::advance(); - // }); - // Return - // rewriter.create(loc, result); - - // Insert the device release operation - // fnFoldedOp.walk([&](func::ReturnOp returnOp) { - // rewriter.setInsertionPoint(returnOp); - // }); + fnFoldedOp.walk([&](quantum::QuantumGate op) { + rewriter.setInsertionPoint(op); + auto loc = op->getLoc(); + const std::vector opQubitArgs = op.getQubitOperands(); + + // Insert a for loop immediately before each quantum::QuantumGate + const auto forVal = + rewriter + .create( + loc, c0, size, c1, /*iterArgsInit=*/opQubitArgs, + [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { + // Set insertion point within the loop + builder.setInsertionPointToEnd(builder.getBlock()); + + // Create adjoint and original operations + quantum::QuantumGate origOp = dyn_cast(builder.clone(*op)); + origOp.setQubitOperands(iterArgs); + auto origOpVal = origOp->getResults(); + + #if 0 + auto adjointOpVal = + builder.create(loc, qregType, origOpVal) + .getResults(); + #endif + + // Yield the qubits. + builder.create(loc, origOpVal); + }) + .getResults(); + + op.setQubitOperands(forVal); + + return WalkResult::advance(); + }); // Return the function symbol reference return SymbolRefAttr::get(rewriter.getContext(), fnFoldedName); From 7a7ca56c2573a9f95c6aa805cde2460242f24296 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Tue, 27 Aug 2024 13:41:11 -0400 Subject: [PATCH 66/94] Fix local folding ForOp --- .../Transforms/MitigationMethods/Zne.cpp | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 1f0b65ae0f..35baab9e99 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -232,18 +232,19 @@ FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std:: builder.setInsertionPointToEnd(builder.getBlock()); // Create adjoint and original operations - quantum::QuantumGate origOp = dyn_cast(builder.clone(*op)); + quantum::QuantumGate origOp = + dyn_cast(builder.clone(*op)); origOp.setQubitOperands(iterArgs); auto origOpVal = origOp->getResults(); - #if 0 - auto adjointOpVal = - builder.create(loc, qregType, origOpVal) - .getResults(); - #endif + quantum::QuantumGate adjointOp = + dyn_cast(builder.clone(*origOp)); + adjointOp.setQubitOperands(origOpVal); + adjointOp.setAdjointFlag(!adjointOp.getAdjointFlag()); + auto adjointOpVal = adjointOp->getResults(); // Yield the qubits. - builder.create(loc, origOpVal); + builder.create(loc, adjointOpVal); }) .getResults(); @@ -329,11 +330,11 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew } else { rewriter.cloneRegionBefore(fnOp.getBody(), fnFoldedOp.getBody(), fnFoldedOp.end()); - + quantum::DeviceInitOp deviceInitOp = *fnFoldedOp.getOps().begin(); rewriter.create(loc, lib, name, kwargs); rewriter.eraseOp(deviceInitOp); - + quantum::AllocOp allocOpWithMeasurements = *fnFoldedOp.getOps().begin(); Value allocQreg = rewriter.create(loc, fnAllocOp, numberQubitsValue).getResult(0); @@ -351,8 +352,9 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew // TODO: Why doesn't this next line work as ReturnOp Value? // Value results = (*fnFoldedOp.getOps().begin()).getResult(); RankedTensorType resultType = cast(fnFoldedOp.getResultTypes().front()); - Value results = - rewriter.create(loc, resultType.getShape(), resultType.getElementType()); + Value results = rewriter.create(loc, resultType.getShape(), + resultType.getElementType()); + rewriter.create(loc, results); } From e0557991021fd0d6a34084c444ab08244df84789 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Wed, 28 Aug 2024 09:18:16 -0400 Subject: [PATCH 67/94] Just copy @circuit --- .../Transforms/MitigationMethods/Zne.cpp | 26 +++++-------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 35baab9e99..4e1d9e72f9 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -310,6 +310,8 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew fnWithMeasurementsOp = SymbolTable::lookupNearestSymbolFrom(op, fnWithMeasurementsRefAttr); } + else { + } rewriter.setInsertionPointToStart(moduleOp.getBody()); FunctionType fnFoldedType = FunctionType::get(ctx, /*inputs=*/ @@ -331,30 +333,14 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew else { rewriter.cloneRegionBefore(fnOp.getBody(), fnFoldedOp.getBody(), fnFoldedOp.end()); - quantum::DeviceInitOp deviceInitOp = *fnFoldedOp.getOps().begin(); - rewriter.create(loc, lib, name, kwargs); - rewriter.eraseOp(deviceInitOp); - - quantum::AllocOp allocOpWithMeasurements = *fnFoldedOp.getOps().begin(); - Value allocQreg = - rewriter.create(loc, fnAllocOp, numberQubitsValue).getResult(0); - allocOpWithMeasurements.replaceAllUsesWith(allocQreg); - rewriter.eraseOp(allocOpWithMeasurements); + // tensor::FromElementsOp fromElementsOp = + // *fnFoldedOp.getOps().begin(); + // rewriter.setInsertionPointToEnd(fnFoldedOpBlock); + // rewriter.create(loc, fromElementsOp.getResult()); - quantum::DeviceReleaseOp deviceReleaseOp = - *fnFoldedOp.getOps().begin(); - rewriter.setInsertionPoint(deviceReleaseOp); - rewriter.eraseOp(deviceReleaseOp); - - rewriter.setInsertionPointToEnd(fnFoldedOpBlock); - rewriter.create(loc); - - // TODO: Why doesn't this next line work as ReturnOp Value? - // Value results = (*fnFoldedOp.getOps().begin()).getResult(); RankedTensorType resultType = cast(fnFoldedOp.getResultTypes().front()); Value results = rewriter.create(loc, resultType.getShape(), resultType.getElementType()); - rewriter.create(loc, results); } From cbc95cb60fe0cb170d35994f04693f0823efb9c6 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Wed, 28 Aug 2024 09:43:08 -0400 Subject: [PATCH 68/94] Cut unused args --- .../Transforms/MitigationMethods/Zne.cpp | 27 ++++++------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 4e1d9e72f9..788af161a1 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -190,12 +190,8 @@ FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, std::st return SymbolRefAttr::get(rewriter.getContext(), fnFoldedName); } // In *.cpp module only, to keep extraneous headers out of *.hpp -FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, - std::string fnFoldedName, StringAttr lib, StringAttr name, - StringAttr kwargs, FunctionType fnFoldedType, - SmallVector typesFolded, func::FuncOp fnFoldedOp, - func::FuncOp fnAllocOp, Value numberQubitsValue, Value c0, - Value c1) +FlatSymbolRefAttr randomLocalFolding(PatternRewriter &rewriter, std::string fnFoldedName, + func::FuncOp fnFoldedOp, Value c0, Value c1) { // TODO: Implement. @@ -205,11 +201,8 @@ FlatSymbolRefAttr randomLocalFolding(Location loc, PatternRewriter &rewriter, return FlatSymbolRefAttr(); } // In *.cpp module only, to keep extraneous headers out of *.hpp -FlatSymbolRefAttr allLocalFolding(Location loc, PatternRewriter &rewriter, std::string fnFoldedName, - StringAttr lib, StringAttr name, StringAttr kwargs, - FunctionType fnFoldedType, SmallVector typesFolded, - func::FuncOp fnFoldedOp, func::FuncOp fnAllocOp, - Value numberQubitsValue, Value c0, Value c1) +FlatSymbolRefAttr allLocalFolding(PatternRewriter &rewriter, std::string fnFoldedName, + func::FuncOp fnFoldedOp, Value c0, Value c1) { // Type qregType = quantum::QuregType::get(rewriter.getContext()); @@ -332,12 +325,10 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew } else { rewriter.cloneRegionBefore(fnOp.getBody(), fnFoldedOp.getBody(), fnFoldedOp.end()); - // tensor::FromElementsOp fromElementsOp = - // *fnFoldedOp.getOps().begin(); - // rewriter.setInsertionPointToEnd(fnFoldedOpBlock); + // *fnFoldedOp.getOps().begin(); + // rewriter.setInsertionPointToEnd(&fnFoldedOp.getBody().back()); // rewriter.create(loc, fromElementsOp.getResult()); - RankedTensorType resultType = cast(fnFoldedOp.getResultTypes().front()); Value results = rewriter.create(loc, resultType.getShape(), resultType.getElementType()); @@ -350,12 +341,10 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew fnWithoutMeasurementsOp, fnWithMeasurementsOp, c0, c1); } if (foldingAlgorithm == Folding(2)) { - return randomLocalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, fnFoldedType, - typesFolded, fnFoldedOp, fnAllocOp, numberQubitsValue, c0, c1); + return randomLocalFolding(rewriter, fnFoldedName, fnFoldedOp, c0, c1); } // Else, if (foldingAlgorithm == Folding(3)): - return allLocalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, fnFoldedType, - typesFolded, fnFoldedOp, fnAllocOp, numberQubitsValue, c0, c1); + return allLocalFolding(rewriter, fnFoldedName, fnFoldedOp, c0, c1); } FlatSymbolRefAttr ZneLowering::getOrInsertQuantumAlloc(Location loc, PatternRewriter &rewriter, mitigation::ZneOp op) From d97ebc3159ff73c8fd8a873e434a9394975fc703 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Wed, 28 Aug 2024 12:13:27 -0400 Subject: [PATCH 69/94] All but func::ReturnOp --- .../Transforms/MitigationMethods/Zne.cpp | 38 +++++++++++-------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 788af161a1..c9e91bc8c7 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -312,27 +312,35 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew /*outputs=*/fnOp.getResultTypes()); func::FuncOp fnFoldedOp = rewriter.create(loc, fnFoldedName, fnFoldedType); + Value c0, c1, numberQubitsValue; fnFoldedOp.setPrivate(); - auto fnFoldedOpBlock = fnFoldedOp.addEntryBlock(); - rewriter.setInsertionPointToStart(fnFoldedOpBlock); - // Loop control variables - Value c0 = rewriter.create(loc, 0); - Value c1 = rewriter.create(loc, 1); - TypedAttr numberQubitsAttr = rewriter.getI64IntegerAttr(numberQubits); - Value numberQubitsValue = rewriter.create(loc, numberQubitsAttr); if (foldingAlgorithm == Folding(1)) { + Block *fnFoldedOpBlock = fnFoldedOp.addEntryBlock(); + rewriter.setInsertionPointToStart(fnFoldedOpBlock); + // Loop control variables + c0 = rewriter.create(loc, 0); + c1 = rewriter.create(loc, 1); + TypedAttr numberQubitsAttr = rewriter.getI64IntegerAttr(numberQubits); + numberQubitsValue = rewriter.create(loc, numberQubitsAttr); rewriter.create(loc, lib, name, kwargs); } else { rewriter.cloneRegionBefore(fnOp.getBody(), fnFoldedOp.getBody(), fnFoldedOp.end()); - // tensor::FromElementsOp fromElementsOp = - // *fnFoldedOp.getOps().begin(); - // rewriter.setInsertionPointToEnd(&fnFoldedOp.getBody().back()); - // rewriter.create(loc, fromElementsOp.getResult()); - RankedTensorType resultType = cast(fnFoldedOp.getResultTypes().front()); - Value results = rewriter.create(loc, resultType.getShape(), - resultType.getElementType()); - rewriter.create(loc, results); + + Block *fnFoldedOpBlock = &fnFoldedOp.getBody().front(); + rewriter.setInsertionPointToStart(fnFoldedOpBlock); + // Loop control variables + c0 = rewriter.create(loc, 0); + c1 = rewriter.create(loc, 1); + TypedAttr numberQubitsAttr = rewriter.getI64IntegerAttr(numberQubits); + numberQubitsValue = rewriter.create(loc, numberQubitsAttr); + + fnFoldedOpBlock->addArgument(fnFoldedOp.getArgumentTypes().front(), loc); + + tensor::FromElementsOp fromElementsOp = + *fnFoldedOpBlock->getOps().begin(); + rewriter.setInsertionPointToEnd(fnFoldedOpBlock); + rewriter.create(loc, fromElementsOp.getResult()); } if (foldingAlgorithm == Folding(1)) { From 7a687f8956cc0520c476976123907c4186977e98 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Wed, 28 Aug 2024 12:25:18 -0400 Subject: [PATCH 70/94] Generates code for @circuit.folded --- mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index c9e91bc8c7..05296b255c 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -336,11 +336,6 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew numberQubitsValue = rewriter.create(loc, numberQubitsAttr); fnFoldedOpBlock->addArgument(fnFoldedOp.getArgumentTypes().front(), loc); - - tensor::FromElementsOp fromElementsOp = - *fnFoldedOpBlock->getOps().begin(); - rewriter.setInsertionPointToEnd(fnFoldedOpBlock); - rewriter.create(loc, fromElementsOp.getResult()); } if (foldingAlgorithm == Folding(1)) { From 1aaa73fd05f62c0af8e0b789bdfe8a05cb502ba0 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Wed, 28 Aug 2024 12:40:19 -0400 Subject: [PATCH 71/94] No alloc. func. for local folding --- mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 05296b255c..16139ed96c 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -268,9 +268,12 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew // Set insertion in the module rewriter.setInsertionPointToStart(moduleOp.getBody()); // Quantum Alloc function - FlatSymbolRefAttr quantumAllocRefAttr = getOrInsertQuantumAlloc(loc, rewriter, op); - func::FuncOp fnAllocOp = - SymbolTable::lookupNearestSymbolFrom(op, quantumAllocRefAttr); + FlatSymbolRefAttr quantumAllocRefAttr; + func::FuncOp fnAllocOp; + if (foldingAlgorithm == Folding(1)) { + quantumAllocRefAttr = getOrInsertQuantumAlloc(loc, rewriter, op); + fnAllocOp = SymbolTable::lookupNearestSymbolFrom(op, quantumAllocRefAttr); + } // Get the number of qubits const int64_t numberQubits = From 2ba29dbafdb9aa9a92279bbb1d66272e06820eb6 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Wed, 28 Aug 2024 12:43:16 -0400 Subject: [PATCH 72/94] Cut redundant code in ZNE rewrite --- mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 16139ed96c..6ea1fb736f 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -335,8 +335,6 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew // Loop control variables c0 = rewriter.create(loc, 0); c1 = rewriter.create(loc, 1); - TypedAttr numberQubitsAttr = rewriter.getI64IntegerAttr(numberQubits); - numberQubitsValue = rewriter.create(loc, numberQubitsAttr); fnFoldedOpBlock->addArgument(fnFoldedOp.getArgumentTypes().front(), loc); } From bc0c51783d81d679cacd10c64c0d7f1563d788e7 Mon Sep 17 00:00:00 2001 From: Daniel Strano Date: Fri, 30 Aug 2024 12:54:40 -0400 Subject: [PATCH 73/94] Update mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp Co-authored-by: Romain Moyard --- mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 6ea1fb736f..7a7a24999f 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -204,7 +204,6 @@ FlatSymbolRefAttr randomLocalFolding(PatternRewriter &rewriter, std::string fnFo FlatSymbolRefAttr allLocalFolding(PatternRewriter &rewriter, std::string fnFoldedName, func::FuncOp fnFoldedOp, Value c0, Value c1) { - // Type qregType = quantum::QuregType::get(rewriter.getContext()); int64_t sizeArgs = fnFoldedOp.getArguments().size(); Value size = fnFoldedOp.getArgument(sizeArgs - 1); From facfec1b5c49a9f9dab65c4e8c37a2fd21cbef3f Mon Sep 17 00:00:00 2001 From: Daniel Strano Date: Fri, 30 Aug 2024 12:54:58 -0400 Subject: [PATCH 74/94] Update mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp Co-authored-by: Romain Moyard --- mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 7a7a24999f..98b49e40aa 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -305,8 +305,6 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew fnWithMeasurementsOp = SymbolTable::lookupNearestSymbolFrom(op, fnWithMeasurementsRefAttr); } - else { - } rewriter.setInsertionPointToStart(moduleOp.getBody()); FunctionType fnFoldedType = FunctionType::get(ctx, /*inputs=*/ From 4d5735b9dabf8813731f57088775dc989a87ff05 Mon Sep 17 00:00:00 2001 From: Daniel Strano Date: Fri, 30 Aug 2024 12:55:15 -0400 Subject: [PATCH 75/94] Update mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp Co-authored-by: Romain Moyard --- mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 98b49e40aa..e6f9f6c976 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -208,7 +208,7 @@ FlatSymbolRefAttr allLocalFolding(PatternRewriter &rewriter, std::string fnFolde int64_t sizeArgs = fnFoldedOp.getArguments().size(); Value size = fnFoldedOp.getArgument(sizeArgs - 1); - // Walk through the operations in fnWithMeasurementsOp + // Walk through the operations in fnFoldedOp fnFoldedOp.walk([&](quantum::QuantumGate op) { rewriter.setInsertionPoint(op); auto loc = op->getLoc(); From 9e945b58a15e7991a07d85955622365eff16a03f Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Fri, 30 Aug 2024 12:58:32 -0400 Subject: [PATCH 76/94] Per @rmoyard review --- mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index e6f9f6c976..95c79c43ce 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -220,9 +220,6 @@ FlatSymbolRefAttr allLocalFolding(PatternRewriter &rewriter, std::string fnFolde .create( loc, c0, size, c1, /*iterArgsInit=*/opQubitArgs, [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { - // Set insertion point within the loop - builder.setInsertionPointToEnd(builder.getBlock()); - // Create adjoint and original operations quantum::QuantumGate origOp = dyn_cast(builder.clone(*op)); From 1e9e95bab6021eff0b4a545a0ce8739dde47c9a1 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Fri, 30 Aug 2024 13:46:24 -0400 Subject: [PATCH 77/94] Per @rmoyard review --- .../Transforms/MitigationMethods/Zne.cpp | 91 +++++++++---------- 1 file changed, 42 insertions(+), 49 deletions(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 95c79c43ce..3929e0847f 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -129,14 +129,23 @@ void ZneLowering::rewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const // In *.cpp module only, to keep extraneous headers out of *.hpp FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, std::string fnFoldedName, StringAttr lib, StringAttr name, StringAttr kwargs, - FunctionType fnFoldedType, SmallVector typesFolded, - func::FuncOp fnFoldedOp, func::FuncOp fnAllocOp, - Value numberQubitsValue, func::FuncOp fnWithoutMeasurementsOp, - func::FuncOp fnWithMeasurementsOp, Value c0, Value c1) + int64_t numberQubits, FunctionType fnFoldedType, + SmallVector typesFolded, func::FuncOp fnFoldedOp, + func::FuncOp fnAllocOp, func::FuncOp fnWithoutMeasurementsOp, + func::FuncOp fnWithMeasurementsOp) { // Function folded: Create the folded circuit (withoutMeasurement * // Adjoint(withoutMeasurement))**scalar_factor * withMeasurements Type qregType = quantum::QuregType::get(rewriter.getContext()); + + rewriter.setInsertionPointToStart(fnFoldedOp.addEntryBlock()); + // Loop control variables + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + TypedAttr numberQubitsAttr = rewriter.getI64IntegerAttr(numberQubits); + Value numberQubitsValue = rewriter.create(loc, numberQubitsAttr); + rewriter.create(loc, lib, name, kwargs); + Value allocQreg = rewriter.create(loc, fnAllocOp, numberQubitsValue).getResult(0); int64_t sizeArgs = fnFoldedOp.getArguments().size(); @@ -263,13 +272,6 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew // Set insertion in the module rewriter.setInsertionPointToStart(moduleOp.getBody()); - // Quantum Alloc function - FlatSymbolRefAttr quantumAllocRefAttr; - func::FuncOp fnAllocOp; - if (foldingAlgorithm == Folding(1)) { - quantumAllocRefAttr = getOrInsertQuantumAlloc(loc, rewriter, op); - fnAllocOp = SymbolTable::lookupNearestSymbolFrom(op, quantumAllocRefAttr); - } // Get the number of qubits const int64_t numberQubits = @@ -286,58 +288,49 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew Type indexType = rewriter.getIndexType(); typesFolded.push_back(indexType); - // Function without measurements: Create function without measurements and with qreg as last - // argument - func::FuncOp fnWithoutMeasurementsOp; - func::FuncOp fnWithMeasurementsOp; + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + FunctionType fnFoldedType = FunctionType::get(ctx, /*inputs=*/ + typesFolded, + /*outputs=*/fnOp.getResultTypes()); + + func::FuncOp fnFoldedOp = rewriter.create(loc, fnFoldedName, fnFoldedType); + fnFoldedOp.setPrivate(); if (foldingAlgorithm == Folding(1)) { + // Quantum Alloc function + FlatSymbolRefAttr quantumAllocRefAttr = getOrInsertQuantumAlloc(loc, rewriter, op); + func::FuncOp fnAllocOp = + SymbolTable::lookupNearestSymbolFrom(op, quantumAllocRefAttr); + + // Function without measurements: Create function without measurements and with qreg as last + // argument FlatSymbolRefAttr fnWithoutMeasurementsRefAttr = getOrInsertFnWithoutMeasurements(loc, rewriter, op); - fnWithoutMeasurementsOp = + func::FuncOp fnWithoutMeasurementsOp = SymbolTable::lookupNearestSymbolFrom(op, fnWithoutMeasurementsRefAttr); + // Function with measurements: Modify the original function to take a quantum register as // last arg and keep measurements FlatSymbolRefAttr fnWithMeasurementsRefAttr = getOrInsertFnWithMeasurements(loc, rewriter, op); - fnWithMeasurementsOp = + func::FuncOp fnWithMeasurementsOp = SymbolTable::lookupNearestSymbolFrom(op, fnWithMeasurementsRefAttr); - } - rewriter.setInsertionPointToStart(moduleOp.getBody()); - - FunctionType fnFoldedType = FunctionType::get(ctx, /*inputs=*/ - typesFolded, - /*outputs=*/fnOp.getResultTypes()); - func::FuncOp fnFoldedOp = rewriter.create(loc, fnFoldedName, fnFoldedType); - Value c0, c1, numberQubitsValue; - fnFoldedOp.setPrivate(); - if (foldingAlgorithm == Folding(1)) { - Block *fnFoldedOpBlock = fnFoldedOp.addEntryBlock(); - rewriter.setInsertionPointToStart(fnFoldedOpBlock); - // Loop control variables - c0 = rewriter.create(loc, 0); - c1 = rewriter.create(loc, 1); - TypedAttr numberQubitsAttr = rewriter.getI64IntegerAttr(numberQubits); - numberQubitsValue = rewriter.create(loc, numberQubitsAttr); - rewriter.create(loc, lib, name, kwargs); + return globalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, numberQubits, + fnFoldedType, typesFolded, fnFoldedOp, fnAllocOp, + fnWithoutMeasurementsOp, fnWithMeasurementsOp); } - else { - rewriter.cloneRegionBefore(fnOp.getBody(), fnFoldedOp.getBody(), fnFoldedOp.end()); - Block *fnFoldedOpBlock = &fnFoldedOp.getBody().front(); - rewriter.setInsertionPointToStart(fnFoldedOpBlock); - // Loop control variables - c0 = rewriter.create(loc, 0); - c1 = rewriter.create(loc, 1); + rewriter.cloneRegionBefore(fnOp.getBody(), fnFoldedOp.getBody(), fnFoldedOp.end()); - fnFoldedOpBlock->addArgument(fnFoldedOp.getArgumentTypes().front(), loc); - } + Block *fnFoldedOpBlock = &fnFoldedOp.getBody().front(); + rewriter.setInsertionPointToStart(fnFoldedOpBlock); + // Loop control variables + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + + fnFoldedOpBlock->addArgument(fnFoldedOp.getArgumentTypes().front(), loc); - if (foldingAlgorithm == Folding(1)) { - return globalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, fnFoldedType, - typesFolded, fnFoldedOp, fnAllocOp, numberQubitsValue, - fnWithoutMeasurementsOp, fnWithMeasurementsOp, c0, c1); - } if (foldingAlgorithm == Folding(2)) { return randomLocalFolding(rewriter, fnFoldedName, fnFoldedOp, c0, c1); } From ea78c00f8895dd1efb0eafbdb152cf1e074bbf17 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Tue, 3 Sep 2024 13:12:57 -0400 Subject: [PATCH 78/94] Per @rmoyard review --- mlir/include/Mitigation/IR/MitigationOps.td | 4 ++-- mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/include/Mitigation/IR/MitigationOps.td b/mlir/include/Mitigation/IR/MitigationOps.td index 8cb682d26b..4e9b806ad8 100644 --- a/mlir/include/Mitigation/IR/MitigationOps.td +++ b/mlir/include/Mitigation/IR/MitigationOps.td @@ -27,8 +27,8 @@ def Folding : I32EnumAttr<"Folding", "Folding types", [ I32EnumAttrCase<"global", 1>, - I32EnumAttrCase<"random", 2>, - I32EnumAttrCase<"all", 3>, + I32EnumAttrCase<"all", 2>, + I32EnumAttrCase<"random", 3>, ]> { let cppNamespace = "catalyst::mitigation"; let genSpecializedAttr = 0; diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 3929e0847f..ff0c7aa3f4 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -332,10 +332,10 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew fnFoldedOpBlock->addArgument(fnFoldedOp.getArgumentTypes().front(), loc); if (foldingAlgorithm == Folding(2)) { - return randomLocalFolding(rewriter, fnFoldedName, fnFoldedOp, c0, c1); + return allLocalFolding(rewriter, fnFoldedName, fnFoldedOp, c0, c1); } // Else, if (foldingAlgorithm == Folding(3)): - return allLocalFolding(rewriter, fnFoldedName, fnFoldedOp, c0, c1); + return randomLocalFolding(rewriter, fnFoldedName, fnFoldedOp, c0, c1); } FlatSymbolRefAttr ZneLowering::getOrInsertQuantumAlloc(Location loc, PatternRewriter &rewriter, mitigation::ZneOp op) From 204386dff1824aefc356494f9279c0fa8816ea47 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Tue, 3 Sep 2024 13:15:31 -0400 Subject: [PATCH 79/94] Per @cosenal review --- frontend/catalyst/api_extensions/error_mitigation.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/frontend/catalyst/api_extensions/error_mitigation.py b/frontend/catalyst/api_extensions/error_mitigation.py index b5381788b7..61f4d0ffe6 100644 --- a/frontend/catalyst/api_extensions/error_mitigation.py +++ b/frontend/catalyst/api_extensions/error_mitigation.py @@ -91,9 +91,6 @@ def mitigated_circuit(args, n): return mitigate_with_zne(circuit, scale_factors=s)(args, n) """ - if folding.upper() == "RANDOM": - raise NotImplementedError("Random global folding not yet implemented!") - kwargs = copy.copy(locals()) kwargs.pop("fn") @@ -152,7 +149,7 @@ def __call__(self, *args, **kwargs): except ValueError as e: raise ValueError(f"Folding type must be one of {list(map(str, Folding))}") from e # TODO: remove the following check once #755 is completed - if folding != Folding.GLOBAL: + if folding == Folding.RANDOM: raise NotImplementedError(f"Folding type {folding.value} is being developed") results = zne_p.bind( From 301f2a2e60cb59dd4c3ec1f90a5e59ab2a2901c0 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Tue, 3 Sep 2024 13:32:01 -0400 Subject: [PATCH 80/94] Partial fix for unit test --- mlir/test/Mitigation/ZneFoldingAllTest.mlir | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/mlir/test/Mitigation/ZneFoldingAllTest.mlir b/mlir/test/Mitigation/ZneFoldingAllTest.mlir index c7075b529d..e54d503884 100644 --- a/mlir/test/Mitigation/ZneFoldingAllTest.mlir +++ b/mlir/test/Mitigation/ZneFoldingAllTest.mlir @@ -15,23 +15,22 @@ // RUN: quantum-opt %s --lower-mitigation --split-input-file --verify-diagnostics | FileCheck %s // CHECK-LABEL: func.func private @circuit.folded(%arg0: index) -> tensor { - // CHECK: [[nQubits:%.+]] = arith.constant 2 // CHECK: [[c0:%.+]] = index.constant 0 // CHECK: [[c1:%.+]] = index.constant 1 // CHECK: quantum.device["rtd_lightning.so", "LightningQubit", "{shots: 0}"] - // CHECK: [[qReg:%.+]] = call @circuit.quantumAlloc([[nQubits]]) : (i64) -> !quantum.reg + // CHECK: [[qReg:%.+]] = quantum.alloc( 2) : !quantum.reg // CHECK: [[q0:%.+]] = quantum.extract [[qReg]][ 0] : !quantum.reg -> !quantum.bit - // CHECK: [[q0_out:%.+]] = quantum.custom "Hadamard"() [[q0]] : !quantum.bit - // CHECK: [[q0_out_1:%.+]] = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] -> (!quantum.bit) { + // CHECK: [[q0_out:%.+]] = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] iter_args([[q0_in:%.+]] = [[q0]]) -> (!quantum.bit) { + // CHECK: [[q0_out]] = quantum.custom "Hadamard"() [[q0_in]] : !quantum.bit // CHECK: [[q0_out]] = quantum.custom "Hadamard"() [[q0_out]] {adjoint} : !quantum.bit - // CHECK: [[q0_out]] = quantum.custom "Hadamard"() [[q0_out]] : !quantum.bit // CHECK: scf.yield [[q0_out]]: !quantum.bit + // CHECK: [[q0_out1:%.+]] = quantum.custom "Hadamard"() [[q0]] : !quantum.bit // CHECK: [[q1:%.+]] = quantum.extract [[qReg]][ 1] : !quantum.reg -> !quantum.bit - // CHECK: [[q01_out:%.+]] = quantum.custom "CNOT"() [[q0_out_1]],[[q1]] : !quantum.bit, !quantum.bit - // CHECK: [[q01_out2:%.+]] = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] -> (!quantum.bit, !quantum.bit) { - // CHECK: [[q01_out]]:2 = quantum.custom "CNOT"() [[q01_out]]#0, [[q01_out]]#1 {adjoint} : !quantum.bit, !quantum.bit + // CHECK: [[q01_out:%.+]] = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] iter_args([[q01_in:%.+]] = [[q1]]) -> (!quantum.bit, !quantum.bit) { // CHECK: [[q01_out]]:2 = quantum.custom "CNOT"() [[q01_out]]#0, [[q01_out]]#1 : !quantum.bit, !quantum.bit + // CHECK: [[q01_out]]:2 = quantum.custom "CNOT"() [[q01_out]]#0, [[q01_out]]#1 {adjoint} : !quantum.bit, !quantum.bit // CHECK: scf.yield [[q01_out]] : (!quantum.bit, !quantum.bit) + // CHECK: [[q01_out2:%.+]] = quantum.custom "CNOT"() [[q0_out1]],[[q1]] : !quantum.bit, !quantum.bit // CHECK: [[q2:%.+]] = quantum.namedobs [[q01_out2]]#0[ PauliY] : !quantum.obs // CHECK: [[results:%.+]] = quantum.expval [[q1]] : f64 // CHECK: [[tensorRes:%.+]] = tensor.from_elements [[result]] : tensor @@ -41,11 +40,7 @@ // CHECK: quantum.device_release // CHECK: return [[tensorRes]] -// CHECK-LABEL: func.func private @circuit.quantumAlloc(%arg0: i64) -> !quantum.reg { - // CHECK: [[allocQreg:%.+]] = quantum.alloc(%arg0) : !quantum.reg - // CHECK: return [[allocQreg]] : !quantum.reg - -//CHECK-LABEL: func.func @circuit -> tensor attributes {qnode} { +//CHECK-LABEL: func.func @circuit() -> tensor attributes {qnode} { func.func @circuit() -> tensor attributes {qnode} { quantum.device ["rtd_lightning.so", "LightningQubit", "{shots: 0}"] %0 = quantum.alloc( 2) : !quantum.reg From b10aaefaa86608b88f1dd8b932bae79f84e1bedd Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Tue, 3 Sep 2024 15:09:04 -0400 Subject: [PATCH 81/94] Partial fix for unit test --- mlir/test/Mitigation/ZneFoldingAllTest.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/Mitigation/ZneFoldingAllTest.mlir b/mlir/test/Mitigation/ZneFoldingAllTest.mlir index e54d503884..cdf8ffc784 100644 --- a/mlir/test/Mitigation/ZneFoldingAllTest.mlir +++ b/mlir/test/Mitigation/ZneFoldingAllTest.mlir @@ -27,7 +27,7 @@ // CHECK: [[q0_out1:%.+]] = quantum.custom "Hadamard"() [[q0]] : !quantum.bit // CHECK: [[q1:%.+]] = quantum.extract [[qReg]][ 1] : !quantum.reg -> !quantum.bit // CHECK: [[q01_out:%.+]] = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] iter_args([[q01_in:%.+]] = [[q1]]) -> (!quantum.bit, !quantum.bit) { - // CHECK: [[q01_out]]:2 = quantum.custom "CNOT"() [[q01_out]]#0, [[q01_out]]#1 : !quantum.bit, !quantum.bit + // CHECK: [[q01_out]]:2 = quantum.custom "CNOT"() [[q01_in]]#0, [[q01_in]]#1 : !quantum.bit, !quantum.bit // CHECK: [[q01_out]]:2 = quantum.custom "CNOT"() [[q01_out]]#0, [[q01_out]]#1 {adjoint} : !quantum.bit, !quantum.bit // CHECK: scf.yield [[q01_out]] : (!quantum.bit, !quantum.bit) // CHECK: [[q01_out2:%.+]] = quantum.custom "CNOT"() [[q0_out1]],[[q1]] : !quantum.bit, !quantum.bit From 9d296eacc434116334a33337ebdc3a237628ebaf Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Wed, 4 Sep 2024 15:55:16 -0400 Subject: [PATCH 82/94] Partial unit test fix --- mlir/test/Mitigation/ZneFoldingAllTest.mlir | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mlir/test/Mitigation/ZneFoldingAllTest.mlir b/mlir/test/Mitigation/ZneFoldingAllTest.mlir index cdf8ffc784..0012dadba7 100644 --- a/mlir/test/Mitigation/ZneFoldingAllTest.mlir +++ b/mlir/test/Mitigation/ZneFoldingAllTest.mlir @@ -21,16 +21,16 @@ // CHECK: [[qReg:%.+]] = quantum.alloc( 2) : !quantum.reg // CHECK: [[q0:%.+]] = quantum.extract [[qReg]][ 0] : !quantum.reg -> !quantum.bit // CHECK: [[q0_out:%.+]] = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] iter_args([[q0_in:%.+]] = [[q0]]) -> (!quantum.bit) { - // CHECK: [[q0_out]] = quantum.custom "Hadamard"() [[q0_in]] : !quantum.bit - // CHECK: [[q0_out]] = quantum.custom "Hadamard"() [[q0_out]] {adjoint} : !quantum.bit - // CHECK: scf.yield [[q0_out]]: !quantum.bit - // CHECK: [[q0_out1:%.+]] = quantum.custom "Hadamard"() [[q0]] : !quantum.bit + // CHECK: [[q0_loop:%.+]] = quantum.custom "Hadamard"() [[q0_in]] : !quantum.bit + // CHECK: [[q0_loop2:%.+]] = quantum.custom "Hadamard"() [[q0_loop]] {adjoint} : !quantum.bit + // CHECK: scf.yield [[q0_loop2]] : !quantum.bit + // CHECK: [[q0_out2:%.+]] = quantum.custom "Hadamard"() [[q0_out]] : !quantum.bit // CHECK: [[q1:%.+]] = quantum.extract [[qReg]][ 1] : !quantum.reg -> !quantum.bit // CHECK: [[q01_out:%.+]] = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] iter_args([[q01_in:%.+]] = [[q1]]) -> (!quantum.bit, !quantum.bit) { - // CHECK: [[q01_out]]:2 = quantum.custom "CNOT"() [[q01_in]]#0, [[q01_in]]#1 : !quantum.bit, !quantum.bit - // CHECK: [[q01_out]]:2 = quantum.custom "CNOT"() [[q01_out]]#0, [[q01_out]]#1 {adjoint} : !quantum.bit, !quantum.bit - // CHECK: scf.yield [[q01_out]] : (!quantum.bit, !quantum.bit) - // CHECK: [[q01_out2:%.+]] = quantum.custom "CNOT"() [[q0_out1]],[[q1]] : !quantum.bit, !quantum.bit + // CHECK: [[q01_loop:%.+]]:2 = quantum.custom "CNOT"() [[q01_in]]#0, [[q01_in]]#1 : !quantum.bit, !quantum.bit + // CHECK: [[q01_loop2:%.+]]:2 = quantum.custom "CNOT"() [[q01_loop]]#0, [[q01_out]]#1 {adjoint} : !quantum.bit, !quantum.bit + // CHECK: scf.yield [[q01_loop2]] : (!quantum.bit, !quantum.bit) + // CHECK: [[q01_out2:%.+]] = quantum.custom "CNOT"() [[q0_out]]#0, [[q0_out]]#1 : !quantum.bit, !quantum.bit // CHECK: [[q2:%.+]] = quantum.namedobs [[q01_out2]]#0[ PauliY] : !quantum.obs // CHECK: [[results:%.+]] = quantum.expval [[q1]] : f64 // CHECK: [[tensorRes:%.+]] = tensor.from_elements [[result]] : tensor From f5b2c40fcc1ac32744b234273a38a5a113d34a69 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Wed, 4 Sep 2024 16:08:42 -0400 Subject: [PATCH 83/94] Partial unit test fix --- mlir/test/Mitigation/ZneFoldingAllTest.mlir | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mlir/test/Mitigation/ZneFoldingAllTest.mlir b/mlir/test/Mitigation/ZneFoldingAllTest.mlir index 0012dadba7..40fd25f531 100644 --- a/mlir/test/Mitigation/ZneFoldingAllTest.mlir +++ b/mlir/test/Mitigation/ZneFoldingAllTest.mlir @@ -26,13 +26,13 @@ // CHECK: scf.yield [[q0_loop2]] : !quantum.bit // CHECK: [[q0_out2:%.+]] = quantum.custom "Hadamard"() [[q0_out]] : !quantum.bit // CHECK: [[q1:%.+]] = quantum.extract [[qReg]][ 1] : !quantum.reg -> !quantum.bit - // CHECK: [[q01_out:%.+]] = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] iter_args([[q01_in:%.+]] = [[q1]]) -> (!quantum.bit, !quantum.bit) { - // CHECK: [[q01_loop:%.+]]:2 = quantum.custom "CNOT"() [[q01_in]]#0, [[q01_in]]#1 : !quantum.bit, !quantum.bit - // CHECK: [[q01_loop2:%.+]]:2 = quantum.custom "CNOT"() [[q01_loop]]#0, [[q01_out]]#1 {adjoint} : !quantum.bit, !quantum.bit - // CHECK: scf.yield [[q01_loop2]] : (!quantum.bit, !quantum.bit) - // CHECK: [[q01_out2:%.+]] = quantum.custom "CNOT"() [[q0_out]]#0, [[q0_out]]#1 : !quantum.bit, !quantum.bit + // CHECK: [[q01_out:%.+]]:2 = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] iter_args([[q01_in1:%.+]] = [[q0_out2]], [[q01_in2:%.+]] = [[q1]]) -> (!quantum.bit, !quantum.bit) { + // CHECK: [[q01_loop:%.+]]:2 = quantum.custom "CNOT"() [[q01_in1]], [[q01_in2]] : !quantum.bit, !quantum.bit + // CHECK: [[q01_loop2:%.+]]:2 = quantum.custom "CNOT"() [[q01_loop]]#0, [[q01_loop]]#1 {adjoint} : !quantum.bit, !quantum.bit + // CHECK: scf.yield [[q01_loop2]]#0, [[q01_loop2]]#1 : !quantum.bit, !quantum.bit + // CHECK: [[q01_out2:%.+]]:2 = quantum.custom "CNOT"() [[q01_out]]#0, [[q01_out]]#1 : !quantum.bit, !quantum.bit // CHECK: [[q2:%.+]] = quantum.namedobs [[q01_out2]]#0[ PauliY] : !quantum.obs - // CHECK: [[results:%.+]] = quantum.expval [[q1]] : f64 + // CHECK: [[result:%.+]] = quantum.expval [[q2]] : f64 // CHECK: [[tensorRes:%.+]] = tensor.from_elements [[result]] : tensor // CHECK: [[q2:%.+]] = quantum.insert %0[ 0], [[q01_out2]]#0 : !quantum.reg, !quantum.bit // CHECK: [[q3:%.+]] = quantum.insert %7[ 1], [[q01_out2]]#1 : !quantum.reg, !quantum.bit From 3642b8dbfae4ba279bf64665a8923af4c7769548 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Wed, 4 Sep 2024 16:13:03 -0400 Subject: [PATCH 84/94] Partial unit test fix --- mlir/test/Mitigation/ZneFoldingAllTest.mlir | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/test/Mitigation/ZneFoldingAllTest.mlir b/mlir/test/Mitigation/ZneFoldingAllTest.mlir index 40fd25f531..46f2c77eff 100644 --- a/mlir/test/Mitigation/ZneFoldingAllTest.mlir +++ b/mlir/test/Mitigation/ZneFoldingAllTest.mlir @@ -36,7 +36,7 @@ // CHECK: [[tensorRes:%.+]] = tensor.from_elements [[result]] : tensor // CHECK: [[q2:%.+]] = quantum.insert %0[ 0], [[q01_out2]]#0 : !quantum.reg, !quantum.bit // CHECK: [[q3:%.+]] = quantum.insert %7[ 1], [[q01_out2]]#1 : !quantum.reg, !quantum.bit - // CHECK: quantum.dealloc [[q2]] : !quantum.reg + // CHECK: quantum.dealloc [[q3]] : !quantum.reg // CHECK: quantum.device_release // CHECK: return [[tensorRes]] @@ -62,6 +62,7 @@ func.func @circuit() -> tensor attributes {qnode} { //CHECK: [[c0:%.+]] = index.constant 0 //CHECK: [[c1:%.+]] = index.constant 1 //CHECK: [[c3:%.+]] = index.constant 3 + //CHECK: [[t:%.+]] = tensor.empty() : tensor<3xf64> //CHECK: [[dense3:%.+]] = arith.constant dense<[1, 2, 3]> //CHECK: [[emptyRes:%.+]] = tensor.empty() : tensor<3xf64> //CHECK: [[results:%.+]] = scf.for [[idx:%.+]] = [[c0]] to [[c3]] step [[c1]] iter_args(%arg1 = [[emptyRes]]) -> (tensor<3xf64>) { From d81bddade2edb7e955474ef8db8baf88382b6fff Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Wed, 4 Sep 2024 17:05:54 -0400 Subject: [PATCH 85/94] Passing unit test --- mlir/test/Mitigation/ZneFoldingAllTest.mlir | 32 ++++++++++----------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/mlir/test/Mitigation/ZneFoldingAllTest.mlir b/mlir/test/Mitigation/ZneFoldingAllTest.mlir index 46f2c77eff..dbfb777d99 100644 --- a/mlir/test/Mitigation/ZneFoldingAllTest.mlir +++ b/mlir/test/Mitigation/ZneFoldingAllTest.mlir @@ -59,22 +59,22 @@ func.func @circuit() -> tensor attributes {qnode} { } //CHECK-LABEL: func.func @mitigated_circuit() - //CHECK: [[c0:%.+]] = index.constant 0 - //CHECK: [[c1:%.+]] = index.constant 1 - //CHECK: [[c3:%.+]] = index.constant 3 - //CHECK: [[t:%.+]] = tensor.empty() : tensor<3xf64> - //CHECK: [[dense3:%.+]] = arith.constant dense<[1, 2, 3]> - //CHECK: [[emptyRes:%.+]] = tensor.empty() : tensor<3xf64> - //CHECK: [[results:%.+]] = scf.for [[idx:%.+]] = [[c0]] to [[c3]] step [[c1]] iter_args(%arg1 = [[emptyRes]]) -> (tensor<3xf64>) { - //CHECK: [[scaleFactor:%.+]] = tensor.extract [[dense3]][[[idx]]] : tensor<3xindex> - //CHECK: [[intermediateRes:%.+]] = func.call @circuit.folded([[scaleFactor]]) : (index) -> tensor - //CHECK: [[tensorRes:%.+]] = tensor.from_elements [[intermediateRes]] : tensor<1xf64> - //CHECK: [[resultsFor:%.+]] = scf.for %arg2 = [[c0]] to [[c1]] step [[c1]] iter_args(%arg3 = %arg1) -> (tensor<3xf64>) { - //CHECK: [[extracted:%.+]] = tensor.extract [[tensorRes]][%arg3] : tensor<1xf64> - //CHECK: [[insertedRes:%.+]] = tensor.insert [[extracted]] into %arg3[%arg1] : tensor<5xf64> - //CHECK: scf.yield [[insertedRes]] - //CHECK: scf.yield [[resultsFor]] - //CHECK: return [[results]] + //CHECK-DAG: [[c0:%.+]] = index.constant 0 + //CHECK-DAG: [[c1:%.+]] = index.constant 1 + //CHECK-DAG: [[c3:%.+]] = index.constant 3 + //CHECK-DAG: [[emptyRes:%.+]] = tensor.empty() : tensor<3xf64> + //CHECK-DAG: [[dense3:%.+]] = arith.constant dense<[1, 2, 3]> + // CHECK: [[results:%.+]] = scf.for [[idx:%.+]] = [[c0]] to [[c3]] step [[c1]] iter_args([[emptyArg:%.+]] = [[emptyRes]]) -> (tensor<3xf64>) { + // CHECK: [[scalarFactor:%.+]] = tensor.extract [[dense3]][[[idx]]] : tensor<3xindex> + // CHECK: [[intermediateRes:%.+]] = func.call @circuit.folded([[scalarFactor]]) : (index) -> tensor + // CHECK: [[extracted:%.+]] = tensor.extract [[intermediateRes]][] : tensor + // CHECK: [[from_elements:%.+]] = tensor.from_elements [[extracted]] : tensor<1xf64> + // CHECK: [[resultsFor:%.+]] = scf.for [[idxJ:%.+]] = [[c0]] to [[c1]] step [[c1]] iter_args([[scalarArg:%.+]] = [[emptyArg]]) -> (tensor<3xf64>) { + // CHECK: [[extracted:%.+]] = tensor.extract %from_elements[%arg2] : tensor<1xf64> + // CHECK: [[insertedRes:%.+]] = tensor.insert [[extracted]] into %arg3[%arg0] : tensor<3xf64> + // CHECK: scf.yield [[insertedRes]] + // CHECK: scf.yield [[resultsFor]] + // CHECK: return [[results]] func.func @mitigated_circuit() -> tensor<3xf64> { %scaleFactors = arith.constant dense<[1, 2, 3]> : tensor<3xindex> %0 = mitigation.zne @circuit() folding (all) scaleFactors (%scaleFactors : tensor<3xindex>) : () -> tensor<3xf64> From b6acb1ce89ecd2837a45ce4c49675e8f680b23d9 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Thu, 5 Sep 2024 14:58:51 -0400 Subject: [PATCH 86/94] Unit test for adjoint and more qubits, per @rmoyard review --- mlir/test/Mitigation/ZneFoldingAllTest2.mlir | 105 +++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 mlir/test/Mitigation/ZneFoldingAllTest2.mlir diff --git a/mlir/test/Mitigation/ZneFoldingAllTest2.mlir b/mlir/test/Mitigation/ZneFoldingAllTest2.mlir new file mode 100644 index 0000000000..fcd4e94eba --- /dev/null +++ b/mlir/test/Mitigation/ZneFoldingAllTest2.mlir @@ -0,0 +1,105 @@ +// Copyright 2023 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// RUN: quantum-opt %s --lower-mitigation --split-input-file --verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func private @circuit.folded(%arg0: index) -> tensor { + // CHECK: [[c0:%.+]] = index.constant 0 + // CHECK: [[c1:%.+]] = index.constant 1 + // CHECK: quantum.device["rtd_lightning.so", "LightningQubit", "{shots: 0}"] + // CHECK: [[qReg:%.+]] = quantum.alloc( 4) : !quantum.reg + // CHECK: [[q0:%.+]] = quantum.extract [[qReg]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[q0_out:%.+]] = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] iter_args([[q0_in:%.+]] = [[q0]]) -> (!quantum.bit) { + // CHECK: [[q0_loop:%.+]] = quantum.custom "Hadamard"() [[q0_in]] : !quantum.bit + // CHECK: [[q0_loop2:%.+]] = quantum.custom "Hadamard"() [[q0_loop]] {adjoint} : !quantum.bit + // CHECK: scf.yield [[q0_loop2]] : !quantum.bit + // CHECK: [[q0_out2:%.+]] = quantum.custom "Hadamard"() [[q0_out]] : !quantum.bit + // CHECK: [[q1:%.+]] = quantum.extract [[qReg]][ 1] : !quantum.reg -> !quantum.bit + // CHECK: [[q01_out:%.+]]:2 = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] iter_args([[q01_in1:%.+]] = [[q0_out2]], [[q01_in2:%.+]] = [[q1]]) -> (!quantum.bit, !quantum.bit) { + // CHECK: [[q01_loop:%.+]]:2 = quantum.custom "CNOT"() [[q01_in1]], [[q01_in2]] : !quantum.bit, !quantum.bit + // CHECK: [[q01_loop2:%.+]]:2 = quantum.custom "CNOT"() [[q01_loop]]#0, [[q01_loop]]#1 {adjoint} : !quantum.bit, !quantum.bit + // CHECK: scf.yield [[q01_loop2]]#0, [[q01_loop2]]#1 : !quantum.bit, !quantum.bit + // CHECK: [[q01_out2:%.+]]:2 = quantum.custom "CNOT"() [[q01_out]]#0, [[q01_out]]#1 : !quantum.bit, !quantum.bit + // CHECK: [[q2:%.+]] = quantum.extract [[qReg]][ 2] : !quantum.reg -> !quantum.bit + // CHECK: [[q12_out:%.+]]:2 = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] iter_args([[q12_in1:%.+]] = [[q01_out2]]#1, [[q12_in2:%.+]] = [[q2]]) -> (!quantum.bit, !quantum.bit) { + // CHECK: [[q12_loop:%.+]]:2 = quantum.custom "CNOT"() [[q12_in1]], [[q12_in2]] : !quantum.bit, !quantum.bit + // CHECK: [[q12_loop2:%.+]]:2 = quantum.custom "CNOT"() [[q12_loop]]#0, [[q12_loop]]#1 {adjoint} : !quantum.bit, !quantum.bit + // CHECK: scf.yield [[q12_loop2]]#0, [[q12_loop2]]#1 : !quantum.bit, !quantum.bit + // CHECK: [[q12_out2:%.+]]:2 = quantum.custom "CNOT"() [[q12_out]]#0, [[q12_out]]#1 : !quantum.bit, !quantum.bit + // CHECK: [[q1_out:%.+]] = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] iter_args([[q1_in:%.+]] = [[q12_out2]]#0) -> (!quantum.bit) { + // CHECK: [[q1_loop:%.+]] = quantum.custom "T"() [[q1_in]] : !quantum.bit + // CHECK: [[q1_loop2:%.+]] = quantum.custom "T"() [[q1_loop]] {adjoint} : !quantum.bit + // CHECK: scf.yield [[q1_loop2]] : !quantum.bit + // CHECK: [[q1_out2:%.+]] = quantum.custom "T"() [[q1_out]] : !quantum.bit + // CHECK: [[q3:%.+]] = quantum.extract [[qReg]][ 3] : !quantum.reg -> !quantum.bit + // CHECK: [[q23_out:%.+]]:2 = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] iter_args([[q23_in1:%.+]] = [[q12_out2]]#1, [[q23_in2:%.+]] = [[q3]]) -> (!quantum.bit, !quantum.bit) { + // CHECK: [[q23_loop:%.+]]:2 = quantum.custom "CNOT"() [[q23_in1]], [[q23_in2]] : !quantum.bit, !quantum.bit + // CHECK: [[q23_loop2:%.+]]:2 = quantum.custom "CNOT"() [[q23_loop]]#0, [[q23_loop]]#1 {adjoint} : !quantum.bit, !quantum.bit + // CHECK: scf.yield [[q23_loop2]]#0, [[q23_loop2]]#1 : !quantum.bit, !quantum.bit + // CHECK: [[q23_out2:%.+]]:2 = quantum.custom "CNOT"() [[q23_out]]#0, [[q23_out]]#1 : !quantum.bit, !quantum.bit + // CHECK: [[q3_out:%.+]] = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] iter_args([[q3_in:%.+]] = [[q23_out2]]#1) -> (!quantum.bit) { + // CHECK: [[q3_loop:%.+]] = quantum.custom "T"() [[q3_in]] {adjoint} : !quantum.bit + // CHECK: [[q3_loop2:%.+]] = quantum.custom "T"() [[q3_loop]] : !quantum.bit + // CHECK: scf.yield [[q3_loop2]] : !quantum.bit + // CHECK: [[q3_out2:%.+]] = quantum.custom "T"() [[q3_out]] {adjoint} : !quantum.bit + + +//CHECK-LABEL: func.func @circuit() -> tensor attributes {qnode} { +func.func @circuit() -> tensor attributes {qnode} { + quantum.device["rtd_lightning.so", "LightningQubit", "{shots: 0}"] + %0 = quantum.alloc( 4) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %out_qubits = quantum.custom "Hadamard"() %1 : !quantum.bit + %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit + %out_qubits_0:2 = quantum.custom "CNOT"() %out_qubits, %2 : !quantum.bit, !quantum.bit + %3 = quantum.extract %0[ 2] : !quantum.reg -> !quantum.bit + %out_qubits_1:2 = quantum.custom "CNOT"() %out_qubits_0#1, %3 : !quantum.bit, !quantum.bit + %out_qubits_2 = quantum.custom "T"() %out_qubits_1#0 : !quantum.bit + %4 = quantum.extract %0[ 3] : !quantum.reg -> !quantum.bit + %out_qubits_3:2 = quantum.custom "CNOT"() %out_qubits_1#1, %4 : !quantum.bit, !quantum.bit + %out_qubits_4 = quantum.custom "T"() %out_qubits_3#1 {adjoint} : !quantum.bit + %5 = quantum.namedobs %out_qubits_0#0[ PauliY] : !quantum.obs + %6 = quantum.expval %5 {shots = 5 : i64} : f64 + %from_elements = tensor.from_elements %6 : tensor + %7 = quantum.insert %0[ 0], %out_qubits_0#0 : !quantum.reg, !quantum.bit + %8 = quantum.insert %7[ 1], %out_qubits_2 : !quantum.reg, !quantum.bit + %9 = quantum.insert %8[ 2], %out_qubits_3#0 : !quantum.reg, !quantum.bit + %10 = quantum.insert %9[ 3], %out_qubits_4 : !quantum.reg, !quantum.bit + quantum.dealloc %10 : !quantum.reg + quantum.device_release + return %from_elements : tensor + } + +//CHECK-LABEL: func.func @mitigated_circuit() + //CHECK-DAG: [[c0:%.+]] = index.constant 0 + //CHECK-DAG: [[c1:%.+]] = index.constant 1 + //CHECK-DAG: [[c3:%.+]] = index.constant 3 + //CHECK-DAG: [[emptyRes:%.+]] = tensor.empty() : tensor<3xf64> + //CHECK-DAG: [[dense3:%.+]] = arith.constant dense<[1, 2, 3]> + // CHECK: [[results:%.+]] = scf.for [[idx:%.+]] = [[c0]] to [[c3]] step [[c1]] iter_args([[emptyArg:%.+]] = [[emptyRes]]) -> (tensor<3xf64>) { + // CHECK: [[scalarFactor:%.+]] = tensor.extract [[dense3]][[[idx]]] : tensor<3xindex> + // CHECK: [[intermediateRes:%.+]] = func.call @circuit.folded([[scalarFactor]]) : (index) -> tensor + // CHECK: [[extracted:%.+]] = tensor.extract [[intermediateRes]][] : tensor + // CHECK: [[from_elements:%.+]] = tensor.from_elements [[extracted]] : tensor<1xf64> + // CHECK: [[resultsFor:%.+]] = scf.for [[idxJ:%.+]] = [[c0]] to [[c1]] step [[c1]] iter_args([[scalarArg:%.+]] = [[emptyArg]]) -> (tensor<3xf64>) { + // CHECK: [[extracted:%.+]] = tensor.extract %from_elements[%arg2] : tensor<1xf64> + // CHECK: [[insertedRes:%.+]] = tensor.insert [[extracted]] into %arg3[%arg0] : tensor<3xf64> + // CHECK: scf.yield [[insertedRes]] + // CHECK: scf.yield [[resultsFor]] + // CHECK: return [[results]] +func.func @mitigated_circuit() -> tensor<3xf64> { + %scaleFactors = arith.constant dense<[1, 2, 3]> : tensor<3xindex> + %0 = mitigation.zne @circuit() folding (all) scaleFactors (%scaleFactors : tensor<3xindex>) : () -> tensor<3xf64> + func.return %0 : tensor<3xf64> +} From ce2ab05b299713b78e79ecfe48a4d1d88ef0e0c9 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Thu, 5 Sep 2024 16:04:06 -0400 Subject: [PATCH 87/94] Pytest parameterization for local folding --- frontend/test/pytest/test_mitigation.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/frontend/test/pytest/test_mitigation.py b/frontend/test/pytest/test_mitigation.py index ec0dfda7a3..f163bf8740 100644 --- a/frontend/test/pytest/test_mitigation.py +++ b/frontend/test/pytest/test_mitigation.py @@ -34,7 +34,8 @@ def skip_if_exponential_extrapolation_unstable(circuit_param, extrapolation_func @pytest.mark.parametrize("params", [0.1, 0.2, 0.3, 0.4, 0.5]) @pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate]) -def test_single_measurement(params, extrapolation): +@pytest.mark.parametrize("folding", ["global", "all"]) +def test_single_measurement(params, extrapolation, folding): """Test that without noise the same results are returned for single measurements.""" skip_if_exponential_extrapolation_unstable(params, extrapolation) @@ -52,7 +53,7 @@ def circuit(x): @catalyst.qjit def mitigated_qnode(args): return catalyst.mitigate_with_zne( - circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation + circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation, folding=folding )(args) assert np.allclose(mitigated_qnode(params), circuit(params)) @@ -60,7 +61,8 @@ def mitigated_qnode(args): @pytest.mark.parametrize("params", [0.1, 0.2, 0.3, 0.4, 0.5]) @pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate]) -def test_multiple_measurements(params, extrapolation): +@pytest.mark.parametrize("folding", ["global", "all"]) +def test_multiple_measurements(params, extrapolation, folding): """Test that without noise the same results are returned for multiple measurements""" skip_if_exponential_extrapolation_unstable(params, extrapolation) @@ -78,14 +80,15 @@ def circuit(x): @catalyst.qjit def mitigated_qnode(args): return catalyst.mitigate_with_zne( - circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation + circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation, folding=folding )(args) assert np.allclose(mitigated_qnode(params), circuit(params)) @pytest.mark.parametrize("params", [0.1, 0.2, 0.3, 0.4, 0.5]) -def test_single_measurement_control_flow(params): +@pytest.mark.parametrize("folding", ["global", "all"]) +def test_single_measurement_control_flow(params, folding): """Test that without noise the same results are returned for single measurement and with control flow.""" dev = qml.device("lightning.qubit", wires=2) @@ -113,7 +116,7 @@ def loop_1(i): # pylint: disable=unused-argument @catalyst.qjit def mitigated_qnode(args, n): - return catalyst.mitigate_with_zne(circuit, scale_factors=jax.numpy.array([1, 2, 3]))( + return catalyst.mitigate_with_zne(circuit, scale_factors=jax.numpy.array([1, 2, 3]), folding=folding)( args, n ) @@ -238,7 +241,7 @@ def circuit(): return 0.0 def mitigated_qnode(): - return catalyst.mitigate_with_zne(circuit, scale_factors=[], folding="all")() + return catalyst.mitigate_with_zne(circuit, scale_factors=[], folding="random")() with pytest.raises(NotImplementedError): catalyst.qjit(mitigated_qnode) @@ -246,7 +249,8 @@ def mitigated_qnode(): @pytest.mark.parametrize("params", [0.1, 0.2, 0.3, 0.4, 0.5]) @pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate]) -def test_zne_usage_patterns(params, extrapolation): +@pytest.mark.parametrize("folding", ["global", "all"]) +def test_zne_usage_patterns(params, extrapolation, folding): """Test usage patterns of catalyst.zne.""" skip_if_exponential_extrapolation_unstable(params, extrapolation) @@ -264,13 +268,13 @@ def fn(x): @catalyst.qjit def mitigated_qnode_fn_as_argument(args): return catalyst.mitigate_with_zne( - fn, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation + fn, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation, folding=folding )(args) @catalyst.qjit def mitigated_qnode_partial(args): return catalyst.mitigate_with_zne( - scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation + scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation, folding=folding )(fn)(args) assert np.allclose(mitigated_qnode_fn_as_argument(params), fn(params)) From 4dc195c9a7e42b4d125dad9510e040d048c9ba5c Mon Sep 17 00:00:00 2001 From: Daniel Strano Date: Fri, 6 Sep 2024 10:45:50 -0400 Subject: [PATCH 88/94] Update mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp Co-authored-by: Romain Moyard --- mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index ff0c7aa3f4..ec3005860b 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -329,7 +329,7 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew Value c0 = rewriter.create(loc, 0); Value c1 = rewriter.create(loc, 1); - fnFoldedOpBlock->addArgument(fnFoldedOp.getArgumentTypes().front(), loc); + fnFoldedOpBlock->addArgument(fnFoldedOp.getArgumentTypes().back(), loc); if (foldingAlgorithm == Folding(2)) { return allLocalFolding(rewriter, fnFoldedName, fnFoldedOp, c0, c1); From 3c23dc7aa750dbf27b6f4c1f9262d6558a797395 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Fri, 6 Sep 2024 13:29:31 -0400 Subject: [PATCH 89/94] Local folding docs (per @rmoyard review) --- doc/releases/changelog-dev.md | 35 ++++++++++++++++++- .../api_extensions/error_mitigation.py | 1 + 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 04714f19e4..19013299fd 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -26,6 +26,38 @@ Array([[1], [0], [1], [1], [0], [1],[0]], dtype=int64)) ``` +* ZNE local folding: Introduces the option to fold gates "locally" as well as the existing method of "globally." Global folding (as in previous versions) applies the scale factor by forming the inverse of the entire quantum circuit (without measurements) and repeating the circuit with its inverse; local folding inserts per-gate folding sequences directly in place of each gate in the original circuit instead of applying the scale factor to the entire circuit at once. [(#1006)](https://github.com/PennyLaneAI/catalyst/pull/1006) + + For example, + + ```python + import jax + import pennylane as qml + from catalyst import qjit, mitigate_with_zne + from pennylane.transforms import exponential_extrapolate + + dev = qml.device("lightning.qubit", wires=4, shots=5) + + @qml.qnode(dev) + def circuit(): + qml.Hadamard(wires=0) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.PauliY(wires=0)) + + @qjit(keep_intermediate=True) + def mitigated_circuit(): + s = jax.numpy.array([1, 2, 3]) + return mitigate_with_zne( + circuit, + scale_factors=s, + extrapolate=exponential_extrapolate, + folding="all" #"all" for local or "global" for the original method, default is "global + )() + +print(circuit()) +print(mitigated_circuit()) + ``` +

Improvements

Breaking changes

@@ -49,4 +81,5 @@ This release contains contributions from (in alphabetical order): Romain Moyard, Paul Haochen Wang, -Sengthai Heng, \ No newline at end of file +Sengthai Heng, +Daniel Strano diff --git a/frontend/catalyst/api_extensions/error_mitigation.py b/frontend/catalyst/api_extensions/error_mitigation.py index b504ed9110..f303dd8a52 100644 --- a/frontend/catalyst/api_extensions/error_mitigation.py +++ b/frontend/catalyst/api_extensions/error_mitigation.py @@ -56,6 +56,7 @@ def mitigate_with_zne( function. folding (str): Unitary folding technique to be used to scale the circuit. Possible values: - global: the global unitary of the input circuit is folded + - all: per-gate folding sequences replace original gates in-place in the circuit Returns: Callable: A callable object that computes the mitigated of the wrapped :class:`~.QNode` From 2c1192af35cd73372360a5d1c23ac6b0e5ddc715 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Fri, 6 Sep 2024 14:17:25 -0400 Subject: [PATCH 90/94] Per @cosenal review --- doc/releases/changelog-dev.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 19013299fd..7b252baf3b 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -26,7 +26,7 @@ Array([[1], [0], [1], [1], [0], [1],[0]], dtype=int64)) ``` -* ZNE local folding: Introduces the option to fold gates "locally" as well as the existing method of "globally." Global folding (as in previous versions) applies the scale factor by forming the inverse of the entire quantum circuit (without measurements) and repeating the circuit with its inverse; local folding inserts per-gate folding sequences directly in place of each gate in the original circuit instead of applying the scale factor to the entire circuit at once. [(#1006)](https://github.com/PennyLaneAI/catalyst/pull/1006) +* ZNE local folding: Introduces the option to fold gates locally as well as the existing method of globally. Global folding (as in previous versions) applies the scale factor by forming the inverse of the entire quantum circuit (without measurements) and repeating the circuit with its inverse; local folding inserts per-gate folding sequences directly in place of each gate in the original circuit instead of applying the scale factor to the entire circuit at once. [(#1006)](https://github.com/PennyLaneAI/catalyst/pull/1006) For example, From 4bdad2b641a25118c1a0536665c7425df788f948 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Fri, 6 Sep 2024 14:18:14 -0400 Subject: [PATCH 91/94] Per @cosenal review --- doc/releases/changelog-dev.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 7b252baf3b..e15896a49d 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -26,7 +26,7 @@ Array([[1], [0], [1], [1], [0], [1],[0]], dtype=int64)) ``` -* ZNE local folding: Introduces the option to fold gates locally as well as the existing method of globally. Global folding (as in previous versions) applies the scale factor by forming the inverse of the entire quantum circuit (without measurements) and repeating the circuit with its inverse; local folding inserts per-gate folding sequences directly in place of each gate in the original circuit instead of applying the scale factor to the entire circuit at once. [(#1006)](https://github.com/PennyLaneAI/catalyst/pull/1006) +* Zero-Noise Extrapolation (ZNE) local folding: Introduces the option to fold gates locally as well as the existing method of globally. Global folding (as in previous versions) applies the scale factor by forming the inverse of the entire quantum circuit (without measurements) and repeating the circuit with its inverse; local folding inserts per-gate folding sequences directly in place of each gate in the original circuit instead of applying the scale factor to the entire circuit at once. [(#1006)](https://github.com/PennyLaneAI/catalyst/pull/1006) For example, From bde26f412d08927cf47afde9cac3f7d85dd981cb Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Fri, 6 Sep 2024 14:19:34 -0400 Subject: [PATCH 92/94] Per @cosenal review --- .../{ZneFoldingAllTest2.mlir => ZneFoldingAllFullTest.mlir} | 0 .../{ZneFoldingAllTest.mlir => ZneFoldingAllMinimalTest.mlir} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename mlir/test/Mitigation/{ZneFoldingAllTest2.mlir => ZneFoldingAllFullTest.mlir} (100%) rename mlir/test/Mitigation/{ZneFoldingAllTest.mlir => ZneFoldingAllMinimalTest.mlir} (100%) diff --git a/mlir/test/Mitigation/ZneFoldingAllTest2.mlir b/mlir/test/Mitigation/ZneFoldingAllFullTest.mlir similarity index 100% rename from mlir/test/Mitigation/ZneFoldingAllTest2.mlir rename to mlir/test/Mitigation/ZneFoldingAllFullTest.mlir diff --git a/mlir/test/Mitigation/ZneFoldingAllTest.mlir b/mlir/test/Mitigation/ZneFoldingAllMinimalTest.mlir similarity index 100% rename from mlir/test/Mitigation/ZneFoldingAllTest.mlir rename to mlir/test/Mitigation/ZneFoldingAllMinimalTest.mlir From 43ca20853783ed50aa9c1885e3afc9d5e0d4f457 Mon Sep 17 00:00:00 2001 From: Daniel Strano Date: Fri, 6 Sep 2024 14:25:59 -0400 Subject: [PATCH 93/94] Update doc/releases/changelog-dev.md Co-authored-by: Romain Moyard --- doc/releases/changelog-dev.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index e15896a49d..b8637a7d7f 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -51,7 +51,7 @@ circuit, scale_factors=s, extrapolate=exponential_extrapolate, - folding="all" #"all" for local or "global" for the original method, default is "global + folding="all" #"all" for local or "global" for the original method, default is "global" )() print(circuit()) From d11cd9d02be47632839ba604cad8cad4835eca94 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Fri, 6 Sep 2024 14:31:04 -0400 Subject: [PATCH 94/94] make format --- frontend/test/pytest/test_mitigation.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/frontend/test/pytest/test_mitigation.py b/frontend/test/pytest/test_mitigation.py index f163bf8740..f75d8ce9db 100644 --- a/frontend/test/pytest/test_mitigation.py +++ b/frontend/test/pytest/test_mitigation.py @@ -53,7 +53,10 @@ def circuit(x): @catalyst.qjit def mitigated_qnode(args): return catalyst.mitigate_with_zne( - circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation, folding=folding + circuit, + scale_factors=jax.numpy.array([1, 2, 3]), + extrapolate=extrapolation, + folding=folding, )(args) assert np.allclose(mitigated_qnode(params), circuit(params)) @@ -80,7 +83,10 @@ def circuit(x): @catalyst.qjit def mitigated_qnode(args): return catalyst.mitigate_with_zne( - circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation, folding=folding + circuit, + scale_factors=jax.numpy.array([1, 2, 3]), + extrapolate=extrapolation, + folding=folding, )(args) assert np.allclose(mitigated_qnode(params), circuit(params)) @@ -116,9 +122,9 @@ def loop_1(i): # pylint: disable=unused-argument @catalyst.qjit def mitigated_qnode(args, n): - return catalyst.mitigate_with_zne(circuit, scale_factors=jax.numpy.array([1, 2, 3]), folding=folding)( - args, n - ) + return catalyst.mitigate_with_zne( + circuit, scale_factors=jax.numpy.array([1, 2, 3]), folding=folding + )(args, n) assert np.allclose(mitigated_qnode(params, 3), catalyst.qjit(circuit)(params, 3))