
[Frontend, MLIR] Support indexing of the dynamically shaped arrays #411

Merged · 14 commits · Jan 8, 2024
3 changes: 3 additions & 0 deletions doc/changelog.md
@@ -2,6 +2,9 @@

<h3>New features</h3>

* Catalyst now supports indexing and assignment of dynamically-shaped arrays.
[(#411)](https://github.com/PennyLaneAI/catalyst/pull/411)

* Error mitigation using the zero-noise extrapolation method is now available through the
`catalyst.mitigate_with_zne` transform.

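A minimal usage sketch of the new feature, mirroring the tests added later in this PR (the imports are the standard Catalyst/JAX ones; the function name is illustrative):

```python
import jax.numpy as jnp
from catalyst import qjit

@qjit
def take(sz, idx):
    # `a` has dynamic dimensions derived from the argument `sz`;
    # indexing it lowers to a stablehlo.gather operation.
    a = jnp.ones((sz, 3, sz + 1), dtype=int)
    return a[idx, 2, idx]

take(5, 2)  # returns 1, as in test_array_indexing below
```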
158 changes: 156 additions & 2 deletions frontend/catalyst/utils/jax_extras.py
@@ -24,14 +24,25 @@
from jax._src.core import _update_thread_local_jit_state
from jax._src.dispatch import jaxpr_replicas
from jax._src.effects import ordered_effects as jax_ordered_effects
from jax._src.interpreters.mlir import _module_name_regex
from jax._src.interpreters.mlir import _module_name_regex, register_lowering
from jax._src.interpreters.partial_eval import (
_input_type_to_tracers,
infer_lambda_input_type,
trace_to_jaxpr_dynamic2,
)
from jax._src.lax.control_flow import _initial_style_jaxpr, _initial_style_open_jaxpr
from jax._src.lax.lax import _abstractify, xla
from jax._src.lax.slicing import (
_argnum_weak_type,
_gather_dtype_rule,
_gather_lower,
_gather_shape_computation,
_is_sorted,
_no_duplicate_dims,
_rank,
_sorted_dims_in_range,
standard_primitive,
)
from jax._src.linear_util import annotate
from jax._src.pjit import _extract_implicit_args, _flat_axes_specs
from jax._src.sharding_impls import ReplicaAxisContext
@@ -463,10 +474,23 @@
in_type = infer_lambda_input_type(axes_specs, flat_args)
return in_type, in_tree

# TODO: See the `_gather_shape_rule_dynamic` comment. Remove once the upstream change is
# applied.
gather2_p = standard_primitive(
_gather_shape_rule_dynamic,
_gather_dtype_rule,
"gather",
weak_type_rule=_argnum_weak_type(0),
)
register_lowering(gather2_p, _gather_lower)

@wraps(fun)
def make_jaxpr_f(*args, **kwargs):
# TODO: re-use `deduce_avals` here.
with Patcher((jax._src.interpreters.partial_eval, "get_aval", get_aval2)), ExitStack():
with Patcher(
(jax._src.interpreters.partial_eval, "get_aval", get_aval2),
(jax._src.lax.slicing, "gather_p", gather2_p),
Review thread:

Contributor: @dime10 @grwlf here we have an issue: the new rule does not define a JVP and is therefore not compatible with grad or jacobian transformations (`ad.defjvp(gather_p, _gather_jvp_rule, None)`); see the original `gather_p`:

```python
gather_p = standard_primitive(
    _gather_shape_rule, _gather_dtype_rule, 'gather',
    weak_type_rule=_argnum_weak_type(0))
ad.defjvp(gather_p, _gather_jvp_rule, None)
ad.primitive_transposes[gather_p] = _gather_transpose_rule
batching.primitive_batchers[gather_p] = _gather_batching_rule
pe.padding_rules[gather_p] = _gather_pad_rule
```

Member: @rmoyard does this result in a user-facing bug?

Contributor: Yes, you cannot use jax.grad inside qjit when a gather operation is created, for example by slicing.

Member: Ah, so #305?

Collaborator: That issue is about Catalyst gradients; I think Romain is talking about JAX gradients (which are run in the frontend on the jaxpr, hence the primitives need gradient rules).

Member: Oh, got it!

Contributor: Yes, exactly what David said: qjit with jax.grad and slicing.

Contributor: qjit(jax.grad(f))(x)

Member: Just checking if we have a resolution on this particular comment thread :)

Collaborator: @grwlf How is the upstream PR looking for this patch? In the meantime, can we just attach the original gradient rule to the patched primitive?
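A minimal sketch of the suggestion above (not part of the merged PR): reattach the original gather differentiation, batching, and padding rules quoted in the thread to the patched primitive, assuming the private helpers in `jax._src` keep their upstream names:

```python
from jax._src.interpreters import ad, batching
from jax._src.interpreters import partial_eval as pe
from jax._src.lax.slicing import (
    _gather_batching_rule,
    _gather_jvp_rule,
    _gather_pad_rule,
    _gather_transpose_rule,
)

# Same registrations as upstream gather_p, applied to the patched
# primitive so that jax.grad/jax.jacobian can differentiate through it.
ad.defjvp(gather2_p, _gather_jvp_rule, None)
ad.primitive_transposes[gather2_p] = _gather_transpose_rule
batching.primitive_batchers[gather2_p] = _gather_batching_rule
pe.padding_rules[gather2_p] = _gather_pad_rule
```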

), ExitStack():
f = wrap_init(fun)
in_type, in_tree = abstractify(args, kwargs)
f, out_tree_promise = flatten_fun(f, in_tree)
@@ -477,3 +501,133 @@

make_jaxpr_f.__name__ = f"make_jaxpr2({make_jaxpr2.__name__})"
return make_jaxpr_f


def _gather_shape_rule_dynamic(
operand,
indices,
*,
dimension_numbers,
slice_sizes,
unique_indices,
indices_are_sorted,
mode,
fill_value,
): # pragma: no cover
"""Validates the well-formedness of the arguments to Gather. Compared to the original version,
this implementation skips static shape checks if variable dimensions are used.

This function has been modified from its original form in the JAX project at
https://github.com/google/jax/blob/88a60b808c1f91260cc9e75b9aa2508aae5bc9f9/jax/_src/lax/slicing.py#L1438
released under the Apache License, Version 2.0, with the following copyright notice:

Copyright 2021 The JAX Authors.
"""
# pylint: disable=unused-argument
# pylint: disable=too-many-branches
# pylint: disable=consider-using-enumerate
# pylint: disable=chained-comparison
offset_dims = dimension_numbers.offset_dims
collapsed_slice_dims = dimension_numbers.collapsed_slice_dims
start_index_map = dimension_numbers.start_index_map

# Note: in JAX, index_vector_dim is always computed as below, cf. the
# documentation of the GatherDimensionNumbers class.
index_vector_dim = _rank(indices) - 1

# This case should never happen in JAX, due to the implicit construction of
# index_vector_dim, but is included for completeness.
if _rank(indices) < index_vector_dim or index_vector_dim < 0:
raise TypeError(
f"Gather index leaf dimension must be within [0, rank("
f"indices) + 1). rank(indices) is {_rank(indices)} and "
f"gather index leaf dimension is {index_vector_dim}."
)

# Start ValidateGatherDimensions
# In the error messages output by XLA, "offset_dims" is called "Output window
# dimensions". For consistency's sake, our error messages stick to "offset_dims".
_is_sorted(offset_dims, "gather", "offset_dims")
_no_duplicate_dims(offset_dims, "gather", "offset_dims")

output_offset_dim_count = len(offset_dims)
output_shape_rank = len(offset_dims) + _rank(indices) - 1

for i in range(output_offset_dim_count):
offset_dim = offset_dims[i]
if offset_dim < 0 or offset_dim >= output_shape_rank:
raise TypeError(
f"Offset dimension {i} in gather op is out of bounds; "
f"got {offset_dim}, but should have been in "
f"[0, {output_shape_rank})"
)

if len(start_index_map) != indices.shape[index_vector_dim]:
raise TypeError(
f"Gather op has {len(start_index_map)} elements in "
f"start_index_map and the bound of dimension "
f"{index_vector_dim=} of indices is "
f"{indices.shape[index_vector_dim]}. These two "
f"numbers must be equal."
)

for i in range(len(start_index_map)):
operand_dim_for_start_index_i = start_index_map[i]
if operand_dim_for_start_index_i < 0 or operand_dim_for_start_index_i >= _rank(operand):
raise TypeError(
f"Invalid start_index_map; domain is "
f"[0, {_rank(operand)}), got: "
f"{i}->{operand_dim_for_start_index_i}."
)

_no_duplicate_dims(start_index_map, "gather", "start_index_map")

# _is_sorted and _sorted_dims_in_range are checked in the opposite order
# compared to the XLA implementation. In cases when the input is not sorted
# AND there are problematic collapsed_slice_dims, the error message will thus
# be different.
_is_sorted(collapsed_slice_dims, "gather", "collapsed_slice_dims")
_sorted_dims_in_range(collapsed_slice_dims, _rank(operand), "gather", "collapsed_slice_dims")
_no_duplicate_dims(collapsed_slice_dims, "gather", "collapsed_slice_dims")
# End ValidateGatherDimensions

if _rank(operand) != len(slice_sizes):
raise TypeError(
f"Gather op must have one slice size for every input "
f"dimension; got: len(slice_sizes)={len(slice_sizes)}, "
f"input_shape.rank={_rank(operand)}"
)

if len(slice_sizes) != len(offset_dims) + len(collapsed_slice_dims):
raise TypeError(
f"All components of the offset index in a gather op must "
f"either be a offset dimension or explicitly collapsed; "
f"got len(slice_sizes)={len(slice_sizes)}, "
f"output_slice_sizes={offset_dims}, collapsed_slice_dims="
f"{collapsed_slice_dims}."
)

# This section contains the patch suggested to upstream JAX.
for i in range(len(slice_sizes)):
slice_size = slice_sizes[i]
corresponding_input_size = operand.shape[i]

if jax.core.is_constant_dim(corresponding_input_size):
if not (slice_size >= 0 and corresponding_input_size >= slice_size):
raise TypeError(
f"Slice size at index {i} in gather op is out of range, "
f"must be within [0, {corresponding_input_size} + 1), "
f"got {slice_size}."
)

for i in range(len(collapsed_slice_dims)):
bound = slice_sizes[collapsed_slice_dims[i]]
if bound != 1:
raise TypeError(
f"Gather op can only collapse slice dims with bound 1, "
f"but bound is {bound} for index "

f"{collapsed_slice_dims[i]} at position {i}."
)

return _gather_shape_computation(indices, dimension_numbers, slice_sizes)
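For intuition, the dimension predicate the patch hinges on behaves as sketched below; `jax.core.is_constant_dim` is the same helper the function calls above:

```python
import jax

# Concrete Python integers are constant dimensions, so the slice-size
# bound check above runs exactly as in upstream JAX.
assert jax.core.is_constant_dim(5)

# Under dynamic-shape tracing (e.g. inside Catalyst's qjit), a dimension
# can be a tracer rather than an int; is_constant_dim is then False and
# the static bound check is skipped, deferring validation to runtime.
```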
11 changes: 11 additions & 0 deletions frontend/test/lit/test_jax_dynamic_api.py
@@ -117,3 +117,14 @@ def test_qjit_aot(a: ShapedArray([1, 3, 1], dtype=float)):


print_mlir(test_qjit_aot, aot=True)


@qjit
def test_qjit_indexing(sz):
"""Check the usage of stablehlo.gather for indexing"""
r = jnp.ones((sz + 1,), dtype=int)
# CHECK: gather
return r[0]


print_mlir(test_qjit_indexing, 3)
26 changes: 26 additions & 0 deletions frontend/test/pytest/test_jax_dynamic_api.py
@@ -387,5 +387,31 @@ def i(x):
assert _id0 == _id1


def test_array_indexing():
"""Test the support of indexing of dynamically-shaped arrays"""

@qjit
def fun(sz, idx):
r = jnp.ones((sz, 3, sz + 1), dtype=int)
return r[idx, 2, idx]

res = fun(5, 2)
assert res == 1


def test_array_assignment():
"""Test the support of assigning a value to a dynamically-shaped array"""

@qjit
def fun(sz, idx, val):
r = jnp.ones((sz, 3, sz), dtype=int)
r = r.at[idx, 0, idx].set(val)
return r

result = fun(5, 2, 33)
expected = jnp.ones((5, 3, 5), dtype=int).at[2, 0, 2].set(33)
assert_array_and_dtype_equal(result, expected)


if __name__ == "__main__":
pytest.main(["-x", __file__])
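As flagged in the review thread on `jax_extras.py`, the patched primitive does not yet carry a JVP rule; a hedged sketch of the pattern expected to fail until one is attached (this is not a test in this PR):

```python
import jax
import jax.numpy as jnp
from catalyst import qjit

@qjit
def grad_of_slice(x):
    # Indexing/slicing creates a gather; without ad.defjvp on the
    # patched primitive, jax.grad cannot differentiate through it.
    return jax.grad(lambda y: y[0])(x)
```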
4 changes: 2 additions & 2 deletions mlir/lib/Catalyst/Transforms/ScatterPatterns.cpp
@@ -366,7 +366,7 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern<mhlo::ScatterOp>
auto indexScatter =
builder.create<tensor::ExtractOp>(loc, scatterIndices, index);
auto indexUpdateCasted =
- builder.create<index::CastSOp>(loc, builder.getI32Type(), indexUpdate);
+ builder.create<index::CastSOp>(loc, indexScatter.getType(), indexUpdate);
Value addValue =
builder.create<arith::AddIOp>(loc, indexScatter, indexUpdateCasted);
Value addValueCasted =
@@ -409,7 +409,7 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern<mhlo::ScatterOp>
Value indexScatter = fullStartIndex[i];
auto indexUpdate = updateWindowsIndices[i];
auto indexUpdateCasted =
- builder.create<index::CastSOp>(loc, builder.getI32Type(), indexUpdate);
+ builder.create<index::CastSOp>(loc, indexScatter.getType(), indexUpdate);
Value addValue =
builder.create<arith::AddIOp>(loc, indexScatter, indexUpdateCasted);
Value addValueCasted =
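In both hunks, the cast target changes from a hard-coded i32 to the type of `indexScatter` itself, which presumably keeps the operands of the subsequent `arith::AddIOp` type-consistent when the scatter indices are not 32-bit.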