Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance mask_indices and move_axis #622

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 78 additions & 9 deletions tests/integration/test_mask_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,94 @@
import cunumeric as num

KS = [0, -1, 1, -2, 2]
FUNCTIONS = ["tril", "triu"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prefer tuples unless mutability is an explicit need

Suggested change
KS = [0, -1, 1, -2, 2]
FUNCTIONS = ["tril", "triu"]
KS = (0, -1, 1, -2, 2)
FUNCTIONS = ("tril", "triu")

(here, and also several places below)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right. Tuple is better than list if the case is immutable. Fixed.

N = 100


def _test(mask_func, k):
def _test(mask_func, n, k):
num_f = getattr(num, mask_func)
np_f = getattr(np, mask_func)

a = num.mask_indices(100, num_f, k=k)
an = np.mask_indices(100, np_f, k=k)
a = num.mask_indices(n, num_f, k=k)
an = np.mask_indices(n, np_f, k=k)
assert num.array_equal(a, an)


@pytest.mark.parametrize("k", KS, ids=lambda k: f"(k={k})")
def test_mask_indices_tril(k):
_test("tril", k)
def _test_default_k(mask_func, n):
num_f = getattr(num, mask_func)
np_f = getattr(np, mask_func)

a = num.mask_indices(n, num_f)
an = np.mask_indices(n, np_f)
assert num.array_equal(a, an)


@pytest.mark.parametrize("n", [0, 1, 100], ids=lambda n: f"(n={n})")
@pytest.mark.parametrize("mask_func", FUNCTIONS)
def test_mask_indices_default_k(n, mask_func):
_test_default_k(mask_func, n)


@pytest.mark.parametrize(
"k", KS + [-N, N, -10 * N, 10 * N], ids=lambda k: f"(k={k})"
)
@pytest.mark.parametrize("mask_func", FUNCTIONS)
def test_mask_indices(k, mask_func):
_test(mask_func, N, k)


@pytest.mark.xfail
@pytest.mark.parametrize(
"k", [-10.5, -0.5, 0.5, 10.5], ids=lambda k: f"(k={k})"
)
@pytest.mark.parametrize("mask_func", FUNCTIONS)
def test_mask_indices_float_k(k, mask_func):
# cuNumeric: struct.error: required argument is not an integer
# Numpy: pass
_test(mask_func, N, k)


class TestMaskIndicesErrors:
def test_negative_int_n(self):
with pytest.raises(ValueError):
num.mask_indices(-1, num.tril)

@pytest.mark.parametrize("n", [-10.0, 0.0, 10.5])
def test_float_n(self, n):
msg = "expected a sequence of integers or a single integer"
with pytest.raises(TypeError, match=msg):
num.mask_indices(n, num.tril)

@pytest.mark.xfail
def test_k_complex(self):
# In cuNumeric, it raises struct.error,
# msg is required argument is not an integer
# In Numpy, it raises TypeError,
# msg is '<=' not supported between instances of 'complex' and 'int'
with pytest.raises(TypeError):
num.mask_indices(10, num.tril, 1 + 2j)

@pytest.mark.xfail
def test_k_none(self):
# In cuNumeric, it raises struct.error,
# msg is required argument is not an integer
# In Numpy, it raises TypeError,
# msg is unsupported operand type(s) for -: 'NoneType' and 'int'
with pytest.raises(TypeError):
num.mask_indices(10, num.tril, None)

def test_bad_mask_func(self):
msg = "takes 1 positional argument but 2 were given"
with pytest.raises(TypeError, match=msg):
num.mask_indices(10, num.block)

msg = "'str' object is not callable"
with pytest.raises(TypeError, match=msg):
num.mask_indices(10, "abc")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be better split into two separate named tests

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.


@pytest.mark.parametrize("k", KS, ids=lambda k: f"(k={k})")
def test_indices_triu(k):
_test("triu", k)
msg = "'NoneType' object is not callable"
with pytest.raises(TypeError, match=msg):
num.mask_indices(10, None)


if __name__ == "__main__":
Expand Down
76 changes: 76 additions & 0 deletions tests/integration/test_moveaxis.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,82 @@ def test_moveaxis(ndim, axes):
assert cn_a.sum() == 0


def test_moveaxis_with_empty_axis():
np_a = np.ones((3, 4, 5))
cn_a = cn.ones((3, 4, 5))

axes = ([], [])
source, destination = axes

np_res = np.moveaxis(np_a, source, destination)
cn_res = cn.moveaxis(cn_a, source, destination)
assert np.array_equal(np_res, cn_res)


EMPTY_ARRAYS = [
[],
[[]],
[[], []],
]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
EMPTY_ARRAYS = [
[],
[[]],
[[], []],
]
EMPTY_ARRAYS = (
[],
[[]],
[[], []],
)



@pytest.mark.parametrize("a", EMPTY_ARRAYS)
def test_moveaxis_with_empty_array(a):
axes = (0, -1)
source, destination = axes

np_res = np.moveaxis(a, source, destination)
cn_res = cn.moveaxis(a, source, destination)
assert np.array_equal(np_res, cn_res)


class TestMoveAxisErrors:
def setup(self):
self.x = cn.ones((3, 4, 5))

def test_repeated_axis(self):
msg = "repeated axis"
with pytest.raises(ValueError, match=msg):
cn.moveaxis(self.x, [0, 0], [1, 0])

with pytest.raises(ValueError, match=msg):
cn.moveaxis(self.x, [0, 1], [0, -3])

def test_axis_out_of_bound(self):
msg = "out of bound"
with pytest.raises(np.AxisError, match=msg):
cn.moveaxis(self.x, [0, 3], [0, 1])

with pytest.raises(np.AxisError, match=msg):
cn.moveaxis(self.x, [0, 1], [0, -4])

with pytest.raises(np.AxisError, match=msg):
cn.moveaxis(self.x, 4, 0)

with pytest.raises(np.AxisError, match=msg):
cn.moveaxis(self.x, 0, -4)

def test_axis_with_different_length(self):
msg = "arguments must have the same number of elements"
with pytest.raises(ValueError, match=msg):
cn.moveaxis(self.x, [0], [1, 0])

def test_axis_with_bad_type(self):
msg = "integer argument expected, got float"
with pytest.raises(TypeError, match=msg):
cn.moveaxis(self.x, [0.0, 1], [1, 0])

with pytest.raises(TypeError, match=msg):
cn.moveaxis(self.x, [0, 1], [1, 0.0])

msg = "'NoneType' object is not iterable"
with pytest.raises(TypeError, match=msg):
cn.moveaxis(self.x, None, 0)

with pytest.raises(TypeError, match=msg):
cn.moveaxis(self.x, 0, None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

split into separate tests for float and None

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed



if __name__ == "__main__":
import sys

Expand Down