10 #ifndef EIGEN_PASTIXSUPPORT_H
11 #define EIGEN_PASTIXSUPPORT_H
// Forward declarations of the PaStiX-backed solvers defined further down.
//  - PastixLU: LU factorization; the IsStrSym flag defaults to false.
//  - PastixLLT / PastixLDLT: factorizations of self-adjoint matrices; the int
//    parameter is the triangle selector (named _UpLo in the class definitions).
template<typename _MatrixType, bool IsStrSym = false> class PastixLU;
template<typename _MatrixType, int Options>           class PastixLLT;
template<typename _MatrixType, int Options>           class PastixLDLT;

// Traits hook: each solver wrapper specializes this to expose its
// MatrixType, Scalar, RealScalar and Index typedefs.
template<class PastixSolver> struct pastix_traits;
32 template<
typename _MatrixType>
33 struct pastix_traits< PastixLU<_MatrixType> >
35 typedef _MatrixType MatrixType;
36 typedef typename _MatrixType::Scalar Scalar;
37 typedef typename _MatrixType::RealScalar RealScalar;
38 typedef typename _MatrixType::Index Index;
41 template<
typename _MatrixType,
int Options>
42 struct pastix_traits< PastixLLT<_MatrixType,Options> >
44 typedef _MatrixType MatrixType;
45 typedef typename _MatrixType::Scalar Scalar;
46 typedef typename _MatrixType::RealScalar RealScalar;
47 typedef typename _MatrixType::Index Index;
50 template<
typename _MatrixType,
int Options>
51 struct pastix_traits< PastixLDLT<_MatrixType,Options> >
53 typedef _MatrixType MatrixType;
54 typedef typename _MatrixType::Scalar Scalar;
55 typedef typename _MatrixType::RealScalar RealScalar;
56 typedef typename _MatrixType::Index Index;
59 void eigen_pastix(pastix_data_t **pastix_data,
int pastix_comm,
int n,
int *ptr,
int *idx,
float *vals,
int *perm,
int * invp,
float *x,
int nbrhs,
int *iparm,
double *dparm)
61 if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; }
62 if (nbrhs == 0) {x = NULL; nbrhs=1;}
63 s_pastix(pastix_data, pastix_comm, n, ptr, idx, vals, perm, invp, x, nbrhs, iparm, dparm);
66 void eigen_pastix(pastix_data_t **pastix_data,
int pastix_comm,
int n,
int *ptr,
int *idx,
double *vals,
int *perm,
int * invp,
double *x,
int nbrhs,
int *iparm,
double *dparm)
68 if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; }
69 if (nbrhs == 0) {x = NULL; nbrhs=1;}
70 d_pastix(pastix_data, pastix_comm, n, ptr, idx, vals, perm, invp, x, nbrhs, iparm, dparm);
73 void eigen_pastix(pastix_data_t **pastix_data,
int pastix_comm,
int n,
int *ptr,
int *idx, std::complex<float> *vals,
int *perm,
int * invp, std::complex<float> *x,
int nbrhs,
int *iparm,
double *dparm)
75 if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; }
76 if (nbrhs == 0) {x = NULL; nbrhs=1;}
77 c_pastix(pastix_data, pastix_comm, n, ptr, idx, reinterpret_cast<COMPLEX*>(vals), perm, invp, reinterpret_cast<COMPLEX*>(x), nbrhs, iparm, dparm);
80 void eigen_pastix(pastix_data_t **pastix_data,
int pastix_comm,
int n,
int *ptr,
int *idx, std::complex<double> *vals,
int *perm,
int * invp, std::complex<double> *x,
int nbrhs,
int *iparm,
double *dparm)
82 if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; }
83 if (nbrhs == 0) {x = NULL; nbrhs=1;}
84 z_pastix(pastix_data, pastix_comm, n, ptr, idx, reinterpret_cast<DCOMPLEX*>(vals), perm, invp, reinterpret_cast<DCOMPLEX*>(x), nbrhs, iparm, dparm);
// Convert a compressed sparse matrix in place from C-style (0-based) to
// Fortran-style (1-based) index numbering.
//
// Every entry of the outer index array (outerSize() + 1 entries) and every
// stored inner index (nonZeros() entries) is incremented by one. Nothing is
// done when the first outer index is already non-zero, i.e. when the matrix
// already appears to be 1-based, so repeated calls are harmless.
// NOTE(review): assumes compressed storage (nonZeros() contiguous inner
// indices) — confirm for all callers.
template <typename MatrixType>
void c_to_fortran_numbering (MatrixType& mat)
{
  typedef typename MatrixType::Index Index;
  if ( !(mat.outerIndexPtr()[0]) )
  {
    // The outer index array has outerSize() + 1 entries. rows() equals
    // outerSize() only for square or row-major matrices; bounding the loop by
    // rows() overran the array (tall) or left entries unconverted (wide).
    const Index outerEnd = mat.outerSize();
    for(Index i = 0; i <= outerEnd; ++i)
      ++mat.outerIndexPtr()[i];
    const Index nnz = mat.nonZeros();
    for(Index i = 0; i < nnz; ++i)
      ++mat.innerIndexPtr()[i];
  }
}
// Convert a compressed sparse matrix in place from Fortran-style (1-based)
// back to C-style (0-based) index numbering; the inverse of
// c_to_fortran_numbering().
//
// Every entry of the outer index array (outerSize() + 1 entries) and every
// stored inner index (nonZeros() entries) is decremented by one. Nothing is
// done unless the first outer index is exactly 1, so calling this on a matrix
// that is already 0-based is harmless.
// NOTE(review): assumes compressed storage (nonZeros() contiguous inner
// indices) — confirm for all callers.
template <typename MatrixType>
void fortran_to_c_numbering (MatrixType& mat)
{
  typedef typename MatrixType::Index Index;
  if ( mat.outerIndexPtr()[0] == 1 )
  {
    // The outer index array has outerSize() + 1 entries. rows() equals
    // outerSize() only for square or row-major matrices; bounding the loop by
    // rows() overran the array (tall) or left entries unconverted (wide).
    const Index outerEnd = mat.outerSize();
    for(Index i = 0; i <= outerEnd; ++i)
      --mat.outerIndexPtr()[i];
    const Index nnz = mat.nonZeros();
    for(Index i = 0; i < nnz; ++i)
      --mat.innerIndexPtr()[i];
  }
}
// PastixBase: CRTP base class shared by the PaStiX wrappers (PastixLU,
// PastixLLT, PastixLDLT). Derived is the concrete wrapper; the wrapped matrix
// type comes from internal::pastix_traits<Derived>. It holds the PaStiX handle
// (m_pastixdata) and the integer/double parameter arrays (m_iparm/m_dparm)
// that are passed to every internal::eigen_pastix() call.
// NOTE(review): this view of the class is incomplete (braces, access
// specifiers and several member signatures/bodies are not visible); the notes
// below only describe the visible lines.
119 template <
class Derived>
120 class PastixBase : internal::noncopyable
// Types taken from the wrapped sparse matrix type.
123 typedef typename internal::pastix_traits<Derived>::MatrixType _MatrixType;
124 typedef _MatrixType MatrixType;
125 typedef typename MatrixType::Scalar Scalar;
126 typedef typename MatrixType::RealScalar RealScalar;
127 typedef typename MatrixType::Index Index;
128 typedef Matrix<Scalar,Dynamic,1> Vector;
// Column-major sparse type accepted by analyzePattern(), factorize() and compute().
129 typedef SparseMatrix<Scalar, ColMajor> ColSpMatrix;
// All status flags start false, the PaStiX handle null and the size 0
// (the constructor body is not visible here).
133 PastixBase() : m_initisOk(false), m_analysisIsOk(false), m_factorizationIsOk(false), m_isInitialized(false), m_pastixdata(0), m_size(0)
// Dense solve: returns a lazy internal::solve_retval expression for A x = b.
// Asserts that the solver is initialized and that b has rows() rows.
147 template<
typename Rhs>
148 inline const internal::solve_retval<PastixBase, Rhs>
149 solve(
const MatrixBase<Rhs>& b)
const
151 eigen_assert(m_isInitialized &&
"Pastix solver is not initialized.");
152 eigen_assert(rows()==b.rows()
153 &&
"PastixBase::solve(): invalid number of rows of the right hand side matrix b");
154 return internal::solve_retval<PastixBase, Rhs>(*
this, b.derived());
// Dense solve kernel (defined out of class); returns true when PaStiX reports
// no error.
157 template<
typename Rhs,
typename Dest>
158 bool _solve (
const MatrixBase<Rhs> &b, MatrixBase<Dest> &x)
const;
// Sparse right-hand-side solve kernel: copies NbColsAtOnce (= 1) columns of b
// at a time into the dense buffer tmp, solves them through derived().solve(),
// and stores the result into dest as a sparse view.
// NOTE(review): the declaration of tmp is not visible here; rhsCols and
// actualCols are held as int. The assertion message mentions
// symbolic()/numeric(), whereas the phase methods visible in this class are
// analyzePattern()/factorize().
161 template<
typename Rhs,
typename DestScalar,
int DestOptions,
typename DestIndex>
162 void _solve_sparse(
const Rhs& b, SparseMatrix<DestScalar,DestOptions,DestIndex> &dest)
const
164 eigen_assert(m_factorizationIsOk &&
"The decomposition is not in a valid state for solving, you must first call either compute() or symbolic()/numeric()");
165 eigen_assert(rows()==b.rows());
168 static const int NbColsAtOnce = 1;
169 int rhsCols = b.cols();
172 for(
int k=0; k<rhsCols; k+=NbColsAtOnce)
174 int actualCols = std::min<int>(rhsCols-k, NbColsAtOnce);
175 tmp.leftCols(actualCols) = b.middleCols(k,actualCols);
176 tmp.leftCols(actualCols) = derived().solve(tmp.leftCols(actualCols));
177 dest.middleCols(k,actualCols) = tmp.leftCols(actualCols).sparseView();
// Presumably the body of the non-const derived() accessor (its signature is
// not visible here): CRTP downcast of this to Derived.
183 return *
static_cast<Derived*
>(
this);
// Const CRTP downcast of this to Derived.
185 const Derived& derived()
const
187 return *
static_cast<const Derived*
>(
this);
// Access to the whole PaStiX integer parameter array (body not visible here).
// NOTE(review): the return type Array<Index,IPARM_SIZE,1>& does not match the
// member m_iparm (Matrix<int,IPARM_SIZE,1>). An Array reference cannot bind to
// a Matrix, and the scalar differs whenever Index is not int. If the body
// returns m_iparm, it fails to compile once instantiated. Declaring both as
// Array<int,IPARM_SIZE,1> would make them agree.
195 Array<Index,IPARM_SIZE,1>& iparm()
// Access to a single PaStiX integer parameter by its IPARM_* index.
204 int& iparm(
int idxparam)
206 return m_iparm(idxparam);
// Access to the whole PaStiX double parameter array (body not visible here).
// NOTE(review): Array<RealScalar,IPARM_SIZE,1>& disagrees with the member
// m_dparm (Matrix<double,DPARM_SIZE,1>) in class, scalar type and size
// (IPARM_SIZE vs DPARM_SIZE). Array<double,DPARM_SIZE,1> for both would agree.
213 Array<RealScalar,IPARM_SIZE,1>& dparm()
// Access to a single PaStiX double parameter by its DPARM_* index.
222 double& dparm(
int idxparam)
224 return m_dparm(idxparam);
// Both dimensions return m_size; compute() requires a square matrix.
227 inline Index cols()
const {
return m_size; }
228 inline Index rows()
const {
return m_size; }
// Fragment whose enclosing signature is not visible here (presumably a status
// accessor such as info() — confirm).
240 eigen_assert(m_isInitialized &&
"Decomposition is not initialized.");
// Sparse solve: returns a lazy internal::sparse_solve_retval expression for a
// sparse right-hand side b. Asserts initialization and matching row counts.
248 template<
typename Rhs>
249 inline const internal::sparse_solve_retval<PastixBase, Rhs>
250 solve(
const SparseMatrixBase<Rhs>& b)
const
252 eigen_assert(m_isInitialized &&
"Pastix LU, LLT or LDLT is not initialized.");
253 eigen_assert(rows()==b.rows()
254 &&
"PastixBase::solve(): invalid number of rows of the right hand side matrix b");
255 return internal::sparse_solve_retval<PastixBase, Rhs>(*
this, b.derived());
// Symbolic (ordering/analysis) and numeric factorization phases, defined out
// of class.
264 void analyzePattern(ColSpMatrix& mat);
267 void factorize(ColSpMatrix& mat);
// Fragment whose signature is not visible here (presumably clean()). It asserts
// m_initisOk, then runs a PaStiX call with both start and end task set to
// API_TASK_CLEAN on the current handle and permutation arrays.
272 eigen_assert(m_initisOk &&
"The Pastix structure should be allocated first");
273 m_iparm(IPARM_START_TASK) = API_TASK_CLEAN;
274 m_iparm(IPARM_END_TASK) = API_TASK_CLEAN;
275 internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, 0, 0, 0, (Scalar*)0,
276 m_perm.data(), m_invp.data(), 0, 0, m_iparm.data(), m_dparm.data());
// Combined entry point (defined out of class).
279 void compute(ColSpMatrix& mat);
// NOTE(review): m_factorizationIsOk is declared int although it is initialized
// with false and used as a flag; the other status flags are bool. The
// declarations of m_initisOk, m_analysisIsOk and m_size are not visible here.
283 int m_factorizationIsOk;
284 bool m_isInitialized;
// PaStiX instance handle, passed by address to every eigen_pastix() call.
286 mutable pastix_data_t *m_pastixdata;
// PaStiX integer/double parameter arrays. They are mutable because const
// members (e.g. _solve) rewrite the task fields before each PaStiX call.
288 mutable Matrix<int,IPARM_SIZE,1> m_iparm;
289 mutable Matrix<double,DPARM_SIZE,1> m_dparm;
// Permutation arrays handed to PaStiX as perm/invp (presumably the computed
// ordering and its inverse — confirm).
// NOTE(review): their .data() is passed to eigen_pastix()'s int* perm/invp
// parameters, which only compiles when Index is int. Matrix<int,Dynamic,1>
// would match the PaStiX interface regardless of Index.
290 mutable Matrix<Index,Dynamic,1> m_perm;
291 mutable Matrix<Index,Dynamic,1> m_invp;
// init(): sets up the PaStiX instance.
//  1. Zeroes the integer (m_iparm) and double (m_dparm) parameter arrays.
//  2. Calls pastix() with IPARM_MODIFY_PARAMETER = API_NO. Presumably this asks
//     PaStiX to fill both arrays with its defaults — confirm against the PaStiX
//     documentation.
//  3. Overrides selected parameters, then runs the API_TASK_INIT step through
//     internal::eigen_pastix() with an empty matrix.
//  4. Inspects IPARM_ERROR_NUMBER. The handling branch is not visible here.
299 template <
class Derived>
300 void PastixBase<Derived>::init()
303 m_iparm.setZero(IPARM_SIZE);
304 m_dparm.setZero(DPARM_SIZE);
306 m_iparm(IPARM_MODIFY_PARAMETER) = API_NO;
// NOTE(review): as visible, this call passes 8 arguments, whereas the typed
// *_pastix() calls inside eigen_pastix() take 12 — confirm the full argument
// list in the complete source.
307 pastix(&m_pastixdata, MPI_COMM_WORLD,
309 0, 0, 0, 1, m_iparm.data(), m_dparm.data());
// Hard-coded parameter choices: verbosity 2, SCOTCH ordering, IPARM_INCOMPLETE
// off, IPARM_OOC_LIMIT 2000 and IPARM_RHS_MAKING = API_RHS_B.
// NOTE(review): IPARM_MATRIX_VERIFICATION is assigned API_NO twice (once with
// operator[] and once with operator()); one of the two is redundant. Mixing
// [] and () indexing on m_iparm is also inconsistent.
311 m_iparm[IPARM_MATRIX_VERIFICATION] = API_NO;
312 m_iparm[IPARM_VERBOSE] = 2;
313 m_iparm[IPARM_ORDERING] = API_ORDER_SCOTCH;
314 m_iparm[IPARM_INCOMPLETE] = API_NO;
315 m_iparm[IPARM_OOC_LIMIT] = 2000;
316 m_iparm[IPARM_RHS_MAKING] = API_RHS_B;
317 m_iparm(IPARM_MATRIX_VERIFICATION) = API_NO;
// Run only the initialization task on an empty (n == 0) problem.
319 m_iparm(IPARM_START_TASK) = API_TASK_INIT;
320 m_iparm(IPARM_END_TASK) = API_TASK_INIT;
321 internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, 0, 0, 0, (Scalar*)0,
322 0, 0, 0, 0, m_iparm.data(), m_dparm.data());
// A non-zero IPARM_ERROR_NUMBER signals that initialization failed.
325 if(m_iparm(IPARM_ERROR_NUMBER)) {
// compute(): asserts that mat is square, resets IPARM_MATRIX_VERIFICATION to
// API_NO, and sets m_isInitialized from m_factorizationIsOk.
// NOTE(review): as visible here, compute() does not call analyzePattern() or
// factorize(). Confirm those calls are present in the full source; otherwise
// m_factorizationIsOk is never updated by compute().
335 template <
class Derived>
336 void PastixBase<Derived>::compute(ColSpMatrix& mat)
338 eigen_assert(mat.rows() == mat.cols() &&
"The input matrix should be squared");
343 m_iparm(IPARM_MATRIX_VERIFICATION) = API_NO;
344 m_isInitialized = m_factorizationIsOk;
// analyzePattern(): symbolic phase. It runs PaStiX from API_TASK_ORDERING
// through API_TASK_ANALYSE on the structure and values of mat. The wrappers
// presumably pass mat already converted to 1-based numbering by their
// grabMatrix(). m_perm/m_invp are resized to m_size and handed to PaStiX as
// perm/invp.
// NOTE(review): m_size is read here, but its assignment is not visible in this
// view. Requires a successful init() (asserts m_initisOk).
348 template <
class Derived>
349 void PastixBase<Derived>::analyzePattern(ColSpMatrix& mat)
351 eigen_assert(m_initisOk &&
"The initialization of PaSTiX failed");
358 m_perm.resize(m_size);
359 m_invp.resize(m_size);
361 m_iparm(IPARM_START_TASK) = API_TASK_ORDERING;
362 m_iparm(IPARM_END_TASK) = API_TASK_ANALYSE;
363 internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, m_size, mat.outerIndexPtr(), mat.innerIndexPtr(),
364 mat.valuePtr(), m_perm.data(), m_invp.data(), 0, 0, m_iparm.data(), m_dparm.data());
// NOTE(review): as literally shown, no braces/else separate the error and
// success assignments below, so m_analysisIsOk would always end up true.
// Confirm the error branch is properly delimited in the full source.
367 if(m_iparm(IPARM_ERROR_NUMBER))
370 m_analysisIsOk =
false;
375 m_analysisIsOk =
true;
// factorize(): numeric phase. It runs only API_TASK_NUMFACT on mat, reusing the
// handle and permutation arrays from analyzePattern(). It asserts that
// analysis succeeded, then updates m_factorizationIsOk and m_isInitialized
// from IPARM_ERROR_NUMBER.
379 template <
class Derived>
380 void PastixBase<Derived>::factorize(ColSpMatrix& mat)
383 eigen_assert(m_analysisIsOk &&
"The analysis phase should be called before the factorization phase");
384 m_iparm(IPARM_START_TASK) = API_TASK_NUMFACT;
385 m_iparm(IPARM_END_TASK) = API_TASK_NUMFACT;
388 internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, m_size, mat.outerIndexPtr(), mat.innerIndexPtr(),
389 mat.valuePtr(), m_perm.data(), m_invp.data(), 0, 0, m_iparm.data(), m_dparm.data());
// NOTE(review): as literally shown, no braces/else separate the failure
// assignments from the success assignments, so both flags would always end up
// true. Confirm the error branch is properly delimited in the full source.
392 if(m_iparm(IPARM_ERROR_NUMBER))
395 m_factorizationIsOk =
false;
396 m_isInitialized =
false;
401 m_factorizationIsOk =
true;
402 m_isInitialized =
true;
// _solve(): dense solve, one right-hand-side column per PaStiX call. Each column
// of x is solved in place: &x(0, i) is passed as the solution/right-hand-side
// buffer, with tasks API_TASK_SOLVE through API_TASK_REFINE. It returns true
// when IPARM_ERROR_NUMBER is 0 after the loop.
// NOTE(review): only the status left by the last PaStiX call is returned, so a
// failure on an earlier column goes unreported unless hidden lines check it.
// The start of the static assertion ending in
// THIS_METHOD_IS_ONLY_FOR_COLUMN_MAJOR_MATRICES and the declaration of `rhs`
// are not visible here. x is presumably initialized from b in hidden lines —
// confirm.
407 template<
typename Base>
408 template<
typename Rhs,
typename Dest>
409 bool PastixBase<Base>::_solve (
const MatrixBase<Rhs> &b, MatrixBase<Dest> &x)
const
411 eigen_assert(m_isInitialized &&
"The matrix should be factorized first");
413 THIS_METHOD_IS_ONLY_FOR_COLUMN_MAJOR_MATRICES);
418 for (
int i = 0; i < b.cols(); i++){
419 m_iparm[IPARM_START_TASK] = API_TASK_SOLVE;
420 m_iparm[IPARM_END_TASK] = API_TASK_REFINE;
422 internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, x.rows(), 0, 0, 0,
423 m_perm.data(), m_invp.data(), &x(0, i), rhs, m_iparm.data(), m_dparm.data());
429 return m_iparm(IPARM_ERROR_NUMBER)==0;
// PastixLU: sparse LU factorization through PaStiX. The visible configuration
// sets IPARM_SYM = API_SYM_NO and IPARM_FACTORIZATION = API_FACT_LU.
// NOTE(review): the base class (and the Base typedef) is
// PastixBase< PastixLU<_MatrixType> >, which drops IsStrSym so it always takes
// its default (false). For PastixLU<M, true>, the CRTP downcast in
// PastixBase::derived() would therefore target PastixLU<M, false>, a different
// class. Passing IsStrSym through (PastixLU<_MatrixType, IsStrSym>), together
// with a pastix_traits specialization covering both flag values, would fix it.
451 template<
typename _MatrixType,
bool IsStrSym>
452 class PastixLU :
public PastixBase< PastixLU<_MatrixType> >
455 typedef _MatrixType MatrixType;
456 typedef PastixBase<PastixLU<MatrixType> > Base;
458 typedef typename MatrixType::Index Index;
// Constructor taking an input matrix (body not visible here).
466 PastixLU(
const MatrixType& matrix):Base()
// Fragments of member functions whose signatures are not visible here. They
// invalidate the cached transpose structure and/or call grabMatrix() to build
// the PaStiX input in temp.
478 m_structureIsUptodate =
false;
480 grabMatrix(matrix, temp);
490 m_structureIsUptodate =
false;
492 grabMatrix(matrix, temp);
504 grabMatrix(matrix, temp);
// Fragment (signature not visible): resets the structure cache and selects a
// non-symmetric LU factorization in PaStiX.
// NOTE(review): m_iparm lives in the dependent base class. Unqualified use needs
// a `using Base::m_iparm;` or this->, which is not visible here.
511 m_structureIsUptodate =
false;
512 m_iparm(IPARM_SYM) = API_SYM_NO;
513 m_iparm(IPARM_FACTORIZATION) = API_FACT_LU;
// Builds the PaStiX input as out = m_transposedStructure + matrix, where
// m_transposedStructure caches matrix.transpose(). The sum therefore carries
// the nonzero pattern of both A and A^T. The cache is rebuilt under the
// !m_structureIsUptodate guard.
// NOTE(review): the inner-loop body over m_transposedStructure is not visible.
// It presumably zeroes its values so that only the structure of A^T
// contributes — confirm.
// Finally, out is converted to 1-based (Fortran) numbering.
516 void grabMatrix(
const MatrixType& matrix, ColSpMatrix& out)
522 if(!m_structureIsUptodate)
525 m_transposedStructure = matrix.transpose();
528 for (Index j=0; j<m_transposedStructure.
outerSize(); ++j)
529 for(
typename ColSpMatrix::InnerIterator it(m_transposedStructure, j); it; ++it)
532 m_structureIsUptodate =
true;
535 out = m_transposedStructure + matrix;
537 internal::c_to_fortran_numbering(out);
// Cached transpose of the input matrix and the flag saying whether it is current.
543 ColSpMatrix m_transposedStructure;
544 bool m_structureIsUptodate;
// PastixLLT: LL^T factorization through PaStiX for self-adjoint matrices
// (IPARM_SYM = API_SYM_YES, IPARM_FACTORIZATION = API_FACT_LLT). _UpLo, exposed
// as the UpLo enum, selects which triangle of the input matrix is read.
// NOTE(review): for complex Scalar, confirm that API_SYM_YES matches what PaStiX
// expects for Hermitian (as opposed to complex-symmetric) input.
561 template<
typename _MatrixType,
int _UpLo>
562 class PastixLLT :
public PastixBase< PastixLLT<_MatrixType, _UpLo> >
565 typedef _MatrixType MatrixType;
566 typedef PastixBase<PastixLLT<MatrixType, _UpLo> > Base;
570 enum { UpLo = _UpLo };
// Constructor taking an input matrix (body not visible here).
576 PastixLLT(
const MatrixType& matrix):Base()
// Fragments of member functions whose signatures are not visible here. Each
// builds the PaStiX input matrix in temp via grabMatrix().
588 grabMatrix(matrix, temp);
599 grabMatrix(matrix, temp);
608 grabMatrix(matrix, temp);
// Fragment (signature not visible): selects a symmetric LL^T factorization.
616 m_iparm(IPARM_SYM) = API_SYM_YES;
617 m_iparm(IPARM_FACTORIZATION) = API_FACT_LLT;
// Builds the PaStiX input. It assigns the self-adjoint view formed from the
// UpLo triangle of matrix to the Lower self-adjoint view of out, then converts
// out to 1-based (Fortran) numbering.
620 void grabMatrix(
const MatrixType& matrix, ColSpMatrix& out)
623 out.template selfadjointView<Lower>() = matrix.template selfadjointView<UpLo>();
624 internal::c_to_fortran_numbering(out);
// PastixLDLT: LDL^T factorization through PaStiX for self-adjoint matrices
// (IPARM_SYM = API_SYM_YES, IPARM_FACTORIZATION = API_FACT_LDLT). _UpLo, exposed
// as the UpLo enum, selects which triangle of the input matrix is read.
// NOTE(review): no constructor is visible in this view of the class. For complex
// Scalar, confirm that API_SYM_YES matches PaStiX's expectation for Hermitian
// input.
642 template<
typename _MatrixType,
int _UpLo>
643 class PastixLDLT :
public PastixBase< PastixLDLT<_MatrixType, _UpLo> >
646 typedef _MatrixType MatrixType;
647 typedef PastixBase<PastixLDLT<MatrixType, _UpLo> > Base;
651 enum { UpLo = _UpLo };
// Fragments of member functions whose signatures are not visible here. Each
// builds the PaStiX input matrix in temp via grabMatrix().
669 grabMatrix(matrix, temp);
680 grabMatrix(matrix, temp);
689 grabMatrix(matrix, temp);
// Fragment (signature not visible): selects a symmetric LDL^T factorization.
698 m_iparm(IPARM_SYM) = API_SYM_YES;
699 m_iparm(IPARM_FACTORIZATION) = API_FACT_LDLT;
// Builds the PaStiX input. It assigns the self-adjoint view formed from the
// UpLo triangle of matrix to the Lower self-adjoint view of out, then converts
// out to 1-based (Fortran) numbering.
702 void grabMatrix(
const MatrixType& matrix, ColSpMatrix& out)
705 out.template selfadjointView<Lower>() = matrix.template selfadjointView<UpLo>();
706 internal::c_to_fortran_numbering(out);
712 template<
typename _MatrixType,
typename Rhs>
713 struct solve_retval<PastixBase<_MatrixType>, Rhs>
714 : solve_retval_base<PastixBase<_MatrixType>, Rhs>
716 typedef PastixBase<_MatrixType> Dec;
717 EIGEN_MAKE_SOLVE_HELPERS(Dec,Rhs)
719 template<typename Dest>
void evalTo(Dest& dst)
const
721 dec()._solve(rhs(),dst);
725 template<
typename _MatrixType,
typename Rhs>
726 struct sparse_solve_retval<PastixBase<_MatrixType>, Rhs>
727 : sparse_solve_retval_base<PastixBase<_MatrixType>, Rhs>
729 typedef PastixBase<_MatrixType> Dec;
730 EIGEN_MAKE_SPARSE_SOLVE_HELPERS(Dec,Rhs)
732 template<typename Dest>
void evalTo(Dest& dst)
const
734 dec()._solve_sparse(rhs(),dst);