Skip to content

Commit

Permalink
Merge pull request #22 from jorenham/fix-scipy.linalg
Browse files Browse the repository at this point in the history
fix `scipy.linalg` stubtests
  • Loading branch information
jorenham authored Sep 5, 2024
2 parents 83c1287 + f23bcff commit 80cbe7d
Show file tree
Hide file tree
Showing 17 changed files with 2,002 additions and 230 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ pip install scipy-stubs
| `scipy.io` | 2: partial |
| `scipy.io.arff` | 2: partial |
| `scipy.io.matlab` | 2: partial |
| `scipy.linalg` | 3: ready |
| `scipy.linalg` | **4: done** |
| `scipy.misc` | 0: missing |
| `scipy.ndimage` | 2: partial |
| `scipy.odr` | 1: skeleton |
Expand Down
161 changes: 116 additions & 45 deletions scipy-stubs/linalg/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,26 +1,4 @@
from . import (
_basic,
_cythonized_array_utils,
_decomp,
_decomp_cholesky,
_decomp_cossin,
_decomp_ldl,
_decomp_lu,
_decomp_polar,
_decomp_qr,
_decomp_qz,
_decomp_schur,
_decomp_svd,
_decomp_update,
_matfuncs,
_misc,
_procrustes,
_sketches,
_solvers,
_special_matrices,
blas,
lapack,
)
from . import blas, cython_blas, cython_lapack, lapack
from ._basic import *
from ._cythonized_array_utils import *
from ._decomp import *
Expand All @@ -43,25 +21,118 @@ from ._special_matrices import *
from .blas import *
from .lapack import *

