diff --git a/cunumeric/config.py b/cunumeric/config.py
index 752509f85..9300545e4 100644
--- a/cunumeric/config.py
+++ b/cunumeric/config.py
@@ -186,6 +186,7 @@ class _CunumericSharedLib:
     CUNUMERIC_SCAN_PROD: int
     CUNUMERIC_SCAN_SUM: int
     CUNUMERIC_SEARCHSORTED: int
+    CUNUMERIC_SOLVE: int  # read below as _cunumeric.CUNUMERIC_SOLVE
     CUNUMERIC_SORT: int
     CUNUMERIC_SYRK: int
     CUNUMERIC_TILE: int
@@ -363,6 +364,7 @@ class CuNumericOpCode(IntEnum):
     SCAN_GLOBAL = _cunumeric.CUNUMERIC_SCAN_GLOBAL
     SCAN_LOCAL = _cunumeric.CUNUMERIC_SCAN_LOCAL
     SEARCHSORTED = _cunumeric.CUNUMERIC_SEARCHSORTED
+    SOLVE = _cunumeric.CUNUMERIC_SOLVE  # opcode used by the solve task
     SORT = _cunumeric.CUNUMERIC_SORT
     SYRK = _cunumeric.CUNUMERIC_SYRK
     TILE = _cunumeric.CUNUMERIC_TILE
diff --git a/cunumeric/deferred.py b/cunumeric/deferred.py
index ce0783821..dd7a72e45 100644
--- a/cunumeric/deferred.py
+++ b/cunumeric/deferred.py
@@ -52,6 +52,7 @@
     UnaryRedCode,
 )
 from .linalg.cholesky import cholesky
+from .linalg.solve import solve
 from .sort import sort
 from .thunk import NumPyThunk
 from .utils import is_advanced_indexing
@@ -3095,6 +3096,10 @@ def compute_strides(shape: NdShape) -> tuple[int, ...]:
     def cholesky(self, src: Any, no_tril: bool = False) -> None:
         cholesky(self, src, no_tril)

