Refactor the axis sanitization #412

Closed
magnatelee opened this issue Jun 16, 2022 · 8 comments · Fixed by #419
Labels
enhancement New feature or request

Comments

@magnatelee
Contributor

Many operations optionally take an axis argument to denote a specific dimension of an ndarray. Although these operations sanitize the axis value in almost the same way, each implements the logic individually and there is no common utility. We need to refactor the code so that all of those operations use the same sanitization implementation.

@magnatelee magnatelee added the enhancement New feature or request label Jun 16, 2022
@bryevdv
Contributor

bryevdv commented Jun 17, 2022

@magnatelee I have one up-front question about philosophy/policy. Is our goal to match numpy's handling of axis args exactly, including any limitations? Or is it acceptable for our APIs to be "wider"? I am thinking specifically of the possibility that some numpy method accepts a strict axis type, but we want to accept a more general (convertible) "axis-like" everywhere for simplicity.

Put another way: is it OK if our functions accept everything numpy accepts, and more? (I'm also speculating here, perhaps numpy already accepts a more general "axis-like" everywhere and the issue is moot).

@magnatelee
Contributor Author

The goal is to match NumPy's behavior. In most cases, axis can be either None or an integer, and if the value is negative, it follows the usual Python indexing rule (i.e., -k means the k-th value from the back). An IndexError should be raised for out-of-bounds accesses (in this case, non-existent axes). There are some places where the axis can be a tuple of integers, in which case the logic applies element-wise. I feel there should be two functions, one that accepts tuples and one that doesn't, and we should choose the right one to call depending on the context. And of course, we first need to identify all places where an axis can be passed.
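A minimal sketch of the two helpers described above, one for a single axis and one that also accepts a tuple. The names sanitize_axis and sanitize_axes are hypothetical, not existing cuNumeric functions; axis=None is assumed to be handled by the caller, and duplicate axes are assumed to be an error, as NumPy's normalize_axis_tuple does by default. NumPy itself raises AxisError (a subclass of both IndexError and ValueError); this sketch raises a plain IndexError.

```python
import operator
from typing import Sequence, Tuple, Union


def sanitize_axis(axis: int, ndim: int) -> int:
    """Normalize one axis index; -k means the k-th axis from the back."""
    ax = operator.index(axis)  # TypeError for non-integers such as 1.5
    if not -ndim <= ax < ndim:
        # NumPy raises AxisError here, a subclass of IndexError and ValueError
        raise IndexError(
            f"axis {ax} is out of bounds for array of dimension {ndim}"
        )
    return ax % ndim  # maps -1 to ndim - 1, leaves 0..ndim-1 unchanged


def sanitize_axes(axis: Union[int, Sequence[int]], ndim: int) -> Tuple[int, ...]:
    """Apply sanitize_axis element-wise to an int or a tuple/list of ints."""
    if isinstance(axis, (tuple, list)):
        axes = tuple(sanitize_axis(ax, ndim) for ax in axis)
    else:
        axes = (sanitize_axis(axis, ndim),)
    if len(set(axes)) != len(axes):
        raise ValueError(f"repeated axis in {axis!r}")
    return axes
```

A reduction such as sum would call sanitize_axes, while argmax would call sanitize_axis, each after dealing with axis=None itself.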

@bryevdv
Contributor

bryevdv commented Jun 17, 2022

grep etc. was not really workable for a census. I wrote a little script to check more thoroughly based on the actual modules:

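```python
from importlib import import_module
from inspect import isclass, isfunction, signature
from pkgutil import walk_packages

import cunumeric

PATH = cunumeric.__path__
PREFIX = f"{cunumeric.__name__}."
SKIP = ("cunumeric.config", "cunumeric.runtime")


def takes_axis(func) -> bool:
    """True if func's signature has an ``axis`` parameter."""
    try:
        return "axis" in signature(func).parameters
    except ValueError:  # some callables have no inspectable signature
        return False


for _, name, _ in walk_packages(path=PATH, prefix=PREFIX):
    if name in SKIP:
        continue

    # import_module replaces the deprecated find_module()/load_module() pair
    module = import_module(name)

    for k, v in vars(module).items():
        if k.startswith("_"):
            continue

        # module-level functions (including ones re-exported from elsewhere)
        if isfunction(v) and takes_axis(v):
            print(f"{name}.{k}")

        # methods of classes defined in this module
        if isclass(v) and v.__module__ == name:
            for kk, vv in vars(v).items():
                if kk.startswith("_"):
                    continue
                if isfunction(vv) and takes_axis(vv):
                    print(f"{name}.{k}.{kk}")
```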

And the results appear to be:

cunumeric._ufunc.ufunc.binary_ufunc.reduce
cunumeric.array.ndarray.all
cunumeric.array.ndarray.any
cunumeric.array.ndarray.argmax
cunumeric.array.ndarray.argmin
cunumeric.array.ndarray.take
cunumeric.array.ndarray.compress
cunumeric.array.ndarray.max
cunumeric.array.ndarray.mean
cunumeric.array.ndarray.min
cunumeric.array.ndarray.partition
cunumeric.array.ndarray.argpartition
cunumeric.array.ndarray.prod
cunumeric.array.ndarray.sort
cunumeric.array.ndarray.argsort
cunumeric.array.ndarray.squeeze
cunumeric.array.ndarray.sum
cunumeric.array.ndarray.flip
cunumeric.deferred.sort
cunumeric.deferred.DeferredArray.squeeze
cunumeric.deferred.DeferredArray.repeat
cunumeric.eager.EagerArray.squeeze
cunumeric.eager.EagerArray.repeat
cunumeric.eager.EagerArray.sort
cunumeric.eager.EagerArray.partition
cunumeric.fft.fft
cunumeric.fft.ifft
cunumeric.fft.rfft
cunumeric.fft.irfft
cunumeric.fft.hfft
cunumeric.fft.ihfft
cunumeric.fft.fft.fft
cunumeric.fft.fft.ifft
cunumeric.fft.fft.rfft
cunumeric.fft.fft.irfft
cunumeric.fft.fft.hfft
cunumeric.fft.fft.ihfft
cunumeric.linalg.norm
cunumeric.linalg.linalg.norm
cunumeric.module.normalize_axis_tuple
cunumeric.module.linspace
cunumeric.module.squeeze
cunumeric.module.check_shape_dtype
cunumeric.module.append
cunumeric.module.concatenate
cunumeric.module.stack
cunumeric.module.split
cunumeric.module.array_split
cunumeric.module.repeat
cunumeric.module.flip
cunumeric.module.take
cunumeric.module.compress
cunumeric.module.all
cunumeric.module.any
cunumeric.module.prod
cunumeric.module.sum
cunumeric.module.amax
cunumeric.module.max
cunumeric.module.amin
cunumeric.module.min
cunumeric.module.unique
cunumeric.module.argsort
cunumeric.module.sort
cunumeric.module.argpartition
cunumeric.module.partition
cunumeric.module.argmax
cunumeric.module.argmin
cunumeric.module.count_nonzero
cunumeric.module.mean
cunumeric.sort.sort
cunumeric.thunk.NumPyThunk.repeat
cunumeric.thunk.NumPyThunk.squeeze

@bryevdv
Contributor

bryevdv commented Jun 21, 2022

@magnatelee I just wrote

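```python
from __future__ import annotations

from typing import Sequence, Union

# AxisLike is not defined in this excerpt; assumed alias for illustration
AxisLike = Union[None, int, Sequence[int]]


def _sanitize_axes(arg: AxisLike, ndim: int, name: str = "axis") -> tuple[int, ...]:
    if arg is None:
        axes = tuple(range(ndim))
    elif isinstance(arg, int):
        axes = (arg,)
    else:
        # tuples/lists of axes; without this branch axes would be unbound
        axes = tuple(arg)

    if any(type(ax) is not int for ax in axes):
        raise TypeError(
            f"{name!r} must be an integer or a tuple of integers, but got {arg}"
        )

    computed_axes = tuple(ax + ndim if ax < 0 else ax for ax in axes)

    # reject axes that are out of range on either side
    if any(ax < 0 or ax >= ndim for ax in computed_axes):
        raise ValueError(f"Invalid {name!r} value {arg}")

    return computed_axes
```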

when I discovered that Numpy itself defines an almost identical normalize_axis_tuple, which we already happen to use in a couple of places (e.g. moveaxis). The Numpy function does not handle the None case, and in Numpy we often see that case handled explicitly at the call site:

```python
if axis is None:
    return roll(a.ravel(), shift, 0).reshape(a.shape)

else:
    axis = normalize_axis_tuple(axis, a.ndim, allow_duplicate=True)
    ...
```

or

```python
if axis is None:
    axis = tuple(range(x.ndim))
else:
    axis = normalize_axis_tuple(axis, x.ndim)
```

Rather than duplicate what is already in numpy, I would suggest first just applying normalize_axis_tuple in places where we are already doing that work by hand (and also handling the None case explicitly at the call site, following Numpy's pattern).
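A minimal sketch of that call-site pattern for a sum-like reduction. normalize_axes is a hypothetical pure-Python stand-in for NumPy's normalize_axis_tuple (like it, it does not accept None), and reduced_shape is likewise a hypothetical helper used only to show where the explicit None branch goes.

```python
import operator
from typing import Sequence, Tuple, Union

# Hypothetical alias for the accepted argument forms
AxisArg = Union[None, int, Sequence[int]]


def normalize_axes(axis: Union[int, Sequence[int]], ndim: int) -> Tuple[int, ...]:
    """Stand-in for NumPy's normalize_axis_tuple: no None handling here."""
    items = axis if isinstance(axis, (tuple, list)) else (axis,)
    axes = []
    for item in items:
        ax = operator.index(item)
        if not -ndim <= ax < ndim:
            raise IndexError(f"axis {ax} is out of bounds for ndim {ndim}")
        axes.append(ax % ndim)
    if len(set(axes)) != len(axes):
        raise ValueError(f"repeated axis in {axis!r}")
    return tuple(axes)


def reduced_shape(shape: Tuple[int, ...], axis: AxisArg = None) -> Tuple[int, ...]:
    """Output shape of a sum-like reduction over the given axis/axes."""
    if axis is None:
        # None means "reduce over every axis"; handled before normalizing
        axes = tuple(range(len(shape)))
    else:
        axes = normalize_axes(axis, len(shape))
    return tuple(n for i, n in enumerate(shape) if i not in axes)
```

Keeping None out of the shared helper lets each operator decide what None means, which is exactly where operators tend to differ.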

@bryevdv
Contributor

bryevdv commented Jun 21, 2022

I'm also somewhat confused by some cases like this:

```python
if axis is not None and not isinstance(axis, int):
    raise ValueError("axis must be an integer")
return self._perform_unary_reduction(
    UnaryRedCode.ARGMAX,
    self,
    axis=axis,
```
In that case, if axis=None, it gets passed on to _perform_unary_reduction, which then does axis = list(range(src.ndim)). That is decidedly different from the (axis,) that a single integer would produce, and a single integer is all the exception message claims is supported. Should the None default value be removed here? Is the error message misleading?

@magnatelee
Contributor Author

> Rather than duplicate what is already in numpy I would suggest first just apply normalize_axis_tuple in places where we are already doing that work by hand (and also handling None case explicitly at site, following Numpy's pattern)

If we can reuse the primitives in NumPy, I don't see a reason not to use them, as long as they have the semantics we expect. And you're right, axis=None needs special treatment, as operators often handle that case differently.

> I'm also somewhat confused by some cases like this:

argmin and argmax support only up to one axis, whereas other reduction operators, such as sum and prod, can take multiple axes: https://numpy.org/doc/stable/reference/generated/numpy.argmax.html. Peculiarities like this might appear in other operators as well, though I don't know them off the top of my head.

@bryevdv
Contributor

bryevdv commented Jun 21, 2022

> argmin and argmax support only up to one axis

So is accepting None in argmax a bug, then? Letting it through will result in a tuple of ints being used inside _perform_unary_reduction.

@magnatelee
Contributor Author

> So then accepting None there in argmax there is a bug? Because letting it through will result in a tuple of ints being used inside _perform_unary_reduction

That's not a bug. For some reason, NumPy supports argmin and argmax with axis=None, in which case they operate on the flattened array and yield scalar results. We simply follow these semantics, and _perform_unary_reduction does the right thing when the output is a scalar array.
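A minimal sketch of those argmax/argmin axis rules, expressed as the output shape they produce. argmax_output_shape is a hypothetical helper, not part of cuNumeric: axis=None yields a 0-d (scalar) result, a single integer removes that axis, and a tuple is rejected because only one axis is allowed.

```python
import operator
from typing import Optional, Tuple


def argmax_output_shape(
    shape: Tuple[int, ...], axis: Optional[int] = None
) -> Tuple[int, ...]:
    """Shape of an argmax/argmin result, following NumPy's axis rules."""
    if axis is None:
        # The input is treated as flattened, so one index comes back: 0-d
        return ()
    # Only a single axis is allowed; operator.index rejects tuples (TypeError)
    ax = operator.index(axis)
    ndim = len(shape)
    if not -ndim <= ax < ndim:
        raise IndexError(
            f"axis {ax} is out of bounds for array of dimension {ndim}"
        )
    ax %= ndim
    # The reduced axis is dropped from the shape
    return shape[:ax] + shape[ax + 1 :]
```

For axis=None the scalar result is an index into the flattened array; numpy.unravel_index can map it back to per-axis coordinates if needed.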


2 participants