Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add seeded execution to the Catalyst runtime #936

Merged
merged 42 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
da4dd08
unskipping the flaky test `test_dynamic_one_shot_several_mcms` to run…
paul0403 Jul 16, 2024
784da98
Implemented a random seeding infrastructure for qjit.
paul0403 Jul 18, 2024
d8439e7
separating the runtime init capi into a seeded one and an unseeded one
paul0403 Jul 18, 2024
b1643cd
format
paul0403 Jul 18, 2024
3a3bd23
add oqc device
paul0403 Jul 18, 2024
c26e2bf
Adding a `has_seed` method in the device classes as a quality of life…
paul0403 Jul 19, 2024
4077b79
Merge remote-tracking branch 'origin' into flaky_test_dynamic_one_shot
paul0403 Jul 19, 2024
d9e22b7
addressing some comments
paul0403 Jul 19, 2024
99b2139
Unifying ALL __catalyst__rt__initilize to take in a char*
paul0403 Jul 19, 2024
ddf0abc
Set the default char *seed = nullptr in the declaration of __catalyst…
paul0403 Jul 19, 2024
f108b08
removing the seed string from the devices, as they only need a pointe…
paul0403 Jul 19, 2024
0a18c19
add frontend tests
paul0403 Jul 19, 2024
4b1d6af
- making `SetDevicePRNG` a method in the base class `QuantumDevice` w…
paul0403 Jul 22, 2024
34d0dc6
format
paul0403 Jul 22, 2024
7ff4746
Merge remote-tracking branch 'origin/main' into flaky_test_dynamic_on…
paul0403 Jul 22, 2024
38a8f74
giving the decive prng attributes a default nullptr value
paul0403 Jul 22, 2024
622f158
changing the `__catalyst__rt__initialize` to take in char* (i8* null)…
paul0403 Jul 22, 2024
de4b4db
pylint
paul0403 Jul 22, 2024
77f7d58
Merge remote-tracking branch 'origin/main' into flaky_test_dynamic_on…
paul0403 Jul 22, 2024
9b51e0c
changelog
paul0403 Jul 22, 2024
60e4b07
changelog typo
paul0403 Jul 22, 2024
a059c16
add a getter for the devicePRNG; add runtime tests for devicePRNG
paul0403 Jul 22, 2024
d213566
skipping frontend kokkos test on x86 mac; add runtime tests for seede…
paul0403 Jul 23, 2024
40771d8
Merge remote-tracking branch 'origin/main' into flaky_test_dynamic_on…
paul0403 Jul 23, 2024
32a4aff
[TEMPORARY] removing the seed in the frontend test to see if mac x86 …
paul0403 Jul 23, 2024
3579255
add seeding tests for openqasm to test the default set/getdevicePRNG …
paul0403 Jul 23, 2024
7a6bd7d
Merge remote-tracking branch 'origin/main' into flaky_test_dynamic_on…
paul0403 Jul 23, 2024
23aab7d
disallow seeding to be used together with async
paul0403 Jul 23, 2024
46bf0ec
format
paul0403 Jul 23, 2024
d943882
Merge remote-tracking branch 'origin/main' into flaky_test_dynamic_on…
paul0403 Jul 23, 2024
fb8042f
addressing comments
paul0403 Jul 23, 2024
859d772
Merge remote-tracking branch 'origin/main' into flaky_test_dynamic_on…
paul0403 Jul 23, 2024
b40aa4d
add qjit documentation to reflect the unseeded shots in lightning
paul0403 Jul 23, 2024
ae3ea16
As per Lee and Josh's suggestion, we change the seed from a string to…
paul0403 Jul 24, 2024
75ba76b
remove the GetDevicePRNG getter, and remove the associated tests
paul0403 Jul 24, 2024
64443f9
format
paul0403 Jul 24, 2024
aa60345
changelog example fix
paul0403 Jul 24, 2024
333a593
Merge remote-tracking branch 'origin/main' into flaky_test_dynamic_on…
paul0403 Jul 25, 2024
9d57b13
Merge remote-tracking branch 'origin/main' into flaky_test_dynamic_on…
paul0403 Jul 25, 2024
a49cd41
removing the default argument in __catalyst__rt__initialize and chang…
paul0403 Jul 25, 2024
ad1c6bd
using initilizer list for ExecutionContext.seed instead of having two…
paul0403 Jul 25, 2024
276761e
remove `this->seed = seed` after initialization list
paul0403 Jul 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions frontend/catalyst/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class CompileOptions:
static_argnums: Optional[Union[int, Iterable[int]]] = None
abstracted_axes: Optional[Union[Iterable[Iterable[str]], Dict[int, str]]] = None
lower_to_llvm: Optional[bool] = True
seed: Optional[str] = ""

