From 0997eb487cdb00ef7b419cfe9be5fe383beed9cc Mon Sep 17 00:00:00 2001
From: Manolis Papadakis
Date: Mon, 13 Jun 2022 14:35:23 -0700
Subject: [PATCH 1/4] Handle dtype and casting args in generic contraction op

---
 cunumeric/module.py | 60 ++++++++++++++++++++++++++++----------------
 1 file changed, 37 insertions(+), 23 deletions(-)

diff --git a/cunumeric/module.py b/cunumeric/module.py
index 4b9d028dd..93fedc2cd 100644
--- a/cunumeric/module.py
+++ b/cunumeric/module.py
@@ -2941,8 +2941,18 @@ def __call__(self, inputs, output, size_dict, memory_limit=None):
         return [(0, 1)] + [(0, -1)] * (len(inputs) - 2)
 
 
+def _maybe_cast_input(arr, to_dtype, casting):
+    if arr is None or arr.dtype == to_dtype:
+        return arr
+    if not np.can_cast(arr.dtype, to_dtype, casting=casting):
+        raise TypeError(
+            f"Cannot cast input array of type {arr.dtype} to {to_dtype} with "
+            f"casting rule '{casting}'"
+        )
+    return arr.astype(to_dtype)
+
+
 # Generalized tensor contraction
