diff --git a/.github/workflows/build-wheel-linux-x86_64.yaml b/.github/workflows/build-wheel-linux-x86_64.yaml index b0ab0b9cfa..5f0ad2485e 100644 --- a/.github/workflows/build-wheel-linux-x86_64.yaml +++ b/.github/workflows/build-wheel-linux-x86_64.yaml @@ -375,7 +375,6 @@ jobs: - name: Install Python dependencies run: | python${{ matrix.python_version }} -m pip install pytest pytest-xdist - python${{ matrix.python_version }} -m pip install tensorflow-cpu # for autograph tests - name: Install Catalyst run: | diff --git a/.github/workflows/build-wheel-macos-arm64.yaml b/.github/workflows/build-wheel-macos-arm64.yaml index bf0a1b8662..571a7d704c 100644 --- a/.github/workflows/build-wheel-macos-arm64.yaml +++ b/.github/workflows/build-wheel-macos-arm64.yaml @@ -383,9 +383,7 @@ jobs: - name: Install Python dependencies run: | - # tensorflow-cpu is not distributed for macOS ARM python${{ matrix.python_version }} -m pip install pytest pytest-xdist - python${{ matrix.python_version }} -m pip install tensorflow # for autograph tests - name: Install Catalyst run: | diff --git a/.github/workflows/build-wheel-macos-x86_64.yaml b/.github/workflows/build-wheel-macos-x86_64.yaml index ab830e9c3c..c2a6214150 100644 --- a/.github/workflows/build-wheel-macos-x86_64.yaml +++ b/.github/workflows/build-wheel-macos-x86_64.yaml @@ -354,7 +354,6 @@ jobs: - name: Install Python dependencies run: | python${{ matrix.python_version }} -m pip install pytest pytest-xdist - python${{ matrix.python_version }} -m pip install tensorflow-cpu # for autograph tests - name: Install Catalyst run: | diff --git a/.pylintrc b/.pylintrc index 13771d98ce..640a427a95 100644 --- a/.pylintrc +++ b/.pylintrc @@ -10,7 +10,7 @@ extension-pkg-allow-list=catalyst.utils.wrapper # (useful for modules/projects where namespaces are manipulated during runtime # and thus existing member attributes cannot be deduced by static analysis. It # supports qualified module names, as well as Unix pattern matching. -ignored-modules=pennylane.ops,jaxlib.mlir.ir,jaxlib.xla_extension,tensorflow +ignored-modules=pennylane.ops,jaxlib.mlir.ir,jaxlib.xla_extension # List of classes names for which member attributes should not be checked # (useful for classes with attributes dynamically set). This supports can work diff --git a/doc/changelog.md b/doc/changelog.md index ec611ce774..e624d719d4 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -53,6 +53,10 @@

Improvements