def __post_init__(self):
# Make the format of static_argnums easier to handle.
Expand Down
8 changes: 7 additions & 1 deletion frontend/catalyst/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def qjit(
pipelines=None,
static_argnums=None,
abstracted_axes=None,
seed="",
): # pylint: disable=too-many-arguments,unused-argument
"""A just-in-time decorator for PennyLane and JAX programs using Catalyst.

Expand Down Expand Up @@ -123,6 +124,11 @@ def qjit(
Function arguments with ``abstracted_axes`` specified will be compiled to ranked tensors
with dynamic shapes. For more details, please see the Dynamically-shaped Arrays section
below.
seed (str):
The seed for random operations in a qjit call, such as circuit measurement results.
The default value is an empty string, which means no seeding is performed, and all
processes are random.
Note that if the circuit is run from shots, the sampled results are NOT seeded.

Returns:
QJIT object.
Expand Down Expand Up @@ -652,7 +658,7 @@ def generate_ir(self):
mlir_module, ctx = lower_jaxpr_to_mlir(self.jaxpr, self.__name__)

# Inject Runtime Library-specific functions (e.g. setup/teardown).
inject_functions(mlir_module, ctx)
inject_functions(mlir_module, ctx, self.compile_options.seed)

# Canonicalize the MLIR since there can be a lot of redundancy coming from JAX.
options = copy.deepcopy(self.compile_options)
Expand Down
4 changes: 4 additions & 0 deletions frontend/catalyst/third_party/oqc/src/OQCDevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ void OQCDevice::SetDeviceShots([[maybe_unused]] size_t shots) { device_shots = s

auto OQCDevice::GetDeviceShots() const -> size_t { return device_shots; }

void OQCDevice::SetDeviceSeed([[maybe_unused]] std::string _seed) { seed = _seed; }

void OQCDevice::SetDevicePRNG([[maybe_unused]] std::mt19937 *_gen) { gen = _gen; }
paul0403 marked this conversation as resolved.
Show resolved Hide resolved

auto OQCDevice::Zero() const -> Result { return const_cast<Result>(&GLOBAL_RESULT_FALSE_CONST); }

auto OQCDevice::One() const -> Result { return const_cast<Result>(&GLOBAL_RESULT_TRUE_CONST); }
Expand Down
3 changes: 3 additions & 0 deletions frontend/catalyst/third_party/oqc/src/OQCDevice.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ class OQCDevice final : public Catalyst::Runtime::QuantumDevice {
Catalyst::Runtime::CacheManager<std::complex<double>> cache_manager{};
bool tape_recording{false};
size_t device_shots;

std::string seed;
paul0403 marked this conversation as resolved.
Show resolved Hide resolved
std::mt19937 *gen;

std::unordered_map<std::string, std::string> device_kwargs;

Expand Down
16 changes: 12 additions & 4 deletions frontend/catalyst/utils/gen_mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,21 @@
from jaxlib.mlir.dialects._func_ops_gen import FuncOp


def gen_setup(ctx):
def gen_setup(ctx, seed):
"""
This function returns an MLIR module with the "setup" function. The setup
function is a function that needs to be called before calling a
JIT-compiled function. It initializes the global device context in the runtime.
"""
txt = """
if seed != "":
txt = f"""
func.func @setup() -> () {{
"quantum.init"() {{seed = "{seed}"}} : () -> ()
return
}}
"""
else:
txt = """
func.func @setup() -> () {
"quantum.init"() : () -> ()
return
Expand All @@ -50,14 +58,14 @@ def gen_teardown(ctx):
return ir.Module.parse(txt, ctx)


def inject_functions(module, ctx):
def inject_functions(module, ctx, seed):
"""
This function appends functions to the input module.
"""
# Add C interface for the quantum function.
module.body.operations[0].attributes["llvm.emit_c_interface"] = ir.UnitAttr.get(context=ctx)

setup_module = gen_setup(ctx)
setup_module = gen_setup(ctx, seed)
setup_func = setup_module.body.operations[0]
module.body.append(setup_func)

Expand Down
7 changes: 1 addition & 6 deletions frontend/test/pytest/test_mid_circuit_measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,11 +496,6 @@ def outer():
assert result.shape == (shots,)
assert jnp.allclose(result, 1.0)

# TODO: dynamic_one_shot_several_mcms is a flaky test.
# We remove this test for now and revisit in the future.
@pytest.mark.skip(
reason="dynamic_one_shot_several_mcms is a flaky test and needs further investigation"
)
@pytest.mark.parametrize("shots", [10000])
@pytest.mark.parametrize("postselect", [None, 0, 1])
@pytest.mark.parametrize("measure_f", [qml.counts, qml.expval, qml.probs, qml.sample, qml.var])
Expand Down Expand Up @@ -542,7 +537,7 @@ def ref_func(x, y):

dev = qml.device(backend, wires=2, shots=shots)

@qjit
@qjit(seed="8v4Lj7L")
@qml.qnode(dev, postselect_mode=postselect_mode, mcm_method="one-shot")
def func(x, y):
qml.RX(x, 0)
Expand Down
31 changes: 25 additions & 6 deletions mlir/lib/Quantum/Transforms/ConversionPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,17 +154,36 @@ template <typename T> struct RTBasedPattern : public OpConversionPattern<T> {
StringRef qirName;
if constexpr (std::is_same_v<T, InitializeOp>) {
qirName = "__catalyst__rt__initialize";
InitializeOp InitOp = cast<InitializeOp>(op);
Location loc = InitOp.getLoc();
ModuleOp mod = InitOp->getParentOfType<ModuleOp>();
Type charPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
Type qirSignature = LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx),
/* seed = */ {charPtrType});
Value seed_gs;
if (InitOp->hasAttr("seed")) {
auto seed_str = cast<StringAttr>(InitOp->getAttr("seed")).str();
paul0403 marked this conversation as resolved.
Show resolved Hide resolved
seed_gs = getGlobalString(loc, rewriter, seed_str,
StringRef(seed_str.c_str(), seed_str.length() + 1), mod);
}
else {
seed_gs = getGlobalString(loc, rewriter, "unseeded",
"__catalyst__unseeded__qjit__run__", mod);
}
LLVM::LLVMFuncOp fnDecl =
ensureFunctionDeclaration(rewriter, op, qirName, qirSignature);
SmallVector<Value> operands = {seed_gs};
rewriter.create<LLVM::CallOp>(loc, fnDecl, operands);
rewriter.eraseOp(op);
paul0403 marked this conversation as resolved.
Show resolved Hide resolved
}
else {
qirName = "__catalyst__rt__finalize";
Type qirSignature = LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx), {});
LLVM::LLVMFuncOp fnDecl =
ensureFunctionDeclaration(rewriter, op, qirName, qirSignature);
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, fnDecl, ValueRange{});
}

Type qirSignature = LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(ctx), {});

LLVM::LLVMFuncOp fnDecl = ensureFunctionDeclaration(rewriter, op, qirName, qirSignature);

rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, fnDecl, ValueRange{});

return success();
}
};
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Quantum/ConversionTest.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
// Runtime Management //
////////////////////////

// CHECK: llvm.func @__catalyst__rt__initialize()
// CHECK: llvm.func @__catalyst__rt__initialize(!llvm.ptr)

// CHECK-LABEL: @init
func.func @init() {

// CHECK: llvm.call @__catalyst__rt__initialize()
// CHECK: llvm.call @__catalyst__rt__initialize({{%.+}}) : (!llvm.ptr) -> ()
quantum.init

return
Expand Down
15 changes: 15 additions & 0 deletions runtime/include/QuantumDevice.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <complex>
#include <memory>
#include <optional>
#include <random>
#include <vector>

#include "DataView.hpp"
Expand Down Expand Up @@ -103,6 +104,20 @@ struct QuantumDevice {
*/
[[nodiscard]] virtual auto GetDeviceShots() const -> size_t = 0;

/**
* @brief Set the PRNG seed of the device.
*
* @param seed The PRNG seed.
*/
virtual void SetDeviceSeed(std::string seed) = 0;

/**
* @brief Set the PRNG of the device.
*
* @param gen The std::mt19937 PRNG object.
*/
virtual void SetDevicePRNG(std::mt19937 *gen) = 0;

/**
* @brief Start recording a quantum tape if provided.
*
Expand Down
2 changes: 1 addition & 1 deletion runtime/include/RuntimeCAPI.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ extern "C" {

// Quantum Runtime Instructions
void __catalyst__rt__fail_cstr(const char *);
void __catalyst__rt__initialize();
void __catalyst__rt__initialize(char *seed = nullptr);
void __catalyst__rt__device_init(int8_t *, int8_t *, int8_t *);
void __catalyst__rt__device_release();
void __catalyst__rt__finalize();
Expand Down
21 changes: 16 additions & 5 deletions runtime/lib/backend/common/Utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
void StopTapeRecording() override; \
void SetDeviceShots(size_t shots) override; \
[[nodiscard]] auto GetDeviceShots() const->size_t override; \
void SetDeviceSeed(std::string seed) override; \
void SetDevicePRNG(std::mt19937 *gen) override; \
void PrintState() override; \
[[nodiscard]] auto Zero() const->Result override; \
[[nodiscard]] auto One() const->Result override;
Expand Down Expand Up @@ -271,8 +273,8 @@ constexpr auto has_gate(const SimulatorGateInfoDataT<size> &arr, const std::stri
return false;
}

static inline auto simulateDraw(const std::vector<double> &probs, std::optional<int32_t> postselect)
-> bool
static inline auto simulateDraw(const std::vector<double> &probs, std::optional<int32_t> postselect,
std::mt19937 *gen, bool has_seed) -> bool
{
if (postselect) {
auto postselect_value = postselect.value();
Expand All @@ -283,10 +285,19 @@ static inline auto simulateDraw(const std::vector<double> &probs, std::optional<

// Normal flow, no post-selection
// Draw a number according to the given distribution
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<> dis(0., 1.);
float draw = dis(gen);

float draw;
if (has_seed) {
paul0403 marked this conversation as resolved.
Show resolved Hide resolved
draw = dis(*gen);
(*gen)();
}
else {
std::random_device rd;
std::mt19937 gen_no_seed(rd());
draw = dis(gen_no_seed);
}

return draw > probs[0];
}

Expand Down
2 changes: 2 additions & 0 deletions runtime/lib/backend/dummy/dummy_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ struct DummyDevice final : public Catalyst::Runtime::QuantumDevice {
[[nodiscard]] auto GetNumQubits() const -> size_t override { return 0; }
void SetDeviceShots(size_t shots) override {}
[[nodiscard]] auto GetDeviceShots() const -> size_t override { return 0; }
void SetDeviceSeed(std::string seed) override {}
void SetDevicePRNG(std::mt19937 *gen) override {}
void StartTapeRecording() override {}
void StopTapeRecording() override {}
void PrintState() override {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ void LightningSimulator::SetDeviceShots(size_t shots) { this->device_shots = sho

auto LightningSimulator::GetDeviceShots() const -> size_t { return this->device_shots; }

void LightningSimulator::SetDeviceSeed(std::string _seed) { this->seed = _seed; }

void LightningSimulator::SetDevicePRNG(std::mt19937 *_gen) { this->gen = _gen; }

void LightningSimulator::PrintState()
{
using std::cout;
Expand Down Expand Up @@ -433,7 +437,7 @@ auto LightningSimulator::Measure(QubitIdType wire, std::optional<int32_t> postse
SetDeviceShots(device_shots);

// It represents the measured result, true for 1, false for 0
bool mres = Lightning::simulateDraw(probs, postselect);
bool mres = Lightning::simulateDraw(probs, postselect, this->gen, this->hasSeed());
auto dev_wires = getDeviceWires(wires);
this->device_sv->collapse(dev_wires[0], mres ? 1 : 0);
return mres ? this->One() : this->Zero();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ class LightningSimulator final : public Catalyst::Runtime::QuantumDevice {
bool tape_recording{false};
size_t device_shots;

std::string seed;
std::mt19937 *gen;

bool mcmc{false};
size_t num_burnin{0};
std::string kernel_name;
Expand Down Expand Up @@ -84,6 +87,8 @@ class LightningSimulator final : public Catalyst::Runtime::QuantumDevice {
return res;
}

inline auto hasSeed() -> bool { return this->seed != ""; }

public:
explicit LightningSimulator(const std::string &kwargs = "{}")
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ void LightningKokkosSimulator::SetDeviceShots(size_t shots) { this->device_shots

auto LightningKokkosSimulator::GetDeviceShots() const -> size_t { return this->device_shots; }

void LightningKokkosSimulator::SetDeviceSeed(std::string _seed) { this->seed = _seed; }

void LightningKokkosSimulator::SetDevicePRNG(std::mt19937 *_gen) { this->gen = _gen; }

void LightningKokkosSimulator::PrintState()
{
using std::cout;
Expand Down Expand Up @@ -455,7 +459,7 @@ auto LightningKokkosSimulator::Measure(QubitIdType wire, std::optional<int32_t>
SetDeviceShots(device_shots);

// It represents the measured result, true for 1, false for 0
bool mres = Lightning::simulateDraw(probs, postselect);
bool mres = Lightning::simulateDraw(probs, postselect, this->gen, this->hasSeed());
auto dev_wires = getDeviceWires(wires);
this->device_sv->collapse(dev_wires[0], mres ? 1 : 0);
return mres ? this->One() : this->Zero();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ class LightningKokkosSimulator final : public Catalyst::Runtime::QuantumDevice {

size_t device_shots;

std::string seed;
std::mt19937 *gen;

std::unique_ptr<StateVectorT> device_sv = std::make_unique<StateVectorT>(0);
LightningKokkosObsManager<double> obs_manager{};

Expand Down Expand Up @@ -78,6 +81,8 @@ class LightningKokkosSimulator final : public Catalyst::Runtime::QuantumDevice {
return res;
}

inline auto hasSeed() -> bool { return this->seed != ""; }

public:
explicit LightningKokkosSimulator(const std::string &kwargs = "{}")
{
Expand Down
4 changes: 4 additions & 0 deletions runtime/lib/backend/openqasm/OpenQasmDevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ void OpenQasmDevice::SetDeviceShots([[maybe_unused]] size_t shots) { device_shot

auto OpenQasmDevice::GetDeviceShots() const -> size_t { return device_shots; }

void OpenQasmDevice::SetDeviceSeed([[maybe_unused]] std::string _seed) { seed = _seed; }

void OpenQasmDevice::SetDevicePRNG([[maybe_unused]] std::mt19937 *_gen) { gen = _gen; }

void OpenQasmDevice::PrintState()
{
using std::cout;
Expand Down
3 changes: 3 additions & 0 deletions runtime/lib/backend/openqasm/OpenQasmDevice.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ class OpenQasmDevice final : public Catalyst::Runtime::QuantumDevice {
bool tape_recording{false};
size_t device_shots;

std::string seed;
std::mt19937 *gen;

OpenQasm::OpenQasmObsManager obs_manager{};
OpenQasm::BuilderType builder_type;
std::unordered_map<std::string, std::string> device_kwargs;
Expand Down
Loading