diff --git a/scripts/import_test.py b/scripts/import_test.py index 48a4d4b8c..b712901ae 100644 --- a/scripts/import_test.py +++ b/scripts/import_test.py @@ -38,6 +38,8 @@ def test_imports(): _ = tfq.math.inner_product _ = tfq.math.fidelity _ = tfq.math.mps_1d_expectation + _ = tfq.math.mps_1d_sample + _ = tfq.math.mps_1d_sampled_expectation # Noisy simulation ops. _ = tfq.noise.expectation diff --git a/tensorflow_quantum/core/ops/math_ops/BUILD b/tensorflow_quantum/core/ops/math_ops/BUILD index f62f47311..51bad72fb 100644 --- a/tensorflow_quantum/core/ops/math_ops/BUILD +++ b/tensorflow_quantum/core/ops/math_ops/BUILD @@ -18,6 +18,8 @@ cc_binary( "tfq_inner_product.cc", "tfq_inner_product_grad.cc", "tfq_simulate_1d_expectation.cc", + "tfq_simulate_1d_samples.cc", + "tfq_simulate_1d_sampled_expectation.cc", ], copts = select({ ":windows": [ @@ -127,7 +129,9 @@ py_library( srcs = ["simulate_mps.py"], data = [":_tfq_math_ops.so"], deps = [ + "//tensorflow_quantum/core/ops:batch_util", "//tensorflow_quantum/core/ops:load_module", + "//tensorflow_quantum/core/ops:tfq_utility_ops_py", ], ) diff --git a/tensorflow_quantum/core/ops/math_ops/__init__.py b/tensorflow_quantum/core/ops/math_ops/__init__.py index caf9e7648..982e26911 100644 --- a/tensorflow_quantum/core/ops/math_ops/__init__.py +++ b/tensorflow_quantum/core/ops/math_ops/__init__.py @@ -16,4 +16,5 @@ from tensorflow_quantum.core.ops.math_ops.fidelity_op import fidelity from tensorflow_quantum.core.ops.math_ops.inner_product_op import inner_product -from tensorflow_quantum.core.ops.math_ops.simulate_mps import mps_1d_expectation +from tensorflow_quantum.core.ops.math_ops.simulate_mps import ( + mps_1d_expectation, mps_1d_sample, mps_1d_sampled_expectation) diff --git a/tensorflow_quantum/core/ops/math_ops/simulate_mps.py b/tensorflow_quantum/core/ops/math_ops/simulate_mps.py index 799fe38c0..468ce97d3 100644 --- a/tensorflow_quantum/core/ops/math_ops/simulate_mps.py +++ b/tensorflow_quantum/core/ops/math_ops/simulate_mps.py @@ -16,6 +16,7 @@ import os import tensorflow as tf from tensorflow_quantum.core.ops.load_module import load_module +from tensorflow_quantum.core.ops import tfq_utility_ops MATH_OP_MODULE = load_module(os.path.join("math_ops", "_tfq_math_ops.so")) @@ -55,3 +56,90 @@ def mps_1d_expectation(programs, tf.float32), pauli_sums, bond_dim=bond_dim) + + +def mps_1d_sample(programs, + symbol_names, + symbol_values, + num_samples, + bond_dim=4): + """Generate samples using the C++ MPS simulator. + + Simulate the final state of `programs` given `symbol_values` are placed + inside of the symbols with the name in `symbol_names` in each circuit. + From there we will then sample from the final state. + + Args: + programs: `tf.Tensor` of strings with shape [batch_size] containing + the string representations of the circuits to be executed. + symbol_names: `tf.Tensor` of strings with shape [n_params], which + is used to specify the order in which the values in + `symbol_values` should be placed inside of the circuits in + `programs`. + symbol_values: `tf.Tensor` of real numbers with shape + [batch_size, n_params] specifying parameter values to resolve + into the circuits specified by programs, following the ordering + dictated by `symbol_names`. + num_samples: `tf.Tensor` with one element indicating the number of + samples to draw. + bond_dim: Integer value used for the bond dimension during simulation. + + Returns: + A `tf.RaggedTensor` containing the samples taken from each circuit in + `programs`. + """ + padded_samples = MATH_OP_MODULE.tfq_simulate_mps1d_samples( + programs, + symbol_names, + tf.cast(symbol_values, tf.float32), + num_samples, + bond_dim=bond_dim) + + return tfq_utility_ops.padded_to_ragged(padded_samples) + + +def mps_1d_sampled_expectation(programs, + symbol_names, + symbol_values, + pauli_sums, + num_samples, + bond_dim=4): + """Calculate the expectation value of circuits using samples. + + Simulate the final state of `programs` given `symbol_values` are placed + inside of the symbols with the name in `symbol_names` in each circuit. + Them, sample the resulting state `num_samples` times and use these samples + to compute expectation values of the given `pauli_sums`. + + Args: + programs: `tf.Tensor` of strings with shape [batch_size] containing + the string representations of the circuits to be executed. + symbol_names: `tf.Tensor` of strings with shape [n_params], which + is used to specify the order in which the values in + `symbol_values` should be placed inside of the circuits in + `programs`. + symbol_values: `tf.Tensor` of real numbers with shape + [batch_size, n_params] specifying parameter values to resolve + into the circuits specificed by programs, following the ordering + dictated by `symbol_names`. + pauli_sums: `tf.Tensor` of strings with shape [batch_size, n_ops] + containing the string representation of the operators that will + be used on all of the circuits in the expectation calculations. + num_samples: `tf.Tensor` with `num_samples[i][j]` is equal to the + number of samples to draw in each term of `pauli_sums[i][j]` + when estimating the expectation. Therefore, `num_samples` must + have the same shape as `pauli_sums`. + bond_dim: Integer value used for the bond dimension during simulation. + + Returns: + `tf.Tensor` with shape [batch_size, n_ops] that holds the + expectation value for each circuit with each op applied to it + (after resolving the corresponding parameters in). + """ + return MATH_OP_MODULE.tfq_simulate_mps1d_sampled_expectation( + programs, + symbol_names, + tf.cast(symbol_values, tf.float32), + pauli_sums, + tf.cast(num_samples, dtype=tf.int32), + bond_dim=bond_dim) diff --git a/tensorflow_quantum/core/ops/math_ops/simulate_mps_test.py b/tensorflow_quantum/core/ops/math_ops/simulate_mps_test.py index 7682d1f9a..1120b1392 100644 --- a/tensorflow_quantum/core/ops/math_ops/simulate_mps_test.py +++ b/tensorflow_quantum/core/ops/math_ops/simulate_mps_test.py @@ -20,16 +20,42 @@ sys.path = NEW_PATH # pylint: enable=wrong-import-position +from absl.testing import parameterized import numpy as np import tensorflow as tf import cirq import cirq_google import sympy +from scipy import stats + +from tensorflow_quantum.core.ops import batch_util from tensorflow_quantum.core.ops.math_ops import simulate_mps from tensorflow_quantum.python import util +def _make_1d_circuit(qubits, depth): + """Create a 1d ladder circuit.""" + even_pairs = list(zip(qubits[::2], qubits[1::2])) + odd_pairs = list(zip(qubits[1::2], qubits[2::2])) + ret = cirq.Circuit() + + for _ in range(depth): + # return ret + ret += [(cirq.Y(q)**np.random.random()) for q in qubits] + ret += [ + cirq_google.SycamoreGate()(q0, q1)**np.random.random() + for q0, q1 in even_pairs + ] + ret += [(cirq.Y(q)**np.random.random()) for q in qubits] + ret += [ + cirq_google.SycamoreGate()(q1, q0)**np.random.random() + for q0, q1 in odd_pairs + ] + + return ret + + class SimulateMPS1DExpectationTest(tf.test.TestCase): """Tests mps_1d_expectation.""" @@ -254,11 +280,17 @@ def test_simulate_mps_1d_expectation_inputs(self): symbol_names, symbol_values_array, util.convert_to_tensor([[x] for x in pauli_sums])) - res = simulate_mps.mps_1d_expectation( - util.convert_to_tensor([cirq.Circuit() for _ in pauli_sums]), - symbol_names, symbol_values_array.astype(np.float64), - util.convert_to_tensor([[x] for x in pauli_sums])) - self.assertDTypeEqual(res, np.float32) + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + expected_regex='minimum 3 qubits'): + # too few qubits. + circuit_small = cirq.Circuit(cirq.X(qubits[0]), cirq.X(qubits[1]), + cirq.X(qubits[2])) + small_pauli = cirq.Z(qubits[0]) + + simulate_mps.mps_1d_expectation( + util.convert_to_tensor([circuit_small for _ in pauli_sums]), + symbol_names, symbol_values_array, + util.convert_to_tensor([[small_pauli] for _ in pauli_sums])) def test_simulate_mps_1d_expectation_simple(self): """Makes sure that the op shows the same result with Cirq.""" @@ -297,34 +329,11 @@ def test_simulate_mps_1d_expectation_simple(self): # Expected value of 0.349... self.assertAllClose(mps_result, cirq_result) - def _make_1d_circuit(self, qubits, depth): - """Create a 1d ladder circuit.""" - even_pairs = list(zip(qubits[::2], qubits[1::2])) - odd_pairs = list(zip(qubits[1::2], qubits[2::2])) - ret = cirq.Circuit() - - for _ in range(depth): - # return ret - ret += [(cirq.Y(q)**np.random.random()) for q in qubits] - ret += [ - cirq_google.SycamoreGate()(q0, q1)**np.random.random() - for q0, q1 in even_pairs - ] - ret += [(cirq.Y(q)**np.random.random()) for q in qubits] - ret += [ - cirq_google.SycamoreGate()(q1, q0)**np.random.random() - for q0, q1 in odd_pairs - ] - - return ret - def test_complex_equality(self): """Check moderate sized 1d random circuits.""" batch_size = 10 qubits = cirq.GridQubit.rect(1, 8) - circuit_batch = [ - self._make_1d_circuit(qubits, 3) for _ in range(batch_size) - ] + circuit_batch = [_make_1d_circuit(qubits, 3) for _ in range(batch_size)] pauli_sums = [[ cirq.Z(qubits[0]), @@ -366,5 +375,645 @@ def test_correctness_empty(self): self.assertShapeEqual(np.zeros((0, 0)), out) +class SimulateMPS1DSamplesTest(tf.test.TestCase, parameterized.TestCase): + """Tests tfq_simulate_mps1d_samples.""" + + def test_simulate_mps1d_samples_inputs(self): + """Make sure the sample op fails gracefully on bad inputs.""" + n_qubits = 5 + num_samples = 10 + batch_size = 5 + symbol_names = ['alpha'] + qubits = cirq.GridQubit.rect(1, n_qubits) + circuit_batch = [ + cirq.Circuit( + cirq.X(qubits[0])**sympy.Symbol(symbol_names[0]), + cirq.Z(qubits[1]), + cirq.CNOT(qubits[2], qubits[3]), + cirq.Y(qubits[4])**sympy.Symbol(symbol_names[0]), + ) for _ in range(batch_size) + ] + resolver_batch = [{symbol_names[0]: 0.123} for _ in range(batch_size)] + + symbol_values_array = np.array( + [[resolver[symbol] + for symbol in symbol_names] + for resolver in resolver_batch]) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + 'rank 1. Got rank 2'): + # programs tensor has the wrong shape. + simulate_mps.mps_1d_sample(util.convert_to_tensor([circuit_batch]), + symbol_names, symbol_values_array, + [num_samples]) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + 'rank 1. Got rank 2'): + # symbol_names tensor has the wrong shape. + simulate_mps.mps_1d_sample(util.convert_to_tensor(circuit_batch), + np.array([symbol_names]), + symbol_values_array, [num_samples]) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + 'rank 2. Got rank 3'): + # symbol_values tensor has the wrong shape. + simulate_mps.mps_1d_sample(util.convert_to_tensor(circuit_batch), + symbol_names, + np.array([symbol_values_array]), + [num_samples]) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + 'rank 2. Got rank 1'): + # symbol_values tensor has the wrong shape 2. + simulate_mps.mps_1d_sample(util.convert_to_tensor(circuit_batch), + symbol_names, symbol_values_array[0], + [num_samples]) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + 'rank 1. Got rank 2'): + # num_samples tensor has the wrong shape. + simulate_mps.mps_1d_sample(util.convert_to_tensor(circuit_batch), + symbol_names, symbol_values_array, + [[num_samples]]) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + 'Could not find symbol in parameter map'): + # symbol_names tensor has the right type, but invalid value. + simulate_mps.mps_1d_sample(util.convert_to_tensor(circuit_batch), + ['junk'], symbol_values_array, + [num_samples]) + + with self.assertRaisesRegex(TypeError, 'Cannot convert'): + # programs tensor has the wrong type. + simulate_mps.mps_1d_sample([1] * batch_size, symbol_names, + symbol_values_array, [num_samples]) + + with self.assertRaisesRegex(TypeError, 'Cannot convert'): + # programs tensor has the wrong type. + simulate_mps.mps_1d_sample(util.convert_to_tensor(circuit_batch), + [1], symbol_values_array, [num_samples]) + + with self.assertRaisesRegex(tf.errors.UnimplementedError, + 'Cast string to float is not supported'): + # programs tensor has the wrong type. + simulate_mps.mps_1d_sample(util.convert_to_tensor(circuit_batch), + symbol_names, [['junk']] * batch_size, + [num_samples]) + + with self.assertRaisesRegex(Exception, 'junk'): + # num_samples tensor has the wrong shape. + simulate_mps.mps_1d_sample(util.convert_to_tensor(circuit_batch), + symbol_names, symbol_values_array, + ['junk']) + + with self.assertRaisesRegex(TypeError, 'missing'): + # too few tensors. + # pylint: disable=no-value-for-parameter + simulate_mps.mps_1d_sample(util.convert_to_tensor(circuit_batch), + symbol_names, symbol_values_array) + # pylint: enable=no-value-for-parameter + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + expected_regex='do not match'): + # wrong symbol_values size. + simulate_mps.mps_1d_sample( + util.convert_to_tensor(circuit_batch), symbol_names, + symbol_values_array[:int(batch_size * 0.5)], num_samples) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + expected_regex='cirq.Channel'): + # attempting to use noisy circuit. + noisy_circuit = cirq.Circuit(cirq.depolarize(0.3).on_each(*qubits)) + simulate_mps.mps_1d_sample( + util.convert_to_tensor([noisy_circuit for _ in circuit_batch]), + symbol_names, symbol_values_array, [num_samples]) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + 'at least minimum 4'): + # pylint: disable=too-many-function-args + simulate_mps.mps_1d_sample(util.convert_to_tensor(circuit_batch), + symbol_names, symbol_values_array, + [num_samples], 1) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + expected_regex='not in 1D topology'): + # attempting to use a circuit not in 1D topology + # 0--1--2--3 + # \-4 + circuit_not_1d = cirq.Circuit( + cirq.X(qubits[0])**sympy.Symbol(symbol_names[0]), + cirq.Z(qubits[1])**sympy.Symbol(symbol_names[0]), + cirq.CNOT(qubits[2], qubits[3]), + cirq.CNOT(qubits[2], qubits[4]), + ) + simulate_mps.mps_1d_sample( + util.convert_to_tensor([circuit_not_1d for _ in circuit_batch]), + symbol_names, symbol_values_array, [num_samples]) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + expected_regex='not in 1D topology'): + # attempting to use a circuit in 1D topology, which looks in 2D. + # 0--1 + # \-2-\ + # 3--4 == 1--0--2--4--3 + circuit_not_1d = cirq.Circuit( + cirq.CNOT(qubits[0], qubits[1]), + cirq.CNOT(qubits[0], qubits[2]), + cirq.CNOT(qubits[2], qubits[4]), + cirq.CNOT(qubits[3], qubits[4]), + ) + simulate_mps.mps_1d_sample( + util.convert_to_tensor([circuit_not_1d for _ in circuit_batch]), + symbol_names, symbol_values_array, [num_samples]) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + expected_regex='minimum 3 qubits'): + # too few qubits. + circuit_small = cirq.Circuit(cirq.X(qubits[0]), cirq.X(qubits[1]), + cirq.X(qubits[2])) + + simulate_mps.mps_1d_sample( + util.convert_to_tensor([circuit_small for _ in circuit_batch]), + symbol_names, symbol_values_array, [num_samples]) + + @parameterized.parameters([ + { + 'all_n_qubits': [4, 5], + 'n_samples': 10 + }, + { + 'all_n_qubits': [4, 5, 8], + 'n_samples': 10 + }, + ]) + def test_sampling_output_padding(self, all_n_qubits, n_samples): + """Check that the sampling ops pad outputs correctly""" + op = simulate_mps.mps_1d_sample + circuits = [] + expected_outputs = [] + for n_qubits in all_n_qubits: + expected_outputs.append(np.ones((n_samples, n_qubits))) + circuits.append( + cirq.Circuit(cirq.X.on_each(*cirq.GridQubit.rect(1, n_qubits)))) + results = op(util.convert_to_tensor(circuits), [], [[]] * len(circuits), + [n_samples]).to_list() + for a, b in zip(expected_outputs, results): + self.assertAllClose(a, b) + + def test_ghz_state(self): + """Test a simple GHZ-like state.""" + op = simulate_mps.mps_1d_sample + qubits = cirq.GridQubit.rect(1, 6) + circuit = cirq.Circuit(cirq.I.on_each(*qubits)) + circuit += [ + cirq.X(qubits[0]), + cirq.H(qubits[1]), + cirq.CNOT(qubits[1], qubits[2]) + ] + + circuit_batch = [circuit] + resolver_batch = [cirq.ParamResolver({})] + n_samples = 1000 + + cirq_samples = batch_util.batch_sample(circuit_batch, resolver_batch, + n_samples, cirq.Simulator()) + + op_samples = np.array( + op(util.convert_to_tensor(circuit_batch), [], [[]], [n_samples], + bond_dim=16).to_list()) + self.assertAllClose(np.mean(op_samples, axis=1), + np.mean(cirq_samples, axis=1), + atol=1e-1) + + def test_sampling_fuzz(self): + """Compare sampling with tfq ops and Cirq.""" + op = simulate_mps.mps_1d_sample + batch_size = 10 + n_qubits = 6 + qubits = cirq.GridQubit.rect(1, n_qubits) + symbol_names = [] + n_samples = 10_000 + + circuit_batch = [_make_1d_circuit(qubits, 1) for _ in range(batch_size)] + resolver_batch = [cirq.ParamResolver({}) for _ in range(batch_size)] + + symbol_values_array = np.array( + [[resolver[symbol] + for symbol in symbol_names] + for resolver in resolver_batch]) + + op_samples = np.array( + op(util.convert_to_tensor(circuit_batch), + symbol_names, + symbol_values_array, [n_samples], + bond_dim=16).to_list()) + + op_histograms = [ + np.histogram( + sample.dot(1 << np.arange(sample.shape[-1] - 1, -1, -1)), + range=(0, 2**len(qubits)), + bins=2**len(qubits))[0] for sample in op_samples + ] + + cirq_samples = batch_util.batch_sample(circuit_batch, resolver_batch, + n_samples, cirq.Simulator()) + + cirq_histograms = [ + np.histogram( + sample.dot(1 << np.arange(sample.shape[-1] - 1, -1, -1)), + range=(0, 2**len(qubits)), + bins=2**len(qubits))[0] for sample in cirq_samples + ] + + for a, b in zip(op_histograms, cirq_histograms): + self.assertLess(stats.entropy(a + 1e-8, b + 1e-8), 0.05) + + +class SimulateMPS1DSampledExpectationTest(tf.test.TestCase): + """Tests tfq_simulate_mps1d_sampled_expectation.""" + + def test_simulate_mps1d_sampled_expectation_inputs(self): + """Make sure sampled expectation op fails gracefully on bad inputs.""" + n_qubits = 5 + batch_size = 5 + symbol_names = ['alpha'] + qubits = cirq.GridQubit.rect(1, n_qubits) + circuit_batch = [ + cirq.Circuit( + cirq.X(qubits[0])**sympy.Symbol(symbol_names[0]), + cirq.Z(qubits[1]), + cirq.CNOT(qubits[2], qubits[3]), + cirq.Y(qubits[4])**sympy.Symbol(symbol_names[0]), + ) for _ in range(batch_size) + ] + resolver_batch = [{symbol_names[0]: 0.123} for _ in range(batch_size)] + + symbol_values_array = np.array( + [[resolver[symbol] + for symbol in symbol_names] + for resolver in resolver_batch]) + + pauli_sums = util.random_pauli_sums(qubits, 3, batch_size) + num_samples = [[10]] * batch_size + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + 'programs must be rank 1'): + # Circuit tensor has too many dimensions. + simulate_mps.mps_1d_sampled_expectation( + util.convert_to_tensor([circuit_batch]), symbol_names, + symbol_values_array, + util.convert_to_tensor([[x] for x in pauli_sums]), num_samples) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + 'symbol_names must be rank 1.'): + # symbol_names tensor has too many dimensions. + simulate_mps.mps_1d_sampled_expectation( + util.convert_to_tensor(circuit_batch), np.array([symbol_names]), + symbol_values_array, + util.convert_to_tensor([[x] for x in pauli_sums]), num_samples) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + 'symbol_values must be rank 2.'): + # symbol_values_array tensor has too many dimensions. + simulate_mps.mps_1d_sampled_expectation( + util.convert_to_tensor(circuit_batch), symbol_names, + np.array([symbol_values_array]), + util.convert_to_tensor([[x] for x in pauli_sums]), num_samples) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + 'symbol_values must be rank 2.'): + # symbol_values_array tensor has too few dimensions. + simulate_mps.mps_1d_sampled_expectation( + util.convert_to_tensor(circuit_batch), symbol_names, + symbol_values_array[0], + util.convert_to_tensor([[x] for x in pauli_sums]), num_samples) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + 'pauli_sums must be rank 2.'): + # pauli_sums tensor has too few dimensions. + simulate_mps.mps_1d_sampled_expectation( + util.convert_to_tensor(circuit_batch), + symbol_names, symbol_values_array, + util.convert_to_tensor(list(pauli_sums)), num_samples) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + 'pauli_sums must be rank 2.'): + # pauli_sums tensor has too many dimensions. + simulate_mps.mps_1d_sampled_expectation( + util.convert_to_tensor(circuit_batch), symbol_names, + symbol_values_array, + [util.convert_to_tensor([[x] for x in pauli_sums])], + num_samples) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + 'num_samples must be rank 2'): + # num_samples tensor has the wrong shape. + simulate_mps.mps_1d_sampled_expectation( + util.convert_to_tensor(circuit_batch), symbol_names, + symbol_values_array, + util.convert_to_tensor([[x] for x in pauli_sums]), + [num_samples]) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + 'num_samples must be rank 2'): + # num_samples tensor has the wrong shape. + simulate_mps.mps_1d_sampled_expectation( + util.convert_to_tensor(circuit_batch), symbol_names, + symbol_values_array, + util.convert_to_tensor([[x] for x in pauli_sums]), + num_samples[0]) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + 'Unparseable proto'): + # circuit tensor has the right type but invalid values. + simulate_mps.mps_1d_sampled_expectation( + ['junk'] * batch_size, symbol_names, symbol_values_array, + util.convert_to_tensor([[x] for x in pauli_sums]), num_samples) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + 'Could not find symbol in parameter map'): + # symbol_names tensor has the right type but invalid values. + simulate_mps.mps_1d_sampled_expectation( + util.convert_to_tensor(circuit_batch), ['junk'], + symbol_values_array, + util.convert_to_tensor([[x] for x in pauli_sums]), num_samples) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + 'qubits not found in circuit'): + # pauli_sums tensor has the right type but invalid values. + new_qubits = [cirq.GridQubit(5, 5), cirq.GridQubit(9, 9)] + new_pauli_sums = util.random_pauli_sums(new_qubits, 2, batch_size) + simulate_mps.mps_1d_sampled_expectation( + util.convert_to_tensor(circuit_batch), symbol_names, + symbol_values_array, + util.convert_to_tensor([[x] for x in new_pauli_sums]), + num_samples) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + 'Unparseable proto'): + # pauli_sums tensor has the right type but invalid values 2. + simulate_mps.mps_1d_sampled_expectation( + util.convert_to_tensor(circuit_batch), symbol_names, + symbol_values_array, [['junk']] * batch_size, num_samples) + + with self.assertRaisesRegex(TypeError, 'Cannot convert'): + # circuits tensor has the wrong type. + simulate_mps.mps_1d_sampled_expectation( + [1.0] * batch_size, symbol_names, symbol_values_array, + util.convert_to_tensor([[x] for x in pauli_sums]), num_samples) + + with self.assertRaisesRegex(TypeError, 'Cannot convert'): + # symbol_names tensor has the wrong type. + simulate_mps.mps_1d_sampled_expectation( + util.convert_to_tensor(circuit_batch), [0.1234], + symbol_values_array, + util.convert_to_tensor([[x] for x in pauli_sums]), num_samples) + + with self.assertRaisesRegex(tf.errors.UnimplementedError, ''): + # symbol_values tensor has the wrong type. + simulate_mps.mps_1d_sampled_expectation( + util.convert_to_tensor(circuit_batch), symbol_names, + [['junk']] * batch_size, + util.convert_to_tensor([[x] for x in pauli_sums]), num_samples) + + with self.assertRaisesRegex(TypeError, 'Cannot convert'): + # pauli_sums tensor has the wrong type. + simulate_mps.mps_1d_sampled_expectation( + util.convert_to_tensor(circuit_batch), symbol_names, + symbol_values_array, [[1.0]] * batch_size, num_samples) + + with self.assertRaisesRegex(TypeError, 'missing'): + # we are missing an argument. + # pylint: disable=no-value-for-parameter + simulate_mps.mps_1d_sampled_expectation( + util.convert_to_tensor(circuit_batch), symbol_names, + symbol_values_array, num_samples) + # pylint: enable=no-value-for-parameter + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + expected_regex='do not match'): + # wrong op size. + simulate_mps.mps_1d_sampled_expectation( + util.convert_to_tensor([cirq.Circuit()]), symbol_names, + symbol_values_array.astype(np.float64), + util.convert_to_tensor([[x] for x in pauli_sums]), num_samples) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + 'minimum 4'): + # pylint: disable=too-many-function-args + simulate_mps.mps_1d_sampled_expectation( + util.convert_to_tensor(circuit_batch), + symbol_names, + symbol_values_array, + util.convert_to_tensor([[x] for x in pauli_sums]), + num_samples, + bond_dim=-10) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + expected_regex='do not match'): + # wrong symbol_values size. + simulate_mps.mps_1d_sampled_expectation( + util.convert_to_tensor(circuit_batch), symbol_names, + symbol_values_array[:int(batch_size * 0.5)], + util.convert_to_tensor([[x] for x in pauli_sums]), num_samples) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + expected_regex='cirq.Channel'): + # attempting to use noisy circuit. + noisy_circuit = cirq.Circuit(cirq.depolarize(0.3).on_each(*qubits)) + simulate_mps.mps_1d_sampled_expectation( + util.convert_to_tensor([noisy_circuit for _ in pauli_sums]), + symbol_names, symbol_values_array, + util.convert_to_tensor([[x] for x in pauli_sums]), num_samples) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + 'at least minimum 4'): + # pylint: disable=too-many-function-args + simulate_mps.mps_1d_sampled_expectation( + util.convert_to_tensor(circuit_batch), symbol_names, + symbol_values_array, + util.convert_to_tensor([[x] for x in pauli_sums]), num_samples, + 1) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + expected_regex='not in 1D topology'): + # attempting to use a circuit not in 1D topology + # 0--1--2--3 + # \-4 + circuit_not_1d = cirq.Circuit( + cirq.X(qubits[0])**sympy.Symbol(symbol_names[0]), + cirq.Z(qubits[1])**sympy.Symbol(symbol_names[0]), + cirq.CNOT(qubits[2], qubits[3]), + cirq.CNOT(qubits[2], qubits[4]), + ) + simulate_mps.mps_1d_sampled_expectation( + util.convert_to_tensor([circuit_not_1d for _ in pauli_sums]), + symbol_names, symbol_values_array, + util.convert_to_tensor([[x] for x in pauli_sums]), num_samples) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + expected_regex='not in 1D topology'): + # attempting to use a circuit in 1D topology, which looks in 2D. + # 0--1 + # \-2-\ + # 3--4 == 1--0--2--4--3 + circuit_not_1d = cirq.Circuit( + cirq.CNOT(qubits[0], qubits[1]), + cirq.CNOT(qubits[0], qubits[2]), + cirq.CNOT(qubits[2], qubits[4]), + cirq.CNOT(qubits[3], qubits[4]), + ) + simulate_mps.mps_1d_sampled_expectation( + util.convert_to_tensor([circuit_not_1d for _ in pauli_sums]), + symbol_names, symbol_values_array, + util.convert_to_tensor([[x] for x in pauli_sums]), num_samples) + + with self.assertRaisesRegex(tf.errors.InvalidArgumentError, + expected_regex='minimum 3 qubits'): + # too few qubits. + circuit_small = cirq.Circuit(cirq.X(qubits[0]), cirq.X(qubits[1]), + cirq.X(qubits[2])) + small_pauli = cirq.Z(qubits[0]) + + simulate_mps.mps_1d_sampled_expectation( + util.convert_to_tensor([circuit_small for _ in pauli_sums]), + symbol_names, symbol_values_array, + util.convert_to_tensor([[small_pauli] for _ in pauli_sums]), + num_samples) + + def test_simulate_sampled_mps_1d_expectation_simple(self): + """Makes sure that the op shows the same result with Cirq.""" + n_qubits = 5 + batch_size = 5 + symbol_names = ['alpha'] + qubits = cirq.GridQubit.rect(1, n_qubits) + circuit_batch = [ + cirq.Circuit( + cirq.X(qubits[0])**sympy.Symbol(symbol_names[0]), + cirq.Z(qubits[1]), + cirq.CNOT(qubits[2], qubits[3]), + cirq.Y(qubits[4])**sympy.Symbol(symbol_names[0]), + ) for _ in range(batch_size) + ] + resolver_batch = [{symbol_names[0]: 0.123} for _ in range(batch_size)] + + symbol_values_array = np.array( + [[resolver[symbol] + for symbol in symbol_names] + for resolver in resolver_batch]) + + pauli_sums = [ + cirq.Z(qubits[0]) * cirq.X(qubits[4]) for _ in range(batch_size) + ] + + num_samples = np.ones(shape=(len(pauli_sums), 1)) * 10000 + + cirq_result = [ + cirq.Simulator().simulate_expectation_values(c, p, r) + for c, p, r in zip(circuit_batch, pauli_sums, resolver_batch) + ] + # Default bond_dim=4 + mps_result = simulate_mps.mps_1d_sampled_expectation( + util.convert_to_tensor(circuit_batch), symbol_names, + symbol_values_array, + util.convert_to_tensor([[x] for x in pauli_sums]), num_samples) + # Expected value of 0.349... + self.assertAllClose(mps_result, cirq_result, atol=5e-2) + + def test_complex_equality(self): + """Check moderate sized 1d random circuits.""" + batch_size = 10 + qubits = cirq.GridQubit.rect(1, 8) + circuit_batch = [_make_1d_circuit(qubits, 3) for _ in range(batch_size)] + + pauli_sums = [[ + cirq.Z(qubits[0]), + cirq.Z(qubits[-1]), + cirq.Z(qubits[0]) * cirq.Z(qubits[-1]), + cirq.Z(qubits[0]) + cirq.Z(qubits[-1]) + ] for _ in range(batch_size)] + symbol_names = [] + resolver_batch = [{} for _ in range(batch_size)] + num_samples = np.ones_like(pauli_sums, dtype=np.int32) * 1000 + + symbol_values_array = np.array( + [[resolver[symbol] + for symbol in symbol_names] + for resolver in resolver_batch]) + + cirq_result = [ + cirq.Simulator().simulate_expectation_values(c, p, r) + for c, p, r in zip(circuit_batch, pauli_sums, resolver_batch) + ] + mps_result = simulate_mps.mps_1d_sampled_expectation( + util.convert_to_tensor(circuit_batch), + symbol_names, + symbol_values_array, + util.convert_to_tensor(pauli_sums), + num_samples, + bond_dim=32) + self.assertAllClose(mps_result, cirq_result, atol=2e-1) + + def test_correctness_empty(self): + """Tests the mps op with empty circuits.""" + + empty_circuit = tf.raw_ops.Empty(shape=(0,), dtype=tf.string) + empty_symbols = tf.raw_ops.Empty(shape=(0,), dtype=tf.string) + empty_values = tf.raw_ops.Empty(shape=(0, 0), dtype=tf.float32) + empty_paulis = tf.raw_ops.Empty(shape=(0, 0), dtype=tf.string) + num_samples = tf.raw_ops.Empty(shape=(0, 0), dtype=tf.int32) + + out = simulate_mps.mps_1d_sampled_expectation(empty_circuit, + empty_symbols, + empty_values, + empty_paulis, num_samples, + 32) + + self.assertShapeEqual(np.zeros((0, 0)), out) + + +class InputTypesTest(tf.test.TestCase, parameterized.TestCase): + """Tests that different inputs types work for all of the ops. """ + + @parameterized.parameters([ + { + 'symbol_type': tf.float32 + }, + { + 'symbol_type': tf.float64 + }, + { + 'symbol_type': tf.int32 + }, + { + 'symbol_type': tf.int64 + }, + { + 'symbol_type': tf.complex64 + }, + ]) + def test_symbol_values_type(self, symbol_type): + """Tests all three ops for the different types. """ + qubits = cirq.GridQubit.rect(1, 5) + circuits = util.convert_to_tensor( + [cirq.Circuit(cirq.H.on_each(*qubits))]) + symbol_names = ['symbol'] + symbol_values = tf.convert_to_tensor([[1]], dtype=symbol_type) + pauli_sums = util.random_pauli_sums(qubits, 3, 1) + pauli_sums = util.convert_to_tensor([[x] for x in pauli_sums]) + + result = simulate_mps.mps_1d_expectation(circuits, symbol_names, + symbol_values, pauli_sums) + self.assertDTypeEqual(result, np.float32) + + result = simulate_mps.mps_1d_sample(circuits, symbol_names, + symbol_values, [100]) + self.assertDTypeEqual(result.numpy(), np.int8) + + result = simulate_mps.mps_1d_sampled_expectation( + circuits, symbol_names, symbol_values, pauli_sums, [[100]]) + self.assertDTypeEqual(result, np.float32) + + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow_quantum/core/ops/math_ops/tfq_simulate_1d_expectation.cc b/tensorflow_quantum/core/ops/math_ops/tfq_simulate_1d_expectation.cc index 03aecae7c..adb1d9bb6 100644 --- a/tensorflow_quantum/core/ops/math_ops/tfq_simulate_1d_expectation.cc +++ b/tensorflow_quantum/core/ops/math_ops/tfq_simulate_1d_expectation.cc @@ -120,11 +120,19 @@ class TfqSimulateMPS1DExpectationOp : public tensorflow::OpKernel { output_dim_batch_size, num_cycles, construct_f); OP_REQUIRES_OK(context, parse_status); + // Find largest circuit for tensor size padding and allocate + // the output tensor. int max_num_qubits = 0; + int min_num_qubits = 1 << 30; for (const int num : num_qubits) { max_num_qubits = std::max(max_num_qubits, num); + min_num_qubits = std::min(min_num_qubits, num); } + OP_REQUIRES(context, min_num_qubits > 3, + tensorflow::errors::InvalidArgument( + "All input circuits require minimum 3 qubits.")); + // Since MPS simulations have much smaller memory footprint, // we do not need a ComputeLarge like we do for state vector simulation. ComputeSmall(num_qubits, max_num_qubits, qsim_circuits, pauli_sums, context, diff --git a/tensorflow_quantum/core/ops/math_ops/tfq_simulate_1d_sampled_expectation.cc b/tensorflow_quantum/core/ops/math_ops/tfq_simulate_1d_sampled_expectation.cc new file mode 100644 index 000000000..750531f16 --- /dev/null +++ b/tensorflow_quantum/core/ops/math_ops/tfq_simulate_1d_sampled_expectation.cc @@ -0,0 +1,297 @@ +/* Copyright 2020 The TensorFlow Quantum Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "../qsim/lib/circuit.h" +#include "../qsim/lib/formux.h" +#include "../qsim/lib/gate_appl.h" +#include "../qsim/lib/gates_cirq.h" +#include "../qsim/lib/mps_simulator.h" +#include "../qsim/lib/mps_statespace.h" +#include "../qsim/lib/seqfor.h" +#include "../qsim/lib/simmux.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/util/guarded_philox_random.h" +#include "tensorflow_quantum/core/ops/parse_context.h" +#include "tensorflow_quantum/core/proto/pauli_sum.pb.h" +#include "tensorflow_quantum/core/proto/program.pb.h" +#include "tensorflow_quantum/core/src/program_resolution.h" +#include "tensorflow_quantum/core/src/util_qsim.h" + +namespace tfq { + +using ::tensorflow::Status; +using ::tfq::proto::PauliSum; +using ::tfq::proto::Program; + +typedef qsim::Cirq::GateCirq QsimGate; +typedef qsim::Circuit QsimCircuit; + +class TfqSimulateMPS1DSampledExpectationOp : public tensorflow::OpKernel { + public: + explicit TfqSimulateMPS1DSampledExpectationOp( + tensorflow::OpKernelConstruction* context) + : OpKernel(context) { + // Get the bond dimension of MPS + OP_REQUIRES_OK(context, context->GetAttr("bond_dim", &bond_dim_)); + } + + void Compute(tensorflow::OpKernelContext* context) override { + // TODO (mbbrough): add more dimension checks for other inputs here. + const int num_inputs = context->num_inputs(); + OP_REQUIRES(context, num_inputs == 5, + tensorflow::errors::InvalidArgument(absl::StrCat( + "Expected 5 inputs, got ", num_inputs, " inputs."))); + + // Create the output Tensor. + const int output_dim_batch_size = context->input(0).dim_size(0); + const int output_dim_op_size = context->input(3).dim_size(1); + tensorflow::TensorShape output_shape; + output_shape.AddDim(output_dim_batch_size); + output_shape.AddDim(output_dim_op_size); + + tensorflow::Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + auto output_tensor = output->matrix(); + + std::vector programs; + std::vector num_qubits; + std::vector> pauli_sums; + OP_REQUIRES_OK(context, + GetProgramsAndNumQubits(context, &programs, &num_qubits, + &pauli_sums, true)); + + std::vector maps; + OP_REQUIRES_OK(context, GetSymbolMaps(context, &maps)); + + OP_REQUIRES(context, programs.size() == maps.size(), + tensorflow::errors::InvalidArgument(absl::StrCat( + "Number of circuits and symbol_values do not match. Got ", + programs.size(), " circuits and ", maps.size(), + " symbol values."))); + + std::vector> num_samples; + OP_REQUIRES_OK(context, GetNumSamples(context, &num_samples)); + + OP_REQUIRES(context, num_samples.size() == pauli_sums.size(), + tensorflow::errors::InvalidArgument(absl::StrCat( + "Dimension 0 of num_samples and pauli_sums do not match.", + "Got ", num_samples.size(), " lists of sample sizes and ", + pauli_sums.size(), " lists of pauli sums."))); + + OP_REQUIRES( + context, context->input(4).dim_size(1) == context->input(3).dim_size(1), + tensorflow::errors::InvalidArgument(absl::StrCat( + "Dimension 1 of num_samples and pauli_sums do not match.", "Got ", + context->input(4).dim_size(1), " lists of sample sizes and ", + context->input(3).dim_size(1), " lists of pauli sums."))); + + // Construct qsim circuits. + std::vector qsim_circuits(programs.size(), QsimCircuit()); + std::vector>> fused_circuits( + programs.size(), std::vector>({})); + + Status parse_status = Status::OK(); + auto p_lock = tensorflow::mutex(); + auto construct_f = [&](int start, int end) { + for (int i = start; i < end; i++) { + Status local = + QsimCircuitFromProgram(programs[i], maps[i], num_qubits[i], + &qsim_circuits[i], &fused_circuits[i]); + // If parsing works, check MPS constraints. + if (local.ok()) { + local = CheckMPSSupported(programs[i]); + } + NESTED_FN_STATUS_SYNC(parse_status, local, p_lock); + } + }; + + const int num_cycles = 1000; + context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor( + programs.size(), num_cycles, construct_f); + OP_REQUIRES_OK(context, parse_status); + + // Find largest circuit for tensor size padding and allocate + // the output tensor. + int max_num_qubits = 0; + int min_num_qubits = 1 << 30; + for (const int num : num_qubits) { + max_num_qubits = std::max(max_num_qubits, num); + min_num_qubits = std::min(min_num_qubits, num); + } + + OP_REQUIRES(context, min_num_qubits > 3, + tensorflow::errors::InvalidArgument( + "All input circuits require minimum 3 qubits.")); + + // Since MPS simulations have much smaller memory footprint, + // we do not need a ComputeLarge like we do for state vector simulation. + ComputeSmall(num_qubits, max_num_qubits, qsim_circuits, pauli_sums, + num_samples, context, &output_tensor); + } + + private: + int bond_dim_; + void ComputeSmall(const std::vector& num_qubits, + const int max_num_qubits, + const std::vector& unfused_circuits, + const std::vector>& pauli_sums, + const std::vector>& num_samples, + tensorflow::OpKernelContext* context, + tensorflow::TTypes::Matrix* output_tensor) { + // Instantiate qsim objects. + using Simulator = qsim::mps::MPSSimulator; + using StateSpace = Simulator::MPSStateSpace_; + + const int output_dim_op_size = output_tensor->dimension(1); + + tensorflow::GuardedPhiloxRandom random_gen; + random_gen.Init(tensorflow::random::New64(), tensorflow::random::New64()); + int largest_sum = -1; + for (const auto& sums : pauli_sums) { + for (const auto& sum : sums) { + largest_sum = std::max(largest_sum, sum.terms().size()); + } + } + const int num_threads = context->device() + ->tensorflow_cpu_worker_threads() + ->workers->NumThreads(); + + Status compute_status = Status::OK(); + auto c_lock = tensorflow::mutex(); + auto DoWork = [&](int start, int end) { + int old_batch_index = -2; + int cur_batch_index = -1; + int largest_nq = 1; + int cur_op_index; + + // Note: ForArgs in MPSSimulator and MPSStateState are currently unused. + // So, this 1 is a dummy for qsim::For. + Simulator sim = Simulator(1); + StateSpace ss = StateSpace(1); + auto sv = ss.Create(largest_nq, bond_dim_); + auto scratch = ss.Create(largest_nq, bond_dim_); + auto scratch2 = ss.Create(largest_nq, bond_dim_); + auto scratch3 = ss.Create(largest_nq, bond_dim_); + + int n_random = largest_sum * output_dim_op_size * unfused_circuits.size(); + n_random /= num_threads; + n_random += 1; + auto local_gen = random_gen.ReserveSamples32(n_random); + tensorflow::random::SimplePhilox rand_source(&local_gen); + + for (int i = start; i < end; i++) { + cur_batch_index = i / output_dim_op_size; + cur_op_index = i % output_dim_op_size; + + const int nq = num_qubits[cur_batch_index]; + + // (#679) Just ignore empty program + auto unfused_gates = unfused_circuits[cur_batch_index].gates; + // (#679) Just ignore empty program + if (unfused_gates.size() == 0) { + (*output_tensor)(cur_batch_index, cur_op_index) = -2.0; + continue; + } + + if (cur_batch_index != old_batch_index) { + // We've run into a new state vector we must compute. + // Only compute a new state vector when we have to. + if (nq > largest_nq) { + largest_nq = nq; + sv = ss.Create(largest_nq, bond_dim_); + scratch = ss.Create(largest_nq, bond_dim_); + scratch2 = ss.Create(largest_nq, bond_dim_); + scratch3 = ss.Create(largest_nq, bond_dim_); + } + // no need to update scratch_state since ComputeExpectation + // will take care of things for us. + ss.SetStateZero(sv); + for (auto gate : unfused_gates) { + // Can't fuse, since this might break nearest neighbor constraints. + qsim::ApplyGate(sim, gate, sv); + } + } + + float exp_v = 0.0; + NESTED_FN_STATUS_SYNC( + compute_status, + ComputeMPSSampledExpectationQsim( + pauli_sums[cur_batch_index][cur_op_index], sim, ss, sv, scratch, + scratch2, scratch3, num_samples[cur_batch_index][cur_op_index], + rand_source, &exp_v), + c_lock); + + (*output_tensor)(cur_batch_index, cur_op_index) = exp_v; + old_batch_index = cur_batch_index; + } + }; + + const int64_t num_cycles = + 200 * (int64_t(1) << static_cast(max_num_qubits)); + context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor( + unfused_circuits.size() * output_dim_op_size, num_cycles, DoWork); + OP_REQUIRES_OK(context, compute_status); + } +}; + +REGISTER_KERNEL_BUILDER( + Name("TfqSimulateMPS1DSampledExpectation").Device(tensorflow::DEVICE_CPU), + TfqSimulateMPS1DSampledExpectationOp); + +REGISTER_OP("TfqSimulateMPS1DSampledExpectation") + .Input("programs: string") + .Input("symbol_names: string") + .Input("symbol_values: float") + .Input("pauli_sums: string") + .Input("num_samples: int32") + .Output("expectations: float") + .Attr("bond_dim: int >= 4 = 4") + .SetShapeFn([](tensorflow::shape_inference::InferenceContext* c) { + tensorflow::shape_inference::ShapeHandle programs_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &programs_shape)); + + tensorflow::shape_inference::ShapeHandle symbol_names_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &symbol_names_shape)); + + tensorflow::shape_inference::ShapeHandle symbol_values_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &symbol_values_shape)); + + tensorflow::shape_inference::ShapeHandle pauli_sums_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &pauli_sums_shape)); + + tensorflow::shape_inference::ShapeHandle num_samples_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 2, &num_samples_shape)); + + tensorflow::shape_inference::DimensionHandle output_rows = + c->Dim(programs_shape, 0); + tensorflow::shape_inference::DimensionHandle output_cols = + c->Dim(pauli_sums_shape, 1); + c->set_output(0, c->Matrix(output_rows, output_cols)); + + return tensorflow::Status::OK(); + }); + +} // namespace tfq diff --git a/tensorflow_quantum/core/ops/math_ops/tfq_simulate_1d_samples.cc b/tensorflow_quantum/core/ops/math_ops/tfq_simulate_1d_samples.cc new file mode 100644 index 000000000..495e5f8f2 --- /dev/null +++ b/tensorflow_quantum/core/ops/math_ops/tfq_simulate_1d_samples.cc @@ -0,0 +1,248 @@ +/* Copyright 2020 The TensorFlow Quantum Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include + +#include "../qsim/lib/circuit.h" +#include "../qsim/lib/formux.h" +#include "../qsim/lib/gate_appl.h" +#include "../qsim/lib/gates_cirq.h" +#include "../qsim/lib/mps_simulator.h" +#include "../qsim/lib/mps_statespace.h" +#include "../qsim/lib/seqfor.h" +#include "../qsim/lib/simmux.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/util/guarded_philox_random.h" +#include "tensorflow_quantum/core/ops/parse_context.h" +#include "tensorflow_quantum/core/proto/program.pb.h" +#include "tensorflow_quantum/core/src/circuit_parser_qsim.h" +#include "tensorflow_quantum/core/src/program_resolution.h" +#include "tensorflow_quantum/core/src/util_qsim.h" + +namespace tfq { + +using ::tensorflow::Status; +using ::tfq::proto::Program; + +typedef qsim::Cirq::GateCirq QsimGate; +typedef qsim::Circuit QsimCircuit; + +class TfqSimulateMPS1DSamplesOp : public tensorflow::OpKernel { + public: + explicit TfqSimulateMPS1DSamplesOp(tensorflow::OpKernelConstruction* context) + : OpKernel(context) { + // Get the bond dimension of MPS + OP_REQUIRES_OK(context, context->GetAttr("bond_dim", &bond_dim_)); + } + + void Compute(tensorflow::OpKernelContext* context) override { + // TODO (mbbrough): add more dimension checks for other inputs here. + DCHECK_EQ(4, context->num_inputs()); + + // Parse to Program Proto and num_qubits. + std::vector programs; + std::vector num_qubits; + OP_REQUIRES_OK(context, + GetProgramsAndNumQubits(context, &programs, &num_qubits, + nullptr, true)); + + // Parse symbol maps for parameter resolution in the circuits. + std::vector maps; + OP_REQUIRES_OK(context, GetSymbolMaps(context, &maps)); + OP_REQUIRES( + context, maps.size() == programs.size(), + tensorflow::errors::InvalidArgument(absl::StrCat( + "Number of circuits and values do not match. Got ", programs.size(), + " circuits and ", maps.size(), " values."))); + + int num_samples = 0; + OP_REQUIRES_OK(context, GetIndividualSample(context, &num_samples)); + + // Construct qsim circuits. + std::vector qsim_circuits(programs.size(), QsimCircuit()); + std::vector>> fused_circuits( + programs.size(), std::vector>({})); + + Status parse_status = Status::OK(); + auto p_lock = tensorflow::mutex(); + auto construct_f = [&](int start, int end) { + for (int i = start; i < end; i++) { + Status local = + QsimCircuitFromProgram(programs[i], maps[i], num_qubits[i], + &qsim_circuits[i], &fused_circuits[i]); + // If parsing works, check MPS constraints. + if (local.ok()) { + local = CheckMPSSupported(programs[i]); + } + NESTED_FN_STATUS_SYNC(parse_status, local, p_lock); + } + }; + + const int num_cycles = 1000; + context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor( + programs.size(), num_cycles, construct_f); + OP_REQUIRES_OK(context, parse_status); + + // Find largest circuit for tensor size padding and allocate + // the output tensor. + int max_num_qubits = 0; + int min_num_qubits = 1 << 30; + for (const int num : num_qubits) { + max_num_qubits = std::max(max_num_qubits, num); + min_num_qubits = std::min(min_num_qubits, num); + } + + OP_REQUIRES(context, min_num_qubits > 3, + tensorflow::errors::InvalidArgument( + "All input circuits require minimum 3 qubits.")); + + const int output_dim_size = maps.size(); + tensorflow::TensorShape output_shape; + output_shape.AddDim(output_dim_size); + output_shape.AddDim(num_samples); + output_shape.AddDim(max_num_qubits); + + tensorflow::Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + auto output_tensor = output->tensor(); + + if (num_samples == 0) { + return; // bug in qsim dependency we can't control. + } + + // Since MPS simulations have much smaller memory footprint, + // we do not need a ComputeLarge like we do for state vector simulation. + ComputeSmall(num_qubits, max_num_qubits, num_samples, qsim_circuits, + context, &output_tensor); + } + + private: + int bond_dim_; + + void ComputeSmall(const std::vector& num_qubits, + const int max_num_qubits, const int num_samples, + const std::vector& unfused_circuits, + tensorflow::OpKernelContext* context, + tensorflow::TTypes::Tensor* output_tensor) { + // Instantiate qsim objects. + using Simulator = qsim::mps::MPSSimulator; + using StateSpace = Simulator::MPSStateSpace_; + + tensorflow::GuardedPhiloxRandom random_gen; + random_gen.Init(tensorflow::random::New64(), tensorflow::random::New64()); + + auto DoWork = [&](int start, int end) { + int largest_nq = 1; + // Note: ForArgs in MPSSimulator and MPSStateState are currently unused. + // So, this 1 is a dummy for qsim::For. + Simulator sim = Simulator(1); + StateSpace ss = StateSpace(1); + auto sv = ss.Create(largest_nq, bond_dim_); + auto scratch = ss.Create(largest_nq, bond_dim_); + auto scratch2 = ss.Create(largest_nq, bond_dim_); + + auto local_gen = random_gen.ReserveSamples32(unfused_circuits.size() + 1); + tensorflow::random::SimplePhilox rand_source(&local_gen); + + for (int i = start; i < end; i++) { + int nq = num_qubits[i]; + + if (nq > largest_nq) { + // need to switch to larger statespace. + largest_nq = nq; + sv = ss.Create(largest_nq, bond_dim_); + scratch = ss.Create(largest_nq, bond_dim_); + scratch2 = ss.Create(largest_nq, bond_dim_); + } + ss.SetStateZero(sv); + auto unfused_gates = unfused_circuits[i].gates; + for (auto gate : unfused_gates) { + // Can't fuse, since this might break nearest neighbor constraints. + qsim::ApplyGate(sim, gate, sv); + } + + std::vector> results(num_samples, + std::vector({})); + + ss.Sample(sv, scratch, scratch2, num_samples, rand_source.Rand32(), + &results); + + for (int j = 0; j < num_samples; j++) { + int64_t q_ind = 0; + while (q_ind < max_num_qubits - nq) { + (*output_tensor)(i, j, static_cast(q_ind)) = -2; + q_ind++; + } + while (q_ind < max_num_qubits) { + (*output_tensor)(i, j, static_cast(q_ind)) = + results[j][q_ind - max_num_qubits + nq]; + q_ind++; + } + } + } + }; + + const int64_t num_cycles = + 200 * (int64_t(1) << static_cast(max_num_qubits)); + context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor( + unfused_circuits.size(), num_cycles, DoWork); + } +}; + +REGISTER_KERNEL_BUILDER( + Name("TfqSimulateMPS1DSamples").Device(tensorflow::DEVICE_CPU), + TfqSimulateMPS1DSamplesOp); + +REGISTER_OP("TfqSimulateMPS1DSamples") + .Input("programs: string") + .Input("symbol_names: string") + .Input("symbol_values: float") + .Input("num_samples: int32") + .Output("samples: int8") + .Attr("bond_dim: int >= 4 = 4") + .SetShapeFn([](tensorflow::shape_inference::InferenceContext* c) { + tensorflow::shape_inference::ShapeHandle programs_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &programs_shape)); + + tensorflow::shape_inference::ShapeHandle symbol_names_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &symbol_names_shape)); + + tensorflow::shape_inference::ShapeHandle symbol_values_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &symbol_values_shape)); + + tensorflow::shape_inference::ShapeHandle num_samples_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &num_samples_shape)); + + // [batch_size, n_samples, largest_n_qubits] + c->set_output( + 0, c->MakeShape( + {c->Dim(programs_shape, 0), + tensorflow::shape_inference::InferenceContext::kUnknownDim, + tensorflow::shape_inference::InferenceContext::kUnknownDim})); + + return tensorflow::Status::OK(); + }); + +} // namespace tfq diff --git a/tensorflow_quantum/core/src/util_qsim.h b/tensorflow_quantum/core/src/util_qsim.h index f5f9c7285..94531a884 100644 --- a/tensorflow_quantum/core/src/util_qsim.h +++ b/tensorflow_quantum/core/src/util_qsim.h @@ -18,6 +18,8 @@ limitations under the License. #include #include +#include +#include #include #include @@ -267,6 +269,91 @@ tensorflow::Status ComputeSampledExpectationQsim( return status; } +// Overloading for MPS : it requires more scratch states. +// bad style standards here that we are forced to follow from qsim. +// computes the expectation value using +// scratch to save on memory. Implementation does this: +// 1. Copy state onto scratch +// 2. Convert scratch to Z basis +// 3. Compute < state | scratch > via sampling. +// 4. Sum and repeat. +// scratch is required to have memory initialized, but does not require +// values in memory to be set. +template +tensorflow::Status ComputeMPSSampledExpectationQsim( + const tfq::proto::PauliSum& p_sum, const SimT& sim, const StateSpaceT& ss, + StateT& state, StateT& scratch, StateT& scratch2, StateT& scratch3, + const int num_samples, tensorflow::random::SimplePhilox& random_source, + float* expectation_value) { + std::uniform_int_distribution<> distrib(1, 1 << 30); + + if (num_samples == 0) { + return tensorflow::Status::OK(); + } + // apply the gates of the pauliterms to a copy of the state vector + // and add up expectation value term by term. + tensorflow::Status status = tensorflow::Status::OK(); + for (const tfq::proto::PauliTerm& term : p_sum.terms()) { + // catch identity terms + if (term.paulis_size() == 0) { + *expectation_value += term.coefficient_real(); + // TODO(zaqqwerty): error somewhere if identities have any imaginary part + continue; + } + + // Transform state into the measurement basis and sample it + QsimCircuit main_circuit; + std::vector> fused_circuit; + + status = QsimZBasisCircuitFromPauliTerm(term, state.num_qubits(), + &main_circuit, &fused_circuit); + if (!status.ok()) { + return status; + } + // copy from src to scratch. + ss.Copy(state, scratch); + for (const auto& unfused_gate : main_circuit.gates) { + qsim::ApplyGate(sim, unfused_gate, scratch); + } + + if (!status.ok()) { + return status; + } + std::vector> state_samples(num_samples, + std::vector({})); + + ss.Sample(scratch, scratch2, scratch3, num_samples, random_source.Rand32(), + &state_samples); + + // Find qubits on which to measure parity and compute the BitMask. + const unsigned int max_num_qubits = state.num_qubits(); + std::vector mask(max_num_qubits, false); + for (const tfq::proto::PauliQubitPair& pair : term.paulis()) { + unsigned int location; + // GridQubit id should be parsed down to integer at this upstream + // so it is safe to just use atoi. + (void)absl::SimpleAtoi(pair.qubit_id(), &location); + // Parity functions use little-endian indexing + mask[max_num_qubits - location - 1] = 1; + } + + // Compute the running parity. + int parity_total(0); + int count = 0; + for (std::vector& state_sample : state_samples) { + std::transform(mask.begin(), mask.end(), state_sample.begin(), + state_sample.begin(), + [](bool x, bool y) -> bool { return x & y; }); + count = std::accumulate(state_sample.begin(), state_sample.end(), 0); + parity_total += (count & 1) ? -1 : 1; + } + *expectation_value += static_cast(parity_total) * + term.coefficient_real() / + static_cast(num_samples); + } + return status; +} + // Assumes p_sums.size() == op_coeffs.size() // state stores |psi>. scratch has been created, but does not // require initialization. dest has been created, but does not require