__all__: list[str] = []
__all__ += _basic.__all__
__all__ += _cythonized_array_utils.__all__
__all__ += _decomp.__all__
__all__ += _decomp_cholesky.__all__
__all__ += _decomp_cossin.__all__
__all__ += _decomp_ldl.__all__
__all__ += _decomp_lu.__all__
__all__ += _decomp_polar.__all__
__all__ += _decomp_qr.__all__
__all__ += _decomp_qz.__all__
__all__ += _decomp_schur.__all__
__all__ += _decomp_svd.__all__
__all__ += _decomp_update.__all__
__all__ += _matfuncs.__all__
__all__ += _misc.__all__
__all__ += _procrustes.__all__
__all__ += _sketches.__all__
__all__ += _solvers.__all__
__all__ += _special_matrices.__all__
__all__ += blas.__all__
__all__ += lapack.__all__
__all__ = [
"LinAlgError",
"LinAlgWarning",
"bandwidth",
# "basic",
"blas",
"block_diag",
"cdf2rdf",
"cho_factor",
"cho_solve",
"cho_solve_banded",
"cholesky",
"cholesky_banded",
"circulant",
"clarkson_woodruff_transform",
"companion",
"convolution_matrix",
"coshm",
"cosm",
"cossin",
"cython_blas",
"cython_lapack",
# "decomp",
# "decomp_cholesky",
# "decomp_lu",
# "decomp_qr",
# "decomp_schur",
# "decomp_svd",
"det",
"dft",
"diagsvd",
"eig",
"eig_banded",
"eigh",
"eigh_tridiagonal",
"eigvals",
"eigvals_banded",
"eigvalsh",
"eigvalsh_tridiagonal",
"expm",
"expm_cond",
"expm_frechet",
"fiedler",
"fiedler_companion",
"find_best_blas_type",
"fractional_matrix_power",
"funm",
"get_blas_funcs",
"get_lapack_funcs",
"hadamard",
"hankel",
"helmert",
"hessenberg",
"hilbert",
"inv",
"invhilbert",
"invpascal",
"ishermitian",
"issymmetric",
"khatri_rao",
"kron",
"lapack",
"ldl",
"leslie",
"logm",
"lstsq",
"lu",
"lu_factor",
"lu_solve",
# "matfuncs",
"matmul_toeplitz",
"matrix_balance",
# "misc",
"norm",
"null_space",
"ordqz",
"orth",
"orthogonal_procrustes",
"pascal",
"pinv",
"pinvh",
"polar",
"qr",
"qr_delete",
"qr_insert",
"qr_multiply",
"qr_update",
"qz",
"rq",
"rsf2csf",
"schur",
"signm",
"sinhm",
"sinm",
"solve",
"solve_banded",
"solve_circulant",
"solve_continuous_are",
"solve_continuous_lyapunov",
"solve_discrete_are",
"solve_discrete_lyapunov",
"solve_lyapunov",
"solve_sylvester",
"solve_toeplitz",
"solve_triangular",
"solveh_banded",
# "special_matrices",
"sqrtm",
"subspace_angles",
"svd",
"svdvals",
"tanhm",
"tanm",
"toeplitz",
]
20 changes: 2 additions & 18 deletions scipy-stubs/linalg/_basic.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -37,21 +37,7 @@ def solve(
overwrite_a: bool = False,
overwrite_b: bool = False,
check_finite: bool = True,
assume_a: Literal[
"diagonal",
"tridiagonal",
"upper triangular",
"lower triangular",
"sym",
"symmetric",
"her",
"hermitian",
"positive definite",
"pos",
"general",
"gen",
]
| None = None,
assume_a: Literal["gen", "sym", "her", "pos"] = "gen",
transposed: bool = False,
) -> _Array_fc_2d: ...
def solve_triangular(
Expand Down Expand Up @@ -150,7 +136,6 @@ def pinvh(
rtol: spt.AnyReal | None,
lower: bool,
return_rank: Literal[True],
/,
check_finite: bool = True,
) -> tuple[_Array_fc_2d, int]: ...
@overload
Expand All @@ -176,12 +161,11 @@ def matrix_balance(
permute: bool,
scale: bool,
separate: Literal[True],
/,
overwrite_a: bool = False,
) -> tuple[_Array_fc_2d, tuple[_Array_fc_1d, _Array_fc_1d]]: ...
def matmul_toeplitz(
c_or_cr: npt.ArrayLike | tuple[npt.ArrayLike, npt.ArrayLike],
x: npt.ArrayLike,
check_finite: bool = True,
check_finite: bool = False,
workers: int | None = None,
) -> _Array_fc_1d | _Array_fc_2d: ...
8 changes: 4 additions & 4 deletions scipy-stubs/linalg/_cythonized_array_utils.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import numpy.typing as npt
__all__ = ["bandwidth", "ishermitian", "issymmetric"]

# see `scipy/linalg/_cythonized_array_utils.pxd`
numeric_t: TypeAlias = ( # noqa: PYI042
_Numeric: TypeAlias = (
np.int8
| np.int16
| np.int32
Expand All @@ -22,6 +22,6 @@ numeric_t: TypeAlias = ( # noqa: PYI042
| np.complex128
)

def bandwidth(a: npt.NDArray[numeric_t]) -> tuple[int, int]: ...
def issymmetric(a: npt.NDArray[numeric_t], atol: float | None = None, rtol: float | None = None) -> bool: ...
def ishermitian(a: npt.NDArray[numeric_t], atol: float | None = None, rtol: float | None = None) -> bool: ...
def bandwidth(a: npt.NDArray[_Numeric]) -> tuple[int, int]: ...
def issymmetric(a: npt.NDArray[_Numeric], atol: float | None = None, rtol: float | None = None) -> bool: ...
def ishermitian(a: npt.NDArray[_Numeric], atol: float | None = None, rtol: float | None = None) -> bool: ...
6 changes: 1 addition & 5 deletions scipy-stubs/linalg/_decomp.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def eig(
a: npt.ArrayLike,
b: npt.ArrayLike | None,
left: Literal[True],
/,
right: Literal[False] = False,
overwrite_a: bool = False,
overwrite_b: bool = False,
Expand All @@ -90,7 +89,6 @@ def eig(
b: npt.ArrayLike | None,
left: Literal[False],
right: Literal[True],
/,
overwrite_a: bool = False,
overwrite_b: bool = False,
check_finite: bool = True,
Expand All @@ -114,7 +112,6 @@ def eig(
b: npt.ArrayLike | None,
left: Literal[True],
right: Literal[True],
/,
overwrite_a: bool = False,
overwrite_b: bool = False,
check_finite: bool = True,
Expand Down Expand Up @@ -178,7 +175,6 @@ def eig_banded(
a_band: npt.ArrayLike,
lower: bool,
eigvals_only: Literal[True],
/,
overwrite_a_band: bool = False,
select: _EigSelect = "a",
select_range: _EigSelectRange | None = None,
Expand Down Expand Up @@ -207,7 +203,7 @@ def eigvalsh(
) -> _Array_fc_1d: ...
def eigvals_banded(
a_band: npt.ArrayLike,
lower: bool = True,
lower: bool = False,
overwrite_a_band: bool = False,
select: _EigSelect = "a",
select_range: _EigSelectRange | None = None,
Expand Down
38 changes: 24 additions & 14 deletions scipy-stubs/linalg/_decomp_lu.pyi
Original file line number Diff line number Diff line change
@@ -1,34 +1,53 @@
from typing import Literal, TypeAlias, overload
from collections.abc import Sequence
from typing import Any, Literal, TypeAlias, overload

import numpy as np
import numpy.typing as npt
import optype.numpy as onpt

__all__ = ["lu", "lu_factor", "lu_solve"]

_ArrayLike_2d_fc: TypeAlias = onpt.AnyNumberArray | Sequence[Sequence[complex | np.number[Any]]]
_Array_i: TypeAlias = np.ndarray[tuple[int, ...], np.dtype[np.intp]]
_Array_fc: TypeAlias = np.ndarray[tuple[int, ...], np.dtype[np.inexact[npt.NBitBase]]]
_Array_fc_1d: TypeAlias = np.ndarray[tuple[int], np.dtype[np.inexact[npt.NBitBase]]]
_Array_fc_2d: TypeAlias = np.ndarray[tuple[int, int], np.dtype[np.inexact[npt.NBitBase]]]

def lu_factor(a: npt.ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> tuple[_Array_fc_2d, _Array_fc_1d]: ...
def lu_factor(
a: _ArrayLike_2d_fc,
overwrite_a: bool = False,
check_finite: bool = True,
) -> tuple[_Array_fc_2d, _Array_fc_1d]: ...

#
def lu_solve(
lu_and_piv: tuple[_Array_fc_2d, _Array_fc_1d],
b: npt.ArrayLike,
trans: Literal[0, 1, 2] = 0,
overwrite_b: bool = False,
check_finite: bool = True,
) -> _Array_fc_2d: ...

#
@overload
def lu(
a: npt.ArrayLike,
a: _ArrayLike_2d_fc,
permute_l: Literal[False, 0] = False,
overwrite_a: bool = False,
check_finite: bool = True,
p_indices: Literal[False] = False,
) -> tuple[_Array_fc, _Array_fc, _Array_fc]: ...
@overload
def lu(
a: npt.ArrayLike,
a: _ArrayLike_2d_fc,
permute_l: Literal[False],
overwrite_a: bool,
check_finite: bool,
p_indices: Literal[True],
) -> tuple[_Array_i, _Array_fc, _Array_fc]: ...
@overload
def lu(
a: _ArrayLike_2d_fc,
permute_l: Literal[False, 0] = False,
overwrite_a: bool = False,
check_finite: bool = True,
Expand All @@ -37,16 +56,7 @@ def lu(
) -> tuple[_Array_i, _Array_fc, _Array_fc]: ...
@overload
def lu(
a: npt.ArrayLike,
permute_l: Literal[False],
overwrite_a: bool,
check_finite: bool,
p_indices: Literal[True],
/,
) -> tuple[_Array_i, _Array_fc, _Array_fc]: ...
@overload
def lu(
a: npt.ArrayLike,
a: _ArrayLike_2d_fc,
permute_l: Literal[True],
overwrite_a: bool = False,
check_finite: bool = True,
Expand Down
10 changes: 6 additions & 4 deletions scipy-stubs/linalg/_decomp_lu_cython.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ from typing import TypeVar
import numpy as np
import numpy.typing as npt

__all__ = ["lu_decompose", "lu_dispatcher"]

# this name was chosen to match `ctypedef fused lapack_t`
_LapackT = TypeVar("_LapackT", bound=np.float32 | np.float64 | np.complex64 | np.complex128)

def lu_decompose(a: npt.NDArray[_LapackT], lu: npt.NDArray[_LapackT], perm: npt.NDArray[np.int_], permute_l: bool) -> None: ...
def lu_dispatcher(a: npt.NDArray[_LapackT], lu: npt.NDArray[_LapackT], perm: npt.NDArray[np.int_], permute_l: bool) -> None: ...
def lu_dispatcher(
a: npt.NDArray[_LapackT],
u: npt.NDArray[_LapackT],
piv: npt.NDArray[np.int32 | np.int64],
permute_l: bool,
) -> None: ...
Loading

0 comments on commit 80cbe7d

Please sign in to comment.