Skip to content

Commit

Permalink
[skip ci][frontend] Add support to AG for Jax single array assignment
Browse files Browse the repository at this point in the history
  • Loading branch information
rauletorresc committed May 6, 2024
1 parent 57dcc0a commit 559a59f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
26 changes: 26 additions & 0 deletions frontend/catalyst/ag_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
"and_",
"or_",
"not_",
"set_item",
]


Expand Down Expand Up @@ -565,6 +566,31 @@ def qnode_call_wrapper():
return ag_converted_call(fn, args, kwargs, caller_fn_scope, options)


def set_item(target, i, x):
"""An implementation of the AutoGraph 'set_item' function. The interface is defined by
AutoGraph, here we merely provide an implementation of it in terms of Catalyst primitives.
The idea is simply to accept the much simpler single index assigment syntax for Jax arrays,
to subsequently transform it under the hood into the set of 'at' and 'set' calls that
Autograph supports. E.g.:
target[i] = x -> target = target.at[i].set(x)
.. note::
For this feature to work, 'converter.Feature.LISTS' had to be added to the
TOP_LEVEL_OPTIONS and NESTED_LEVEL_OPTIONS conversion options of our own Catalyst
Autograph transformer. If you create a new transformer and want to support this feature,
make sure you enable such option there as well.
"""

# Apply the 'at...set' transformation only to Jax arrays.
# Otherwise, fallback to Python's default syntax.
if isinstance(target, DynamicJaxprTracer):
target = target.at[i].set(x)
else:
target[i] = x

return target


class CRange:
"""Catalyst range object.
Expand Down
7 changes: 5 additions & 2 deletions frontend/catalyst/autograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,18 +265,21 @@ def __init__(self):
# Singleton instance of DisableAutograph
disable_autograph = DisableAutograph()

# converter.Feature.LISTS permits overloading the 'set_item' function in 'ag_primitives.py'
OPTIONAL_FEATURES = [converter.Feature.BUILTIN_FUNCTIONS, converter.Feature.LISTS]

TOPLEVEL_OPTIONS = converter.ConversionOptions(
recursive=True,
user_requested=True,
internal_convert_user_code=True,
optional_features=[converter.Feature.BUILTIN_FUNCTIONS],
optional_features=OPTIONAL_FEATURES,
)

NESTED_OPTIONS = converter.ConversionOptions(
recursive=True,
user_requested=False,
internal_convert_user_code=True,
optional_features=[converter.Feature.BUILTIN_FUNCTIONS],
optional_features=OPTIONAL_FEATURES,
)

STANDARD_OPTIONS = converter.STANDARD_OPTIONS
Expand Down

0 comments on commit 559a59f

Please sign in to comment.