From 559a59fb79477e1c350902deb189c26f6d76bacc Mon Sep 17 00:00:00 2001 From: Raul Torres Date: Mon, 6 May 2024 12:48:53 +0300 Subject: [PATCH] [skip ci][frontend] Add support to AG for Jax single array assignment --- frontend/catalyst/ag_primitives.py | 26 ++++++++++++++++++++++++++ frontend/catalyst/autograph.py | 7 +++++-- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/frontend/catalyst/ag_primitives.py b/frontend/catalyst/ag_primitives.py index 9d9ecbaeee..9662fe7ad6 100644 --- a/frontend/catalyst/ag_primitives.py +++ b/frontend/catalyst/ag_primitives.py @@ -44,6 +44,7 @@ "and_", "or_", "not_", + "set_item", ] @@ -565,6 +566,31 @@ def qnode_call_wrapper(): return ag_converted_call(fn, args, kwargs, caller_fn_scope, options) +def set_item(target, i, x): + """An implementation of the AutoGraph 'set_item' function. The interface is defined by + AutoGraph, here we merely provide an implementation of it in terms of Catalyst primitives. + The idea is simply to accept the much simpler single index assigment syntax for Jax arrays, + to subsequently transform it under the hood into the set of 'at' and 'set' calls that + Autograph supports. E.g.: + target[i] = x -> target = target.at[i].set(x) + + .. note:: + For this feature to work, 'converter.Feature.LISTS' had to be added to the + TOP_LEVEL_OPTIONS and NESTED_LEVEL_OPTIONS conversion options of our own Catalyst + Autograph transformer. If you create a new transformer and want to support this feature, + make sure you enable such option there as well. + """ + + # Apply the 'at...set' transformation only to Jax arrays. + # Otherwise, fallback to Python's default syntax. + if isinstance(target, DynamicJaxprTracer): + target = target.at[i].set(x) + else: + target[i] = x + + return target + + class CRange: """Catalyst range object. diff --git a/frontend/catalyst/autograph.py b/frontend/catalyst/autograph.py index 19faa4330f..926e6c11d2 100644 --- a/frontend/catalyst/autograph.py +++ b/frontend/catalyst/autograph.py @@ -265,18 +265,21 @@ def __init__(self): # Singleton instance of DisableAutograph disable_autograph = DisableAutograph() +# converter.Feature.LISTS permits overloading the 'set_item' function in 'ag_primitives.py' +OPTIONAL_FEATURES = [converter.Feature.BUILTIN_FUNCTIONS, converter.Feature.LISTS] + TOPLEVEL_OPTIONS = converter.ConversionOptions( recursive=True, user_requested=True, internal_convert_user_code=True, - optional_features=[converter.Feature.BUILTIN_FUNCTIONS], + optional_features=OPTIONAL_FEATURES, ) NESTED_OPTIONS = converter.ConversionOptions( recursive=True, user_requested=False, internal_convert_user_code=True, - optional_features=[converter.Feature.BUILTIN_FUNCTIONS], + optional_features=OPTIONAL_FEATURES, ) STANDARD_OPTIONS = converter.STANDARD_OPTIONS