Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Single processor implementation for linalg.solve #568

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cunumeric/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ class _CunumericSharedLib:
CUNUMERIC_SCAN_PROD: int
CUNUMERIC_SCAN_SUM: int
CUNUMERIC_SEARCHSORTED: int
CUNUMERIC_SOLVE: int
CUNUMERIC_SORT: int
CUNUMERIC_SYRK: int
CUNUMERIC_TILE: int
Expand Down Expand Up @@ -360,6 +361,7 @@ class CuNumericOpCode(IntEnum):
SCAN_GLOBAL = _cunumeric.CUNUMERIC_SCAN_GLOBAL
SCAN_LOCAL = _cunumeric.CUNUMERIC_SCAN_LOCAL
SEARCHSORTED = _cunumeric.CUNUMERIC_SEARCHSORTED
SOLVE = _cunumeric.CUNUMERIC_SOLVE
SORT = _cunumeric.CUNUMERIC_SORT
SYRK = _cunumeric.CUNUMERIC_SYRK
TILE = _cunumeric.CUNUMERIC_TILE
Expand Down
5 changes: 5 additions & 0 deletions cunumeric/deferred.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
UnaryRedCode,
)
from .linalg.cholesky import cholesky
from .linalg.solve import solve
from .sort import sort
from .thunk import NumPyThunk
from .utils import is_advanced_indexing
Expand Down Expand Up @@ -3089,6 +3090,10 @@ def compute_strides(shape: NdShape) -> tuple[int, ...]:
def cholesky(self, src: Any, no_tril: bool = False) -> None:
cholesky(self, src, no_tril)

@auto_convert([1, 2])
def solve(self, a: Any, b: Any) -> None:
solve(self, a, b)

