13#ifndef SUPERLU_SOLVE_HH
14#define SUPERLU_SOLVE_HH
32template <
class Scalar>
/// Constructor variant that takes the matrix in triplet (coordinate) format by
/// const reference, so the caller's ridx/cidx/values are left unchanged.
///
/// n_     : dimension of the square system (A is created as n x n).
/// ridx   : row indices of the N nonzeros (N = ridx.size()).
/// cidx   : column indices of the N nonzeros.
/// values : the N nonzero values; cidx and values must have the same length
///          as ridx (checked by the assert below).
///
/// The compressed-column arrays Az/Ai/Ap are wrapped in a SuperLU SuperMatrix A,
/// and a single dgssvx() call then computes the LU factorization.
45 template <
class Index>
47 std::vector<Index>
const& ridx,
48 std::vector<Index>
const& cidx,
49 std::vector<Scalar>
const& values)
50 : N(ridx.
size()), n(n_)
52 assert(cidx.size()==N && values.size()==N);
57 std::cout <<
"SuperLU" <<
" solver, n=" << n <<
", nnz=" << N << std::endl;
// Start from SuperLU's default options; selected fields are overridden below
// before the factorization.
59 set_default_options(&options);
// SuperLU's index arrays are int (Ap/Ai are std::vector<int>); indices of another
// type are copied into int vectors first.
70 if (std::is_same_v<int,Index>)
74 std::vector<int> iRidx(ridx.size()), iCidx(cidx.size());
75 std::copy(begin(ridx),end(ridx),begin(iRidx));
76 std::copy(begin(cidx),end(cidx),begin(iCidx));
// Wrap Az/Ai/Ap as an n x n compressed-column (SLU_NC), double (SLU_D), general
// (SLU_GE) matrix. Per SuperLU's documentation the argument order is
// (nzval, rowind, colptr), i.e. Ai before Ap.
80 dCreate_CompCol_Matrix(&A, n, n, N, &Az[0], &Ai[0], &Ap[0], SLU_NC, SLU_D, SLU_GE);
// Dense n x nrhs buffers (leading dimension n): B wraps the right-hand side rhsb,
// X wraps the solution rhsx.
83 if ( !(rhsb = doubleMalloc(n * nrhs)) ) ABORT(
"Malloc fails for rhsb[].");
84 if ( !(rhsx = doubleMalloc(n * nrhs)) ) ABORT(
"Malloc fails for rhsx[].");
85 dCreate_Dense_Matrix(&B, n, nrhs, rhsb, n, SLU_DN, SLU_D, SLU_GE);
86 dCreate_Dense_Matrix(&X, n, nrhs, rhsx, n, SLU_DN, SLU_D, SLU_GE);
// dgssvx solves while it factorizes, so a right-hand side is needed already here.
// Presumably dGenXtrue/dFillRHS (SuperLU test helpers) generate a known solution
// xact and fill B from it — confirm.
87 xact = doubleMalloc(n * nrhs);
89 dGenXtrue(n, nrhs, xact, ldx);
90 dFillRHS(trans, nrhs, xact, ldx, &A, &B);
// Workspace for the elimination tree and the row/column permutations that dgssvx computes.
92 if ( !(etree = intMalloc(n)) ) ABORT(
"Malloc fails for etree[].");
93 if ( !(perm_r = intMalloc(n)) ) ABORT(
"Malloc fails for perm_r[].");
94 if ( !(perm_c = intMalloc(n)) ) ABORT(
"Malloc fails for perm_c[].");
// Row/column scale factors (sized by A's rows/columns) and per-right-hand-side
// error bounds (sized nrhs).
95 if ( !(R = (
double *) SUPERLU_MALLOC(A.nrow *
sizeof(
double))) )
96 ABORT(
"SUPERLU_MALLOC fails for R[].");
97 if ( !(C = (
double *) SUPERLU_MALLOC(A.ncol *
sizeof(
double))) )
98 ABORT(
"SUPERLU_MALLOC fails for C[].");
100 if ( !(ferr = (
double *) SUPERLU_MALLOC(nrhs *
sizeof(
double))) )
101 ABORT(
"SUPERLU_MALLOC fails for ferr[].");
102 if ( !(berr = (
double *) SUPERLU_MALLOC(nrhs *
sizeof(
double))) )
103 ABORT(
"SUPERLU_MALLOC fails for berr[].");
// Apply this object's settings; u is the diagonal pivoting threshold.
105 options.Equil = equil;
106 options.DiagPivotThresh = u;
107 options.Trans = trans;
// Factorize A (and solve for B) in one expert-driver call.
110 dgssvx(&options, &A, perm_c, perm_r, etree, equed, R, C,
111 &L, &U, work, lwork, &B, &X, &rpg, &rcond, ferr, berr,
112 &mem_usage, &stat, &info);
// Per SuperLU's documentation, info == n+1 means the factorization completed but
// A is singular to working precision; statistics are still reported then.
115 if ( (info == 0 || info == n+1) && (verbose>0) ) {
117 if ( options.PivotGrowth ) printf(
"Recip. pivot growth = %e\n", rpg);
118 if ( options.ConditionNumber )
119 printf(
"Recip. condition number = %e\n", rcond);
120 Lstore = (SCformat *) L.Store;
121 Ustore = (NCformat *) U.Store;
122 printf(
"No of nonzeros in factor L = %d\n", Lstore->nnz);
123 printf(
"No of nonzeros in factor U = %d\n", Ustore->nnz);
// NOTE(review): %ld is used for an expression built from Lstore->nnz, Ustore->nnz
// and n, while the neighbouring nnz prints use %d. A length-modifier mismatch is
// undefined behaviour in printf — confirm the actual types and unify the format.
124 printf(
"No of nonzeros in L+U = %ld\n", Lstore->nnz + Ustore->nnz - n);
125 printf(
"FILL ratio = %.1f\n", (
float)(Lstore->nnz + Ustore->nnz - n)/N);
127 printf(
"L\\U MB %.3f\ttotal MB needed %.3f\n",
128 mem_usage.for_lu/1e6, mem_usage.total_needed/1e6);
131 }
else if ( info > 0 && lwork == -1 ) {
// NOTE(review): info is an int; unless info - n has type long, %ld does not match — confirm.
132 printf(
"** Estimated memory: %ld bytes\n", info - n);
135 printf(
"LU factorization: dgssvx() returns info %d\n", info);
// SuperLU's statistics are printed only in verbose mode.
137 if ( verbose>0 ) StatPrint(&stat);
/// Constructor variant that takes ownership of the triplet (coordinate) data via
/// std::unique_ptr, so the input vectors can be released before factorization to
/// reduce peak memory.
///
/// n_     : dimension of the square system (A is created as n x n).
/// ridx   : row indices of the N nonzeros (N = ridx->size()).
/// cidx   : column indices of the N nonzeros.
/// values : the N nonzero values; cidx and values must have the same length
///          as ridx (checked by the assert below).
149 template <
class Index>
151 std::unique_ptr<std::vector<Index>> ridx,
152 std::unique_ptr<std::vector<Index>> cidx,
153 std::unique_ptr<std::vector<Scalar>> values)
154 : N(ridx->
size()), n(n_)
156 assert(cidx->size()==N && values->size()==N);
160 std::cout <<
"SuperLU" <<
" solver, n=" << n <<
", nnz=" << N << std::endl;
// Start from SuperLU's default options; selected fields are overridden below.
164 set_default_options(&options);
// SuperLU's index arrays are int; indices of another type are copied into int vectors first.
175 if (std::is_same_v<int,Index>)
179 std::vector<int> iRidx(ridx->size()), iCidx(cidx->size());
180 std::copy(begin(*ridx),end(*ridx),begin(iRidx));
181 std::copy(begin(*cidx),end(*cidx),begin(iCidx));
// NOTE(review): Ap and Ai appear swapped here. Per SuperLU's documentation the
// argument order of dCreate_CompCol_Matrix is (nzval, rowind, colptr), and the
// const-reference constructor passes &Ai[0], &Ap[0]. As written, column pointers
// would be read as row indices — confirm and change to &Ai[0], &Ap[0].
186 dCreate_CompCol_Matrix(&A, n, n, N, &Az[0], &Ap[0], &Ai[0], SLU_NC, SLU_D, SLU_GE);
// Dense n x nrhs buffers (leading dimension n): B wraps rhsb, X wraps rhsx.
189 if ( !(rhsb = doubleMalloc(n * nrhs)) ) ABORT(
"Malloc fails for rhsb[].");
190 if ( !(rhsx = doubleMalloc(n * nrhs)) ) ABORT(
"Malloc fails for rhsx[].");
191 dCreate_Dense_Matrix(&B, n, nrhs, rhsb, n, SLU_DN, SLU_D, SLU_GE);
192 dCreate_Dense_Matrix(&X, n, nrhs, rhsx, n, SLU_DN, SLU_D, SLU_GE);
// dgssvx solves while it factorizes, so a right-hand side is needed already here.
// Presumably dGenXtrue/dFillRHS generate a known solution xact and fill B from it — confirm.
193 xact = doubleMalloc(n * nrhs);
195 dGenXtrue(n, nrhs, xact, ldx);
196 dFillRHS(trans, nrhs, xact, ldx, &A, &B);
// Workspace for the elimination tree and the row/column permutations.
198 if ( !(etree = intMalloc(n)) ) ABORT(
"Malloc fails for etree[].");
199 if ( !(perm_r = intMalloc(n)) ) ABORT(
"Malloc fails for perm_r[].");
200 if ( !(perm_c = intMalloc(n)) ) ABORT(
"Malloc fails for perm_c[].");
// Row/column scale factors and per-right-hand-side error bounds.
201 if ( !(R = (
double *) SUPERLU_MALLOC(A.nrow *
sizeof(
double))) )
202 ABORT(
"SUPERLU_MALLOC fails for R[].");
203 if ( !(C = (
double *) SUPERLU_MALLOC(A.ncol *
sizeof(
double))) )
204 ABORT(
"SUPERLU_MALLOC fails for C[].");
205 if ( !(ferr = (
double *) SUPERLU_MALLOC(nrhs *
sizeof(
double))) )
206 ABORT(
"SUPERLU_MALLOC fails for ferr[].");
207 if ( !(berr = (
double *) SUPERLU_MALLOC(nrhs *
sizeof(
double))) )
208 ABORT(
"SUPERLU_MALLOC fails for berr[].");
// Apply this object's settings; u is the diagonal pivoting threshold.
210 options.Equil = equil;
211 options.DiagPivotThresh = u;
212 options.Trans = trans;
// Factorize A (and solve for B) in one expert-driver call.
216 dgssvx(&options, &A, perm_c, perm_r, etree, equed, R, C,
217 &L, &U, work, lwork, &B, &X, &rpg, &rcond, ferr, berr,
218 &mem_usage, &stat, &info);
// Per SuperLU's documentation, info == n+1 means the factorization completed but
// A is singular to working precision.
221 if ( (info == 0 || info == n+1) && (verbose>0) ) {
222 if ( options.PivotGrowth ) printf(
"Recip. pivot growth = %e\n", rpg);
223 if ( options.ConditionNumber )
224 printf(
"Recip. condition number = %e\n", rcond);
225 Lstore = (SCformat *) L.Store;
226 Ustore = (NCformat *) U.Store;
227 printf(
"No of nonzeros in factor L = %d\n", Lstore->nnz);
228 printf(
"No of nonzeros in factor U = %d\n", Ustore->nnz);
229 printf(
"No of nonzeros in L+U = %d\n", Lstore->nnz + Ustore->nnz - n);
230 printf(
"FILL ratio = %.1f\n", (
float)(Lstore->nnz + Ustore->nnz - n)/N);
232 printf(
"L\\U MB %.3f\ttotal MB needed %.3f\n",
233 mem_usage.for_lu/1e6, mem_usage.total_needed/1e6);
236 }
else if ( info > 0 && lwork == -1 ) {
// lwork == -1 is a memory query; info - n then reports the estimated size.
237 printf(
"** Estimated memory: %d bytes\n", info - n);
240 printf(
"LU factorization: dgssvx() returns info %d\n", info);
// SuperLU's statistics are printed only in verbose mode.
242 if ( verbose>0 ) StatPrint(&stat);
// Release the index workspaces that the constructors obtained with intMalloc.
249 SUPERLU_FREE (etree);
250 SUPERLU_FREE (perm_r);
251 SUPERLU_FREE (perm_c);
// L is held in supernodal format (SCformat) and U in compressed-column format
// (NCformat), hence the two different destroy routines.
261 Destroy_SuperNode_Matrix(&L);
262 Destroy_CompCol_Matrix(&U); }
/// Solves the system for the right-hand side b and stores the solution in x.
/// Convenience overload that forwards to the pointer-based solve().
/// NOTE(review): &b[0] and &x[0] are undefined behaviour for empty vectors;
/// b.data()/x.data() would avoid that. x presumably must already hold at least
/// n entries — confirm.
268 void solve(std::vector<Scalar>
const& b, std::vector<Scalar>& x,
bool transposed=
false)
const
272 solve(&b[0],&x[0],transposed);
/// Solves the system for the right-hand side b (n entries) and writes n entries into x.
/// Reuses the existing factorization: per SuperLU's documentation,
/// options.Fact = FACTORED makes dgssvx reuse L, U and the permutations instead of
/// refactorizing. The method is const because the SuperLU state it updates is
/// declared mutable.
/// NOTE(review): exactly n values are copied in and out, i.e. one right-hand side,
/// although B and X were created with nrhs columns. This is fine if nrhs == 1 — confirm.
275 virtual void solve(Scalar
const* b, Scalar* x,
bool transposed=
false)
const
277 options.Fact = FACTORED;
// Load b into the buffer behind B, solve, then copy the solution out of X's buffer.
282 std::copy(b,b+n,rhsb);
283 dgssvx(&options, &A, perm_c, perm_r, etree, equed, R, C,
284 &L, &U, work, lwork, &B, &X, &rpg, &rcond, ferr, berr,
285 &mem_usage, &stat, &info);
286 std::copy(rhsx,rhsx+n,x);
289 printf(
"Triangular solve: dgssvx() returns info %d\n", info);
// info == n+1: solved, but A is singular to working precision (per SuperLU's documentation).
291 if ( (info == 0 || info == n+1) && (verbose>0) )
293 if ( options.IterRefine )
295 printf(
"Iterative Refinement:\n");
296 printf(
"%8s%8s%16s%16s\n",
"rhs",
"Steps",
"FERR",
"BERR");
297 for (
int i = 0; i < nrhs; ++i)
298 printf(
"%8d%8d%16e%16e\n", i+1, stat.RefineSteps, ferr[i], berr[i]);
301 }
else if ( info > 0 && lwork == -1 ) {
// NOTE(review): info is an int; unless info - n has type long, %ld does not match — confirm.
302 printf(
"** Estimated memory: %ld bytes\n", info - n);
305 if ( verbose>0 ) StatPrint(&stat);
/// Solves the system in place: b holds the right-hand side on entry and the
/// solution on exit. Reuses the existing factorization (options.Fact = FACTORED).
/// NOTE(review): b is a std::vector, so `b+n` in the std::copy calls below is
/// ill-formed; b.begin()/b.begin()+n (or b.data()) was presumably intended.
/// Because this is a member of a class template, the error only appears once this
/// overload is actually instantiated.
314 void solve(std::vector<Scalar>& b)
const
322 options.Fact = FACTORED;
327 std::copy(b,b+n,rhsb);
328 dgssvx(&options, &A, perm_c, perm_r, etree, equed, R, C,
329 &L, &U, work, lwork, &B, &X, &rpg, &rcond, ferr, berr,
330 &mem_usage, &stat, &info);
331 std::copy(rhsx,rhsx+n,b);
334 printf(
"Triangular solve: dgssvx() returns info %d\n", info);
336 if ( (info == 0 || info == n+1) && (verbose>0) )
338 if ( options.IterRefine )
340 printf(
"Iterative Refinement:\n");
341 printf(
"%8s%8s%16s%16s\n",
"rhs",
"Steps",
"FERR",
"BERR");
342 for (
int i = 0; i < nrhs; ++i)
343 printf(
"%8d%8d%16e%16e\n", i+1, stat.RefineSteps, ferr[i], berr[i]);
347 else if ( info > 0 && lwork == -1 )
// NOTE(review): info is an int; unless info - n has type long, %ld does not match — confirm.
349 printf(
"** Estimated memory: %ld bytes\n", info - n);
// Compressed-column representation of A handed to dCreate_CompCol_Matrix:
// Ap and Ai are the int index arrays, Az holds the nonzero values.
// NOTE(review): the d-prefixed SuperLU routines operate on double, so &Az[0] only
// matches their signature for Scalar == double — confirm this is the intent.
364 std::vector<int> Ap, Ai;
365 std::vector<Scalar> Az;
// The SuperLU state below is mutable so that the const solve() overloads can update it.
366 mutable superlu_options_t options;
// Number of right-hand sides, and the leading dimension passed along with xact.
367 mutable int nrhs, ldx;
// System matrix A and its LU factors L and U as produced by dgssvx.
368 mutable SuperMatrix A, L, U;
// Typed views of U.Store and L.Store, used for the fill-in statistics.
369 mutable NCformat *Ustore;
370 mutable SCformat *Lstore;
// Dense right-hand-side and solution matrices wrapping rhsb and rhsx.
371 mutable SuperMatrix B, X;
// Row/column scale factors passed to dgssvx.
375 mutable double *R, *C;
// Forward and backward error bounds, one entry per right-hand side.
376 mutable double *ferr, *berr;
// Raw buffers behind B and X, plus xact (presumably the generated solution used to
// build the initial right-hand side).
377 mutable double *rhsb, *rhsx, *xact;
378 mutable mem_usage_t mem_usage;
379 mutable SuperLUStat_t stat;
// Equilibration indicator exchanged with dgssvx.
380 mutable char equed[1];
// dgssvx status code and workspace size (lwork == -1 requests a memory estimate).
382 mutable int info, lwork;
// Diagonal pivoting threshold, reciprocal pivot growth and reciprocal condition number.
383 mutable double u, rpg, rcond;
384 mutable yes_no_t equil;
385 mutable trans_t trans;
Abstract base class for matrix factorizations.
Factorization of sparse linear systems with SuperLU.
void solve(std::vector< Scalar > &b) const
Solves the system for the given right-hand side.
virtual void solve(Scalar *b) const
Solves the system for the given right-hand side.
virtual size_t size() const
Reports the dimension of the system.
void solve(std::vector< Scalar > const &b, std::vector< Scalar > &x, bool transposed=false) const
virtual void solve(Scalar const *b, Scalar *x, bool transposed=false) const
Solves the system for the given right-hand side.
SUPERLUFactorization(Index n_, std::vector< Index > const &ridx, std::vector< Index > const &cidx, std::vector< Scalar > const &values)
Constructor variant that keeps the input data in triplet format (aka coordinate format) unchanged.
SUPERLUFactorization(Index n_, std::unique_ptr< std::vector< Index > > ridx, std::unique_ptr< std::vector< Index > > cidx, std::unique_ptr< std::vector< Scalar > > values)
Constructor variant that destroys the input data before factorization, which is more memory-efficient.
void tripletToCompressedColumn(Index nRows, Index nCols, size_t nNonZeros, std::vector< Index > const &ridx, std::vector< Index > const &cidx, std::vector< Scalar > const &values, std::vector< Index > &Ap, std::vector< Index > &Ai, std::vector< Scalar > &Az)
Converts a matrix in triplet format to a compressed column format.