Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Numpy axis normalizations where possible #419

Merged
merged 3 commits into from
Jun 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 15 additions & 26 deletions cunumeric/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
from collections.abc import Iterable
from functools import reduce, wraps
from inspect import signature
from typing import Callable, Optional, Set, Tuple, TypeVar
from typing import Callable, Optional, Set, TypeVar

import numpy as np
import pyarrow
from numpy.core.multiarray import normalize_axis_index
from numpy.core.numeric import normalize_axis_tuple
from typing_extensions import ParamSpec

from legate.core import Array
Expand Down Expand Up @@ -1782,10 +1784,8 @@ def take(self, indices, axis=None, out=None, mode="raise"):
if axis is None:
self = self.ravel()
axis = 0
elif axis < 0:
axis = self.ndim + axis
if axis < 0 or axis >= self.ndim:
raise ValueError("axis argument is out of bounds")
else:
axis = normalize_axis_index(axis, self.ndim)

# TODO remove "raise" logic when bounds check for advanced
# indexing is implementd
Expand Down Expand Up @@ -1935,7 +1935,7 @@ def choose(self, choices, out=None, mode="raise"):
def compress(self, condition, axis=None, out=None):
"""a.compress(self, condition, axis=None, out=None)

Return selected slices of an array along given axis..
Return selected slices of an array along given axis.

Refer to :func:`cunumeric.compress` for full documentation.

Expand All @@ -1959,9 +1959,12 @@ def compress(self, condition, axis=None, out=None):
category=RuntimeWarning,
)
condition = condition.astype(bool)

if axis is None:
axis = 0
a = self.ravel()
else:
axis = normalize_axis_index(axis, self.ndim)
bryevdv marked this conversation as resolved.
Show resolved Hide resolved

if a.shape[axis] < condition.shape[0]:
raise ValueError(
Expand Down Expand Up @@ -2530,7 +2533,7 @@ def getfield(self, dtype, offset=0):
"for ndarray.getfield"
)

def _convert_singleton_key(self, args: Tuple):
def _convert_singleton_key(self, args: tuple):
if len(args) == 0 and self.size == 1:
return (0,) * self.ndim
if len(args) == 1 and isinstance(args[0], int):
Expand Down Expand Up @@ -3070,9 +3073,7 @@ def squeeze(self, axis=None):
"all axis to squeeze must be less than ndim"
)
if self.shape[axis] != 1:
raise ValueError(
"axis to squeeze must have extent " "of one"
)
raise ValueError("axis to squeeze must have extent of one")
elif isinstance(axis, tuple):
for ax in axis:
if ax >= self.ndim:
Expand Down Expand Up @@ -3593,23 +3594,11 @@ def _perform_unary_reduction(
raise NotImplementedError(
"(arg)max/min not supported for complex-type arrays"
)
# Compute the output shape
axes = axis
if axes is None:
axes = tuple(range(src.ndim))
elif not isinstance(axes, tuple):
axes = (axes,)

if any(type(ax) != int for ax in axes):
raise TypeError(
"'axis' must be an integer or a tuple of integers, "
f"but got {axis}"
)

axes = tuple(ax + src.ndim if ax < 0 else ax for ax in axes)

if any(ax < 0 for ax in axes):
raise ValueError(f"Invalid 'axis' value {axis}")
if axis is None:
axes = tuple(range(src.ndim))
else:
axes = normalize_axis_tuple(axis, src.ndim)

out_shape = ()
for dim in range(src.ndim):
Expand Down
18 changes: 7 additions & 11 deletions cunumeric/linalg/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@
from cunumeric._ufunc.math import add, sqrt as _sqrt
from cunumeric.array import add_boilerplate, convert_to_cunumeric_ndarray
from cunumeric.module import dot, empty_like, eye, matmul, ndarray
from numpy.core.multiarray import ( # type: ignore [attr-defined]
normalize_axis_index,
)
from numpy.core.multiarray import normalize_axis_index # type: ignore
from numpy.core.numeric import normalize_axis_tuple # type: ignore
bryevdv marked this conversation as resolved.
Show resolved Hide resolved

if TYPE_CHECKING:
import numpy.typing as npt
Expand Down Expand Up @@ -424,14 +423,11 @@ def norm(
ret = ret.reshape(ndim * [1])
return ret

# Normalize the `axis` argument to a tuple.
nd = x.ndim
if axis is None:
computed_axis = tuple(range(nd))
elif not isinstance(axis, tuple):
computed_axis = (axis,)
computed_axis = tuple(range(x.ndim))
else:
computed_axis = axis
computed_axis = normalize_axis_tuple(axis, x.ndim)

for ax in computed_axis:
if not isinstance(ax, int):
raise TypeError(
Expand Down Expand Up @@ -469,8 +465,8 @@ def norm(
return ret
elif len(computed_axis) == 2:
row_axis, col_axis = computed_axis
row_axis = normalize_axis_index(row_axis, nd)
col_axis = normalize_axis_index(col_axis, nd)
row_axis = normalize_axis_index(row_axis, x.ndim)
col_axis = normalize_axis_index(col_axis, x.ndim)
if row_axis == col_axis:
raise ValueError("Duplicate axes given")
if ord == 2:
Expand Down
15 changes: 7 additions & 8 deletions cunumeric/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

from cunumeric.config import CuNumericOpCode
from numpy.core.multiarray import normalize_axis_index

from legate.core import types as ty

Expand All @@ -32,7 +33,7 @@ def sort_flattened(output, input, argsort, stable):


def sort_swapped(output, input, argsort, sort_axis, stable):
assert sort_axis < input.ndim - 1 and sort_axis >= 0
sort_axis = normalize_axis_index(sort_axis, input.ndim)

# swap axes
swapped = input.swapaxes(sort_axis, input.ndim - 1)
Expand Down Expand Up @@ -97,12 +98,10 @@ def sort(output, input, argsort, axis=-1, stable=False):
else:
if axis is None:
axis = 0
elif axis < 0:
axis = input.ndim + axis

if axis is not input.ndim - 1:
bryevdv marked this conversation as resolved.
Show resolved Hide resolved
sort_swapped(output, input, argsort, axis, stable)

else:
# run actual sort task
axis = normalize_axis_index(axis, input.ndim)

if axis == input.ndim - 1:
sort_task(output, input, argsort, stable)
else:
sort_swapped(output, input, argsort, axis, stable)