-@add_boilerplate("a", "b")
 def _contract(
     a_modes,
     b_modes,
@@ -2950,6 +2960,8 @@ def _contract(
     a,
     b=None,
     out=None,
+    casting="same_kind",
+    dtype=None,
 ):
     # Sanity checks
     if len(a_modes) != a.ndim:
@@ -2972,6 +2984,19 @@ def _contract(
     if len(set(out_modes) - set(a_modes) - set(b_modes)) > 0:
         raise ValueError("Unknown mode labels on output")
 
+    # Handle types
+    if dtype is not None:
+        c_dtype = dtype
+    elif out is not None:
+        c_dtype = out.dtype
+    elif b is None:
+        c_dtype = a.dtype
+    else:
+        c_dtype = ndarray.find_common_type(a, b)
+    a = _maybe_cast_input(a, c_dtype, casting)
+    b = _maybe_cast_input(b, c_dtype, casting)
+    out_dtype = out.dtype if out is not None else c_dtype
+
     # Handle duplicate modes on inputs
     c_a_modes = Counter(a_modes)
     for (mode, count) in c_a_modes.items():
@@ -3065,10 +3090,6 @@ def _contract(
             a = a * b
             b = None
 
-    # Handle types
-    c_dtype = ndarray.find_common_type(a, b) if b is not None else a.dtype
-    out_dtype = out.dtype if out is not None else c_dtype
-
     if b is None:
         # Unary contraction case
         assert len(a_modes) == len(c_modes) and set(a_modes) == set(c_modes)
@@ -3094,23 +3115,6 @@ def _contract(
             dtype=c_dtype,
             inputs=(a, b),
         )
-        # Check for type conversion on the way in
-        if a.dtype != c.dtype:
-            temp = ndarray(
-                shape=a.shape,
-                dtype=c.dtype,
-                inputs=(a,),
-            )
-            temp._thunk.convert(a._thunk)
-            a = temp
-        if b.dtype != c.dtype:
-            temp = ndarray(
-                shape=b.shape,
-                dtype=c.dtype,
-                inputs=(b,),
-            )
-            temp._thunk.convert(b._thunk)
-            b = temp
         # Perform operation
         c._thunk.contract(
             c_modes,
@@ -3128,8 +3132,18 @@ def _contract(
     if out_dtype != c_dtype or out_shape != c_bloated_shape:
         # We need to broadcast the result of the contraction or switch types
         # before returning
+        if not np.can_cast(c_dtype, out_dtype, casting=casting):
+            raise TypeError(
+                f"Cannot cast intermediate result array of type {c_dtype} "
+                f"into output array of type {out_dtype} with casting rule "
+                f"'{casting}'"
+            )
         if out is None:
-            out = empty(out_shape, out_dtype)
+            out = ndarray(
+                shape=out_shape,
+                dtype=out_dtype,
+                inputs=(c,),
+            )
         out[...] = c.reshape(c_bloated_shape)
         return out
     if out_shape != c_shape:

From 22d668aad65e943fd1e17e6ac9148790caf61c4f Mon Sep 17 00:00:00 2001
From: Manolis Papadakis
Date: Mon, 13 Jun 2022 14:37:39 -0700
Subject: [PATCH 2/4] Pass correct dtype/casting for concrete contraction ops

---
 cunumeric/array.py  |  1 +
 cunumeric/module.py | 82 ++++++++++++++++++++++++++++++++++++++-------
 2 files changed, 70 insertions(+), 13 deletions(-)

diff --git a/cunumeric/array.py b/cunumeric/array.py
index 43235ce0a..f9e8fd47d 100644
--- a/cunumeric/array.py
+++ b/cunumeric/array.py
@@ -2301,6 +2301,7 @@ def dot(self, rhs, out=None):
             self,
             rhs,
             out=out,
+            casting="no",
         )
 
     def dump(self, file):
diff --git a/cunumeric/module.py b/cunumeric/module.py
index 93fedc2cd..2e07f5413 100644
--- a/cunumeric/module.py
+++ b/cunumeric/module.py
@@ -2658,7 +2658,15 @@ def inner(a, b, out=None):
     if a.ndim == 0 or b.ndim == 0:
         return multiply(a, b, out=out)
     (a_modes, b_modes, out_modes) = inner_modes(a.ndim, b.ndim)
-    return _contract(a_modes, b_modes, out_modes, a, b, out=out)
+    return _contract(
+        a_modes,
+        b_modes,
+        out_modes,
+        a,
+        b,
+        out=out,
+        casting="unsafe",
+    )
 
 
 @add_boilerplate("a", "b")
@@ -2692,9 +2700,8 @@ def dot(a, b, out=None):
     b : array_like
         Second argument.
     out : ndarray, optional
-        Output argument. This must have the exact shape that would be returned
-        if it was not present. If its dtype is not what would be expected from
-        this operation, then the result will be (unsafely) cast to `out`.
+        Output argument. This must have the exact shape and dtype that would be
+        returned if it was not present.
 
     Returns
     -------
@@ -2721,25 +2728,38 @@ def dot(a, b, out=None):
 
 
 @add_boilerplate("a", "b")
-def matmul(a, b, out=None):
+def matmul(a, b, /, out=None, *, casting="same_kind", dtype=None):
     """
     Matrix product of two arrays.
 
     Parameters
     ----------
-    x1, x2 : array_like
+    a, b : array_like
         Input arrays, scalars not allowed.
     out : ndarray, optional
         A location into which the result is stored. If provided, it must have
-        a shape that matches the signature `(n,k),(k,m)->(n,m)`. If its dtype
-        is not what would be expected from this operation, then the result will
-        be (unsafely) cast to `out`.
+        a shape that matches the signature `(n,k),(k,m)->(n,m)`.
+    casting : ``{'no', 'equiv', 'safe', 'same_kind', 'unsafe'}``, optional
+        Controls what kind of data casting may occur.
+
+        * 'no' means the data types should not be cast at all.
+        * 'equiv' means only byte-order changes are allowed.
+        * 'safe' means only casts which can preserve values are allowed.
+        * 'same_kind' means only safe casts or casts within a kind,
+          like float64 to float32, are allowed.
+        * 'unsafe' means any data conversions may be done.
+
+        Default is 'same_kind'.
+    dtype : data-type, optional
+        If provided, forces the calculation to use the data type specified.
+        Note that you may have to also give a more liberal `casting`
+        parameter to allow the conversions. Default is None.
 
     Returns
     -------
     output : ndarray
         The matrix product of the inputs.
-        This is a scalar only when both x1, x2 are 1-d vectors.
+        This is a scalar only when both a, b are 1-d vectors.
         If `out` is given, then it is returned.
 
     Notes
@@ -2788,7 +2808,16 @@ def matmul(a, b, out=None):
     if a.ndim == 0 or b.ndim == 0:
         raise ValueError("Scalars not allowed in matmul")
     (a_modes, b_modes, out_modes) = matmul_modes(a.ndim, b.ndim)
-    return _contract(a_modes, b_modes, out_modes, a, b, out=out)
+    return _contract(
+        a_modes,
+        b_modes,
+        out_modes,
+        a,
+        b,
+        out=out,
+        casting=casting,
+        dtype=dtype,
+    )
 
 
 @add_boilerplate("a", "b")
@@ -2932,7 +2961,15 @@ def tensordot(a, b, axes=2, out=None):
     Multiple GPUs, Multiple CPUs
     """
     (a_modes, b_modes, out_modes) = tensordot_modes(a.ndim, b.ndim, axes)
-    return _contract(a_modes, b_modes, out_modes, a, b, out=out)
+    return _contract(
+        a_modes,
+        b_modes,
+        out_modes,
+        a,
+        b,
+        out=out,
+        casting="unsafe",
+    )
 
 
 # Trivial multi-tensor contraction strategy: contract in input order
@@ -3163,7 +3200,9 @@ def _contract(
     return c
 
 
-def einsum(expr, *operands, out=None, optimize=False):
+def einsum(
+    expr, *operands, out=None, dtype=None, casting="safe", optimize=False
+):
     """
     Evaluates the Einstein summation convention on the operands.
 
@@ -3187,6 +3226,21 @@ def einsum(
         These are the arrays for the operation.
     out : ndarray, optional
         If provided, the calculation is done into this array.
+    dtype : data-type, optional
+        If provided, forces the calculation to use the data type specified.
+        Note that you may have to also give a more liberal `casting`
+        parameter to allow the conversions. Default is None.
+    casting : ``{'no', 'equiv', 'safe', 'same_kind', 'unsafe'}``, optional
+        Controls what kind of data casting may occur.
+
+        * 'no' means the data types should not be cast at all.
+        * 'equiv' means only byte-order changes are allowed.
+        * 'safe' means only casts which can preserve values are allowed.
+        * 'same_kind' means only safe casts or casts within a kind,
+          like float64 to float32, are allowed.
+        * 'unsafe' means any data conversions may be done.
+
+        Default is 'safe'.
     optimize : ``{False, True, 'greedy', 'optimal'}``, optional
         Controls if intermediate optimization should occur. No optimization
         will occur if False. Uses opt_einsum to find an optimized contraction
@@ -3245,6 +3299,8 @@ def einsum(
             a,
             b,
             out=(out if len(operands) == 0 else None),
+            casting=casting,
+            dtype=dtype,
         )
         operands.append(sub_result)
     assert len(operands) == 1

From 255134563c324534f324ecc6b7bfe7a1472dfcb2 Mon Sep 17 00:00:00 2001
From: Manolis Papadakis
Date: Mon, 13 Jun 2022 14:48:17 -0700
Subject: [PATCH 3/4] Support dtype= on contraction/einsum test runner

---
 tests/integration/test_einsum.py             | 22 +++++++++-----------
 tests/integration/test_tools/contractions.py | 22 +++++++++-----------
 2 files changed, 20 insertions(+), 24 deletions(-)

diff --git a/tests/integration/test_einsum.py b/tests/integration/test_einsum.py
index bd9caf3e8..f0270fa6a 100644
--- a/tests/integration/test_einsum.py
+++ b/tests/integration/test_einsum.py
@@ -213,7 +213,7 @@ def mk_typed_output(lib, shape):
     ]
 
 
-def check_np_vs_cn(expr, mk_input, mk_output=None):
+def check_np_vs_cn(expr, mk_input, mk_output=None, **kwargs):
     lhs, rhs = expr.split("->")
     opers = lhs.split(",")
     in_shapes = [
@@ -224,22 +224,20 @@ def check_np_vs_cn(expr, mk_input, mk_output=None, **kwargs):
         product(*(mk_input(np, sh) for sh in in_shapes)),
         product(*(mk_input(cn, sh) for sh in in_shapes)),
     ):
-        np_res = np.einsum(expr, *np_inputs)
-        cn_res = cn.einsum(expr, *cn_inputs)
+        np_res = np.einsum(expr, *np_inputs, **kwargs)
+        cn_res = cn.einsum(expr, *cn_inputs, **kwargs)
         rtol = (
-            1e-02 if any(x.dtype == np.float16 for x in np_inputs) else 1e-05
+            1e-02
+            if any(x.dtype == np.float16 for x in np_inputs)
+            or kwargs.get("dtype") == np.float16
+            else 1e-05
         )
         assert np.allclose(np_res, cn_res, rtol=rtol)
         if mk_output is not None:
             for cn_out in mk_output(cn, out_shape):
-                cn.einsum(expr, *cn_inputs, out=cn_out)
-                rtol = (
-                    1e-02
-                    if any(x.dtype == np.float16 for x in np_inputs)
-                    or cn_out.dtype == np.float16
-                    else 1e-05
-                )
-                assert np.allclose(cn_out, cn_res, rtol=rtol)
+                cn.einsum(expr, *cn_inputs, out=cn_out, **kwargs)
+                rtol_out = 1e-02 if cn_out.dtype == np.float16 else rtol
+                assert np.allclose(cn_out, cn_res, rtol=rtol_out)
 
 
 @pytest.mark.parametrize("expr", gen_expr())
diff --git a/tests/integration/test_tools/contractions.py b/tests/integration/test_tools/contractions.py
index cd3972fec..9a0372ab0 100644
--- a/tests/integration/test_tools/contractions.py
+++ b/tests/integration/test_tools/contractions.py
@@ -92,7 +92,7 @@ def gen_inputs_of_various_types(lib, modes):
     )
 
 
-def _test(name, modes, operation, gen_inputs, gen_output=None):
+def _test(name, modes, operation, gen_inputs, gen_output=None, **kwargs):
     (a_modes, b_modes, out_modes) = modes
     if len(set(a_modes) | set(b_modes) | set(out_modes)) > LEGATE_MAX_DIM:
         # Total number of distinct modes can't exceed maximum Legion dimension,
@@ -102,22 +102,20 @@ def _test(name, modes, operation, gen_inputs, gen_output=None, **kwargs):
     for (np_inputs, cn_inputs) in zip(
         gen_inputs(np, modes), gen_inputs(cn, modes)
     ):
-        np_res = operation(np, *np_inputs)
-        cn_res = operation(cn, *cn_inputs)
+        np_res = operation(np, *np_inputs, **kwargs)
+        cn_res = operation(cn, *cn_inputs, **kwargs)
         rtol = (
-            2e-03 if any(x.dtype == np.float16 for x in np_inputs) else 1e-05
+            1e-02
+            if any(x.dtype == np.float16 for x in np_inputs)
+            or kwargs.get("dtype") == np.float16
+            else 1e-05
        )
         assert np.allclose(np_res, cn_res, rtol=rtol)
         if gen_output is not None:
             for cn_out in gen_output(cn, modes, *cn_inputs):
-                operation(cn, *cn_inputs, out=cn_out)
-                rtol = (
-                    2e-03
-                    if any(x.dtype == np.float16 for x in np_inputs)
-                    or cn_out.dtype == np.float16
-                    else 1e-05
-                )
-                assert np.allclose(cn_out, cn_res, rtol=rtol)
+                operation(cn, *cn_inputs, out=cn_out, **kwargs)
+                rtol_out = 1e-02 if cn_out.dtype == np.float16 else rtol
+                assert np.allclose(cn_out, cn_res, rtol=rtol_out)
 
 
 def check_default(name, modes, operation):

From c80c07efbcd2a00dbf0d21c497b248102a226cf1 Mon Sep 17 00:00:00 2001
From: Manolis Papadakis
Date: Mon, 13 Jun 2022 14:51:28 -0700
Subject: [PATCH 4/4] Add typed out= and/or dtype= cases to contraction/einsum
 tests

---
 tests/integration/test_dot.py                | 11 +-------
 tests/integration/test_einsum.py             | 11 ++++----
 tests/integration/test_matmul.py             | 11 +++++++-
 tests/integration/test_tools/contractions.py | 29 ++++++++++++++++++--
 4 files changed, 43 insertions(+), 19 deletions(-)

diff --git a/tests/integration/test_dot.py b/tests/integration/test_dot.py
index de606e5d6..eba753a80 100644
--- a/tests/integration/test_dot.py
+++ b/tests/integration/test_dot.py
@@ -14,12 +14,7 @@
 #
 import pytest
 from cunumeric.utils import dot_modes
-from test_tools.contractions import (
-    check_default,
-    check_permutations,
-    check_shapes,
-    check_types,
-)
+from test_tools.contractions import check_default
 
 from legate.core import LEGATE_MAX_DIM
 
@@ -34,10 +29,6 @@ def test(a_ndim, b_ndim):
     def operation(lib, *args, **kwargs):
         return lib.dot(*args, **kwargs)
     check_default(name, modes, operation)
-    check_permutations(name, modes, operation)
-    check_shapes(name, modes, operation)
-    if a_ndim <= 2 and b_ndim <= 2:
-        check_types(name, modes, operation)
 
 
 if __name__ == "__main__":
diff --git a/tests/integration/test_einsum.py b/tests/integration/test_einsum.py
index f0270fa6a..156eb5b86 100644
--- a/tests/integration/test_einsum.py
+++ b/tests/integration/test_einsum.py
@@ -208,7 +208,6 @@ def mk_typed_input(lib, shape):
 def mk_typed_output(lib, shape):
     return [
         lib.zeros(shape, np.float16),
-        lib.zeros(shape, np.float32),
         lib.zeros(shape, np.complex64),
     ]
 
@@ -242,21 +241,21 @@ def check_np_vs_cn(expr, mk_input, mk_output=None, **kwargs):
 
 @pytest.mark.parametrize("expr", gen_expr())
 def test_small(expr):
-    print(f"Test small expressions (permutations and broadcasting): {expr}")
     check_np_vs_cn(expr, mk_input_that_permutes_to)
     check_np_vs_cn(expr, mk_input_that_broadcasts_to)
 
 
 @pytest.mark.parametrize("expr", LARGE_EXPRS)
 def test_large(expr):
-    print(f"Test large expressions (default execution only): {expr}")
     check_np_vs_cn(expr, mk_input_default)
 
 
 @pytest.mark.parametrize("expr", SMALL_EXPRS)
-def test_cast(expr):
-    print(f"Test casting: {expr}")
-    check_np_vs_cn(expr, mk_typed_input, mk_typed_output)
+@pytest.mark.parametrize("dtype", [None, np.float32])
+def test_cast(expr, dtype):
+    check_np_vs_cn(
+        expr, mk_typed_input, mk_typed_output, dtype=dtype, casting="unsafe"
+    )
 
 
 if __name__ == "__main__":
diff --git a/tests/integration/test_matmul.py b/tests/integration/test_matmul.py
index da4e37524..23ed607b8 100644
--- a/tests/integration/test_matmul.py
+++ b/tests/integration/test_matmul.py
@@ -15,7 +15,12 @@
 import pytest
 
 from cunumeric.utils import matmul_modes
-from test_tools.contractions import check_default
+from test_tools.contractions import (
+    check_default,
+    check_permutations,
+    check_shapes,
+    check_types,
+)
 
 from legate.core import LEGATE_MAX_DIM
 
@@ -30,6 +35,10 @@ def test(a_ndim, b_ndim):
         return lib.matmul(*args, **kwargs)
 
     check_default(name, modes, operation)
+    check_permutations(name, modes, operation)
+    check_shapes(name, modes, operation)
+    if a_ndim <= 2 and b_ndim <= 2:
+        check_types(name, modes, operation)
 
 
 if __name__ == "__main__":
diff --git a/tests/integration/test_tools/contractions.py b/tests/integration/test_tools/contractions.py
index 9a0372ab0..267cb76eb 100644
--- a/tests/integration/test_tools/contractions.py
+++ b/tests/integration/test_tools/contractions.py
@@ -82,7 +82,6 @@ def gen_inputs_of_various_types(lib, modes):
         (np.float16, np.float32),
         (np.float32, np.float32),
         (np.complex64, np.complex64),
-        (np.complex128, np.complex128),
     ]:
         if lib == cn:
             print(f"  {a_dtype} x {b_dtype}")
@@ -92,6 +91,15 @@ def gen_inputs_of_various_types(lib, modes):
     )
 
 
+def gen_output_of_various_types(lib, modes, a, b):
+    (a_modes, b_modes, out_modes) = modes
+    out_shape = (5,) * len(out_modes)
+    for out_dtype in [np.float16, np.complex64]:
+        if lib == cn:
+            print(f"  -> {out_dtype}")
+        yield lib.zeros(out_shape, out_dtype)
+
+
 def _test(name, modes, operation, gen_inputs, gen_output=None, **kwargs):
     (a_modes, b_modes, out_modes) = modes
     if len(set(a_modes) | set(b_modes) | set(out_modes)) > LEGATE_MAX_DIM:
@@ -135,4 +143,21 @@ def check_permutations(name, modes, operation):
 
 def check_types(name, modes, operation):
     name = f"{name} -- various types"
-    _test(name, modes, operation, gen_inputs_of_various_types)
+    _test(
+        name,
+        modes,
+        operation,
+        gen_inputs_of_various_types,
+        gen_output_of_various_types,
+        casting="unsafe",
+    )
+    name = f"{name}, dtype=np.float32"
+    _test(
+        name,
+        modes,
+        operation,
+        gen_inputs_of_various_types,
+        gen_output_of_various_types,
+        dtype=np.float32,
+        casting="unsafe",
+    )