@auto_convert([2])
def scan(
self,
Expand Down
13 changes: 13 additions & 0 deletions cunumeric/eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1563,6 +1563,19 @@ def cholesky(self, src: Any, no_tril: bool) -> None:
result = np.triu(result.T.conj(), k=1) + result
self.array[:] = result

def solve(self, a: Any, b: Any) -> None:
self.check_eager_args(a, b)
if self.deferred is not None:
self.deferred.solve(a, b)
else:
try:
result = np.linalg.solve(a.array, b.array)
except np.linalg.LinAlgError as e:
from .linalg import LinAlgError

raise LinAlgError(e) from e
self.array[:] = result

def scan(
self,
op: int,
Expand Down
75 changes: 75 additions & 0 deletions cunumeric/linalg/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
from numpy.core.multiarray import normalize_axis_index # type: ignore
from numpy.core.numeric import normalize_axis_tuple # type: ignore

from .exception import LinAlgError

if TYPE_CHECKING:
from typing import Optional

import numpy.typing as npt


Expand Down Expand Up @@ -80,6 +84,45 @@ def cholesky(a: ndarray) -> ndarray:
return _cholesky(a)


@add_boilerplate("a", "b")
def solve(a: ndarray, b: ndarray) -> ndarray:
if a.ndim < 2:
raise LinAlgError(
f"{a.ndim}-dimensional array given. "
"Array must be at least two-dimensional"
)
if b.ndim < 1:
raise LinAlgError(
f"{a.ndim}-dimensional array given. "
bryevdv marked this conversation as resolved.
Show resolved Hide resolved
"Array must be at least one-dimensional"
)
if np.dtype("e") in (a.dtype, b.dtype):
raise TypeError("array type float16 is unsupported in linalg")
manopapad marked this conversation as resolved.
Show resolved Hide resolved
if a.ndim > 2 or b.ndim > 2:
raise NotImplementedError(
"cuNumeric needs to support stacked 2d arrays"
bryevdv marked this conversation as resolved.
Show resolved Hide resolved
)
if a.shape[-2] != a.shape[-1]:
raise LinAlgError("Last 2 dimensions of the array must be square")
if a.shape[-1] != b.shape[0]:
if b.ndim == 1:
raise ValueError(
"Input operand 1 has a mismatch in its dimension 0, "
f"with signature (m,m),(m)->(m) (size {b.shape[0]} "
f"is different from {a.shape[-1]})"
)
else:
raise ValueError(
"Input operand 1 has a mismatch in its dimension 0, "
f"with signature (m,m),(m,n)->(m,n) (size {b.shape[0]} "
f"is different from {a.shape[-1]})"
bryevdv marked this conversation as resolved.
Show resolved Hide resolved
)
if a.size == 0 or b.size == 0:
return empty_like(b)

return _solve(a, b)


# This implementation is adapted closely from NumPy
@add_boilerplate("a")
def matrix_power(a: ndarray, n: int) -> ndarray:
Expand Down Expand Up @@ -555,3 +598,35 @@ def _cholesky(a: ndarray, no_tril: bool = False) -> ndarray:
)
output._thunk.cholesky(input._thunk, no_tril=no_tril)
return output


def _solve(
a: ndarray, b: ndarray, output: Optional[ndarray] = None
) -> ndarray:
if a.dtype.kind not in ("f", "c"):
a = a.astype("float64")
if b.dtype.kind not in ("f", "c"):
b = b.astype("float64")
if a.dtype != b.dtype:
dtype = np.find_common_type([a.dtype, b.dtype], [])
a = a.astype(dtype)
b = b.astype(dtype)

if output is not None:
out = output
manopapad marked this conversation as resolved.
Show resolved Hide resolved
if out.shape != b.shape:
raise ValueError(
f"Output shape mismatch: expected {b.shape}, "
f"but found {out.shape}"
)
else:
out = ndarray(
shape=b.shape,
dtype=b.dtype,
inputs=(
a,
b,
),
)
out._thunk.solve(a._thunk, b._thunk)
return out
62 changes: 62 additions & 0 deletions cunumeric/linalg/solve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright 2022 NVIDIA Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import annotations

from typing import TYPE_CHECKING, cast

from cunumeric.config import CuNumericOpCode

from .cholesky import transpose_copy_single
from .exception import LinAlgError

if TYPE_CHECKING:
from legate.core.context import Context
from legate.core.store import Store

from ..deferred import DeferredArray


def solve_single(context: Context, a: Store, b: Store) -> None:
task = context.create_auto_task(CuNumericOpCode.SOLVE)
task.throws_exception(LinAlgError)
task.add_input(a)
task.add_input(b)
task.add_output(a)
task.add_output(b)

task.add_broadcast(a)
task.add_broadcast(b)

task.execute()


def solve(output: DeferredArray, a: DeferredArray, b: DeferredArray) -> None:
from ..deferred import DeferredArray

runtime = output.runtime
context = output.context

a_copy = cast(
DeferredArray,
runtime.create_empty_thunk(a.shape, dtype=a.dtype, inputs=(a,)),
)
transpose_copy_single(context, a.base, a_copy.base)

if b.ndim > 1:
transpose_copy_single(context, b.base, output.base)
else:
output.copy(b)

solve_single(context, a_copy.base, output.base)
4 changes: 4 additions & 0 deletions cunumeric/thunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,10 @@ def where(self, rhs1: Any, rhs2: Any, rhs3: Any) -> None:
def cholesky(self, src: Any, no_tril: bool) -> None:
...

@abstractmethod
def solve(self, a: Any, b: Any) -> None:
...

@abstractmethod
def scan(
self,
Expand Down
70 changes: 70 additions & 0 deletions examples/solve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright 2022 NVIDIA Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse

from legate.timing import time

import cunumeric as np


def solve(m, n, nrhs, dtype):
a = np.random.rand(m, n).astype(dtype=dtype)
b = np.random.rand(n, nrhs).astype(dtype=dtype)

start = time()
np.linalg.solve(a, b)
stop = time()

total = (stop - start) / 1000.0
print(f"Elapsed Time: {total} ms")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
"--num_rows",
type=int,
default=10,
dest="m",
help="number of rows in the matrix",
)
parser.add_argument(
"-n",
"--num_cols",
type=int,
default=10,
dest="n",
help="number of columns in the matrix",
)
parser.add_argument(
"-s",
"--nrhs",
type=int,
default=1,
dest="nrhs",
help="number of right hand sides",
)
parser.add_argument(
"-t",
"--type",
default="float64",
choices=["float32", "float64", "complex64", "complex128"],
dest="dtype",
help="data type",
)
args, unknown = parser.parse_known_args()
bryevdv marked this conversation as resolved.
Show resolved Hide resolved
solve(args.m, args.n, args.nrhs, args.dtype)
3 changes: 3 additions & 0 deletions src/cunumeric.mk
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ GEN_CPU_SRC += cunumeric/ternary/where.cc \
cunumeric/matrix/matvecmul.cc \
cunumeric/matrix/dot.cc \
cunumeric/matrix/potrf.cc \
cunumeric/matrix/solve.cc \
cunumeric/matrix/syrk.cc \
cunumeric/matrix/tile.cc \
cunumeric/matrix/transpose.cc \
Expand Down Expand Up @@ -92,6 +93,7 @@ GEN_CPU_SRC += cunumeric/ternary/where_omp.cc \
cunumeric/matrix/matvecmul_omp.cc \
cunumeric/matrix/dot_omp.cc \
cunumeric/matrix/potrf_omp.cc \
cunumeric/matrix/solve_omp.cc \
cunumeric/matrix/syrk_omp.cc \
cunumeric/matrix/tile_omp.cc \
cunumeric/matrix/transpose_omp.cc \
Expand Down Expand Up @@ -136,6 +138,7 @@ GEN_GPU_SRC += cunumeric/ternary/where.cu \
cunumeric/matrix/matvecmul.cu \
cunumeric/matrix/dot.cu \
cunumeric/matrix/potrf.cu \
cunumeric/matrix/solve.cu \
cunumeric/matrix/syrk.cu \
cunumeric/matrix/tile.cu \
cunumeric/matrix/transpose.cu \
Expand Down
1 change: 1 addition & 0 deletions src/cunumeric/cunumeric_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ enum CuNumericOpCode {
CUNUMERIC_REPEAT,
CUNUMERIC_SCALAR_UNARY_RED,
CUNUMERIC_SEARCHSORTED,
CUNUMERIC_SOLVE,
CUNUMERIC_SORT,
CUNUMERIC_SYRK,
CUNUMERIC_TILE,
Expand Down
1 change: 1 addition & 0 deletions src/cunumeric/mapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ std::vector<StoreMapping> CuNumericMapper::store_mappings(
}
case CUNUMERIC_POTRF:
case CUNUMERIC_TRSM:
case CUNUMERIC_SOLVE:
case CUNUMERIC_SYRK:
case CUNUMERIC_GEMM: {
std::vector<StoreMapping> mappings;
Expand Down
Loading