Skip to content

Commit

Permalink
Ensure input to unstable sort algorithms contains no duplicates
Browse files Browse the repository at this point in the history
  • Loading branch information
manopapad committed Aug 14, 2023
1 parent a37a562 commit 4bcb5f7
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions tests/integration/test_argsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,9 @@
(DIM, DIM, DIM),
]

SORT_TYPES = ["quicksort", "mergesort", "stable"]
UNSTABLE_SORT_TYPE = ["heapsort"]
STABLE_SORT_TYPES = ["stable"]
UNSTABLE_SORT_TYPES = ["heapsort", "quicksort", "mergesort"]
SORT_TYPES = STABLE_SORT_TYPES + UNSTABLE_SORT_TYPES


class TestArgSort(object):
Expand Down Expand Up @@ -137,7 +138,7 @@ def test_basic_axis(self, size):
assert np.array_equal(res_num, res_np)

@pytest.mark.parametrize("size", SIZES)
@pytest.mark.parametrize("sort_type", SORT_TYPES)
@pytest.mark.parametrize("sort_type", STABLE_SORT_TYPES)
def test_basic_axis_sort_type(self, size, sort_type):
arr_np = np.random.randint(-100, 100, size)
arr_num = num.array(arr_np)
Expand All @@ -146,13 +147,14 @@ def test_basic_axis_sort_type(self, size, sort_type):
res_num = num.argsort(arr_num, axis=axis, kind=sort_type)
assert np.array_equal(res_num, res_np)

@pytest.mark.xfail
@pytest.mark.parametrize("size", SIZES)
@pytest.mark.parametrize("sort_type", UNSTABLE_SORT_TYPE)
@pytest.mark.parametrize("sort_type", UNSTABLE_SORT_TYPES)
def test_basic_axis_sort_type_unstable(self, size, sort_type):
# intermittent failed due to
# https://github.com/nv-legate/cunumeric/issues/782
arr_np = np.random.randint(-100, 100, size)
# have to guarantee unique values in input
# see https://github.com/nv-legate/cunumeric/issues/782
arr_np = np.arange(np.prod(size))
np.random.shuffle(arr_np)
arr_np = arr_np.reshape(size)
arr_num = num.array(arr_np)
for axis in range(-arr_np.ndim + 1, arr_np.ndim):
res_np = np.argsort(arr_np, axis=axis, kind=sort_type)
Expand All @@ -171,7 +173,7 @@ def test_arr_basic_axis(self, size):
assert np.array_equal(arr_np_copy, arr_num_copy)

@pytest.mark.parametrize("size", SIZES)
@pytest.mark.parametrize("sort_type", SORT_TYPES)
@pytest.mark.parametrize("sort_type", STABLE_SORT_TYPES)
def test_arr_basic_axis_sort(self, size, sort_type):
arr_np = np.random.randint(-100, 100, size)
arr_num = num.array(arr_np)
Expand All @@ -182,13 +184,14 @@ def test_arr_basic_axis_sort(self, size, sort_type):
arr_num_copy.argsort(axis=axis, kind=sort_type)
assert np.array_equal(arr_np_copy, arr_num_copy)

@pytest.mark.xfail
@pytest.mark.parametrize("size", SIZES)
@pytest.mark.parametrize("sort_type", UNSTABLE_SORT_TYPE)
@pytest.mark.parametrize("sort_type", UNSTABLE_SORT_TYPES)
def test_arr_basic_axis_sort_unstable(self, size, sort_type):
# intermittent failed due to
# https://github.com/nv-legate/cunumeric/issues/782
arr_np = np.random.randint(-100, 100, size)
# have to guarantee unique values in input
# see https://github.com/nv-legate/cunumeric/issues/782
arr_np = np.arange(np.prod(size))
np.random.shuffle(arr_np)
arr_np = arr_np.reshape(size)
arr_num = num.array(arr_np)
for axis in range(-arr_num.ndim + 1, arr_num.ndim):
arr_np_copy = arr_np
Expand Down

0 comments on commit 4bcb5f7

Please sign in to comment.