• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

GeoStat-Framework / pentapy / 14784296932

01 May 2025 10:00PM UTC coverage: 79.365% (-1.0%) from 80.328%
14784296932

push

github

web-flow
Merge pull request #31 from GeoStat-Framework/py13_support

Py13 support

3 of 5 new or added lines in 1 file covered. (60.0%)

100 of 126 relevant lines covered (79.37%)

0.79 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

79.49
/src/pentapy/core.py
1
"""The core module of pentapy."""
2

3
# pylint: disable=C0103, C0415, R0911, E0611
4
import warnings
1✔
5

6
import numpy as np
1✔
7

8
from pentapy.solver import penta_solver1, penta_solver2
1✔
9
from pentapy.tools import _check_penta, create_banded, create_full, shift_banded
1✔
10

11

12
def solve(mat, rhs, is_flat=False, index_row_wise=True, solver=1):
    """
    Solver for a pentadiagonal system.

    The matrix can be given as a full n x n matrix or as a flattened one.
    The flattened matrix can be given in a row-wise flattened form::

      [[Dup2[0]  Dup2[1]  Dup2[2]  ... Dup2[N-2]  0          0       ]
       [Dup1[0]  Dup1[1]  Dup1[2]  ... Dup1[N-2]  Dup1[N-1]  0       ]
       [Diag[0]  Diag[1]  Diag[2]  ... Diag[N-2]  Diag[N-1]  Diag[N] ]
       [0        Dlow1[1] Dlow1[2] ... Dlow1[N-2] Dlow1[N-1] Dlow1[N]]
       [0        0        Dlow2[2] ... Dlow2[N-2] Dlow2[N-1] Dlow2[N]]]

    Or a column-wise flattened form::

      [[0        0        Dup2[2]  ... Dup2[N-2]  Dup2[N-1]  Dup2[N] ]
       [0        Dup1[1]  Dup1[2]  ... Dup1[N-2]  Dup1[N-1]  Dup1[N] ]
       [Diag[0]  Diag[1]  Diag[2]  ... Diag[N-2]  Diag[N-1]  Diag[N] ]
       [Dlow1[0] Dlow1[1] Dlow1[2] ... Dlow1[N-2] Dlow1[N-1] 0       ]
       [Dlow2[0] Dlow2[1] Dlow2[2] ... Dlow2[N-2] 0          0       ]]

    Dup1 and Dup2 are the first and second upper minor-diagonals
    and Dlow1 resp. Dlow2 are the lower ones.
    If you provide a column-wise flattened matrix, you have to set::

      index_row_wise=False


    Parameters
    ----------
    mat : :class:`numpy.ndarray`
        The matrix or the flattened version of the pentadiagonal matrix.
    rhs : :class:`numpy.ndarray`
        The right hand side of the equation system.
    is_flat : :class:`bool`, optional
        State if the matrix is already flattened. Default: ``False``
    index_row_wise : :class:`bool`, optional
        State if the flattened matrix is row-wise flattened. Default: ``True``
    solver : :class:`int` or :class:`str`, optional
        Which solver should be used. The following are provided:

            * ``[1, "1", "PTRANS-I"]`` : The PTRANS-I algorithm
            * ``[2, "2", "PTRANS-II"]`` : The PTRANS-II algorithm
            * ``[3, "3", "lapack", "solve_banded"]`` :
              scipy.linalg.solve_banded
            * ``[4, "4", "spsolve"]`` :
              The scipy sparse solver without umf_pack
            * ``[5, "5", "spsolve_umf", "umf", "umf_pack"]`` :
              The scipy sparse solver with umf_pack

        Default: ``1``

    Returns
    -------
    result : :class:`numpy.ndarray`
        Solution of the equation system. For the PTRANS solvers (1 and 2),
        if the algorithm cannot handle the input matrix (including a
        singular 3 x 3 system), a warning is emitted and an array of NaNs
        shaped like ``rhs`` is returned instead.

    Raises
    ------
    ValueError
        If ``solver`` is unknown, or if the scipy module required by
        solvers 3-5 cannot be imported.
    """
    if solver in [1, "1", "PTRANS-I"]:
        return _solve_ptrans(
            mat, rhs, is_flat, index_row_wise, penta_solver1, "PTRANS-I"
        )
    elif solver in [2, "2", "PTRANS-II"]:
        return _solve_ptrans(
            mat, rhs, is_flat, index_row_wise, penta_solver2, "PTRANS-II"
        )
    elif solver in [3, "3", "lapack", "solve_banded"]:  # pragma: no cover
        try:
            from scipy.linalg import solve_banded
        except ImportError as imp_err:  # pragma: no cover
            msg = "pentapy.solve: scipy.linalg.solve_banded could not be imported"
            raise ValueError(msg) from imp_err
        mat_flat = _prepare_col_wise(mat, is_flat, index_row_wise)
        # solve_banded expects the column-wise ("LAPACK") banded layout.
        return solve_banded((2, 2), mat_flat, rhs)
    elif solver in [4, "4", "spsolve"]:  # pragma: no cover
        return _solve_sparse(mat, rhs, is_flat, index_row_wise, use_umfpack=False)
    elif solver in [
        5,
        "5",
        "spsolve_umf",
        "umf",
        "umf_pack",
    ]:  # pragma: no cover
        return _solve_sparse(mat, rhs, is_flat, index_row_wise, use_umfpack=True)
    else:  # pragma: no cover
        msg = f"pentapy.solve: unknown solver ({solver})"
        raise ValueError(msg)


