Skip to content

Commit

Permalink
Support of the shape argument in empty_like() & Co. (#845)
Browse files Browse the repository at this point in the history
* *_like() now takes a shape argument

* style

* test: fixed missing shape args
  • Loading branch information
madsbk authored Mar 21, 2023
1 parent 219e3ab commit e16bd14
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 17 deletions.
39 changes: 31 additions & 8 deletions cunumeric/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,11 @@ def empty(shape: NdShapeLike, dtype: npt.DTypeLike = np.float64) -> ndarray:


@add_boilerplate("a")
def empty_like(a: ndarray, dtype: Optional[npt.DTypeLike] = None) -> ndarray:
def empty_like(
a: ndarray,
dtype: Optional[npt.DTypeLike] = None,
shape: Optional[NdShapeLike] = None,
) -> ndarray:
"""
empty_like(prototype, dtype=None)
Expand All @@ -108,6 +112,8 @@ def empty_like(a: ndarray, dtype: Optional[npt.DTypeLike] = None) -> ndarray:
of the returned array.
dtype : data-type, optional
Overrides the data type of the result.
shape : int or tuple[int], optional
Overrides the shape of the result.
Returns
-------
Expand All @@ -123,7 +129,7 @@ def empty_like(a: ndarray, dtype: Optional[npt.DTypeLike] = None) -> ndarray:
--------
Multiple GPUs, Multiple CPUs
"""
shape = a.shape
shape = a.shape if shape is None else shape
if dtype is not None:
dtype = np.dtype(dtype)
else:
Expand Down Expand Up @@ -238,7 +244,11 @@ def ones(shape: NdShapeLike, dtype: npt.DTypeLike = np.float64) -> ndarray:
return full(shape, 1, dtype=dtype)


def ones_like(a: ndarray, dtype: Optional[npt.DTypeLike] = None) -> ndarray:
def ones_like(
a: ndarray,
dtype: Optional[npt.DTypeLike] = None,
shape: Optional[NdShapeLike] = None,
) -> ndarray:
"""
Return an array of ones with the same shape and type as a given array.
Expand All @@ -250,6 +260,8 @@ def ones_like(a: ndarray, dtype: Optional[npt.DTypeLike] = None) -> ndarray:
returned array.
dtype : data-type, optional
Overrides the data type of the result.
shape : int or tuple[int], optional
Overrides the shape of the result.
Returns
-------
Expand All @@ -267,7 +279,7 @@ def ones_like(a: ndarray, dtype: Optional[npt.DTypeLike] = None) -> ndarray:
usedtype = a.dtype
if dtype is not None:
usedtype = np.dtype(dtype)
return full_like(a, 1, dtype=usedtype)
return full_like(a, 1, dtype=usedtype, shape=shape)


def zeros(shape: NdShapeLike, dtype: npt.DTypeLike = np.float64) -> ndarray:
Expand Down Expand Up @@ -301,7 +313,11 @@ def zeros(shape: NdShapeLike, dtype: npt.DTypeLike = np.float64) -> ndarray:
return full(shape, 0, dtype=dtype)


def zeros_like(a: ndarray, dtype: Optional[npt.DTypeLike] = None) -> ndarray:
def zeros_like(
a: ndarray,
dtype: Optional[npt.DTypeLike] = None,
shape: Optional[NdShapeLike] = None,
) -> ndarray:
"""
Return an array of zeros with the same shape and type as a given array.
Expand All @@ -313,6 +329,8 @@ def zeros_like(a: ndarray, dtype: Optional[npt.DTypeLike] = None) -> ndarray:
the returned array.
dtype : data-type, optional
Overrides the data type of the result.
shape : int or tuple[int], optional
Overrides the shape of the result.
Returns
-------
Expand All @@ -330,7 +348,7 @@ def zeros_like(a: ndarray, dtype: Optional[npt.DTypeLike] = None) -> ndarray:
usedtype = a.dtype
if dtype is not None:
usedtype = np.dtype(dtype)
return full_like(a, 0, dtype=usedtype)
return full_like(a, 0, dtype=usedtype, shape=shape)


def full(
Expand Down Expand Up @@ -376,7 +394,10 @@ def full(


def full_like(
a: ndarray, value: Union[int, float], dtype: Optional[npt.DTypeLike] = None
a: ndarray,
value: Union[int, float],
dtype: Optional[npt.DTypeLike] = None,
shape: Optional[NdShapeLike] = None,
) -> ndarray:
"""
Expand All @@ -391,6 +412,8 @@ def full_like(
Fill value.
dtype : data-type, optional
Overrides the data type of the result.
shape : int or tuple[int], optional
Overrides the shape of the result.
Returns
-------
Expand All @@ -409,7 +432,7 @@ def full_like(
dtype = np.dtype(dtype)
else:
dtype = a.dtype
result = empty_like(a, dtype=dtype)
result = empty_like(a, dtype=dtype, shape=shape)
val = np.array(value, dtype=result.dtype)
result._thunk.fill(val)
return result
Expand Down
25 changes: 16 additions & 9 deletions tests/integration/test_array_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,39 +137,46 @@ def test_full_bad_filled_value(self):
(np.arange(24).reshape(4, 3, 2), "f4"),
]
LIKE_FUNCTIONS = ("zeros_like", "ones_like")
SHAPE_ARG = (None, (-1,), (1, -1))


@pytest.mark.parametrize("x_np,dtype", DATA_ARGS)
def test_empty_like(x_np, dtype):
@pytest.mark.parametrize("shape", SHAPE_ARG)
def test_empty_like(x_np, dtype, shape):
shape = shape if shape is None else x_np.reshape(shape).shape
x = num.array(x_np)
xfl = num.empty_like(x, dtype=dtype)
yfl = np.empty_like(x_np, dtype=dtype)
xfl = num.empty_like(x, dtype=dtype, shape=shape)
yfl = np.empty_like(x_np, dtype=dtype, shape=shape)

assert xfl.shape == yfl.shape
assert xfl.dtype == yfl.dtype


@pytest.mark.parametrize("x_np,dtype", DATA_ARGS)
@pytest.mark.parametrize("fn", LIKE_FUNCTIONS)
def test_func_like(fn, x_np, dtype):
@pytest.mark.parametrize("shape", SHAPE_ARG)
def test_func_like(fn, x_np, dtype, shape):
shape = shape if shape is None else x_np.reshape(shape).shape
num_f = getattr(num, fn)
np_f = getattr(np, fn)

x = num.array(x_np)
xfl = num_f(x, dtype=dtype)
yfl = np_f(x_np, dtype=dtype)
xfl = num_f(x, dtype=dtype, shape=shape)
yfl = np_f(x_np, dtype=dtype, shape=shape)

assert np.array_equal(xfl, yfl)
assert xfl.dtype == yfl.dtype


@pytest.mark.parametrize("value", FILLED_VALUES)
@pytest.mark.parametrize("x_np, dtype", DATA_ARGS)
def test_full_like(x_np, dtype, value):
@pytest.mark.parametrize("shape", SHAPE_ARG)
def test_full_like(x_np, dtype, value, shape):
shape = shape if shape is None else x_np.reshape(shape).shape
x = num.array(x_np)

xfl = num.full_like(x, value, dtype=dtype)
yfl = np.full_like(x_np, value, dtype=dtype)
xfl = num.full_like(x, value, dtype=dtype, shape=shape)
yfl = np.full_like(x_np, value, dtype=dtype, shape=shape)
assert np.array_equal(xfl, yfl)
assert xfl.dtype == yfl.dtype

Expand Down

0 comments on commit e16bd14

Please sign in to comment.