Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
e9ececb
support for dpnp.scipy.linalg.lu()
abagusetty Feb 27, 2026
206f4ee
Fix the test in test_sycl_queue.py
abagusetty Feb 27, 2026
7a6b5bd
Fix black formatting
abagusetty Feb 28, 2026
1f6c928
fix formatting and apply plint
abagusetty Mar 1, 2026
67314b7
Update dpnp/scipy/linalg/_utils.py
abagusetty Mar 2, 2026
384618b
Update dpnp/scipy/linalg/_utils.py
abagusetty Mar 2, 2026
5b83946
udapte CHANGELOG
abagusetty Mar 2, 2026
7161aef
address comments from PR
abagusetty Mar 2, 2026
4626b54
Merge branch 'master' into linalg_lu
abagusetty Mar 2, 2026
ac8117b
Fix pylint multiple returns
abagusetty Mar 2, 2026
f895d2d
Merge branch 'linalg_lu' of https://github.com/abagusetty/dpnp into l…
abagusetty Mar 2, 2026
3033c0b
Fix the test failing in onemkl tests
abagusetty Mar 2, 2026
cdb5e32
one more try..
abagusetty Mar 2, 2026
d3dc5da
fix one more...
abagusetty Mar 2, 2026
a0f9182
Merge branch 'master' into linalg_lu
abagusetty Mar 2, 2026
6e94734
Update dpnp/tests/test_linalg.py
abagusetty Mar 3, 2026
aa985e1
Update dpnp/tests/test_linalg.py
abagusetty Mar 3, 2026
05e3a0f
Update dpnp/tests/test_linalg.py
abagusetty Mar 3, 2026
c7a117a
Update CHANGELOG.md
abagusetty Mar 3, 2026
4bc7226
Update dpnp/scipy/linalg/_utils.py
abagusetty Mar 3, 2026
643f34d
Update dpnp/tests/test_linalg.py
abagusetty Mar 3, 2026
28f26eb
Update dpnp/tests/test_linalg.py
abagusetty Mar 3, 2026
35a6c6e
Merge branch 'master' into linalg_lu
abagusetty Mar 3, 2026
aa00062
Update dpnp/tests/test_linalg.py
abagusetty Mar 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Also, that release drops support for Python 3.9, making Python 3.10 the minimum
* Added implementation of `dpnp.ndarray.__bytes__` method [#2671](https://github.com/IntelPython/dpnp/pull/2671)
* Added implementation of `dpnp.divmod` [#2674](https://github.com/IntelPython/dpnp/pull/2674)
* Added implementation of `dpnp.isin` function [#2595](https://github.com/IntelPython/dpnp/pull/2595)
* Added implementation of `dpnp.scipy.linalg.lu` (SciPy-compatible) [#2787](https://github.com/IntelPython/dpnp/pull/2787)

### Changed

Expand Down
3 changes: 2 additions & 1 deletion dpnp/scipy/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@
"""

from ._decomp_lu import lu_factor, lu_solve
from ._decomp_lu import lu, lu_factor, lu_solve

__all__ = [
"lu",
"lu_factor",
"lu_solve",
]
153 changes: 148 additions & 5 deletions dpnp/scipy/linalg/_decomp_lu.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,154 @@
)

from ._utils import (
dpnp_lu,
dpnp_lu_factor,
dpnp_lu_solve,
)


def lu(
a, permute_l=False, overwrite_a=False, check_finite=True, p_indices=False
):
"""
Compute LU decomposition of a matrix with partial pivoting.

The decomposition satisfies::

A = P @ L @ U

where `P` is a permutation matrix, `L` is lower triangular with unit
diagonal elements, and `U` is upper triangular. If `permute_l` is set to
``True`` then `L` is returned already permuted and hence satisfying
``A = L @ U``.

For full documentation refer to :obj:`scipy.linalg.lu`.

Parameters
----------
a : (..., M, N) {dpnp.ndarray, usm_ndarray}
Input array to decompose.
permute_l : bool, optional
Perform the multiplication ``P @ L`` (Default: do not permute).

Default: ``False``.
overwrite_a : {None, bool}, optional
Whether to overwrite data in `a` (may increase performance).

Default: ``False``.
check_finite : {None, bool}, optional
Whether to check that the input matrix contains only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.

Default: ``True``.
p_indices : bool, optional
If ``True`` the permutation information is returned as row indices
instead of a permutation matrix.

Default: ``False``.

Returns
-------
**(If ``permute_l`` is ``False``)**

p : (..., M, M) dpnp.ndarray or (..., M) dpnp.ndarray
If `p_indices` is ``False`` (default), the permutation matrix.
The permutation matrix always has a real dtype (``float32`` or
``float64``) even when `a` is complex, since it only contains
0s and 1s.
If `p_indices` is ``True``, a 1-D (or batched) array of row
permutation indices such that ``A = L[p] @ U``.
l : (..., M, K) dpnp.ndarray
Lower triangular or trapezoidal matrix with unit diagonal.
``K = min(M, N)``.
u : (..., K, N) dpnp.ndarray
Upper triangular or trapezoidal matrix.

**(If ``permute_l`` is ``True``)**

pl : (..., M, K) dpnp.ndarray
Permuted ``L`` matrix: ``pl = P @ L``.
``K = min(M, N)``.
u : (..., K, N) dpnp.ndarray
Upper triangular or trapezoidal matrix.

Notes
-----
Permutation matrices are costly since they are nothing but row reorder of
``L`` and hence indices are strongly recommended to be used instead if the
permutation is required. The relation in the 2D case then becomes simply
``A = L[P, :] @ U``. In higher dimensions, it is better to use `permute_l`
to avoid complicated indexing tricks.

In the 2D case, if one has the indices however, for some reason, the
permutation matrix is still needed then it can be constructed by
``dpnp.eye(M)[P, :]``.

Warning
-------
This function synchronizes in order to validate array elements
when ``check_finite=True``, and also synchronizes to compute the
permutation from LAPACK pivot indices.

See Also
--------
:obj:`dpnp.scipy.linalg.lu_factor` : LU factorize a matrix
(compact representation).
:obj:`dpnp.scipy.linalg.lu_solve` : Solve an equation system using
the LU factorization of a matrix.

Examples
--------
>>> import dpnp as np
>>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8],
... [7, 5, 6, 6], [5, 4, 4, 8]])
>>> p, l, u = np.scipy.linalg.lu(A)
>>> np.allclose(A, p @ l @ u)
array(True)

Retrieve the permutation as row indices with ``p_indices=True``:

>>> p, l, u = np.scipy.linalg.lu(A, p_indices=True)
>>> p
array([1, 3, 0, 2])
>>> np.allclose(A, l[p] @ u)
array(True)

Return the permuted ``L`` directly with ``permute_l=True``:

>>> pl, u = np.scipy.linalg.lu(A, permute_l=True)
>>> np.allclose(A, pl @ u)
array(True)

Non-square matrices are supported:

>>> B = np.array([[1, 2, 3], [4, 5, 6]])
>>> p, l, u = np.scipy.linalg.lu(B)
>>> np.allclose(B, p @ l @ u)
array(True)

Batched input:

>>> C = np.random.randn(3, 2, 4, 4)
>>> p, l, u = np.scipy.linalg.lu(C)
>>> np.allclose(C, p @ l @ u)
array(True)

"""

dpnp.check_supported_arrays_type(a)
assert_stacked_2d(a)

return dpnp_lu(
a,
overwrite_a=overwrite_a,
check_finite=check_finite,
p_indices=p_indices,
permute_l=permute_l,
)


def lu_factor(a, overwrite_a=False, check_finite=True):
"""
Compute the pivoted LU decomposition of `a` matrix.
Expand Down Expand Up @@ -180,13 +323,13 @@ def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):

"""

lu, piv = lu_and_piv
dpnp.check_supported_arrays_type(lu, piv, b)
assert_stacked_2d(lu)
assert_stacked_square(lu)
lu_matrix, piv = lu_and_piv
dpnp.check_supported_arrays_type(lu_matrix, piv, b)
assert_stacked_2d(lu_matrix)
assert_stacked_square(lu_matrix)

return dpnp_lu_solve(
lu,
lu_matrix,
piv,
b,
trans=trans,
Expand Down
Loading
Loading