def _prepare_row_wise(mat, is_flat, index_row_wise):
    """Return ``mat`` as a row-wise flattened float64 array for the PTRANS solvers."""
    if is_flat and index_row_wise:
        mat_flat = np.asarray(mat, dtype=np.double)
        _check_penta(mat_flat)
    elif is_flat:
        # np.array (not asarray) makes a copy, so the in-place shift below
        # never modifies the caller's array.
        mat_flat = np.array(mat, dtype=np.double)
        _check_penta(mat_flat)
        shift_banded(mat_flat, copy=False)
    else:
        mat_flat = create_banded(mat, col_wise=False, dtype=np.double)
    return mat_flat


def _prepare_col_wise(mat, is_flat, index_row_wise):
    """Return ``mat`` as a column-wise flattened array for the scipy solvers."""
    if is_flat and index_row_wise:
        # Copy first: shift_banded works in place (copy=False).
        mat_flat = np.array(mat)
        _check_penta(mat_flat)
        shift_banded(mat_flat, col_to_row=False, copy=False)
    elif is_flat:
        mat_flat = np.asarray(mat)
    else:
        mat_flat = create_banded(mat)
    return mat_flat


def _solve_ptrans(mat, rhs, is_flat, index_row_wise, penta_solver, name):
    """Solve with one of the Cython PTRANS solvers; on failure warn and return NaNs."""
    mat_flat = _prepare_row_wise(mat, is_flat, index_row_wise)
    rhs = np.asarray(rhs, dtype=np.double)
    # stacklevel=3 makes the warning point at the caller of ``solve``
    # (_solve_ptrans -> solve -> user code).
    failure_msg = f"pentapy: {name} not suitable for input-matrix."
    # Special case: Early exit when the matrix has only 3 rows/columns
    # NOTE: this avoids memory leakage in the Cython-solver that will iterate over
    #       at least 4 rows/columns no matter what
    if mat_flat.shape[1] == 3:
        try:
            return np.linalg.solve(a=create_full(mat_flat, col_wise=False), b=rhs)
        except np.linalg.LinAlgError:
            # Keep the failure mode consistent with the PTRANS path below:
            # warn and return NaNs instead of raising for singular systems.
            warnings.warn(failure_msg, stacklevel=3)
            return np.full_like(rhs, np.nan)
    try:
        return penta_solver(mat_flat, rhs)
    except ZeroDivisionError:
        warnings.warn(failure_msg, stacklevel=3)
        return np.full_like(rhs, np.nan)


def _solve_sparse(mat, rhs, is_flat, index_row_wise, use_umfpack):
    """Solve with scipy.sparse.linalg.spsolve, with or without UMFPACK."""
    try:
        from scipy import sparse as sps
        from scipy.sparse.linalg import spsolve
    except ImportError as imp_err:
        msg = "pentapy.solve: scipy.sparse could not be imported"
        raise ValueError(msg) from imp_err
    mat_flat = _prepare_col_wise(mat, is_flat, index_row_wise)
    size = mat_flat.shape[1]
    # spdiags places data[i, j] in column j, matching the column-wise layout.
    sparse_mat = sps.spdiags(
        mat_flat, [2, 1, 0, -1, -2], size, size, format="csc"
    )
    return spsolve(sparse_mat, rhs, use_umfpack=use_umfpack)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc