diff --git a/src/operator/tensor/control_flow_op.h b/src/operator/tensor/control_flow_op.h index 94e65109c358..e9aa9f63faec 100644 --- a/src/operator/tensor/control_flow_op.h +++ b/src/operator/tensor/control_flow_op.h @@ -189,6 +189,7 @@ inline bool WhereOpShape(const nnvm::NodeAttrs& attrs, return true; } else if ((*in_attrs)[0].ndim() == 1) { CHECK_EQ((*in_attrs)[0].Size(), static_cast<size_t>(tshape[0])); + return true; } return false; } diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index fd60611add8c..d0bc450415e9 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -4515,6 +4515,14 @@ def test_invalid_shape(): y=mx.nd.array([[8,9],[10,11],[12,13]]), condition=mx.nd.array([1,0])), MXNetError) + def test_1d_cond(): + cond = mx.nd.array([1, 0, 1]) + x = mx.nd.array([[2, 3], [4, 5], [6, 7]]) + y = mx.nd.array([[7, 8], [9, 10], [10, 11]]) + expect_out = np.array([[2, 3], [9, 10], [6, 7]]) + out = mx.nd.where(cond, x, y).asnumpy() + assert (expect_out == out).all() + test_where_helper((5, 9), True) test_where_helper((5, 9), False) test_where_helper((5, 7, 9), True) @@ -4526,6 +4534,7 @@ def test_invalid_shape(): test_where_numeric_gradient((5, 7, 9), True) test_where_numeric_gradient((5, 7, 9), False) test_invalid_shape() + test_1d_cond() @with_seed() def test_new_softmax():