Skip to content

Commit

Permalink
Add typing to runtime.py (#428)
Browse files Browse the repository at this point in the history
* add runtime.py

* improve clone_class

* Fix return types

* remove ThunkArray
  • Loading branch information
bryevdv authored Jun 29, 2022
1 parent 8e2f3a3 commit 19f1e64
Show file tree
Hide file tree
Showing 6 changed files with 214 additions and 141 deletions.
9 changes: 6 additions & 3 deletions cunumeric/coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from dataclasses import dataclass
from functools import wraps
from types import FunctionType, MethodDescriptorType, MethodType, ModuleType
from typing import Any, Callable, Container, Mapping, Optional, cast
from typing import Any, Callable, Container, Mapping, Optional, TypeVar, cast

from typing_extensions import Protocol

Expand Down Expand Up @@ -219,7 +219,10 @@ def clone_module(
new_globals[attr] = value


def clone_class(origin_class: type) -> Callable[[type], type]:
C = TypeVar("C", bound=type)


def clone_class(origin_class: type) -> Callable[[C], C]:
"""Copy attributes from one class to another
Method types are wrapped with a decorator to report API calls. All
Expand All @@ -237,7 +240,7 @@ def should_wrap(obj: object) -> bool:
obj, (FunctionType, MethodType, MethodDescriptorType)
)

def decorator(cls: type) -> type:
def decorator(cls: C) -> C:
class_name = f"{origin_class.__module__}.{origin_class.__name__}"

missing = filter_namespace(
Expand Down
41 changes: 21 additions & 20 deletions cunumeric/deferred.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from enum import IntEnum, unique
from functools import reduce
from itertools import product
from typing import TYPE_CHECKING, Any

import numpy as np

Expand All @@ -39,6 +40,9 @@
from .thunk import NumPyThunk
from .utils import get_arg_value_dtype, is_advanced_indexing

if TYPE_CHECKING:
import numpy.typing as npt


def _complex_field_dtype(dtype):
if dtype == np.complex64:
Expand Down Expand Up @@ -166,8 +170,8 @@ class DeferredArray(NumPyThunk):
:meta private:
"""

def __init__(self, runtime, base, dtype, numpy_array=None):
NumPyThunk.__init__(self, runtime, dtype)
def __init__(self, runtime, base, dtype, numpy_array=None) -> None:
super().__init__(runtime, dtype)
assert base is not None
assert isinstance(base, Store)
self.base = base # a Legate Store
Expand Down Expand Up @@ -203,7 +207,7 @@ def _copy_if_overlapping(self, other):
copy.copy(self, deep=True)
return copy

def __numpy_array__(self):
def __numpy_array__(self) -> npt.NDArray[Any]:
if self.numpy_array is not None:
result = self.numpy_array()
if result is not None:
Expand Down Expand Up @@ -654,7 +658,7 @@ def _convert_future_to_store(self, a):
store_copy.copy(a, deep=True)
return store_copy

def get_item(self, key):
def get_item(self, key) -> DeferredArray:
# Check to see if this is advanced indexing or not
if is_advanced_indexing(key):
# Create the indexing array
Expand Down Expand Up @@ -1133,7 +1137,7 @@ def _fill(self, value):

task.execute()

def fill(self, numpy_array):
def fill(self, numpy_array) -> None:
assert isinstance(numpy_array, np.ndarray)
assert numpy_array.size == 1
assert self.dtype == numpy_array.dtype
Expand Down Expand Up @@ -1478,7 +1482,7 @@ def _diag_helper(
task.execute()

# Create an identity array with the ones offset from the diagonal by k
def eye(self, k):
def eye(self, k) -> None:
assert self.ndim == 2 # Only 2-D arrays should be here
# First issue a fill to zero everything out
self.fill(np.array(0, dtype=self.dtype))
Expand All @@ -1489,7 +1493,7 @@ def eye(self, k):

task.execute()

def arange(self, start, stop, step):
def arange(self, start, stop, step) -> None:
assert self.ndim == 1 # Only 1-D arrays should be here
if self.scalar:
# Handle the special case of a single value here
Expand All @@ -1501,11 +1505,10 @@ def arange(self, start, stop, step):

def create_scalar(value, dtype):
array = np.array(value, dtype)
return self.runtime.create_scalar(
return self.runtime.create_wrapped_scalar(
array.data,
array.dtype,
shape=(1,),
wrap=True,
).base

task = self.context.create_task(CuNumericOpCode.ARANGE)
Expand Down Expand Up @@ -1559,7 +1562,7 @@ def trilu(self, rhs, k, lower):
task.execute()

# Repeat elements of an array.
def repeat(self, repeats, axis, scalar_repeats):
def repeat(self, repeats, axis, scalar_repeats) -> DeferredArray:
out = self.runtime.create_unbound_thunk(self.dtype, ndim=self.ndim)
task = self.context.create_task(CuNumericOpCode.REPEAT)
task.add_input(self.base)
Expand Down Expand Up @@ -1614,11 +1617,10 @@ def bincount(self, rhs, weights=None):
src_array.size == 1 and weight_array.size == 1
)
else:
weight_array = self.runtime.create_scalar(
weight_array = self.runtime.create_wrapped_scalar(
np.array(1, dtype=np.int64),
np.dtype(np.int64),
shape=(),
wrap=True,
)

dst_array.fill(np.array(0, dst_array.dtype))
Expand Down Expand Up @@ -1651,7 +1653,7 @@ def nonzero(self):
task.execute()
return results

def random(self, gen_code, args=[]):
def random(self, gen_code, args=[]) -> None:
task = self.context.create_task(CuNumericOpCode.RAND)

task.add_output(self.base)
Expand All @@ -1663,15 +1665,15 @@ def random(self, gen_code, args=[]):

task.execute()

def random_uniform(self):
def random_uniform(self) -> None:
assert self.dtype == np.float64
self.random(RandGenCode.UNIFORM)

def random_normal(self):
def random_normal(self) -> None:
assert self.dtype == np.float64
self.random(RandGenCode.NORMAL)

def random_integer(self, low, high):
def random_integer(self, low, high) -> None:
assert self.dtype.kind == "i"
low = np.array(low, self.dtype)
high = np.array(high, self.dtype)
Expand Down Expand Up @@ -1799,7 +1801,7 @@ def unary_reduction(
[],
)

def isclose(self, rhs1, rhs2, rtol, atol, equal_nan):
def isclose(self, rhs1, rhs2, rtol, atol, equal_nan) -> None:
assert not equal_nan
args = (
np.array(rtol, dtype=np.float64),
Expand Down Expand Up @@ -1882,11 +1884,10 @@ def add_arguments(self, task, args):
return
for numpy_array in args:
assert numpy_array.size == 1
scalar = self.runtime.create_scalar(
scalar = self.runtime.create_wrapped_scalar(
numpy_array.data,
numpy_array.dtype,
shape=(1,),
wrap=True,
)
task.add_input(scalar.base)

Expand Down Expand Up @@ -1963,7 +1964,7 @@ def partition(
# fallback to sort for now
sort(self, rhs, argpartition, axis, False)

def create_window(self, op_code, M, *args):
def create_window(self, op_code, M, *args) -> None:
task = self.context.create_task(CuNumericOpCode.WINDOW)
task.add_output(self.base)
task.add_scalar_arg(op_code, ty.int32)
Expand Down
39 changes: 23 additions & 16 deletions cunumeric/eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#
from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np

from .config import (
Expand All @@ -30,6 +32,12 @@
from .thunk import NumPyThunk
from .utils import is_advanced_indexing

if TYPE_CHECKING:
import numpy.typing as npt

from .deferred import DeferredArray


_UNARY_OPS = {
UnaryOpCode.ABSOLUTE: np.absolute,
UnaryOpCode.ARCCOS: np.arccos,
Expand Down Expand Up @@ -163,8 +171,8 @@ class EagerArray(NumPyThunk):
:meta private:
"""

def __init__(self, runtime, array, parent=None, key=None):
NumPyThunk.__init__(self, runtime, array.dtype)
def __init__(self, runtime, array, parent=None, key=None) -> None:
super().__init__(runtime, array.dtype)
self.array = array
self.parent = parent
self.children = None
Expand All @@ -184,7 +192,7 @@ def storage(self):
def shape(self):
return self.array.shape

def __numpy_array__(self):
def __numpy_array__(self) -> npt.NDArray[Any]:
if self.deferred is not None:
return self.deferred.__numpy_array__()
# Track when this escapes. If it escapes we have
Expand Down Expand Up @@ -231,7 +239,7 @@ def _convert_children(self):
for child in self.children:
child._convert_children()

def to_deferred_array(self):
def to_deferred_array(self) -> DeferredArray:
"""This is a really important method. It will convert a tree of
eager NumPy arrays into an equivalent tree of deferred arrays that
are mirrored by an equivalent logical region tree. To be consistent
Expand All @@ -247,11 +255,10 @@ def to_deferred_array(self):
# We are at the root of the tree so we need to
# actually make a DeferredArray to use
if self.array.size == 1:
self.deferred = self.runtime.create_scalar(
self.deferred = self.runtime.create_wrapped_scalar(
self.array.data,
dtype=self.array.dtype,
shape=self.shape,
wrap=True,
)
else:
self.deferred = self.runtime.find_or_create_array_thunk(
Expand Down Expand Up @@ -356,7 +363,7 @@ def _create_indexing_key(self, key):
assert isinstance(key, NumPyThunk)
return self.runtime.to_eager_array(key).array

def get_item(self, key):
def get_item(self, key) -> NumPyThunk:
if self.deferred is not None:
return self.deferred.get_item(key)
if is_advanced_indexing(key):
Expand Down Expand Up @@ -453,7 +460,7 @@ def convert(self, rhs, warn=True):
else:
self.array[:] = rhs.array

def fill(self, value):
def fill(self, value) -> None:
if self.deferred is not None:
self.deferred.fill(value)
else:
Expand All @@ -480,7 +487,7 @@ def transpose(self, axes):
self.children.append(result)
return result

def repeat(self, repeats, axis, scalar_repeats):
def repeat(self, repeats, axis, scalar_repeats) -> NumPyThunk:
if not scalar_repeats:
self.check_eager_args(repeats)
if self.deferred is not None:
Expand Down Expand Up @@ -564,7 +571,7 @@ def _diag_helper(self, rhs, offset, naxes, extract, trace):
axes = tuple(range(ndims - naxes, ndims))
self.array = diagonal_reference(rhs.array, axes)

def eye(self, k):
def eye(self, k) -> None:
if self.deferred is not None:
self.deferred.eye(k)
else:
Expand All @@ -575,7 +582,7 @@ def eye(self, k):
self.shape[0], self.shape[1], k, dtype=self.dtype
)

def arange(self, start, stop, step):
def arange(self, start, stop, step) -> None:
if self.deferred is not None:
self.deferred.arange(start, stop, step)
else:
Expand Down Expand Up @@ -637,7 +644,7 @@ def partition(
else:
self.array = np.partition(rhs.array, kth, axis, kind, order)

def random_uniform(self):
def random_uniform(self) -> None:
if self.deferred is not None:
self.deferred.random_uniform()
else:
Expand All @@ -646,7 +653,7 @@ def random_uniform(self):
else:
self.array[:] = np.random.rand(*(self.array.shape))

def random_normal(self):
def random_normal(self) -> None:
if self.deferred is not None:
self.deferred.random_normal()
else:
Expand All @@ -655,7 +662,7 @@ def random_normal(self):
else:
self.array[:] = np.random.randn(*(self.array.shape))

def random_integer(self, low, high):
def random_integer(self, low, high) -> None:
if self.deferred is not None:
self.deferred.random_integer(low, high)
else:
Expand Down Expand Up @@ -751,7 +758,7 @@ def unary_reduction(
else:
raise RuntimeError("unsupported unary reduction op " + str(op))

def isclose(self, rhs1, rhs2, rtol, atol, equal_nan):
def isclose(self, rhs1, rhs2, rtol, atol, equal_nan) -> None:
self.check_eager_args(rhs1, rhs2)
if self.deferred is not None:
self.deferred.isclose(rhs1, rhs2, rtol, atol, equal_nan)
Expand Down Expand Up @@ -837,7 +844,7 @@ def unique(self):
else:
return EagerArray(self.runtime, np.unique(self.array))

def create_window(self, op_code, M, *args):
def create_window(self, op_code, M, *args) -> None:
if self.deferred is not None:
return self.deferred.create_window(op_code, M, *args)
else:
Expand Down
Loading

0 comments on commit 19f1e64

Please sign in to comment.