From d5d608c2cdb506bf9edc42a4305b170c24c16546 Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Wed, 20 Dec 2023 13:11:49 +0000 Subject: [PATCH 01/14] Add indexing support --- frontend/catalyst/utils/jax_extras.py | 102 +++++++++++++++++++++++++- 1 file changed, 100 insertions(+), 2 deletions(-) diff --git a/frontend/catalyst/utils/jax_extras.py b/frontend/catalyst/utils/jax_extras.py index a9954ba537..9b278a19c8 100644 --- a/frontend/catalyst/utils/jax_extras.py +++ b/frontend/catalyst/utils/jax_extras.py @@ -24,7 +24,7 @@ from jax._src.core import _update_thread_local_jit_state from jax._src.dispatch import jaxpr_replicas from jax._src.effects import ordered_effects as jax_ordered_effects -from jax._src.interpreters.mlir import _module_name_regex +from jax._src.interpreters.mlir import _module_name_regex, register_lowering from jax._src.interpreters.partial_eval import ( _input_type_to_tracers, infer_lambda_input_type, @@ -38,6 +38,11 @@ from jax._src.source_info_util import current as jax_current from jax._src.source_info_util import new_name_stack from jax._src.util import partition_list, safe_map, unzip2, unzip3, wrap_name, wraps +from jax._src.util import safe_map, unzip2, wrap_name, wraps +from jax._src.lax.slicing import ( + _is_sorted, _no_duplicate_dims, _rank, _sorted_dims_in_range, _gather_shape_computation, + _gather_dtype_rule, _argnum_weak_type, standard_primitive, _gather_lower +) from jax.api_util import flatten_fun from jax.core import AbstractValue, ClosedJaxpr, Jaxpr, JaxprEqn, MainTrace, OutputType from jax.core import Primitive as JaxprPrimitive @@ -463,10 +468,19 @@ def abstractify(args, kwargs): in_type = infer_lambda_input_type(axes_specs, flat_args) return in_type, in_tree + gather2_p = standard_primitive( + _gather_shape_rule_dynamic, _gather_dtype_rule, 'gather', + weak_type_rule=_argnum_weak_type(0)) + + register_lowering(gather2_p, _gather_lower) + @wraps(fun) def make_jaxpr_f(*args, **kwargs): # TODO: re-use `deduce_avals` here. - with Patcher((jax._src.interpreters.partial_eval, "get_aval", get_aval2)), ExitStack(): + with Patcher( + (jax._src.interpreters.partial_eval, "get_aval", get_aval2), + (jax._src.lax.slicing, "gather_p", gather2_p) + ), ExitStack(): f = wrap_init(fun) in_type, in_tree = abstractify(args, kwargs) f, out_tree_promise = flatten_fun(f, in_tree) @@ -477,3 +491,87 @@ def make_jaxpr_f(*args, **kwargs): make_jaxpr_f.__name__ = f"make_jaxpr2({make_jaxpr2.__name__})" return make_jaxpr_f + + +def _gather_shape_rule_dynamic(operand, indices, *, dimension_numbers, + slice_sizes, unique_indices, indices_are_sorted, + mode, fill_value): + offset_dims = dimension_numbers.offset_dims + collapsed_slice_dims = dimension_numbers.collapsed_slice_dims + start_index_map = dimension_numbers.start_index_map + + # Note: in JAX, index_vector_dim is always computed as below, cf. the + # documentation of the GatherDimensionNumbers class. + index_vector_dim = _rank(indices) - 1 + + # This case should never happen in JAX, due to the implicit construction of + # index_vector_dim, but is included for completeness. + if _rank(indices) < index_vector_dim or index_vector_dim < 0: + raise TypeError(f"Gather index leaf dimension must be within [0, rank(" + f"indices) + 1). rank(indices) is {_rank(indices)} and " + f"gather index leaf dimension is {index_vector_dim}.") + + # Start ValidateGatherDimensions + # In the error messages output by XLA, "offset_dims" is called "Output window + # dimensions" in error messages. For consistency's sake, our error messages + # stick to "offset_dims". + _is_sorted(offset_dims, "gather", "offset_dims") + _no_duplicate_dims(offset_dims, "gather", "offset_dims") + + output_offset_dim_count = len(offset_dims) + output_shape_rank = len(offset_dims) + _rank(indices) - 1 + + for i in range(output_offset_dim_count): + offset_dim = offset_dims[i] + if offset_dim < 0 or offset_dim >= output_shape_rank: + raise TypeError(f"Offset dimension {i} in gather op is out of bounds; " + f"got {offset_dim}, but should have been in " + f"[0, {output_shape_rank})") + + if len(start_index_map) != indices.shape[index_vector_dim]: + raise TypeError(f"Gather op has {len(start_index_map)} elements in " + f"start_index_map and the bound of dimension " + f"{index_vector_dim=} of indices is " + f"{indices.shape[index_vector_dim]}. These two " + f"numbers must be equal.") + + for i in range(len(start_index_map)): + operand_dim_for_start_index_i = start_index_map[i] + if (operand_dim_for_start_index_i < 0 or + operand_dim_for_start_index_i >= _rank(operand)): + raise TypeError(f"Invalid start_index_map; domain is " + f"[0, {_rank(operand)}), got: " + f"{i}->{operand_dim_for_start_index_i}.") + + _no_duplicate_dims(start_index_map, "gather", "start_index_map") + + # _is_sorted and _sorted_dims_in_range are checked in the opposite order + # compared to the XLA implementation. In cases when the input is not sorted + # AND there are problematic collapsed_slice_dims, the error message will thus + # be different. + _is_sorted(collapsed_slice_dims, "gather", "collapsed_slice_dims") + _sorted_dims_in_range(collapsed_slice_dims, _rank(operand), "gather", + "collapsed_slice_dims") + _no_duplicate_dims(collapsed_slice_dims, "gather", "collapsed_slice_dims") + # End ValidateGatherDimensions + + if _rank(operand) != len(slice_sizes): + raise TypeError(f"Gather op must have one slice size for every input " + f"dimension; got: len(slice_sizes)={len(slice_sizes)}, " + f"input_shape.rank={_rank(operand)}") + + if len(slice_sizes) != len(offset_dims) + len(collapsed_slice_dims): + raise TypeError(f"All components of the offset index in a gather op must " + f"either be a offset dimension or explicitly collapsed; " + f"got len(slice_sizes)={len(slice_sizes)}, " + f"output_slice_sizes={offset_dims}, collapsed_slice_dims=" + f"{collapsed_slice_dims}.") + + for i in range(len(collapsed_slice_dims)): + bound = slice_sizes[collapsed_slice_dims[i]] + if bound != 1: + raise TypeError(f"Gather op can only collapse slice dims with bound 1, " + f"but bound is {bound} for index " + f"{collapsed_slice_dims[i]} at position {i}.") + + return _gather_shape_computation(indices, dimension_numbers, slice_sizes) From ffc397d4c4558ea7c41e6fd90cd79edf86bbc371 Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Wed, 20 Dec 2023 13:16:21 +0000 Subject: [PATCH 02/14] Test dynamic indexing support --- frontend/catalyst/utils/jax_extras.py | 186 ++++++++++++---------- frontend/test/lit/test_jax_dynamic_api.py | 10 ++ 2 files changed, 114 insertions(+), 82 deletions(-) diff --git a/frontend/catalyst/utils/jax_extras.py b/frontend/catalyst/utils/jax_extras.py index 9b278a19c8..302ff00c49 100644 --- a/frontend/catalyst/utils/jax_extras.py +++ b/frontend/catalyst/utils/jax_extras.py @@ -32,6 +32,17 @@ ) from jax._src.lax.control_flow import _initial_style_jaxpr, _initial_style_open_jaxpr from jax._src.lax.lax import _abstractify, xla +from jax._src.lax.slicing import ( + _argnum_weak_type, + _gather_dtype_rule, + _gather_lower, + _gather_shape_computation, + _is_sorted, + _no_duplicate_dims, + _rank, + _sorted_dims_in_range, + standard_primitive, +) from jax._src.linear_util import annotate from jax._src.pjit import _extract_implicit_args, _flat_axes_specs from jax._src.sharding_impls import ReplicaAxisContext @@ -39,10 +50,6 @@ from jax._src.source_info_util import new_name_stack from jax._src.util import partition_list, safe_map, unzip2, unzip3, wrap_name, wraps from jax._src.util import safe_map, unzip2, wrap_name, wraps -from jax._src.lax.slicing import ( - _is_sorted, _no_duplicate_dims, _rank, _sorted_dims_in_range, _gather_shape_computation, - _gather_dtype_rule, _argnum_weak_type, standard_primitive, _gather_lower -) from jax.api_util import flatten_fun from jax.core import AbstractValue, ClosedJaxpr, Jaxpr, JaxprEqn, MainTrace, OutputType from jax.core import Primitive as JaxprPrimitive @@ -469,8 +476,11 @@ def abstractify(args, kwargs): return in_type, in_tree gather2_p = standard_primitive( - _gather_shape_rule_dynamic, _gather_dtype_rule, 'gather', - weak_type_rule=_argnum_weak_type(0)) + _gather_shape_rule_dynamic, + _gather_dtype_rule, + "gather", + weak_type_rule=_argnum_weak_type(0), + ) register_lowering(gather2_p, _gather_lower) @@ -479,7 +489,7 @@ def make_jaxpr_f(*args, **kwargs): # TODO: re-use `deduce_avals` here. with Patcher( (jax._src.interpreters.partial_eval, "get_aval", get_aval2), - (jax._src.lax.slicing, "gather_p", gather2_p) + (jax._src.lax.slicing, "gather_p", gather2_p), ), ExitStack(): f = wrap_init(fun) in_type, in_tree = abstractify(args, kwargs) @@ -500,78 +510,90 @@ def _gather_shape_rule_dynamic(operand, indices, *, dimension_numbers, collapsed_slice_dims = dimension_numbers.collapsed_slice_dims start_index_map = dimension_numbers.start_index_map - # Note: in JAX, index_vector_dim is always computed as below, cf. the - # documentation of the GatherDimensionNumbers class. - index_vector_dim = _rank(indices) - 1 - - # This case should never happen in JAX, due to the implicit construction of - # index_vector_dim, but is included for completeness. - if _rank(indices) < index_vector_dim or index_vector_dim < 0: - raise TypeError(f"Gather index leaf dimension must be within [0, rank(" - f"indices) + 1). rank(indices) is {_rank(indices)} and " - f"gather index leaf dimension is {index_vector_dim}.") - - # Start ValidateGatherDimensions - # In the error messages output by XLA, "offset_dims" is called "Output window - # dimensions" in error messages. For consistency's sake, our error messages - # stick to "offset_dims". - _is_sorted(offset_dims, "gather", "offset_dims") - _no_duplicate_dims(offset_dims, "gather", "offset_dims") - - output_offset_dim_count = len(offset_dims) - output_shape_rank = len(offset_dims) + _rank(indices) - 1 - - for i in range(output_offset_dim_count): - offset_dim = offset_dims[i] - if offset_dim < 0 or offset_dim >= output_shape_rank: - raise TypeError(f"Offset dimension {i} in gather op is out of bounds; " - f"got {offset_dim}, but should have been in " - f"[0, {output_shape_rank})") - - if len(start_index_map) != indices.shape[index_vector_dim]: - raise TypeError(f"Gather op has {len(start_index_map)} elements in " - f"start_index_map and the bound of dimension " - f"{index_vector_dim=} of indices is " - f"{indices.shape[index_vector_dim]}. These two " - f"numbers must be equal.") - - for i in range(len(start_index_map)): - operand_dim_for_start_index_i = start_index_map[i] - if (operand_dim_for_start_index_i < 0 or - operand_dim_for_start_index_i >= _rank(operand)): - raise TypeError(f"Invalid start_index_map; domain is " - f"[0, {_rank(operand)}), got: " - f"{i}->{operand_dim_for_start_index_i}.") - - _no_duplicate_dims(start_index_map, "gather", "start_index_map") - - # _is_sorted and _sorted_dims_in_range are checked in the opposite order - # compared to the XLA implementation. In cases when the input is not sorted - # AND there are problematic collapsed_slice_dims, the error message will thus - # be different. - _is_sorted(collapsed_slice_dims, "gather", "collapsed_slice_dims") - _sorted_dims_in_range(collapsed_slice_dims, _rank(operand), "gather", - "collapsed_slice_dims") - _no_duplicate_dims(collapsed_slice_dims, "gather", "collapsed_slice_dims") - # End ValidateGatherDimensions - - if _rank(operand) != len(slice_sizes): - raise TypeError(f"Gather op must have one slice size for every input " - f"dimension; got: len(slice_sizes)={len(slice_sizes)}, " - f"input_shape.rank={_rank(operand)}") - - if len(slice_sizes) != len(offset_dims) + len(collapsed_slice_dims): - raise TypeError(f"All components of the offset index in a gather op must " - f"either be a offset dimension or explicitly collapsed; " - f"got len(slice_sizes)={len(slice_sizes)}, " - f"output_slice_sizes={offset_dims}, collapsed_slice_dims=" - f"{collapsed_slice_dims}.") - - for i in range(len(collapsed_slice_dims)): - bound = slice_sizes[collapsed_slice_dims[i]] - if bound != 1: - raise TypeError(f"Gather op can only collapse slice dims with bound 1, " - f"but bound is {bound} for index " - f"{collapsed_slice_dims[i]} at position {i}.") - - return _gather_shape_computation(indices, dimension_numbers, slice_sizes) + # Note: in JAX, index_vector_dim is always computed as below, cf. the + # documentation of the GatherDimensionNumbers class. + index_vector_dim = _rank(indices) - 1 + + # This case should never happen in JAX, due to the implicit construction of + # index_vector_dim, but is included for completeness. + if _rank(indices) < index_vector_dim or index_vector_dim < 0: + raise TypeError( + f"Gather index leaf dimension must be within [0, rank(" + f"indices) + 1). rank(indices) is {_rank(indices)} and " + f"gather index leaf dimension is {index_vector_dim}." + ) + + # Start ValidateGatherDimensions + # In the error messages output by XLA, "offset_dims" is called "Output window + # dimensions" in error messages. For consistency's sake, our error messages + # stick to "offset_dims". + _is_sorted(offset_dims, "gather", "offset_dims") + _no_duplicate_dims(offset_dims, "gather", "offset_dims") + + output_offset_dim_count = len(offset_dims) + output_shape_rank = len(offset_dims) + _rank(indices) - 1 + + for i in range(output_offset_dim_count): + offset_dim = offset_dims[i] + if offset_dim < 0 or offset_dim >= output_shape_rank: + raise TypeError( + f"Offset dimension {i} in gather op is out of bounds; " + f"got {offset_dim}, but should have been in " + f"[0, {output_shape_rank})" + ) + + if len(start_index_map) != indices.shape[index_vector_dim]: + raise TypeError( + f"Gather op has {len(start_index_map)} elements in " + f"start_index_map and the bound of dimension " + f"{index_vector_dim=} of indices is " + f"{indices.shape[index_vector_dim]}. These two " + f"numbers must be equal." + ) + + for i in range(len(start_index_map)): + operand_dim_for_start_index_i = start_index_map[i] + if operand_dim_for_start_index_i < 0 or operand_dim_for_start_index_i >= _rank(operand): + raise TypeError( + f"Invalid start_index_map; domain is " + f"[0, {_rank(operand)}), got: " + f"{i}->{operand_dim_for_start_index_i}." + ) + + _no_duplicate_dims(start_index_map, "gather", "start_index_map") + + # _is_sorted and _sorted_dims_in_range are checked in the opposite order + # compared to the XLA implementation. In cases when the input is not sorted + # AND there are problematic collapsed_slice_dims, the error message will thus + # be different. + _is_sorted(collapsed_slice_dims, "gather", "collapsed_slice_dims") + _sorted_dims_in_range(collapsed_slice_dims, _rank(operand), "gather", "collapsed_slice_dims") + _no_duplicate_dims(collapsed_slice_dims, "gather", "collapsed_slice_dims") + # End ValidateGatherDimensions + + if _rank(operand) != len(slice_sizes): + raise TypeError( + f"Gather op must have one slice size for every input " + f"dimension; got: len(slice_sizes)={len(slice_sizes)}, " + f"input_shape.rank={_rank(operand)}" + ) + + if len(slice_sizes) != len(offset_dims) + len(collapsed_slice_dims): + raise TypeError( + f"All components of the offset index in a gather op must " + f"either be a offset dimension or explicitly collapsed; " + f"got len(slice_sizes)={len(slice_sizes)}, " + f"output_slice_sizes={offset_dims}, collapsed_slice_dims=" + f"{collapsed_slice_dims}." + ) + + for i in range(len(collapsed_slice_dims)): + bound = slice_sizes[collapsed_slice_dims[i]] + if bound != 1: + raise TypeError( + f"Gather op can only collapse slice dims with bound 1, " + f"but bound is {bound} for index " + f"{collapsed_slice_dims[i]} at position {i}." + ) + + return _gather_shape_computation(indices, dimension_numbers, slice_sizes) diff --git a/frontend/test/lit/test_jax_dynamic_api.py b/frontend/test/lit/test_jax_dynamic_api.py index b2a0062937..9ee1d1e952 100644 --- a/frontend/test/lit/test_jax_dynamic_api.py +++ b/frontend/test/lit/test_jax_dynamic_api.py @@ -117,3 +117,13 @@ def test_qjit_aot(a: ShapedArray([1, 3, 1], dtype=float)): print_mlir(test_qjit_aot, aot=True) + + +@qjit +def test_qjit_indexing(sz): + r = jnp.ones((sz + 1,), dtype=int) + # CHECK: gather + return r[0] + + +print_mlir(test_qjit_indexing, 3) From 78f0e93ae056eeda7640d3a215955e9669eaca38 Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Wed, 20 Dec 2023 13:26:40 +0000 Subject: [PATCH 03/14] Add pytest --- frontend/test/pytest/test_jax_dynamic_api.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/frontend/test/pytest/test_jax_dynamic_api.py b/frontend/test/pytest/test_jax_dynamic_api.py index 7228c1ebc7..4af560f757 100644 --- a/frontend/test/pytest/test_jax_dynamic_api.py +++ b/frontend/test/pytest/test_jax_dynamic_api.py @@ -387,5 +387,17 @@ def i(x): assert _id0 == _id1 +def test_indexing(): + """Test the support of indexing of dynamically-shaped arrays""" + + @qjit + def fun(sz, idx): + r = jnp.ones((sz, 3, sz + 1), dtype=int) + return r[idx, 2, idx] + + res = fun(5, 2) + assert res == 1 + + if __name__ == "__main__": pytest.main(["-x", __file__]) From 2145fab9d7cff02ff8c2ef024748eec56758d28e Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Thu, 21 Dec 2023 11:53:09 +0000 Subject: [PATCH 04/14] Fix index type mismatch problem --- mlir/lib/Catalyst/Transforms/ScatterPatterns.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Catalyst/Transforms/ScatterPatterns.cpp b/mlir/lib/Catalyst/Transforms/ScatterPatterns.cpp index 277fdbea4f..6b685a7772 100644 --- a/mlir/lib/Catalyst/Transforms/ScatterPatterns.cpp +++ b/mlir/lib/Catalyst/Transforms/ScatterPatterns.cpp @@ -366,7 +366,7 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern auto indexScatter = builder.create(loc, scatterIndices, index); auto indexUpdateCasted = - builder.create(loc, builder.getI32Type(), indexUpdate); + builder.create(loc, indexScatter.getType(), indexUpdate); Value addValue = builder.create(loc, indexScatter, indexUpdateCasted); Value addValueCasted = @@ -409,7 +409,7 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern Value indexScatter = fullStartIndex[i]; auto indexUpdate = updateWindowsIndices[i]; auto indexUpdateCasted = - builder.create(loc, builder.getI32Type(), indexUpdate); + builder.create(loc, indexScatter.getType(), indexUpdate); Value addValue = builder.create(loc, indexScatter, indexUpdateCasted); Value addValueCasted = From 6ff0e69acb6fd37f91e812f1ead80e318d2043db Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Thu, 21 Dec 2023 13:00:32 +0000 Subject: [PATCH 05/14] Add array assigning test --- frontend/test/pytest/test_jax_dynamic_api.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/frontend/test/pytest/test_jax_dynamic_api.py b/frontend/test/pytest/test_jax_dynamic_api.py index 4af560f757..3c520c178d 100644 --- a/frontend/test/pytest/test_jax_dynamic_api.py +++ b/frontend/test/pytest/test_jax_dynamic_api.py @@ -387,7 +387,7 @@ def i(x): assert _id0 == _id1 -def test_indexing(): +def test_array_indexing(): """Test the support of indexing of dynamically-shaped arrays""" @qjit @@ -399,5 +399,19 @@ def fun(sz, idx): assert res == 1 +def test_array_assigning(): + """Test the support of assigning a value to a dynamically-shaped array""" + + @qjit + def fun(sz, idx, val): + r = jnp.ones((sz, 3, sz), dtype=int) + r = r.at[idx, 0, idx].set(val) + return r + + result = fun(5, 2, 33) + expected = jnp.ones((5, 3, 5), dtype=int).at[2, 0, 2].set(33) + assert_array_and_dtype_equal(result, expected) + + if __name__ == "__main__": pytest.main(["-x", __file__]) From 72fcbe42e2394ba2140b959866c9520a898cce12 Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Fri, 22 Dec 2023 10:23:16 +0000 Subject: [PATCH 06/14] Link the upstream PR; add no cover pragma --- frontend/catalyst/utils/jax_extras.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/utils/jax_extras.py b/frontend/catalyst/utils/jax_extras.py index 302ff00c49..e046480a7a 100644 --- a/frontend/catalyst/utils/jax_extras.py +++ b/frontend/catalyst/utils/jax_extras.py @@ -475,13 +475,14 @@ def abstractify(args, kwargs): in_type = infer_lambda_input_type(axes_specs, flat_args) return in_type, in_tree + # TODO: See the `_gather_shape_rule_dynamic` comment. Remove once the upstream change is + # applied. gather2_p = standard_primitive( _gather_shape_rule_dynamic, _gather_dtype_rule, "gather", weak_type_rule=_argnum_weak_type(0), ) - register_lowering(gather2_p, _gather_lower) @wraps(fun) @@ -587,6 +588,19 @@ def _gather_shape_rule_dynamic(operand, indices, *, dimension_numbers, f"{collapsed_slice_dims}." ) + # This section contains a patch suggested to the upstream. + for i in range(len(slice_sizes)): + slice_size = slice_sizes[i] + corresponding_input_size = operand.shape[i] + + if jax.core.is_constant_dim(corresponding_input_size): + if not (slice_size >= 0 and corresponding_input_size >= slice_size): + raise TypeError( + f"Slice size at index {i} in gather op is out of range, " + f"must be within [0, {corresponding_input_size} + 1), " + f"got {slice_size}." + ) + for i in range(len(collapsed_slice_dims)): bound = slice_sizes[collapsed_slice_dims[i]] if bound != 1: From ae142146a0676b07556fe8d8a623ed6e99a8c6f7 Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Fri, 22 Dec 2023 11:37:01 +0000 Subject: [PATCH 07/14] No test cover for a Jax patched function --- frontend/catalyst/utils/jax_extras.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/frontend/catalyst/utils/jax_extras.py b/frontend/catalyst/utils/jax_extras.py index e046480a7a..2732a7ddde 100644 --- a/frontend/catalyst/utils/jax_extras.py +++ b/frontend/catalyst/utils/jax_extras.py @@ -507,9 +507,10 @@ def make_jaxpr_f(*args, **kwargs): def _gather_shape_rule_dynamic(operand, indices, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, mode, fill_value): - offset_dims = dimension_numbers.offset_dims - collapsed_slice_dims = dimension_numbers.collapsed_slice_dims - start_index_map = dimension_numbers.start_index_map + + offset_dims = dimension_numbers.offset_dims + collapsed_slice_dims = dimension_numbers.collapsed_slice_dims + start_index_map = dimension_numbers.start_index_map # Note: in JAX, index_vector_dim is always computed as below, cf. the # documentation of the GatherDimensionNumbers class. From 22a4ed348448153c738b7daaf80b10544be3f16a Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Fri, 22 Dec 2023 11:59:30 +0000 Subject: [PATCH 08/14] Test array assignment in a loop --- frontend/test/pytest/test_jax_dynamic_api.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/frontend/test/pytest/test_jax_dynamic_api.py b/frontend/test/pytest/test_jax_dynamic_api.py index 3c520c178d..861c362a93 100644 --- a/frontend/test/pytest/test_jax_dynamic_api.py +++ b/frontend/test/pytest/test_jax_dynamic_api.py @@ -399,7 +399,7 @@ def fun(sz, idx): assert res == 1 -def test_array_assigning(): +def test_array_assignment(): """Test the support of assigning a value to a dynamically-shaped array""" @qjit @@ -413,5 +413,22 @@ def fun(sz, idx, val): assert_array_and_dtype_equal(result, expected) +def test_qjit_forloop_array_assignment(): + """Test the array assignment in a loop""" + + @qjit + def fun(sz): + @for_loop(0, sz, 1) + def loop(i, a): + a = a.at[i].set(i) + return a + + return loop(jnp.zeros([sz], dtype=int)) + + result = fun(5) + expected = jnp.array((0, 1, 2, 3, 4), dtype=int) + assert_array_and_dtype_equal(result, expected) + + if __name__ == "__main__": pytest.main(["-x", __file__]) From ce9a63f54aa1f4cbc985e67286617e1c9626a4a7 Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Fri, 22 Dec 2023 12:05:25 +0000 Subject: [PATCH 09/14] Update the changelog --- doc/changelog.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/changelog.md b/doc/changelog.md index 1d78cc3ee3..d64931efd8 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -2,6 +2,9 @@

New features

+* Catalyst supports indexing and assignments of the dynamically-shaped arrays. + [(#411)](https://github.com/PennyLaneAI/catalyst/pull/411) + * Error mitigation using the zero-noise extrapolation method is now available through the `catalyst.mitigate_with_zne` transform. From 75a6d4cd2b39fbef9633f5646499534b60808889 Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Mon, 8 Jan 2024 11:39:40 +0000 Subject: [PATCH 10/14] Rebased onto current main --- frontend/catalyst/utils/jax_extras.py | 16 +++++++++++----- frontend/test/pytest/test_jax_dynamic_api.py | 17 ----------------- 2 files changed, 11 insertions(+), 22 deletions(-) diff --git a/frontend/catalyst/utils/jax_extras.py b/frontend/catalyst/utils/jax_extras.py index 2732a7ddde..c119072863 100644 --- a/frontend/catalyst/utils/jax_extras.py +++ b/frontend/catalyst/utils/jax_extras.py @@ -49,7 +49,6 @@ from jax._src.source_info_util import current as jax_current from jax._src.source_info_util import new_name_stack from jax._src.util import partition_list, safe_map, unzip2, unzip3, wrap_name, wraps -from jax._src.util import safe_map, unzip2, wrap_name, wraps from jax.api_util import flatten_fun from jax.core import AbstractValue, ClosedJaxpr, Jaxpr, JaxprEqn, MainTrace, OutputType from jax.core import Primitive as JaxprPrimitive @@ -504,10 +503,17 @@ def make_jaxpr_f(*args, **kwargs): return make_jaxpr_f -def _gather_shape_rule_dynamic(operand, indices, *, dimension_numbers, - slice_sizes, unique_indices, indices_are_sorted, - mode, fill_value): - +def _gather_shape_rule_dynamic( + operand, + indices, + *, + dimension_numbers, + slice_sizes, + unique_indices, + indices_are_sorted, + mode, + fill_value, +): offset_dims = dimension_numbers.offset_dims collapsed_slice_dims = dimension_numbers.collapsed_slice_dims start_index_map = dimension_numbers.start_index_map diff --git a/frontend/test/pytest/test_jax_dynamic_api.py b/frontend/test/pytest/test_jax_dynamic_api.py index 861c362a93..89120f3bd0 100644 --- a/frontend/test/pytest/test_jax_dynamic_api.py +++ b/frontend/test/pytest/test_jax_dynamic_api.py @@ -413,22 +413,5 @@ def fun(sz, idx, val): assert_array_and_dtype_equal(result, expected) -def test_qjit_forloop_array_assignment(): - """Test the array assignment in a loop""" - - @qjit - def fun(sz): - @for_loop(0, sz, 1) - def loop(i, a): - a = a.at[i].set(i) - return a - - return loop(jnp.zeros([sz], dtype=int)) - - result = fun(5) - expected = jnp.array((0, 1, 2, 3, 4), dtype=int) - assert_array_and_dtype_equal(result, expected) - - if __name__ == "__main__": pytest.main(["-x", __file__]) From 1599b2419f66442a7e718c6c0f4237341136628a Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Mon, 8 Jan 2024 11:44:28 +0000 Subject: [PATCH 11/14] Add jax copyright notice --- frontend/catalyst/utils/jax_extras.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/frontend/catalyst/utils/jax_extras.py b/frontend/catalyst/utils/jax_extras.py index c119072863..691b1a6bf4 100644 --- a/frontend/catalyst/utils/jax_extras.py +++ b/frontend/catalyst/utils/jax_extras.py @@ -514,6 +514,15 @@ def _gather_shape_rule_dynamic( mode, fill_value, ): + """Validates the well-formedness of the arguments to Gather. Compared to the original version, + this implementation skips static shape checks if variable dimensions are used. + + This function has been modified from its original form in the JAX project at + https://github.com/google/jax/blob/88a60b808c1f91260cc9e75b9aa2508aae5bc9f9/jax/_src/lax/slicing.py#L1438 + version released under the Apache License, Version 2.0, with the following copyright notice: + + Copyright 2021 The JAX Authors. + """ offset_dims = dimension_numbers.offset_dims collapsed_slice_dims = dimension_numbers.collapsed_slice_dims start_index_map = dimension_numbers.start_index_map From 232656aaf121a2d0a90325057a36a6131481bd39 Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Mon, 8 Jan 2024 11:48:12 +0000 Subject: [PATCH 12/14] Skip coverage for patched Jax function --- frontend/catalyst/utils/jax_extras.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/utils/jax_extras.py b/frontend/catalyst/utils/jax_extras.py index 691b1a6bf4..db2e62af4b 100644 --- a/frontend/catalyst/utils/jax_extras.py +++ b/frontend/catalyst/utils/jax_extras.py @@ -513,7 +513,7 @@ def _gather_shape_rule_dynamic( indices_are_sorted, mode, fill_value, -): +): # pragma: no cover """Validates the well-formedness of the arguments to Gather. Compared to the original version, this implementation skips static shape checks if variable dimensions are used. From f1bc1304e742e987a8f847ddb45ccafb1f880e5c Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Mon, 8 Jan 2024 11:52:04 +0000 Subject: [PATCH 13/14] Address pylint issues --- frontend/catalyst/utils/jax_extras.py | 6 +++++- frontend/test/lit/test_jax_dynamic_api.py | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/utils/jax_extras.py b/frontend/catalyst/utils/jax_extras.py index db2e62af4b..6aa1ec000b 100644 --- a/frontend/catalyst/utils/jax_extras.py +++ b/frontend/catalyst/utils/jax_extras.py @@ -513,7 +513,7 @@ def _gather_shape_rule_dynamic( indices_are_sorted, mode, fill_value, -): # pragma: no cover +): # pragma: no cover """Validates the well-formedness of the arguments to Gather. Compared to the original version, this implementation skips static shape checks if variable dimensions are used. @@ -523,6 +523,10 @@ def _gather_shape_rule_dynamic( Copyright 2021 The JAX Authors. """ + # pylint: diable=unused-argument + # pylint: diable=too-many-branches + # pylint: diable=consider-using-enumerate + # pylint: diable=chained-comparison offset_dims = dimension_numbers.offset_dims collapsed_slice_dims = dimension_numbers.collapsed_slice_dims start_index_map = dimension_numbers.start_index_map diff --git a/frontend/test/lit/test_jax_dynamic_api.py b/frontend/test/lit/test_jax_dynamic_api.py index 9ee1d1e952..79e3e9f3f0 100644 --- a/frontend/test/lit/test_jax_dynamic_api.py +++ b/frontend/test/lit/test_jax_dynamic_api.py @@ -121,6 +121,7 @@ def test_qjit_aot(a: ShapedArray([1, 3, 1], dtype=float)): @qjit def test_qjit_indexing(sz): + """Check the usage of stablehlo.gather for indexing""" r = jnp.ones((sz + 1,), dtype=int) # CHECK: gather return r[0] From ae82b8bc75387dd6d8bd43180c8149090c2530fe Mon Sep 17 00:00:00 2001 From: Sergei Mironov Date: Mon, 8 Jan 2024 11:52:58 +0000 Subject: [PATCH 14/14] Address pylint issues --- frontend/catalyst/utils/jax_extras.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/frontend/catalyst/utils/jax_extras.py b/frontend/catalyst/utils/jax_extras.py index 6aa1ec000b..1ff31abeb6 100644 --- a/frontend/catalyst/utils/jax_extras.py +++ b/frontend/catalyst/utils/jax_extras.py @@ -523,10 +523,10 @@ def _gather_shape_rule_dynamic( Copyright 2021 The JAX Authors. """ - # pylint: diable=unused-argument - # pylint: diable=too-many-branches - # pylint: diable=consider-using-enumerate - # pylint: diable=chained-comparison + # pylint: disable=unused-argument + # pylint: disable=too-many-branches + # pylint: disable=consider-using-enumerate + # pylint: disable=chained-comparison offset_dims = dimension_numbers.offset_dims collapsed_slice_dims = dimension_numbers.collapsed_slice_dims start_index_map = dimension_numbers.start_index_map