Handle dtype/casting/out properly in contractions (#402)
* Handle dtype and casting args in generic contraction op

* Pass correct dtype/casting for concrete contraction ops

* Support dtype= on contraction/einsum test runner

* Add typed out= and/or dtype= cases to contraction/einsum tests
manopapad authored Jun 15, 2022
1 parent 2a00f02 commit b8d4fde
Showing 6 changed files with 170 additions and 79 deletions.
1 change: 1 addition & 0 deletions cunumeric/array.py
@@ -2308,6 +2308,7 @@ def dot(self, rhs, out=None) -> ndarray:
self,
rhs,
out=out,
casting="no",
)

def dump(self, file):
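The `casting="no"` passed here makes `ndarray.dot` strict about `out=`: the output array must already have the exact result dtype. A minimal sketch of the implied behavior, assuming cunumeric is importable as `cn` (not part of the diff):

import numpy as np
import cunumeric as cn

a = cn.ones((2, 2), dtype=np.float64)
b = cn.ones((2, 2), dtype=np.float64)

out64 = cn.empty((2, 2), dtype=np.float64)
a.dot(b, out=out64)  # OK: out dtype matches the result exactly

out32 = cn.empty((2, 2), dtype=np.float32)
a.dot(b, out=out32)  # TypeError: casting rule 'no' forbids float64 -> float32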
142 changes: 106 additions & 36 deletions cunumeric/module.py
@@ -2659,7 +2659,15 @@ def inner(a, b, out=None):
if a.ndim == 0 or b.ndim == 0:
return multiply(a, b, out=out)
(a_modes, b_modes, out_modes) = inner_modes(a.ndim, b.ndim)
return _contract(a_modes, b_modes, out_modes, a, b, out=out)
return _contract(
a_modes,
b_modes,
out_modes,
a,
b,
out=out,
casting="unsafe",
)
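`inner` (like `dot` and `tensordot` below) keeps its historically permissive behavior by passing `casting="unsafe"`. A sketch, assuming cunumeric as `cn`:

import numpy as np
import cunumeric as cn

a = cn.arange(3, dtype=np.float64)
b = cn.arange(3, dtype=np.float64)

# With casting="unsafe", the float64 inputs may be cast to out's
# integer dtype, so a typed out= is accepted even though
# float64 -> int64 would be rejected under stricter rules.
out = cn.empty((), dtype=np.int64)
cn.inner(a, b, out=out)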


@add_boilerplate("a", "b")
@@ -2693,9 +2701,8 @@ def dot(a, b, out=None) -> ndarray:
b : array_like
Second argument.
out : ndarray, optional
Output argument. This must have the exact shape that would be returned
if it was not present. If its dtype is not what would be expected from
this operation, then the result will be (unsafely) cast to `out`.
Output argument. This must have the exact shape and dtype that would be
returned if it was not present.
Returns
-------
@@ -2722,25 +2729,38 @@ def dot(a, b, out=None) -> ndarray:


@add_boilerplate("a", "b")
def matmul(a, b, out=None):
def matmul(a, b, /, out=None, *, casting="same_kind", dtype=None):
"""
Matrix product of two arrays.
Parameters
----------
x1, x2 : array_like
a, b : array_like
Input arrays, scalars not allowed.
out : ndarray, optional
A location into which the result is stored. If provided, it must have
a shape that matches the signature `(n,k),(k,m)->(n,m)`. If its dtype
is not what would be expected from this operation, then the result will
be (unsafely) cast to `out`.
a shape that matches the signature `(n,k),(k,m)->(n,m)`.
casting : ``{'no', 'equiv', 'safe', 'same_kind', 'unsafe'}``, optional
Controls what kind of data casting may occur.
* 'no' means the data types should not be cast at all.
* 'equiv' means only byte-order changes are allowed.
* 'safe' means only casts which can preserve values are allowed.
* 'same_kind' means only safe casts or casts within a kind,
like float64 to float32, are allowed.
* 'unsafe' means any data conversions may be done.
Default is 'same_kind'.
dtype : data-type, optional
If provided, forces the calculation to use the data type specified.
Note that you may have to also give a more liberal `casting`
parameter to allow the conversions. Default is None.
Returns
-------
output : ndarray
The matrix product of the inputs.
This is a scalar only when both x1, x2 are 1-d vectors.
This is a scalar only when both a, b are 1-d vectors.
If `out` is given, then it is returned.
Notes
@@ -2789,7 +2809,16 @@ def matmul(a, b, out=None):
if a.ndim == 0 or b.ndim == 0:
raise ValueError("Scalars not allowed in matmul")
(a_modes, b_modes, out_modes) = matmul_modes(a.ndim, b.ndim)
return _contract(a_modes, b_modes, out_modes, a, b, out=out)
return _contract(
a_modes,
b_modes,
out_modes,
a,
b,
out=out,
casting=casting,
dtype=dtype,
)
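A short sketch of the new keyword-only arguments in use (assuming cunumeric as `cn`; not taken from the diff):

import numpy as np
import cunumeric as cn

a = cn.ones((4, 3), dtype=np.float64)
b = cn.ones((3, 5), dtype=np.float64)

# float64 -> float32 is a same-kind cast, so the default
# casting="same_kind" allows forcing a float32 computation.
c = cn.matmul(a, b, dtype=np.float32)
assert c.dtype == np.float32

# Under the stricter "safe" rule the same request is rejected,
# because float64 -> float32 can lose precision.
cn.matmul(a, b, dtype=np.float32, casting="safe")  # TypeError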


@add_boilerplate("a", "b")
@@ -2933,7 +2962,15 @@ def tensordot(a, b, axes=2, out=None):
Multiple GPUs, Multiple CPUs
"""
(a_modes, b_modes, out_modes) = tensordot_modes(a.ndim, b.ndim, axes)
return _contract(a_modes, b_modes, out_modes, a, b, out=out)
return _contract(
a_modes,
b_modes,
out_modes,
a,
b,
out=out,
casting="unsafe",
)


# Trivial multi-tensor contraction strategy: contract in input order
@@ -2942,15 +2979,27 @@ def __call__(self, inputs, output, size_dict, memory_limit=None):
return [(0, 1)] + [(0, -1)] * (len(inputs) - 2)


def _maybe_cast_input(arr, to_dtype, casting):
if arr is None or arr.dtype == to_dtype:
return arr
if not np.can_cast(arr.dtype, to_dtype, casting=casting):
raise TypeError(
f"Cannot cast input array of type {arr.dtype} to {to_dtype} with "
f"casting rule '{casting}'"
)
return arr.astype(to_dtype)
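For reference, the `np.can_cast` rules this helper relies on:

import numpy as np

np.can_cast(np.float64, np.float32, casting="same_kind")  # True
np.can_cast(np.float64, np.float32, casting="safe")       # False
np.can_cast(np.float16, np.float32, casting="safe")       # True
np.can_cast(np.float64, np.int64, casting="unsafe")       # True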


# Generalized tensor contraction
@add_boilerplate("a", "b")
def _contract(
a_modes,
b_modes,
out_modes,
a,
b=None,
out=None,
casting="same_kind",
dtype=None,
):
# Sanity checks
if len(a_modes) != a.ndim:
@@ -2973,6 +3022,19 @@ def _contract(
if len(set(out_modes) - set(a_modes) - set(b_modes)) > 0:
raise ValueError("Unknown mode labels on output")

# Handle types
if dtype is not None:
c_dtype = dtype
elif out is not None:
c_dtype = out.dtype
elif b is None:
c_dtype = a.dtype
else:
c_dtype = ndarray.find_common_type(a, b)
a = _maybe_cast_input(a, c_dtype, casting)
b = _maybe_cast_input(b, c_dtype, casting)
out_dtype = out.dtype if out is not None else c_dtype
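The computation dtype is thus resolved in order of precedence: explicit `dtype=`, then `out.dtype`, then the common type of the inputs. A hedged sketch of the effect (array values assumed for illustration):

import numpy as np
import cunumeric as cn

a = cn.ones(3, dtype=np.float32)
b = cn.ones(3, dtype=np.float64)

# No dtype= or out=: compute in the common type, float64.
assert cn.einsum("i,i->", a, b).dtype == np.float64

# Explicit dtype= wins; float64 -> float32 is only a same-kind
# cast, so einsum's default casting="safe" must be relaxed.
r = cn.einsum("i,i->", a, b, dtype=np.float32, casting="same_kind")
assert r.dtype == np.float32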

# Handle duplicate modes on inputs
c_a_modes = Counter(a_modes)
for (mode, count) in c_a_modes.items():
@@ -3066,10 +3128,6 @@ def _contract(
a = a * b
b = None

# Handle types
c_dtype = ndarray.find_common_type(a, b) if b is not None else a.dtype
out_dtype = out.dtype if out is not None else c_dtype

if b is None:
# Unary contraction case
assert len(a_modes) == len(c_modes) and set(a_modes) == set(c_modes)
@@ -3095,23 +3153,6 @@ def _contract(
dtype=c_dtype,
inputs=(a, b),
)
# Check for type conversion on the way in
if a.dtype != c.dtype:
temp = ndarray(
shape=a.shape,
dtype=c.dtype,
inputs=(a,),
)
temp._thunk.convert(a._thunk)
a = temp
if b.dtype != c.dtype:
temp = ndarray(
shape=b.shape,
dtype=c.dtype,
inputs=(b,),
)
temp._thunk.convert(b._thunk)
b = temp
# Perform operation
c._thunk.contract(
c_modes,
@@ -3129,8 +3170,18 @@ def _contract(
if out_dtype != c_dtype or out_shape != c_bloated_shape:
# We need to broadcast the result of the contraction or switch types
# before returning
if not np.can_cast(c_dtype, out_dtype, casting=casting):
raise TypeError(
f"Cannot cast intermediate result array of type {c_dtype} "
f"into output array of type {out_dtype} with casting rule "
f"'{casting}'"
)
if out is None:
out = empty(out_shape, out_dtype)
out = ndarray(
shape=out_shape,
dtype=out_dtype,
inputs=(c,),
)
out[...] = c.reshape(c_bloated_shape)
return out
if out_shape != c_shape:
@@ -3150,7 +3201,9 @@ def _contract(
return c
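One consequence of the new type handling, sketched with matmul's default `casting="same_kind"` (assuming cunumeric as `cn`):

import numpy as np
import cunumeric as cn

a = cn.ones((2, 2), dtype=np.float32)
b = cn.ones((2, 2), dtype=np.float32)

out = cn.empty((2, 2), dtype=np.int32)
# With out= given and no dtype=, the computation dtype is out.dtype;
# casting the float32 inputs to int32 is not a 'same_kind' cast, so
# this raises TypeError instead of silently truncating into out.
cn.matmul(a, b, out=out)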


def einsum(expr, *operands, out=None, optimize=False):
def einsum(
expr, *operands, out=None, dtype=None, casting="safe", optimize=False
):
"""
Evaluates the Einstein summation convention on the operands.
@@ -3174,6 +3227,21 @@ def einsum(expr, *operands, out=None, optimize=False):
These are the arrays for the operation.
out : ndarray, optional
If provided, the calculation is done into this array.
dtype : data-type, optional
If provided, forces the calculation to use the data type specified.
Note that you may have to also give a more liberal `casting`
parameter to allow the conversions. Default is None.
casting : ``{'no', 'equiv', 'safe', 'same_kind', 'unsafe'}``, optional
Controls what kind of data casting may occur.
* 'no' means the data types should not be cast at all.
* 'equiv' means only byte-order changes are allowed.
* 'safe' means only casts which can preserve values are allowed.
* 'same_kind' means only safe casts or casts within a kind,
like float64 to float32, are allowed.
* 'unsafe' means any data conversions may be done.
Default is 'safe'.
optimize : ``{False, True, 'greedy', 'optimal'}``, optional
Controls if intermediate optimization should occur. No optimization
will occur if False. Uses opt_einsum to find an optimized contraction
@@ -3232,6 +3300,8 @@ def einsum(expr, *operands, out=None, optimize=False):
a,
b,
out=(out if len(operands) == 0 else None),
casting=casting,
dtype=dtype,
)
operands.append(sub_result)
assert len(operands) == 1
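A sketch of the new einsum keywords in use (assuming cunumeric as `cn`; not taken from the diff):

import numpy as np
import cunumeric as cn

a = cn.ones((2, 3), dtype=np.float16)
b = cn.ones((3, 4), dtype=np.float16)

# Accumulate in float32 while keeping float16 inputs; the default
# casting="safe" permits float16 -> float32.
c = cn.einsum("ij,jk->ik", a, b, dtype=np.float32)
assert c.dtype == np.float32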
11 changes: 1 addition & 10 deletions tests/integration/test_dot.py
@@ -14,12 +14,7 @@
#
import pytest
from cunumeric.utils import dot_modes
from test_tools.contractions import (
check_default,
check_permutations,
check_shapes,
check_types,
)
from test_tools.contractions import check_default

from legate.core import LEGATE_MAX_DIM

@@ -34,10 +29,6 @@ def operation(lib, *args, **kwargs):
return lib.dot(*args, **kwargs)

check_default(name, modes, operation)
check_permutations(name, modes, operation)
check_shapes(name, modes, operation)
if a_ndim <= 2 and b_ndim <= 2:
check_types(name, modes, operation)


if __name__ == "__main__":
33 changes: 15 additions & 18 deletions tests/integration/test_einsum.py
@@ -208,12 +208,11 @@ def mk_typed_input(lib, shape):
def mk_typed_output(lib, shape):
return [
lib.zeros(shape, np.float16),
lib.zeros(shape, np.float32),
lib.zeros(shape, np.complex64),
]


def check_np_vs_cn(expr, mk_input, mk_output=None):
def check_np_vs_cn(expr, mk_input, mk_output=None, **kwargs):
lhs, rhs = expr.split("->")
opers = lhs.split(",")
in_shapes = [
@@ -224,41 +223,39 @@ def check_np_vs_cn(expr, mk_input, mk_output=None):
product(*(mk_input(np, sh) for sh in in_shapes)),
product(*(mk_input(cn, sh) for sh in in_shapes)),
):
np_res = np.einsum(expr, *np_inputs)
cn_res = cn.einsum(expr, *cn_inputs)
np_res = np.einsum(expr, *np_inputs, **kwargs)
cn_res = cn.einsum(expr, *cn_inputs, **kwargs)
rtol = (
1e-02 if any(x.dtype == np.float16 for x in np_inputs) else 1e-05
1e-02
if any(x.dtype == np.float16 for x in np_inputs)
or kwargs.get("dtype") == np.float16
else 1e-05
)
assert np.allclose(np_res, cn_res, rtol=rtol)
if mk_output is not None:
for cn_out in mk_output(cn, out_shape):
cn.einsum(expr, *cn_inputs, out=cn_out)
rtol = (
1e-02
if any(x.dtype == np.float16 for x in np_inputs)
or cn_out.dtype == np.float16
else 1e-05
)
assert np.allclose(cn_out, cn_res, rtol=rtol)
cn.einsum(expr, *cn_inputs, out=cn_out, **kwargs)
rtol_out = 1e-02 if cn_out.dtype == np.float16 else rtol
assert np.allclose(cn_out, cn_res, rtol=rtol_out)


@pytest.mark.parametrize("expr", gen_expr())
def test_small(expr):
print(f"Test small expressions (permutations and broadcasting): {expr}")
check_np_vs_cn(expr, mk_input_that_permutes_to)
check_np_vs_cn(expr, mk_input_that_broadcasts_to)


@pytest.mark.parametrize("expr", LARGE_EXPRS)
def test_large(expr):
print(f"Test large expressions (default execution only): {expr}")
check_np_vs_cn(expr, mk_input_default)


@pytest.mark.parametrize("expr", SMALL_EXPRS)
def test_cast(expr):
print(f"Test casting: {expr}")
check_np_vs_cn(expr, mk_typed_input, mk_typed_output)
@pytest.mark.parametrize("dtype", [None, np.float32])
def test_cast(expr, dtype):
check_np_vs_cn(
expr, mk_typed_input, mk_typed_output, dtype=dtype, casting="unsafe"
)


if __name__ == "__main__":
11 changes: 10 additions & 1 deletion tests/integration/test_matmul.py
@@ -15,7 +15,12 @@

import pytest
from cunumeric.utils import matmul_modes
from test_tools.contractions import check_default
from test_tools.contractions import (
check_default,
check_permutations,
check_shapes,
check_types,
)

from legate.core import LEGATE_MAX_DIM

@@ -30,6 +35,10 @@ def operation(lib, *args, **kwargs):
return lib.matmul(*args, **kwargs)

check_default(name, modes, operation)
check_permutations(name, modes, operation)
check_shapes(name, modes, operation)
if a_ndim <= 2 and b_ndim <= 2:
check_types(name, modes, operation)


if __name__ == "__main__":
(Diff for the sixth changed file not shown.)
