diff --git a/tests/integration/test_matmul.py b/tests/integration/test_matmul.py index bee51fcd7..bedc62bcc 100644 --- a/tests/integration/test_matmul.py +++ b/tests/integration/test_matmul.py @@ -147,19 +147,23 @@ def test_out_invalid_shape_DIVERGENCE(self): out = num.zeros(shape) num.matmul(A, B, out=out) - def test_out_invalid_dtype(self): + @pytest.mark.parametrize( + ("dtype", "out_dtype", "casting"), + ((None, np.int64, "same_kind"), (float, str, "safe")), + ids=("direct", "intermediate"), + ) + def test_out_invalid_dtype(self, dtype, out_dtype, casting): expected_exc = TypeError - A_np = num.ones((3, 2, 4)) - B_np = num.ones((3, 4, 3)) + A_np = np.ones((3, 2, 4)) + B_np = np.ones((3, 4, 3)) A_num = num.ones((3, 2, 4)) B_num = num.ones((3, 4, 3)) - dtype = np.int64 - out_np = np.zeros((3, 2, 3), dtype=dtype) - out_num = num.zeros((3, 2, 3), dtype=dtype) + out_np = np.zeros((3, 2, 3), dtype=out_dtype) + out_num = num.zeros((3, 2, 3), dtype=out_dtype) with pytest.raises(expected_exc): - np.matmul(A_np, B_np, out=out_np) + np.matmul(A_np, B_np, dtype=dtype, out=out_np, casting=casting) with pytest.raises(expected_exc): - num.matmul(A_num, B_num, out=out_num) + num.matmul(A_num, B_num, dtype=dtype, out=out_num, casting=casting) @pytest.mark.parametrize( "casting_dtype", @@ -183,18 +187,20 @@ def test_invalid_casting_dtype(self, casting_dtype): with pytest.raises(expected_exc): num.matmul(A_num, B_num, casting=casting, dtype=dtype) - @pytest.mark.xfail - def test_invalid_casting(self): - # In Numpy, raise ValueError - # In cuNumeric, pass + @pytest.mark.parametrize( + "dtype", (str, pytest.param(float, marks=pytest.mark.xfail)), ids=str + ) + def test_invalid_casting(self, dtype): expected_exc = ValueError casting = "unknown" A_np = np.ones((2, 4)) - B_np = np.ones((4, 3)) + B_np = np.ones((4, 3), dtype=dtype) A_num = num.ones((2, 4)) - B_num = num.ones((4, 3)) + B_num = num.ones((4, 3), dtype=dtype) + # In Numpy, raise ValueError with pytest.raises(expected_exc): np.matmul(A_np, B_np, casting=casting) + # cuNumeric does not check casting when A and B are of the same dtype with pytest.raises(expected_exc): num.matmul(A_num, B_num, casting=casting)