Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding support for array shape () passed as an index argument in advanced indexing #486

Merged
merged 9 commits into from
Aug 4, 2022
7 changes: 6 additions & 1 deletion cunumeric/deferred.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,10 @@ def _zip_indices(
# NumPy array.
N = self.ndim
pointN_dtype = self.runtime.get_point_type(N)
# if scalar array is passed as an argument, make the output
# shape be (1,)
if out_shape == ():
out_shape = (1,)
store = self.context.create_store(
pointN_dtype, shape=out_shape, optimize_scalar=True
)
Expand All @@ -463,7 +467,8 @@ def _zip_indices(
task.add_scalar_arg(self.shape, (ty.int64,))
for a in arrays:
task.add_input(a)
task.add_alignment(output_arr.base, a)
if a.shape != ():
task.add_alignment(output_arr.base, a)
manopapad marked this conversation as resolved.
Show resolved Hide resolved
task.execute()

return output_arr
Expand Down
10 changes: 7 additions & 3 deletions src/cunumeric/index/zip_template.inl
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,12 @@ struct ZipImpl {
template <VariantKind KIND>
static void zip_template(TaskContext& context)
{
// Here `N` is the number of dimenstions of the input array and the number
// Here `N` is the number of dimensions of the input array and the number
// of dimensions of the Point<N> field
// key_dim - is the number of dimensions of the index arrays before
// they were broadcasted to the shape of the input array (shape of
// all index arrays should be the same))
// start index - is the index from wich first index array was passed
// start index - is the index from which first index array was passed
// DIM - dimension of the output array
//
// for the example:
Expand All @@ -95,7 +95,11 @@ static void zip_template(TaskContext& context)
int64_t start_index = context.scalars()[2].value<int64_t>();
auto shape = context.scalars()[3].value<DomainPoint>();
ZipArgs args{context.outputs()[0], context.inputs(), N, key_dim, start_index, shape};
double_dispatch(args.inputs[0].dim(), N, ZipImpl<KIND>{}, args);
int dim = args.inputs[0].dim();
// if scalar passed as an input, convert it to the array size 1
if (dim == 0) { dim = 1; }

double_dispatch(dim, N, ZipImpl<KIND>{}, args);
}

} // namespace cunumeric
4 changes: 2 additions & 2 deletions tests/integration/test_set_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ def test_basic():
def test_scalar_ndarray_as_index(arr):
offsets = num.arange(5) # [0, 1, 2, 3, 4]
offset = offsets[3] # 3
# arr[offset] = -1 # TODO: doesn't work when arr is a num.ndarray
arr[offset] = -1
arr[offset - 2 : offset] = [-1, -1]
assert np.array_equal(arr, [4, -1, -1, 1, 0])
assert np.array_equal(arr, [4, -1, -1, -1, 0])


manopapad marked this conversation as resolved.
Show resolved Hide resolved
if __name__ == "__main__":
Expand Down