Skip to content

Commit

Permalink
Fix cunumeric.arange issues (#940)
Browse files Browse the repository at this point in the history
* Infer arange dtype from input arguments

* Fix arange error when ceil((stop - start)/step) is a negative number

* Remove invalid test and add the params to basic test.
  • Loading branch information
yimoj authored May 29, 2023
1 parent 21369a2 commit 3e03357
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 13 deletions.
4 changes: 2 additions & 2 deletions cunumeric/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,12 +663,12 @@ def arange(
step = 1

if dtype is None:
dtype = np.array([stop]).dtype
dtype = np.find_common_type([], [type(start), type(stop), type(step)])
else:
dtype = np.dtype(dtype)

N = math.ceil((stop - start) / step)
result = ndarray((N,), dtype)
result = ndarray((_builtin_max(0, N),), dtype)
result._thunk.arange(start, stop, step)
return result

Expand Down
7 changes: 4 additions & 3 deletions tests/integration/test_array_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,11 +191,12 @@ def test_full_like_bad_filled_value():
(0,),
(10,),
(3.5,),
pytest.param((-10), marks=pytest.mark.xfail),
(3.0, 8, None),
(-10,),
(2, 10),
pytest.param((2, -10), marks=pytest.mark.xfail),
(2, -10),
(-2.5, 10.0),
pytest.param((1, -10, -2.5), marks=pytest.mark.xfail),
(1, -10, -2.5),
(1.0, -10.0, -2.5),
(-10, 10, 10),
(-10, 10, -100),
Expand Down
12 changes: 4 additions & 8 deletions tests/integration/test_diag_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
import cunumeric as num


def test_diag_indices_default_ndim():
a_np = np.diag_indices(10)
a_num = num.diag_indices(10)
@pytest.mark.parametrize("n", [10, -10.5, -1])
def test_diag_indices_default_ndim(n):
a_np = np.diag_indices(n)
a_num = num.diag_indices(n)
assert np.array_equal(a_np, a_num)


Expand All @@ -42,11 +43,6 @@ def test_diag_indices(n, ndim):


class TestDiagIndicesErrors:
@pytest.mark.parametrize("n", [-10.5, -1])
def test_negative_n(self, n):
with pytest.raises(ValueError):
num.diag_indices(n)

@pytest.mark.xfail
@pytest.mark.parametrize("n", [-10.5, -1])
def test_negative_n_DIVERGENCE(self, n):
Expand Down

0 comments on commit 3e03357

Please sign in to comment.