-
Notifications
You must be signed in to change notification settings - Fork 70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enhance mask_indices and move_axis #622
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,25 +19,94 @@ | |
import cunumeric as num | ||
|
||
KS = [0, -1, 1, -2, 2] | ||
FUNCTIONS = ["tril", "triu"] | ||
N = 100 | ||
|
||
|
||
def _test(mask_func, k): | ||
def _test(mask_func, n, k): | ||
num_f = getattr(num, mask_func) | ||
np_f = getattr(np, mask_func) | ||
|
||
a = num.mask_indices(100, num_f, k=k) | ||
an = np.mask_indices(100, np_f, k=k) | ||
a = num.mask_indices(n, num_f, k=k) | ||
an = np.mask_indices(n, np_f, k=k) | ||
assert num.array_equal(a, an) | ||
|
||
|
||
@pytest.mark.parametrize("k", KS, ids=lambda k: f"(k={k})") | ||
def test_mask_indices_tril(k): | ||
_test("tril", k) | ||
def _test_default_k(mask_func, n): | ||
num_f = getattr(num, mask_func) | ||
np_f = getattr(np, mask_func) | ||
|
||
a = num.mask_indices(n, num_f) | ||
an = np.mask_indices(n, np_f) | ||
assert num.array_equal(a, an) | ||
|
||
|
||
@pytest.mark.parametrize("n", [0, 1, 100], ids=lambda n: f"(n={n})") | ||
@pytest.mark.parametrize("mask_func", FUNCTIONS) | ||
def test_mask_indices_default_k(n, mask_func): | ||
_test_default_k(mask_func, n) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"k", KS + [-N, N, -10 * N, 10 * N], ids=lambda k: f"(k={k})" | ||
) | ||
@pytest.mark.parametrize("mask_func", FUNCTIONS) | ||
def test_mask_indices(k, mask_func): | ||
_test(mask_func, N, k) | ||
|
||
|
||
@pytest.mark.xfail | ||
@pytest.mark.parametrize( | ||
"k", [-10.5, -0.5, 0.5, 10.5], ids=lambda k: f"(k={k})" | ||
) | ||
@pytest.mark.parametrize("mask_func", FUNCTIONS) | ||
def test_mask_indices_float_k(k, mask_func): | ||
# cuNumeric: struct.error: required argument is not an integer | ||
# Numpy: pass | ||
_test(mask_func, N, k) | ||
|
||
|
||
class TestMaskIndicesErrors: | ||
def test_negative_int_n(self): | ||
with pytest.raises(ValueError): | ||
num.mask_indices(-1, num.tril) | ||
|
||
@pytest.mark.parametrize("n", [-10.0, 0.0, 10.5]) | ||
def test_float_n(self, n): | ||
msg = "expected a sequence of integers or a single integer" | ||
with pytest.raises(TypeError, match=msg): | ||
num.mask_indices(n, num.tril) | ||
|
||
@pytest.mark.xfail | ||
def test_k_complex(self): | ||
# In cuNumeric, it raises struct.error, | ||
# msg is required argument is not an integer | ||
# In Numpy, it raises TypeError, | ||
# msg is '<=' not supported between instances of 'complex' and 'int' | ||
with pytest.raises(TypeError): | ||
num.mask_indices(10, num.tril, 1 + 2j) | ||
|
||
@pytest.mark.xfail | ||
def test_k_none(self): | ||
# In cuNumeric, it raises struct.error, | ||
# msg is required argument is not an integer | ||
# In Numpy, it raises TypeError, | ||
# msg is unsupported operand type(s) for -: 'NoneType' and 'int' | ||
with pytest.raises(TypeError): | ||
num.mask_indices(10, num.tril, None) | ||
|
||
def test_bad_mask_func(self): | ||
msg = "takes 1 positional argument but 2 were given" | ||
with pytest.raises(TypeError, match=msg): | ||
num.mask_indices(10, num.block) | ||
|
||
msg = "'str' object is not callable" | ||
with pytest.raises(TypeError, match=msg): | ||
num.mask_indices(10, "abc") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This would be better split into two separate named tests There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed. |
||
|
||
@pytest.mark.parametrize("k", KS, ids=lambda k: f"(k={k})") | ||
def test_indices_triu(k): | ||
_test("triu", k) | ||
msg = "'NoneType' object is not callable" | ||
with pytest.raises(TypeError, match=msg): | ||
num.mask_indices(10, None) | ||
|
||
|
||
if __name__ == "__main__": | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -44,6 +44,82 @@ def test_moveaxis(ndim, axes): | |||||||||||||||||||||
assert cn_a.sum() == 0 | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
def test_moveaxis_with_empty_axis(): | ||||||||||||||||||||||
np_a = np.ones((3, 4, 5)) | ||||||||||||||||||||||
cn_a = cn.ones((3, 4, 5)) | ||||||||||||||||||||||
|
||||||||||||||||||||||
axes = ([], []) | ||||||||||||||||||||||
source, destination = axes | ||||||||||||||||||||||
|
||||||||||||||||||||||
np_res = np.moveaxis(np_a, source, destination) | ||||||||||||||||||||||
cn_res = cn.moveaxis(cn_a, source, destination) | ||||||||||||||||||||||
assert np.array_equal(np_res, cn_res) | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
EMPTY_ARRAYS = [ | ||||||||||||||||||||||
[], | ||||||||||||||||||||||
[[]], | ||||||||||||||||||||||
[[], []], | ||||||||||||||||||||||
] | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
@pytest.mark.parametrize("a", EMPTY_ARRAYS) | ||||||||||||||||||||||
def test_moveaxis_with_empty_array(a): | ||||||||||||||||||||||
axes = (0, -1) | ||||||||||||||||||||||
source, destination = axes | ||||||||||||||||||||||
|
||||||||||||||||||||||
np_res = np.moveaxis(a, source, destination) | ||||||||||||||||||||||
cn_res = cn.moveaxis(a, source, destination) | ||||||||||||||||||||||
assert np.array_equal(np_res, cn_res) | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
class TestMoveAxisErrors: | ||||||||||||||||||||||
def setup(self): | ||||||||||||||||||||||
self.x = cn.ones((3, 4, 5)) | ||||||||||||||||||||||
|
||||||||||||||||||||||
def test_repeated_axis(self): | ||||||||||||||||||||||
msg = "repeated axis" | ||||||||||||||||||||||
with pytest.raises(ValueError, match=msg): | ||||||||||||||||||||||
cn.moveaxis(self.x, [0, 0], [1, 0]) | ||||||||||||||||||||||
|
||||||||||||||||||||||
with pytest.raises(ValueError, match=msg): | ||||||||||||||||||||||
cn.moveaxis(self.x, [0, 1], [0, -3]) | ||||||||||||||||||||||
|
||||||||||||||||||||||
def test_axis_out_of_bound(self): | ||||||||||||||||||||||
msg = "out of bound" | ||||||||||||||||||||||
with pytest.raises(np.AxisError, match=msg): | ||||||||||||||||||||||
cn.moveaxis(self.x, [0, 3], [0, 1]) | ||||||||||||||||||||||
|
||||||||||||||||||||||
with pytest.raises(np.AxisError, match=msg): | ||||||||||||||||||||||
cn.moveaxis(self.x, [0, 1], [0, -4]) | ||||||||||||||||||||||
|
||||||||||||||||||||||
with pytest.raises(np.AxisError, match=msg): | ||||||||||||||||||||||
cn.moveaxis(self.x, 4, 0) | ||||||||||||||||||||||
|
||||||||||||||||||||||
with pytest.raises(np.AxisError, match=msg): | ||||||||||||||||||||||
cn.moveaxis(self.x, 0, -4) | ||||||||||||||||||||||
|
||||||||||||||||||||||
def test_axis_with_different_length(self): | ||||||||||||||||||||||
msg = "arguments must have the same number of elements" | ||||||||||||||||||||||
with pytest.raises(ValueError, match=msg): | ||||||||||||||||||||||
cn.moveaxis(self.x, [0], [1, 0]) | ||||||||||||||||||||||
|
||||||||||||||||||||||
def test_axis_with_bad_type(self): | ||||||||||||||||||||||
msg = "integer argument expected, got float" | ||||||||||||||||||||||
with pytest.raises(TypeError, match=msg): | ||||||||||||||||||||||
cn.moveaxis(self.x, [0.0, 1], [1, 0]) | ||||||||||||||||||||||
|
||||||||||||||||||||||
with pytest.raises(TypeError, match=msg): | ||||||||||||||||||||||
cn.moveaxis(self.x, [0, 1], [1, 0.0]) | ||||||||||||||||||||||
|
||||||||||||||||||||||
msg = "'NoneType' object is not iterable" | ||||||||||||||||||||||
with pytest.raises(TypeError, match=msg): | ||||||||||||||||||||||
cn.moveaxis(self.x, None, 0) | ||||||||||||||||||||||
|
||||||||||||||||||||||
with pytest.raises(TypeError, match=msg): | ||||||||||||||||||||||
cn.moveaxis(self.x, 0, None) | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. split into separate tests for float and None There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed |
||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
if __name__ == "__main__": | ||||||||||||||||||||||
import sys | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
prefer tuples unless mutability is an explicit need
(here, and also several places below)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You are right. Tuple is better than list if the case is immutable. Fixed.