+    @auto_convert([1, 2])
+    def solve(self, a: Any, b: Any) -> None:
+        solve(self, a, b)  # module-level solve() from .linalg.solve
+
     @auto_convert([2])
     def scan(
         self,
diff --git a/cunumeric/eager.py b/cunumeric/eager.py
index c6f19fb6b..60213c6d1 100644
--- a/cunumeric/eager.py
+++ b/cunumeric/eager.py
@@ -1579,6 +1579,19 @@ def cholesky(self, src: Any, no_tril: bool) -> None:
                 result = np.triu(result.T.conj(), k=1) + result
             self.array[:] = result

+    def solve(self, a: Any, b: Any) -> None:  # eager-thunk entry for solve
+        self.check_eager_args(a, b)
+        if self.deferred is not None:  # forward to the deferred thunk if set
+            self.deferred.solve(a, b)
+        else:
+            try:
+                result = np.linalg.solve(a.array, b.array)
+            except np.linalg.LinAlgError as e:  # wrap NumPy's error type
+                from .linalg import LinAlgError
+
+                raise LinAlgError(e) from e
+            self.array[:] = result  # copy the NumPy result into this thunk
+
     def scan(
         self,
         op: int,
diff --git a/cunumeric/linalg/cholesky.py
b/cunumeric/linalg/cholesky.py
index 5c671aaf4..272034127 100644
--- a/cunumeric/linalg/cholesky.py
+++ b/cunumeric/linalg/cholesky.py
@@ -41,6 +41,9 @@ def transpose_copy_single(
     # to a column major instance
     task.add_scalar_arg(False, ty.int32)

+    task.add_broadcast(output)
+    task.add_broadcast(input)
+
     task.execute()


diff --git a/cunumeric/linalg/linalg.py b/cunumeric/linalg/linalg.py
index 6ea383519..8013a04be 100644
--- a/cunumeric/linalg/linalg.py
+++ b/cunumeric/linalg/linalg.py
@@ -23,7 +23,11 @@
 from numpy.core.multiarray import normalize_axis_index  # type: ignore
 from numpy.core.numeric import normalize_axis_tuple  # type: ignore

+from .exception import LinAlgError
+
 if TYPE_CHECKING:
+    from typing import Optional
+
     import numpy.typing as npt


@@ -80,6 +84,78 @@ def cholesky(a: ndarray) -> ndarray:
     return _cholesky(a)


+@add_boilerplate("a", "b")
+def solve(a: ndarray, b: ndarray, out: Optional[ndarray] = None) -> ndarray:
+    """
+    Solve a linear matrix equation, or system of linear scalar equations.
+
+    Computes the "exact" solution, `x`, of the well-determined, i.e., full
+    rank, linear matrix equation `ax = b`.
+
+    Parameters
+    ----------
+    a : (M, M) array_like
+        Coefficient matrix.
+    b : {(M,), (M, K)}, array_like
+        Ordinate or "dependent variable" values.
+    out : {(M,), (M, K)}, array_like, optional
+        An optional output array for the solution
+
+    Returns
+    -------
+    x : {(M,), (M, K)} ndarray
+        Solution to the system a x = b. Returned shape is identical to `b`.
+
+    Raises
+    ------
+    LinAlgError
+        If `a` is singular or not square.
+
+    See Also
+    --------
+    numpy.linalg.solve
+
+    Availability
+    --------
+    Single GPU, Single CPU
+    """
+    if a.ndim < 2:
+        raise LinAlgError(
+            f"{a.ndim}-dimensional array given. "
+            "Array must be at least two-dimensional"
+        )
+    if b.ndim < 1:
+        raise LinAlgError(
+            f"{b.ndim}-dimensional array given. "
+            "Array must be at least one-dimensional"
+        )
+    if np.dtype("e") in (a.dtype, b.dtype):  # "e" is float16
+        raise TypeError("array type float16 is unsupported in linalg")
+    if a.ndim > 2 or b.ndim > 2:
+        raise NotImplementedError(
+            "cuNumeric does not yet support stacked 2d arrays"
+        )
+    if a.shape[-2] != a.shape[-1]:
+        raise LinAlgError("Last 2 dimensions of the array must be square")
+    if a.shape[-1] != b.shape[0]:
+        if b.ndim == 1:
+            raise ValueError(
+                "Input operand 1 has a mismatch in its dimension 0, "
+                f"with signature (m,m),(m)->(m) (size {b.shape[0]} "
+                f"is different from {a.shape[-1]})"
+            )
+        else:
+            raise ValueError(
+                "Input operand 1 has a mismatch in its dimension 0, "
+                f"with signature (m,m),(m,n)->(m,n) (size {b.shape[0]} "
+                f"is different from {a.shape[-1]})"
+            )
+    if a.size == 0 or b.size == 0:  # nothing to solve; `out` is not used
+        return empty_like(b)
+
+    return _solve(a, b, out)
+
+
 # This implementation is adapted closely from NumPy
 @add_boilerplate("a")
 def matrix_power(a: ndarray, n: int) -> ndarray:
@@ -555,3 +631,49 @@ def _cholesky(a: ndarray, no_tril: bool = False) -> ndarray:
     )
     output._thunk.cholesky(input._thunk, no_tril=no_tril)
     return output
+
+
+def _solve(
+    a: ndarray, b: ndarray, output: Optional[ndarray] = None
+) -> ndarray:
+    """
+    Promote ``a`` and ``b`` to one floating-point or complex dtype, validate
+    or allocate the result array, and run the thunk-level solve on it.
+
+    Raises ValueError / TypeError when ``output`` does not match the shape /
+    dtype of the promoted ``b``.
+    """
+    if a.dtype.kind not in ("f", "c"):
+        a = a.astype("float64")
+    if b.dtype.kind not in ("f", "c"):
+        b = b.astype("float64")
+    if a.dtype != b.dtype:
+        # np.find_common_type is deprecated since NumPy 1.25 and removed in
+        # NumPy 2.0; result_type gives the same promotion for dtype inputs.
+        dtype = np.result_type(a.dtype, b.dtype)
+        a = a.astype(dtype)
+        b = b.astype(dtype)
+
+    if output is not None:
+        out = output
+        if out.shape != b.shape:
+            raise ValueError(
+                f"Output shape mismatch: expected {b.shape}, "
+                f"but found {out.shape}"
+            )
+        elif out.dtype != b.dtype:
+            raise TypeError(
+                f"Output type mismatch: expected {b.dtype}, "
+                f"but found {out.dtype}"
+            )
+    else:
+        out = ndarray(
+            shape=b.shape,
+            dtype=b.dtype,
+            inputs=(
+                a,
+                b,
+            ),
+        )
+    out._thunk.solve(a._thunk, b._thunk)
+    return out
diff --git a/cunumeric/linalg/solve.py b/cunumeric/linalg/solve.py
new file
mode 100644
index 000000000..8eca91bc8
--- /dev/null
+++ b/cunumeric/linalg/solve.py
@@ -0,0 +1,62 @@
+# Copyright 2022 NVIDIA Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, cast
+
+from cunumeric.config import CuNumericOpCode
+
+from .cholesky import transpose_copy_single
+from .exception import LinAlgError
+
+if TYPE_CHECKING:
+    from legate.core.context import Context
+    from legate.core.store import Store
+
+    from ..deferred import DeferredArray
+
+
+def solve_single(context: Context, a: Store, b: Store) -> None:
+    task = context.create_auto_task(CuNumericOpCode.SOLVE)
+    task.throws_exception(LinAlgError)  # the task may raise LinAlgError
+    task.add_input(a)
+    task.add_input(b)
+    task.add_output(a)  # a and b are registered as inputs and as outputs
+    task.add_output(b)
+
+    task.add_broadcast(a)
+    task.add_broadcast(b)
+
+    task.execute()
+
+
+def solve(output: DeferredArray, a: DeferredArray, b: DeferredArray) -> None:
+    from ..deferred import DeferredArray  # needed at runtime for cast()
+
+    runtime = output.runtime
+    context = output.context
+
+    a_copy = cast(
+        DeferredArray,
+        runtime.create_empty_thunk(a.shape, dtype=a.dtype, inputs=(a,)),
+    )
+    transpose_copy_single(context, a.base, a_copy.base)  # work on a copy
+
+    if b.ndim > 1:
+        transpose_copy_single(context, b.base, output.base)
+    else:
+        output.copy(b)  # 1-D right-hand side: plain copy into output
+
+    solve_single(context, a_copy.base, output.base)  # output is the task's b
diff --git a/cunumeric/thunk.py b/cunumeric/thunk.py
index 6230905af..c7d68f07a 100644
--- a/cunumeric/thunk.py
+++ b/cunumeric/thunk.py
@@ -686,6
+686,10 @@ def where(self, rhs1: Any, rhs2: Any, rhs3: Any) -> None:
     def cholesky(self, src: Any, no_tril: bool) -> None:
         ...

+    @abstractmethod
+    def solve(self, a: Any, b: Any) -> None:  # abstract; no body here
+        ...
+
     @abstractmethod
     def scan(
         self,
diff --git a/docs/cunumeric/source/api/linalg.rst b/docs/cunumeric/source/api/linalg.rst
index 271eaba1d..78394ead0 100644
--- a/docs/cunumeric/source/api/linalg.rst
+++ b/docs/cunumeric/source/api/linalg.rst
@@ -38,3 +38,12 @@ Norms and other numbers

    linalg.norm
    trace
+
+
+Solving equations and inverting matrices
+----------------------------------------
+
+.. autosummary::
+   :toctree: generated/
+
+   linalg.solve
diff --git a/examples/solve.py b/examples/solve.py
new file mode 100644
index 000000000..5d5082dd4
--- /dev/null
+++ b/examples/solve.py
@@ -0,0 +1,70 @@
+# Copyright 2022 NVIDIA Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import argparse
+
+from legate.timing import time
+
+import cunumeric as np
+
+
+def solve(m, n, nrhs, dtype):
+    a = np.random.rand(m, n).astype(dtype=dtype)  # solve() rejects m != n
+    b = np.random.rand(n, nrhs).astype(dtype=dtype)
+
+    start = time()
+    np.linalg.solve(a, b)  # result is discarded; only the call is timed
+    stop = time()
+
+    total = (stop - start) / 1000.0  # assumes time() is in us - TODO confirm
+    print(f"Elapsed Time: {total} ms")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-m",
+        "--num_rows",
+        type=int,
+        default=10,
+        dest="m",
+        help="number of rows in the matrix",
+    )
+    parser.add_argument(
+        "-n",
+        "--num_cols",
+        type=int,
+        default=10,
+        dest="n",
+        help="number of columns in the matrix",
+    )
+    parser.add_argument(
+        "-s",
+        "--nrhs",
+        type=int,
+        default=1,
+        dest="nrhs",
+        help="number of right hand sides",
+    )
+    parser.add_argument(
+        "-t",
+        "--type",
+        default="float64",
+        choices=["float32", "float64", "complex64", "complex128"],
+        dest="dtype",
+        help="data type",
+    )
+    args = parser.parse_args()
+    solve(args.m, args.n, args.nrhs, args.dtype)
diff --git a/src/cunumeric.mk b/src/cunumeric.mk
index 6ac7307c9..1b7f17080 100644
--- a/src/cunumeric.mk
+++ b/src/cunumeric.mk
@@ -44,6 +44,7 @@ GEN_CPU_SRC += cunumeric/ternary/where.cc \
                cunumeric/matrix/matvecmul.cc \
                cunumeric/matrix/dot.cc \
                cunumeric/matrix/potrf.cc \
+               cunumeric/matrix/solve.cc \
                cunumeric/matrix/syrk.cc \
                cunumeric/matrix/tile.cc \
                cunumeric/matrix/transpose.cc \
@@ -92,6 +93,7 @@ GEN_CPU_SRC += cunumeric/ternary/where_omp.cc \
                cunumeric/matrix/matvecmul_omp.cc \
                cunumeric/matrix/dot_omp.cc \
                cunumeric/matrix/potrf_omp.cc \
+               cunumeric/matrix/solve_omp.cc \
                cunumeric/matrix/syrk_omp.cc \
                cunumeric/matrix/tile_omp.cc \
                cunumeric/matrix/transpose_omp.cc \
@@ -136,6 +138,7 @@ GEN_GPU_SRC += cunumeric/ternary/where.cu \
                cunumeric/matrix/matvecmul.cu \
                cunumeric/matrix/dot.cu \
                cunumeric/matrix/potrf.cu \
+               cunumeric/matrix/solve.cu \
                cunumeric/matrix/syrk.cu \
                cunumeric/matrix/tile.cu \
                cunumeric/matrix/transpose.cu \
diff
--git a/src/cunumeric/cunumeric_c.h b/src/cunumeric/cunumeric_c.h
index cdf382a0a..31b8ba50f 100644
--- a/src/cunumeric/cunumeric_c.h
+++ b/src/cunumeric/cunumeric_c.h
@@ -58,6 +58,7 @@ enum CuNumericOpCode {
   CUNUMERIC_REPEAT,
   CUNUMERIC_SCALAR_UNARY_RED,
   CUNUMERIC_SEARCHSORTED,
+  CUNUMERIC_SOLVE,  // used as SolveTask::TASK_ID
   CUNUMERIC_SORT,
   CUNUMERIC_SYRK,
   CUNUMERIC_TILE,
diff --git a/src/cunumeric/mapper.cc b/src/cunumeric/mapper.cc
index 8cfc6551a..855121cd2 100644
--- a/src/cunumeric/mapper.cc
+++ b/src/cunumeric/mapper.cc
@@ -133,6 +133,7 @@ std::vector CuNumericMapper::store_mappings(
     }
     case CUNUMERIC_POTRF:
     case CUNUMERIC_TRSM:
+    case CUNUMERIC_SOLVE:  // falls through to the POTRF/TRSM/SYRK/GEMM branch
     case CUNUMERIC_SYRK:
     case CUNUMERIC_GEMM: {
       std::vector mappings;
diff --git a/src/cunumeric/matrix/solve.cc b/src/cunumeric/matrix/solve.cc
new file mode 100644
index 000000000..89681caef
--- /dev/null
+++ b/src/cunumeric/matrix/solve.cc
@@ -0,0 +1,41 @@
+/* Copyright 2022 NVIDIA Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "cunumeric/matrix/solve.h"
+#include "cunumeric/matrix/solve_template.inl"
+#include "cunumeric/matrix/solve_cpu.inl"
+
+namespace cunumeric {
+
+using namespace Legion;
+using namespace legate;
+
+/*static*/ const char* SolveTask::ERROR_MESSAGE = "Singular matrix";  // text of the thrown TaskException
+
+/*static*/ void SolveTask::cpu_variant(TaskContext& context)
+{
+#ifdef LEGATE_USE_OPENMP
+  openblas_set_num_threads(1);  // make sure this isn't overzealous
+#endif
+  solve_template(context);  // NOTE(review): explicit template argument looks stripped from this text
+}
+
+namespace  // unnamed
+{
+static void __attribute__((constructor)) register_tasks(void) { SolveTask::register_variants(); }
+}  // namespace
+
+}  // namespace cunumeric
diff --git a/src/cunumeric/matrix/solve.cu b/src/cunumeric/matrix/solve.cu
new file mode 100644
index 000000000..7fb3ab20d
--- /dev/null
+++ b/src/cunumeric/matrix/solve.cu
@@ -0,0 +1,117 @@
+/* Copyright 2022 NVIDIA Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "cunumeric/matrix/solve.h"
+#include "cunumeric/matrix/solve_template.inl"
+
+#include "cunumeric/cuda_help.h"
+
+namespace cunumeric {
+
+using namespace Legion;
+using namespace legate;
+
+template  // NOTE(review): template parameter list looks stripped from this text
+static inline void solve_template(GetrfBufferSize getrf_buffer_size,
+                                  Getrf getrf,
+                                  Getrs getrs,
+                                  int32_t m,
+                                  int32_t n,
+                                  int32_t nrhs,
+                                  VAL* a,
+                                  VAL* b)
+{
+  const auto trans = CUBLAS_OP_N;
+
+  auto handle = get_cusolver();
+  auto stream = get_cached_stream();
+  CHECK_CUSOLVER(cusolverDnSetStream(handle, stream));
+
+  int32_t buffer_size;
+  CHECK_CUSOLVER(getrf_buffer_size(handle, m, n, a, m, &buffer_size));
+
+  auto ipiv   = create_buffer(std::min(m, n), Memory::Kind::GPU_FB_MEM);
+  auto buffer = create_buffer(buffer_size, Memory::Kind::GPU_FB_MEM);
+  auto info   = create_buffer(1, Memory::Kind::Z_COPY_MEM);
+
+  CHECK_CUSOLVER(getrf(handle, m, n, a, m, buffer.ptr(0), ipiv.ptr(0), info.ptr(0)));
+  CHECK_CUDA(cudaStreamSynchronize(stream));  // wait before info[0] is read below
+
+  if (info[0] != 0) throw legate::TaskException(SolveTask::ERROR_MESSAGE);  // nonzero getrf status
+
+  CHECK_CUSOLVER(getrs(handle, trans, n, nrhs, a, m, ipiv.ptr(0), b, n, info.ptr(0)));
+
+  CHECK_CUDA_STREAM(stream);
+
+#ifdef DEBUG_CUNUMERIC
+  assert(info[0] == 0);
+#endif
+}
+
+template <>
+struct SolveImplBody {
+  void operator()(int32_t m, int32_t n, int32_t nrhs, float* a, float* b)
+  {
+    solve_template(
+      cusolverDnSgetrf_bufferSize, cusolverDnSgetrf, cusolverDnSgetrs, m, n, nrhs, a, b);
+  }
+};
+
+template <>
+struct SolveImplBody {
+  void operator()(int32_t m, int32_t n, int32_t nrhs, double* a, double* b)
+  {
+    solve_template(
+      cusolverDnDgetrf_bufferSize, cusolverDnDgetrf, cusolverDnDgetrs, m, n, nrhs, a, b);
+  }
+};
+
+template <>
+struct SolveImplBody {
+  void operator()(int32_t m, int32_t n, int32_t nrhs, complex* a, complex* b)
+  {
+    solve_template(cusolverDnCgetrf_bufferSize,
+                   cusolverDnCgetrf,
+                   cusolverDnCgetrs,
+                   m,
+                   n,
+                   nrhs,
+                   reinterpret_cast(a),
+                   reinterpret_cast(b));
+  }
+};
+
+template <>
+struct SolveImplBody {
+  void operator()(int32_t m, int32_t n, int32_t nrhs, complex* a, complex* b)
+  {
+    solve_template(cusolverDnZgetrf_bufferSize,
+                   cusolverDnZgetrf,
+                   cusolverDnZgetrs,
+                   m,
+                   n,
+                   nrhs,
+                   reinterpret_cast(a),
+                   reinterpret_cast(b));
+  }
+};
+
+/*static*/ void SolveTask::gpu_variant(TaskContext& context)
+{
+  solve_template(context);  // task-level template from solve_template.inl
+}
+
+}  // namespace cunumeric
diff --git a/src/cunumeric/matrix/solve.h b/src/cunumeric/matrix/solve.h
new file mode 100644
index 000000000..8cb6835ad
--- /dev/null
+++ b/src/cunumeric/matrix/solve.h
@@ -0,0 +1,38 @@
+/* Copyright 2022 NVIDIA Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#pragma once
+
+#include "cunumeric/cunumeric.h"
+
+namespace cunumeric {
+
+class SolveTask : public CuNumericTask {
+ public:
+  static const int TASK_ID = CUNUMERIC_SOLVE;
+  static const char* ERROR_MESSAGE;  // defined in solve.cc
+
+ public:
+  static void cpu_variant(legate::TaskContext& context);
+#ifdef LEGATE_USE_OPENMP
+  static void omp_variant(legate::TaskContext& context);
+#endif
+#ifdef LEGATE_USE_CUDA
+  static void gpu_variant(legate::TaskContext& context);
+#endif
+};
+
+}  // namespace cunumeric
diff --git a/src/cunumeric/matrix/solve_cpu.inl b/src/cunumeric/matrix/solve_cpu.inl
new file mode 100644
index 000000000..98cba89aa
--- /dev/null
+++ b/src/cunumeric/matrix/solve_cpu.inl
@@ -0,0 +1,94 @@
+/* Copyright 2022 NVIDIA Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#pragma once
+
+#include  // NOTE(review): header name is missing in this text
+#include  // NOTE(review): header name is missing in this text
+
+namespace cunumeric {
+
+using namespace Legion;
+using namespace legate;
+
+template  // NOTE(review): template parameter list looks stripped from this text
+Memory::Kind get_memory_kind()
+{
+  if constexpr (KIND == VariantKind::OMP)
+    return CuNumeric::has_numamem ? Memory::Kind::SOCKET_MEM : Memory::Kind::SYSTEM_MEM;
+  else
+    return Memory::Kind::SYSTEM_MEM;
+}
+
+template
+struct SolveImplBody {
+  void operator()(int32_t m, int32_t n, int32_t nrhs, float* a, float* b)
+  {
+    auto ipiv = create_buffer(std::min(m, n), get_memory_kind());
+
+    int32_t info = 0;
+    LAPACK_sgesv(&n, &nrhs, a, &m, ipiv.ptr(0), b, &n, &info);
+
+    if (info != 0) throw legate::TaskException(SolveTask::ERROR_MESSAGE);  // any nonzero status
+  }
+};
+
+template
+struct SolveImplBody {
+  void operator()(int32_t m, int32_t n, int32_t nrhs, double* a, double* b)
+  {
+    auto ipiv = create_buffer(std::min(m, n), get_memory_kind());
+
+    int32_t info = 0;
+    LAPACK_dgesv(&n, &nrhs, a, &m, ipiv.ptr(0), b, &n, &info);
+
+    if (info != 0) throw legate::TaskException(SolveTask::ERROR_MESSAGE);  // any nonzero status
+  }
+};
+
+template
+struct SolveImplBody {
+  void operator()(int32_t m, int32_t n, int32_t nrhs, complex* a_, complex* b_)
+  {
+    auto ipiv = create_buffer(std::min(m, n), get_memory_kind());
+
+    auto a = reinterpret_cast<__complex__ float*>(a_);
+    auto b = reinterpret_cast<__complex__ float*>(b_);
+
+    int32_t info = 0;
+    LAPACK_cgesv(&n, &nrhs, a, &m, ipiv.ptr(0), b, &n, &info);
+
+    if (info != 0) throw legate::TaskException(SolveTask::ERROR_MESSAGE);  // any nonzero status
+  }
+};
+
+template
+struct SolveImplBody {
+  void operator()(int32_t m, int32_t n, int32_t nrhs, complex* a_, complex* b_)
+  {
+    auto ipiv = create_buffer(std::min(m, n), get_memory_kind());
+
+    auto a = reinterpret_cast<__complex__ double*>(a_);
+    auto b = reinterpret_cast<__complex__ double*>(b_);
+
+    int32_t info = 0;
+    LAPACK_zgesv(&n, &nrhs, a, &m, ipiv.ptr(0), b, &n, &info);
+
+    if (info != 0) throw legate::TaskException(SolveTask::ERROR_MESSAGE);  // any nonzero status
+  }
+};
+
+}  // namespace cunumeric
diff --git a/src/cunumeric/matrix/solve_omp.cc b/src/cunumeric/matrix/solve_omp.cc
new file mode 100644
index 000000000..57e14fdb4
--- /dev/null
+++ b/src/cunumeric/matrix/solve_omp.cc
@@ -0,0 +1,31 @@
+/* Copyright 2022 NVIDIA Corporation
+ *
+ * Licensed under the
Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "cunumeric/matrix/solve.h"
+#include "cunumeric/matrix/solve_template.inl"
+#include "cunumeric/matrix/solve_cpu.inl"
+
+#include  // NOTE(review): header name is missing in this text
+
+namespace cunumeric {
+
+/*static*/ void SolveTask::omp_variant(TaskContext& context)
+{
+  openblas_set_num_threads(omp_get_max_threads());  // cpu_variant sets this to 1 instead
+  solve_template(context);  // NOTE(review): explicit template argument looks stripped from this text
+}
+
+}  // namespace cunumeric
diff --git a/src/cunumeric/matrix/solve_template.inl b/src/cunumeric/matrix/solve_template.inl
new file mode 100644
index 000000000..13b1167ff
--- /dev/null
+++ b/src/cunumeric/matrix/solve_template.inl
@@ -0,0 +1,121 @@
+/* Copyright 2022 NVIDIA Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#pragma once
+
+#include  // NOTE(review): header name is missing in this text
+
+#include "core/comm/coll.h"
+
+// Useful for IDEs
+#include "cunumeric/matrix/solve.h"
+
+namespace cunumeric {
+
+using namespace Legion;
+using namespace legate;
+
+template  // NOTE(review): template parameter list looks stripped from this text
+struct SolveImplBody;
+
+template
+struct support_solve : std::false_type {
+};
+template <>
+struct support_solve : std::true_type {
+};
+template <>
+struct support_solve : std::true_type {
+};
+template <>
+struct support_solve : std::true_type {
+};
+template <>
+struct support_solve : std::true_type {
+};
+
+template
+struct SolveImpl {
+  template ::value>* = nullptr>
+  void operator()(Array& a_array, Array& b_array) const
+  {
+    using VAL = legate_type_of;
+
+#ifdef DEBUG_CUNUMERIC
+    assert(a_array.dim() == 2);
+    assert(b_array.dim() == 1 || b_array.dim() == 2);
+#endif
+    const auto a_shape = a_array.shape<2>();
+
+    const int64_t m = a_shape.hi[0] - a_shape.lo[0] + 1;
+    const int64_t n = a_shape.hi[1] - a_shape.lo[1] + 1;
+
+#ifdef DEBUG_CUNUMERIC
+    // The Python code guarantees this property
+    assert(m == n);
+#endif
+
+    size_t a_strides[2];
+    VAL* a = a_array.read_write_accessor(a_shape).ptr(a_shape, a_strides);
+#ifdef DEBUG_CUNUMERIC
+    assert(a_strides[0] == 1 && a_strides[1] == m);
+#endif
+    VAL* b = nullptr;
+
+    int64_t nrhs = 1;  // a 1-D b is a single right-hand side
+    if (b_array.dim() == 1) {
+      const auto b_shape = b_array.shape<1>();
+#ifdef DEBUG_CUNUMERIC
+      assert(m == b_shape.hi[0] - b_shape.lo[0] + 1);
+#endif
+      size_t b_strides;
+      b = b_array.read_write_accessor(b_shape).ptr(b_shape, &b_strides);
+    } else {
+      const auto b_shape = b_array.shape<2>();
+#ifdef DEBUG_CUNUMERIC
+      assert(m == b_shape.hi[0] - b_shape.lo[0] + 1);
+#endif
+      nrhs = b_shape.hi[1] - b_shape.lo[1] + 1;
+      size_t b_strides[2];
+      b = b_array.read_write_accessor(b_shape).ptr(b_shape, b_strides);
+#ifdef DEBUG_CUNUMERIC
+      assert(b_strides[0] == 1 && b_strides[1] == m);
+#endif
+    }
+
+#ifdef DEBUG_CUNUMERIC
+    assert(m > 0 && n > 0 && nrhs > 0);
+#endif
+
+    SolveImplBody()(m, n, nrhs, a, b);  // int64_t extents narrow to the int32_t parameters
+  }
+
+  template ::value>* = nullptr>
+  void operator()(Array& a_array, Array& b_array) const
+  {
+    assert(false);  // overload for type codes without solve support
+  }
+};
+
+template
+static void solve_template(TaskContext& context)
+{
+  auto& a_array = context.outputs()[0];  // both operands are taken from the task outputs
+  auto& b_array = context.outputs()[1];
+  type_dispatch(a_array.code(), SolveImpl{}, a_array, b_array);
+}
+
+}  // namespace cunumeric
diff --git a/tests/integration/test_array_dunders.py b/tests/integration/test_array_dunders.py
index a673e76ce..42b2a6ec2 100644
--- a/tests/integration/test_array_dunders.py
+++ b/tests/integration/test_array_dunders.py
@@ -33,8 +33,8 @@ def test_array_function_implemented():


 def test_array_function_unimplemented():
-    np_res = np.linalg.solve(np_arr, np_vec)
-    cn_res = np.linalg.solve(cn_arr, cn_vec)
+    np_res = np.linalg.tensorsolve(np_arr, np_vec)
+    cn_res = np.linalg.tensorsolve(cn_arr, cn_vec)
     assert np.array_equal(np_res, cn_res)
     assert isinstance(cn_res, np.ndarray)  # unimplemented

diff --git a/tests/integration/test_solve.py b/tests/integration/test_solve.py
new file mode 100644
index 000000000..c14210065
--- /dev/null
+++ b/tests/integration/test_solve.py
@@ -0,0 +1,71 @@
+# Copyright 2022 NVIDIA Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import numpy as np
+import pytest
+from utils.comparisons import allclose
+
+import cunumeric as num
+
+SIZES = (8, 9, 255)  # matrix orders used by both tests
+
+RTOL = {  # relative tolerance, keyed by the result dtype
+    np.dtype("f"): 1e-1,
+    np.dtype("F"): 1e-1,
+    np.dtype("d"): 1e-5,
+    np.dtype("D"): 1e-5,
+}
+
+ATOL = {  # absolute tolerance, keyed by the result dtype
+    np.dtype("f"): 1e-3,
+    np.dtype("F"): 1e-3,
+    np.dtype("d"): 1e-8,
+    np.dtype("D"): 1e-8,
+}
+
+
+@pytest.mark.parametrize("n", SIZES)
+@pytest.mark.parametrize("a_dtype", ("f", "d", "F", "D"))
+@pytest.mark.parametrize("b_dtype", ("f", "d", "F", "D"))
+def test_solve_1d(n, a_dtype, b_dtype):
+    a = np.random.rand(n, n).astype(a_dtype)
+    b = np.random.rand(n).astype(b_dtype)
+
+    out = num.linalg.solve(a, b)
+
+    rtol = RTOL[out.dtype]
+    atol = ATOL[out.dtype]
+    assert allclose(b, num.matmul(a, out), rtol=rtol, atol=atol)  # a @ out ~ b
+
+
+@pytest.mark.parametrize("n", SIZES)
+@pytest.mark.parametrize("a_dtype", ("f", "d", "F", "D"))
+@pytest.mark.parametrize("b_dtype", ("f", "d", "F", "D"))
+def test_solve_2d(n, a_dtype, b_dtype):
+    a = np.random.rand(n, n).astype(a_dtype)
+    b = np.random.rand(n, n + 2).astype(b_dtype)  # n + 2 right-hand sides
+
+    out = num.linalg.solve(a, b)
+
+    rtol = RTOL[out.dtype]
+    atol = ATOL[out.dtype]
+    assert allclose(b, num.matmul(a, out), rtol=rtol, atol=atol)  # a @ out ~ b
+
+
+if __name__ == "__main__":
+    import sys
+
+    np.random.seed(12345)  # runs only when executed as a script
+    sys.exit(pytest.main(sys.argv))
diff --git a/tests/unit/cunumeric/test_config.py b/tests/unit/cunumeric/test_config.py
index e8c79159e..ece34d62f 100644
--- a/tests/unit/cunumeric/test_config.py
+++ b/tests/unit/cunumeric/test_config.py
@@ -160,6 +160,7 @@ def test_CuNumericOpCode() -> None:
         "SCALAR_UNARY_RED",
         "SCAN_GLOBAL",
         "SCAN_LOCAL",
+        "SOLVE",
         "SORT",
         "SEARCHSORTED",
         "SYRK",