+* Catalyst no longer relies on a TensorFlow installation for its AutoGraph functionality. Instead, + the standalone `diastatic-malt` package is used and automatically installed as a dependency. + [(#401)](https://github.com/PennyLaneAI/catalyst/pull/401) + * Catalyst will now remember previously compiled functions when the PyTree metadata of arguments changes, in addition to already rememebering compiled functions when static arguments change. [(#522)](https://github.com/PennyLaneAI/catalyst/pull/531) diff --git a/doc/dev/autograph.rst b/doc/dev/autograph.rst index 3392f69e77..ddd8a3671e 100644 --- a/doc/dev/autograph.rst +++ b/doc/dev/autograph.rst @@ -29,19 +29,12 @@ restrictions and constraints you may discover. Using AutoGraph --------------- -AutoGraph currently requires TensorFlow as a dependency; in most cases it can -be installed via +The AutoGraph feature in Catalyst is supported by the ``diastatic-malt`` package, a standalone +fork of the AutoGraph module in TensorFlow ( +`official documentation `_ +). -.. code-block:: console - - pip install tensorflow - -but please refer to the -`TensorFlow documentation `__ -for specific details on installing TensorFlow for your platform. - -Once TensorFlow is available, AutoGraph can be enabled by passing -``autograph=True`` to the ``@qjit`` decorator: +To enable AutoGraph in Catalyst, simply pass ``autograph=True`` to the ``@qjit`` decorator: .. code-block:: python @@ -661,7 +654,7 @@ of multiple measurements. For example, qml.RY(0.5, wires=1) m1 = measure(0) - m2 = measure(1) + m2 = measure(1) if m1 and not m2: qml.Hadamard(wires=1) @@ -713,8 +706,8 @@ Array arguments Note that, like with NumPy and JAX, logical operators apply elementwise to array arguments: ->>> @qjit(autograph=True) -... def f(x, y): +>>> @qjit(autograph=True) +... def f(x, y): ... return x and y >>> f(jnp.array([0, 1]), jnp.array([1, 1])) array([False, True]) @@ -722,8 +715,8 @@ array([False, True]) Care must therefore be taken when using logical operators within conditional branches; ``jnp.all`` and ``jnp.any`` can be used to generate a single boolean for conditionals: ->>> @qjit(autograph=True) -... def f(x, y): +>>> @qjit(autograph=True) +... def f(x, y): ... if jnp.all(x and y): ... z = 1 ... else: diff --git a/frontend/catalyst/__init__.py b/frontend/catalyst/__init__.py index 20cc2f8223..486e276d40 100644 --- a/frontend/catalyst/__init__.py +++ b/frontend/catalyst/__init__.py @@ -66,7 +66,7 @@ ) from catalyst import debug -from catalyst.ag_utils import AutoGraphError, autograph_source +from catalyst.autograph import autograph_source from catalyst.compiler import CompileOptions from catalyst.jit import QJIT, qjit from catalyst.pennylane_extensions import ( @@ -82,7 +82,7 @@ vjp, while_loop, ) -from catalyst.utils.exceptions import CompileError +from catalyst.utils.exceptions import AutoGraphError, CompileError autograph_ignore_fallbacks = False """bool: Specify whether AutoGraph should avoid raising diff --git a/frontend/catalyst/ag_primitives.py b/frontend/catalyst/ag_primitives.py index 30bc5a0436..9d9ecbaeee 100644 --- a/frontend/catalyst/ag_primitives.py +++ b/frontend/catalyst/ag_primitives.py @@ -21,41 +21,22 @@ import jax import jax.numpy as jnp - -# Use tensorflow implementations for handling function scopes and calls, -# as well as various utility objects. import pennylane as qml -import tensorflow.python.autograph.impl.api as tf_autograph_api +from malt.core import config as ag_config +from malt.impl import api as ag_api +from malt.impl.api import converted_call as ag_converted_call +from malt.operators import py_builtins as ag_py_builtins +from malt.operators.variables import Undefined +from malt.pyct.origin_info import LineLocation from pennylane.queuing import AnnotatedQueue -from tensorflow.python.autograph.core import config -from tensorflow.python.autograph.core.converter import STANDARD_OPTIONS as STD -from tensorflow.python.autograph.core.converter import ConversionOptions -from tensorflow.python.autograph.core.function_wrappers import ( - FunctionScope, - with_function_scope, -) -from tensorflow.python.autograph.impl.api import autograph_artifact -from tensorflow.python.autograph.impl.api import converted_call as tf_converted_call -from tensorflow.python.autograph.operators.variables import ( - Undefined, - UndefinedReturnValue, -) -from tensorflow.python.autograph.pyct.origin_info import LineLocation import catalyst -from catalyst.ag_utils import AutoGraphError from catalyst.jax_extras import DynamicJaxprTracer, ShapedArray from catalyst.tracing.contexts import EvaluationContext +from catalyst.utils.exceptions import AutoGraphError from catalyst.utils.patching import Patcher __all__ = [ - "STD", - "ConversionOptions", - "Undefined", - "UndefinedReturnValue", - "autograph_artifact", - "FunctionScope", - "with_function_scope", "if_stmt", "for_stmt", "while_stmt", @@ -522,19 +503,22 @@ def get_source_code_info(tb_frame): # issues such as always tracing through code that should only be executed conditionally. We might # have to be even more restrictive in the future to prevent issues if necessary. module_allowlist = ( - config.DoNotConvert("pennylane"), - config.DoNotConvert("catalyst"), - config.DoNotConvert("jax"), -) + config.CONVERSION_RULES + ag_config.DoNotConvert("pennylane"), + ag_config.DoNotConvert("catalyst"), + ag_config.DoNotConvert("jax"), + *ag_config.CONVERSION_RULES, +) def converted_call(fn, args, kwargs, caller_fn_scope=None, options=None): """We want AutoGraph to use our own instance of the AST transformer when recursively transforming functions, but otherwise duplicate the same behaviour.""" + # TODO: eliminate the need for patching by improving the autograph interface with Patcher( - (tf_autograph_api, "_TRANSPILER", catalyst.autograph.TRANSFORMER), - (config, "CONVERSION_RULES", module_allowlist), + (ag_api, "_TRANSPILER", catalyst.autograph.TRANSFORMER), + (ag_config, "CONVERSION_RULES", module_allowlist), + (ag_py_builtins, "BUILTIN_FUNCTIONS_MAP", py_builtins_map), ): # HOTFIX: pass through calls of known Catalyst wrapper functions if fn in ( @@ -557,14 +541,7 @@ def passthrough_wrapper(*args, **kwargs): **(kwargs if kwargs is not None else {}), ) - # Dispatch range calls to a custom range class that enables constructs like - # `for .. in range(..)` to be converted natively to `for_loop` calls. This is beneficial - # since the Python range function does not allow tracers as arguments. - if fn is range: - return CRange(*args, **(kwargs if kwargs is not None else {})) - elif fn is enumerate: - return CEnumerate(*args, **(kwargs if kwargs is not None else {})) - + # TODO: find a way to handle custom decorators more effectively with autograph # We need to unpack nested QNode and QJIT calls as autograph will have trouble handling # them. Ideally, we only want the wrapped function to be transformed by autograph, rather # than the QNode or QJIT call method. @@ -580,12 +557,12 @@ def passthrough_wrapper(*args, **kwargs): @functools.wraps(fn.func) def qnode_call_wrapper(): - return tf_converted_call(fn.func, args, kwargs, caller_fn_scope, options) + return ag_converted_call(fn.func, args, kwargs, caller_fn_scope, options) new_qnode = qml.QNode(qnode_call_wrapper, device=fn.device, diff_method=fn.diff_method) return new_qnode() - return tf_converted_call(fn, args, kwargs, caller_fn_scope, options) + return ag_converted_call(fn, args, kwargs, caller_fn_scope, options) class CRange: @@ -674,3 +651,10 @@ class CEnumerate(enumerate): def __init__(self, iterable, start=0): self.iteration_target = iterable self.start_idx = start + + +py_builtins_map = { + **ag_py_builtins.BUILTIN_FUNCTIONS_MAP, + "range": CRange, + "enumerate": CEnumerate, +} diff --git a/frontend/catalyst/ag_utils.py b/frontend/catalyst/ag_utils.py deleted file mode 100644 index c7838b5f32..0000000000 --- a/frontend/catalyst/ag_utils.py +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright 2023 Xanadu Quantum Technologies Inc. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Utility functions for Catalyst's AutoGraph module. This module can be safely imported without -a TensorFlow installation.""" - -import inspect - -import pennylane as qml - -# pylint: disable=import-outside-toplevel - - -class AutoGraphError(Exception): - """Errors related to Catalyst's AutoGraph module.""" - - -def _test_ag_import(): - """Reusable function for attempting to import Catalyst's AutoGraph module, which requires - TensorFlow to be installed.""" - - try: - import catalyst.autograph # pylint: disable=unused-import - except ImportError as e: - raise ImportError( - "The AutoGraph feature in Catalyst requires TensorFlow. " - "Please install TensorFlow (https://www.tensorflow.org/install) and ensure it is " - "available in the current environment." - ) from e - - -def autograph_source(fn): - """Utility function to retrieve the source code of a function converted by AutoGraph. - - .. warning:: - - Nested functions (those not directly decorated with ``@qjit``) are only lazily converted by - AutoGraph. Make sure that the function has been traced at least once before accessing its - transformed source code, for example by specifying the signature of the compiled program - or by running it at least once. - - Args: - fn (Callable): the original function object that was converted - - Returns: - str: the source code of the converted function - - Raises: - AutoGraphError: If the given function was not converted by AutoGraph, an error will be - raised. - ImportError: If TensorFlow is not installed, an error will be raised. - - **Example** - - .. code-block:: python - - def decide(x): - if x < 5: - y = 15 - else: - y = 1 - return y - - @qjit(autograph=True) - def func(x: int): - y = decide(x) - return y ** 2 - - >>> print(autograph_source(decide)) - def decide_1(x): - with ag__.FunctionScope('decide', 'fscope', ag__.STD) as fscope: - def get_state(): - return (y,) - def set_state(vars_): - nonlocal y - (y,) = vars_ - def if_body(): - nonlocal y - y = 15 - def else_body(): - nonlocal y - y = 1 - y = ag__.Undefined('y') - ag__.if_stmt(x < 5, if_body, else_body, get_state, set_state, ('y',), 1) - return y - """ - _test_ag_import() - from catalyst import QJIT - from catalyst.ag_primitives import STD as STD_OPTIONS - from catalyst.autograph import TOPLEVEL_OPTIONS, TRANSFORMER - - # Handle directly converted objects. - if hasattr(fn, "ag_unconverted"): - return inspect.getsource(fn) - - # Unwrap known objects to get the function actually transformed by autograph. - if isinstance(fn, QJIT): - fn = fn.original_function - if isinstance(fn, qml.QNode): - fn = fn.func - - if TRANSFORMER.has_cache(fn, STD_OPTIONS): - new_fn = TRANSFORMER.get_cached_function(fn, STD_OPTIONS) - return inspect.getsource(new_fn) - elif TRANSFORMER.has_cache(fn, TOPLEVEL_OPTIONS): - new_fn = TRANSFORMER.get_cached_function(fn, TOPLEVEL_OPTIONS) - return inspect.getsource(new_fn) - - raise AutoGraphError( - "The given function was not converted by AutoGraph. If you expect the" - "given function to be converted, please submit a bug report." - ) - - -def print_code(fn): - """Convenience function for testing to print the transformed code.""" - - print(autograph_source(fn)) # pragma: nocover - - -def check_cache(fn): - """Convenience function for testing to check the TRANSFORMER cache.""" - _test_ag_import() - from catalyst.ag_primitives import STD as STD_OPTIONS - from catalyst.autograph import TOPLEVEL_OPTIONS, TRANSFORMER - - return TRANSFORMER.has_cache(fn, STD_OPTIONS) or TRANSFORMER.has_cache(fn, TOPLEVEL_OPTIONS) - - -def run_autograph(fn): - """Safe wrapper around the AutoGraph decorator from the catalyst.autograph module.""" - _test_ag_import() - from catalyst.autograph import autograph - - return autograph(fn) diff --git a/frontend/catalyst/autograph.py b/frontend/catalyst/autograph.py index 64bfc57890..07e9ead3c2 100644 --- a/frontend/catalyst/autograph.py +++ b/frontend/catalyst/autograph.py @@ -13,7 +13,8 @@ # limitations under the License. """AutoGraph is a source-to-source transformation system for converting imperative code into -traceable code for compute graph generation. The system is implemented in the tensorflow project. +traceable code for compute graph generation. The system is implemented in the Diastatic-Malt +package (originally from TensorFlow). Here, we integrate AutoGraph into Catalyst to improve the UX and allow programmers to use built-in Python control flow and other imperative expressions rather than the functional equivalents provided by Catalyst.""" @@ -21,23 +22,23 @@ import inspect import pennylane as qml -from tensorflow.python.autograph.converters import ( - call_trees, - control_flow, - functions, - logical_expressions, -) -from tensorflow.python.autograph.core import converter, unsupported_features_checker -from tensorflow.python.autograph.pyct import transpiler +from malt.core import converter +from malt.impl.api import PyToPy +import catalyst from catalyst import ag_primitives -from catalyst.ag_utils import AutoGraphError +from catalyst.utils.exceptions import AutoGraphError -class CFTransformer(transpiler.PyToPy): +class CatalystTransformer(PyToPy): """A source-to-source transformer to convert imperative style control flow into a function style suitable for tracing.""" + def __init__(self): + super().__init__() + + self._extra_locals = None + def transform(self, obj, user_context): """Launch the transformation process. Typically this only works on function objects. Here we also allow QNodes to be transformed.""" @@ -61,54 +62,41 @@ def transform(self, obj, user_context): return new_obj, module, source_map - def transform_ast(self, node, ctx): - """This method must be overwritten to run all desired transformations. AutoGraph provides - several existing transforms, but we can all also provide our own in the future.""" - - # Check some unsupported Python code ahead of time. - unsupported_features_checker.verify(node) - - # First transform the top-level function to avoid infinite recursion. - node = functions.transform(node, ctx) - - # Convert function calls. This allows us to convert these called functions as well. - node = call_trees.transform(node, ctx) - - # Convert Python control flow to custom 'ag__.if_stmt' ... functions. - node = control_flow.transform(node, ctx) - - # Convert logical expressions - node = logical_expressions.transform(node, ctx) - - return node - def get_extra_locals(self): """Here we can provide any extra names that the converted function should have access to. At a minimum we need to provide the module with definitions for AutoGraph primitives.""" - return {"ag__": ag_primitives} - - def get_caching_key(self, user_context): - """AutoGraph automatically caches transformed functions, the caching key is a combination of - the function source as well as a custom key provided by us here. Changing AutoGraph options - should trigger the function transform again, rather than getting it from cache.""" + if self._extra_locals is None: + extra_locals = super().get_extra_locals() + updates = {key: ag_primitives.__dict__[key] for key in ag_primitives.__all__} + extra_locals["ag__"].__dict__.update(updates) + self._extra_locals = extra_locals - return user_context.options + return self._extra_locals - def has_cache(self, fn, cache_key): + def has_cache(self, fn): """Check for the presence of the given function in the cache. Functions to be converted are cached by the function object itself as well as the conversion options.""" - return self._cache.has(fn, cache_key) + return ( + self._cache.has(fn, TOPLEVEL_OPTIONS) + or self._cache.has(fn, NESTED_OPTIONS) + or self._cache.has(fn, STANDARD_OPTIONS) + ) - def get_cached_function(self, fn, cache_key): + def get_cached_function(self, fn): """Retrieve a Python function object for a previously converted function. Note that repeatedly calling this function with the same arguments will result in new function objects every time, however their source code should be identical with the exception of auto-generated names.""" # Converted functions are cached as a _PythonFnFactory object. - cached_factory = self._cached_factory(fn, cache_key) + if self._cache.has(fn, TOPLEVEL_OPTIONS): + cached_factory = self._cached_factory(fn, TOPLEVEL_OPTIONS) + elif self._cache.has(fn, NESTED_OPTIONS): + cached_factory = self._cached_factory(fn, NESTED_OPTIONS) + else: + cached_factory = self._cached_factory(fn, STANDARD_OPTIONS) # Convert to a Python function object before returning (e.g. to obtain its source code). new_fn = cached_factory.instantiate( @@ -121,7 +109,7 @@ def get_cached_function(self, fn, cache_key): return new_fn -def autograph(fn): +def run_autograph(fn): """Decorator that converts the given function into graph form.""" user_context = converter.ProgramContext(TOPLEVEL_OPTIONS) @@ -134,13 +122,96 @@ def autograph(fn): return new_fn +def autograph_source(fn): + """Utility function to retrieve the source code of a function converted by AutoGraph. + + .. warning:: + + Nested functions (those not directly decorated with ``@qjit``) are only lazily converted by + AutoGraph. Make sure that the function has been traced at least once before accessing its + transformed source code, for example by specifying the signature of the compiled program + or by running it at least once. + + Args: + fn (Callable): the original function object that was converted + + Returns: + str: the source code of the converted function + + Raises: + AutoGraphError: If the given function was not converted by AutoGraph, an error will be + raised. + + **Example** + + .. code-block:: python + + def decide(x): + if x < 5: + y = 15 + else: + y = 1 + return y + + @qjit(autograph=True) + def func(x: int): + y = decide(x) + return y ** 2 + + >>> print(autograph_source(decide)) + def decide_1(x): + with ag__.FunctionScope('decide', 'fscope', ag__.STD) as fscope: + def get_state(): + return (y,) + def set_state(vars_): + nonlocal y + (y,) = vars_ + def if_body(): + nonlocal y + y = 15 + def else_body(): + nonlocal y + y = 1 + y = ag__.Undefined('y') + ag__.if_stmt(x < 5, if_body, else_body, get_state, set_state, ('y',), 1) + return y + """ + + # Handle directly converted objects. + if hasattr(fn, "ag_unconverted"): + return inspect.getsource(fn) + + # Unwrap known objects to get the function actually transformed by autograph. + if isinstance(fn, catalyst.QJIT): + fn = fn.original_function + if isinstance(fn, qml.QNode): + fn = fn.func + + if TRANSFORMER.has_cache(fn): + new_fn = TRANSFORMER.get_cached_function(fn) + return inspect.getsource(new_fn) + + raise AutoGraphError( + "The given function was not converted by AutoGraph. If you expect the" + "given function to be converted, please submit a bug report." + ) + + TOPLEVEL_OPTIONS = converter.ConversionOptions( recursive=True, user_requested=True, internal_convert_user_code=True, - optional_features=None, + optional_features=[converter.Feature.BUILTIN_FUNCTIONS], +) + +NESTED_OPTIONS = converter.ConversionOptions( + recursive=True, + user_requested=False, + internal_convert_user_code=True, + optional_features=[converter.Feature.BUILTIN_FUNCTIONS], ) +STANDARD_OPTIONS = converter.STANDARD_OPTIONS # Keep a global instance of the transformer to benefit from caching. -TRANSFORMER = CFTransformer() +TRANSFORMER = CatalystTransformer() diff --git a/frontend/catalyst/jit.py b/frontend/catalyst/jit.py index 31e47dd44b..588f62093c 100644 --- a/frontend/catalyst/jit.py +++ b/frontend/catalyst/jit.py @@ -29,7 +29,7 @@ from jax.tree_util import tree_flatten, tree_unflatten import catalyst -from catalyst.ag_utils import run_autograph +from catalyst.autograph import run_autograph from catalyst.compiled_functions import CompilationCache, CompiledFunction from catalyst.compiler import CompileOptions, Compiler from catalyst.jax_tracer import lower_jaxpr_to_mlir, trace_to_jaxpr diff --git a/frontend/catalyst/utils/exceptions.py b/frontend/catalyst/utils/exceptions.py index c3d739c392..a42b6933bb 100644 --- a/frontend/catalyst/utils/exceptions.py +++ b/frontend/catalyst/utils/exceptions.py @@ -14,6 +14,10 @@ """Custom Catalyst exceptions.""" +class AutoGraphError(Exception): + """Errors related to Catalyst's AutoGraph module.""" + + class CompileError(Exception): """Error encountered in the compilation phase.""" diff --git a/frontend/test/conftest.py b/frontend/test/conftest.py index c736b71f4f..6ad8081762 100644 --- a/frontend/test/conftest.py +++ b/frontend/test/conftest.py @@ -18,14 +18,6 @@ import platform import pytest -try: - import catalyst - import tensorflow as tf -except (ImportError, ModuleNotFoundError) as e: - tf_available = False -else: - tf_available = True - def is_cuda_available(): """Checks if cuda is available by trying an import. @@ -100,31 +92,6 @@ def pytest_configure(config): ) -def pytest_runtest_setup(item): - """Automatically skip tests if interfaces are not installed""" - interfaces = {"tf"} - available_interfaces = { - "tf": tf_available, - } - - allowed_interfaces = [ - allowed_interface - for allowed_interface in interfaces - if available_interfaces[allowed_interface] is True - ] - - # load the marker specifying what the interface is - all_interfaces = {"tf"} - marks = {mark.name for mark in item.iter_markers() if mark.name in all_interfaces} - - for b in marks: - if b not in allowed_interfaces: - pytest.skip( - f"\nTest {item.nodeid} only runs with {allowed_interfaces}" - f" interfaces(s) but {b} interface provided", - ) - - def skip_cuda_tests(config, items): """Skip cuda tests according to the following logic: By default: RUN diff --git a/frontend/test/lit/test_autograph.py b/frontend/test/lit/test_autograph.py index 0a82bf6294..e2a6f1ebbd 100644 --- a/frontend/test/lit/test_autograph.py +++ b/frontend/test/lit/test_autograph.py @@ -17,14 +17,18 @@ # RUN: %PYTHON %s | FileCheck %s from catalyst import qjit -from catalyst.ag_utils import AutoGraphError, print_code -from catalyst.autograph import autograph +from catalyst.autograph import AutoGraphError, autograph_source, run_autograph -# CHECK-LABEL: def while_simple -@autograph +def print_code(fn): + """Print autograph generated code for a function.""" + print(autograph_source(fn)) + + +# CHECK-LABEL: def ag__while_simple +@run_autograph def while_simple(x: float): - """Test a simple while-loop statemnt.""" + """Test a simple while-loop statement.""" # CHECK: def loop_body # CHECK: def loop_test @@ -103,8 +107,8 @@ def while_fallback_jax(a: int): # ----- -# CHECK-LABEL: def if_simple -@autograph +# CHECK-LABEL: def ag__if_simple +@run_autograph def if_simple(x: float): """Test a simple conditional with a single branch.""" @@ -115,7 +119,7 @@ def if_simple(x: float): # CHECK: def else_body(): # CHECK: pass - # CHECK: ag__.if_stmt(x < 3, if_body, else_body, get_state, set_state, (), 0) + # CHECK: ag__.if_stmt(ag__.ld(x) < 3, if_body, else_body, get_state, set_state, (), 0) return x @@ -125,8 +129,8 @@ def if_simple(x: float): # ----- -# CHECK-LABEL: def if_else -@autograph +# CHECK-LABEL: def ag__if_else +@run_autograph def if_else(x: float): """Test a simple conditional with two branches.""" @@ -139,7 +143,7 @@ def if_else(x: float): else: pass - # CHECK: ag__.if_stmt(x < 3, if_body, else_body, get_state, set_state, (), 0) + # CHECK: ag__.if_stmt(ag__.ld(x) < 3, if_body, else_body, get_state, set_state, (), 0) return x @@ -149,8 +153,8 @@ def if_else(x: float): # ----- -# CHECK-LABEL: def if_assign -@autograph +# CHECK-LABEL: def ag__if_assign +@run_autograph def if_assign(x: float): """Test a conditional creates a new variable.""" @@ -165,9 +169,12 @@ def if_assign(x: float): else: y = 5 - # CHECK: ag__.if_stmt(x < 3, if_body, else_body, get_state, set_state, ('y',), 1) + # CHECK: ag__.if_stmt(ag__.ld(x) < 3, if_body, else_body, get_state, set_state, ('y',), 1) - # CHECK: return y + # CHECK: try: + # CHECK: do_return = True + # CHECK: retval_ = ag__.ld(y) + # CHECK: return fscope.ret(retval_, do_return) return y @@ -176,9 +183,9 @@ def if_assign(x: float): # ----- -# CHECK-LABEL: def if_assign_no_type_mismatch +# CHECK-LABEL: def ag__if_assign_no_type_mismatch @qjit # needed to trigger Catalyst type checks during tracing -@autograph +@run_autograph def if_assign_no_type_mismatch(x: float): """Verify the absense of error from a conditional that doesn't produce the same type across branches.""" @@ -205,7 +212,7 @@ def if_assign_no_type_mismatch(x: float): try: @qjit # needed to trigger the execution of ag__.if_stmt which performs the check - @autograph + @run_autograph def if_assign_pytree_shape_mismatch(x: float): """Verify error from a conditional that doesn't produce a value in all branches.""" @@ -226,7 +233,7 @@ def if_assign_pytree_shape_mismatch(x: float): try: @qjit # needed to trigger the execution of ag__.if_stmt which performs the check - @autograph + @run_autograph def if_assign_partial(x: float): """Verify error from a conditional that doesn't produce a value in all branches.""" @@ -242,8 +249,8 @@ def if_assign_partial(x: float): # ----- -# CHECK-LABEL: def if_assign_existing -@autograph +# CHECK-LABEL: def ag__if_assign_existing +@run_autograph def if_assign_existing(x: float): """Test a conditional that assigns to an existing variable in all branches.""" @@ -261,9 +268,12 @@ def if_assign_existing(x: float): else: y = 5 - # CHECK: ag__.if_stmt(x < 3, if_body, else_body, get_state, set_state, ('y',), 1) + # CHECK: ag__.if_stmt(ag__.ld(x) < 3, if_body, else_body, get_state, set_state, ('y',), 1) - # CHECK: return y + # CHECK: try: + # CHECK: do_return = True + # CHECK: retval_ = ag__.ld(y) + # CHECK: return fscope.ret(retval_, do_return) return y @@ -272,8 +282,8 @@ def if_assign_existing(x: float): # ----- -# CHECK-LABEL: def if_assign_existing_type_mismatch -@autograph +# CHECK-LABEL: def ag__if_assign_existing_type_mismatch +@run_autograph def if_assign_existing_type_mismatch(x: float): """Test a conditional that assigns to an existing variable with a different type, while being consistent across all branches.""" @@ -292,9 +302,12 @@ def if_assign_existing_type_mismatch(x: float): else: y = 5.0 - # CHECK: ag__.if_stmt(x < 3, if_body, else_body, get_state, set_state, ('y',), 1) + # CHECK: ag__.if_stmt(ag__.ld(x) < 3, if_body, else_body, get_state, set_state, ('y',), 1) - # CHECK: return y + # CHECK: try: + # CHECK: do_return = True + # CHECK: retval_ = ag__.ld(y) + # CHECK: return fscope.ret(retval_, do_return) return y @@ -303,8 +316,8 @@ def if_assign_existing_type_mismatch(x: float): # ----- -# CHECK-LABEL: def if_assign_existing_partial -@autograph +# CHECK-LABEL: def ag__if_assign_existing_partial +@run_autograph def if_assign_existing_partial(x: float): """Test a conditional that assigns to an existing variable in some branches only.""" @@ -320,9 +333,12 @@ def if_assign_existing_partial(x: float): # CHECK: nonlocal y # CHECK: pass - # CHECK: ag__.if_stmt(x < 3, if_body, else_body, get_state, set_state, ('y',), 1) + # CHECK: ag__.if_stmt(ag__.ld(x) < 3, if_body, else_body, get_state, set_state, ('y',), 1) - # CHECK: return y + # CHECK: try: + # CHECK: do_return = True + # CHECK: retval_ = ag__.ld(y) + # CHECK: return fscope.ret(retval_, do_return) return y @@ -331,9 +347,9 @@ def if_assign_existing_partial(x: float): # ----- -# CHECK-LABEL: def if_assign_existing_partial_no_type_mismatch +# CHECK-LABEL: def ag__if_assign_existing_partial_no_type_mismatch @qjit -@autograph +@run_autograph def if_assign_existing_partial_no_type_mismatch(x: float): """Verify error from a conditional that assigns to an existing value with different type, without defining a value in all branches. This should lead to a type mismatch error.""" @@ -359,8 +375,8 @@ def if_assign_existing_partial_no_type_mismatch(x: float): # ----- -# CHECK-LABEL: def if_assign_multiple -@autograph +# CHECK-LABEL: def ag__if_assign_multiple +@run_autograph def if_assign_multiple(x: float): """Test a conditional that assigns to multiple existing variables.""" @@ -380,9 +396,12 @@ def if_assign_multiple(x: float): y = 5 z = True - # CHECK: ag__.if_stmt(x < 3, if_body, else_body, get_state, set_state, ('y', 'z'), 2) + # CHECK: ag__.if_stmt(ag__.ld(x) < 3, if_body, else_body, get_state, set_state, ('y', 'z'), 2) - # CHECK: return y * z + # CHECK: try: + # CHECK: do_return = True + # CHECK: retval_ = ag__.ld(y) * ag__.ld(z) + # CHECK: return fscope.ret(retval_, do_return) return y * z @@ -394,7 +413,7 @@ def if_assign_multiple(x: float): try: @qjit - @autograph + @run_autograph def if_assign_invalid_type(x: float): """Verify error from a conditional that produces a type invalid for tracing.""" @@ -412,8 +431,8 @@ def if_assign_invalid_type(x: float): # ----- -# CHECK-LABEL: def if_elif -@autograph +# CHECK-LABEL: def ag__if_elif +@run_autograph def if_elif(x: float): """Test a conditional with more than two branches.""" @@ -432,13 +451,16 @@ def if_elif(x: float): # CHECK: def else_body(): # CHECK: nonlocal y # CHECK: pass - # CHECK: ag__.if_stmt(x < 5, if_body, else_body, get_state, set_state, ('y',), 1) + # CHECK: ag__.if_stmt(ag__.ld(x) < 5, if_body, else_body, get_state, set_state, ('y',), 1) elif x < 5: y = 7 - # CHECK: ag__.if_stmt(x < 3, if_body_1, else_body_1, get_state_1, set_state_1, ('y',), 1) + # CHECK: ag__.if_stmt(ag__.ld(x) < 3, if_body_1, else_body_1, get_state_1, set_state_1, ('y',), 1) - # CHECK: return y + # CHECK: try: + # CHECK: do_return = True + # CHECK: retval_ = ag__.ld(y) + # CHECK: return fscope.ret(retval_, do_return) return y @@ -447,7 +469,7 @@ def if_elif(x: float): # ----- -# CHECK-LABEL: def nested_call +# CHECK-LABEL: def ag__nested_call def nested_call(x, y): """Nested function with conditional.""" @@ -460,14 +482,17 @@ def nested_call(x, y): # CHECK: nonlocal y # CHECK: pass - # CHECK: ag__.if_stmt(x < 3, if_body, else_body, get_state, set_state, ('y',), 1) + # CHECK: ag__.if_stmt(ag__.ld(x) < 3, if_body, else_body, get_state, set_state, ('y',), 1) - # CHECK: return y + # CHECK: try: + # CHECK: do_return = True + # CHECK: retval_ = ag__.ld(y) + # CHECK: return fscope.ret(retval_, do_return) return y -# CHECK-LABEL: def if_call -@autograph +# CHECK-LABEL: def ag__if_call +@run_autograph def if_call(x: float): """Test a conditional that is nested inside another function. All (user) functions invoked by the explicitly transformed function should also be transformed.""" @@ -475,7 +500,10 @@ def if_call(x: float): # CHECK: y = 0 y = 0 - # CHECK: return ag__.converted_call(nested_call, (x, y) + # CHECK: try: + # CHECK: do_return = True + # CHECK: retval_ = ag__.converted_call(ag__.ld(nested_call), (ag__.ld(x), ag__.ld(y)) + # CHECK: return fscope.ret(retval_, do_return) return nested_call(x, y) @@ -487,17 +515,22 @@ def if_call(x: float): # ----- -# CHECK-LABEL: def logical_calls -@autograph +# CHECK-LABEL: def ag__logical_calls +@run_autograph def logical_calls(x: float, y: float): """Check that catalyst can handle ``and``, ``or`` and ``not`` using autograph.""" # pylint: disable=chained-comparison - # CHECK: a = ag__.and_ + # CHECK: a = ag__.and_ a = x >= 0.0 and x <= 1.0 - # CHECK: b = ag__.not_ + + # CHECK: b = ag__.not_ b = not y >= 1.0 - # CHECK: return ag__.or_ + + # CHECK: try: + # CHECK: do_return = True + # CHECK: retval_ = ag__.or_(lambda{{\ ?}}: ag__.ld(a), lambda{{\ ?}}: ag__.ld(b)) + # CHECK: return fscope.ret(retval_, do_return) return a or b @@ -507,14 +540,14 @@ def logical_calls(x: float, y: float): # ----- -# CHECK-LABEL: def chain_logical_call -@autograph +# CHECK-LABEL: def ag__chain_logical_call +@run_autograph def chain_logical_call(x: float): """Check that catalyst can handle chained-``and`` using autograph.""" - # CHECK: ag__.and_ - # CHECK-SAME: 0.0 <= x - # CHECK-SAME: x <= 1.0 + # CHECK: ag__.and_ + # CHECK-SAME: 0.0 <= ag__.ld(x) + # CHECK-SAME: ag__.ld(x) <= 1.0 return 0.0 <= x <= 1.0 diff --git a/frontend/test/pytest/test_autograph.py b/frontend/test/pytest/test_autograph.py index b7cc3e5f36..a8b44c4a73 100644 --- a/frontend/test/pytest/test_autograph.py +++ b/frontend/test/pytest/test_autograph.py @@ -14,7 +14,6 @@ """PyTests for the AutoGraph source-to-source transformation feature.""" -import sys import traceback from collections import defaultdict @@ -38,7 +37,9 @@ qjit, vjp, ) -from catalyst.ag_utils import AutoGraphError, autograph_source, check_cache +from catalyst.autograph import TRANSFORMER, AutoGraphError, autograph_source + +check_cache = TRANSFORMER.has_cache # pylint: disable=import-outside-toplevel # pylint: disable=unnecessary-lambda-assignment @@ -65,17 +66,6 @@ def val(self): return self.ref -def test_unavailable(monkeypatch): - """Check the error produced in the absence of tensorflow.""" - monkeypatch.setitem(sys.modules, "tensorflow", None) - - def fn(x): - return x**2 - - with pytest.raises(ImportError, match="AutoGraph feature in Catalyst requires TensorFlow"): - qjit(autograph=True)(fn) - - @pytest.mark.tf class TestSourceCodeInfo: """Unit tests for exception utilities that retrieves traceback information for the original diff --git a/requirements.txt b/requirements.txt index c4378bac09..5e9d03522c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,6 +20,5 @@ pytest-cov nbmake # optional rt/test dependencies -tensorflow -amazon-braket-pennylane-plugin>=1.23.0 pennylane-lightning[kokkos] +amazon-braket-pennylane-plugin>=1.23.0 diff --git a/setup.py b/setup.py index 61961ccd51..eafcc5a5fe 100644 --- a/setup.py +++ b/setup.py @@ -45,6 +45,7 @@ f"jaxlib=={jax_version}", "tomlkit;python_version<'3.11'", "scipy", + "diastatic-malt>=2.15.1", ] entry_points = {