diff --git a/src/cunumeric/index/putmask_template.inl b/src/cunumeric/index/putmask_template.inl index 3a85d2044..60ce6af0e 100644 --- a/src/cunumeric/index/putmask_template.inl +++ b/src/cunumeric/index/putmask_template.inl @@ -105,7 +105,8 @@ static void putmask_template(TaskContext& context) { auto& inputs = context.inputs(); PutmaskArgs args{context.outputs()[0], inputs[1], inputs[2]}; - double_dispatch(args.input.dim(), args.input.code(), PutmaskImpl{}, args); + int dim = std::max(1, args.input.dim()); + double_dispatch(dim, args.input.code(), PutmaskImpl{}, args); } } // namespace cunumeric diff --git a/tests/integration/test_putmask.py b/tests/integration/test_putmask.py index a59c40aba..91d6b7e90 100644 --- a/tests/integration/test_putmask.py +++ b/tests/integration/test_putmask.py @@ -67,6 +67,12 @@ def test_scalar(): num.putmask(x_num, mask_num, values_num[:1]) assert np.array_equal(x_num, x) + # the case when every input is a scalar + x = num.random.rand(3, 3) + s = x.sum() + num.putmask(s, True, 1.0) + assert s == 1.0 + def test_type_convert(): x = mk_seq_array(np, (3, 4, 5))