Actual source code: mpimatmatmatmult.c
petsc-3.13.1 2020-05-02
1: /*
2: Defines matrix-matrix-matrix product routines for MPIAIJ matrices
3: D = A * B * C
4: */
5: #include <../src/mat/impls/aij/mpi/mpiaij.h>
7: #if defined(PETSC_HAVE_HYPRE)
8: PETSC_INTERN PetscErrorCode MatTransposeMatMatMultSymbolic_AIJ_AIJ_AIJ_wHYPRE(Mat,Mat,Mat,PetscReal,Mat);
9: PETSC_INTERN PetscErrorCode MatTransposeMatMatMultNumeric_AIJ_AIJ_AIJ_wHYPRE(Mat,Mat,Mat,Mat);
11: PETSC_INTERN PetscErrorCode MatProductNumeric_ABC_Transpose_AIJ_AIJ(Mat RAP)
12: {
14: Mat_Product *product = RAP->product;
15: Mat Rt,R=product->A,A=product->B,P=product->C;
18: MatTransposeGetMat(R,&Rt);
19: MatTransposeMatMatMultNumeric_AIJ_AIJ_AIJ_wHYPRE(Rt,A,P,RAP);
20: return(0);
21: }
23: PETSC_INTERN PetscErrorCode MatProductSymbolic_ABC_Transpose_AIJ_AIJ(Mat RAP)
24: {
26: Mat_Product *product = RAP->product;
27: Mat Rt,R=product->A,A=product->B,P=product->C;
28: PetscBool flg;
31: /* local sizes of matrices will be checked by the calling subroutines */
32: MatTransposeGetMat(R,&Rt);
33: PetscObjectTypeCompareAny((PetscObject)Rt,&flg,MATSEQAIJ,MATSEQAIJMKL,MATMPIAIJ,NULL);
34: if (!flg) SETERRQ1(PetscObjectComm((PetscObject)Rt),PETSC_ERR_SUP,"Not for matrix type %s",((PetscObject)Rt)->type_name);
35: MatTransposeMatMatMultSymbolic_AIJ_AIJ_AIJ_wHYPRE(Rt,A,P,product->fill,RAP);
36: RAP->ops->productnumeric = MatProductNumeric_ABC_Transpose_AIJ_AIJ;
37: return(0);
38: }
40: PETSC_INTERN PetscErrorCode MatProductSetFromOptions_Transpose_AIJ_AIJ(Mat C)
41: {
43: Mat_Product *product = C->product;
46: MatSetType(C,MATAIJ);
47: if (product->type == MATPRODUCT_ABC) {
48: C->ops->productsymbolic = MatProductSymbolic_ABC_Transpose_AIJ_AIJ;
49: } else SETERRQ1(PetscObjectComm((PetscObject)C),PETSC_ERR_SUP,"MatProduct type %s is not supported for Transpose, AIJ and AIJ matrices",MatProductTypes[product->type]);
50: return(0);
51: }
52: #endif
54: PetscErrorCode MatFreeIntermediateDataStructures_MPIAIJ_BC(Mat ABC)
55: {
56: Mat_MPIAIJ *a = (Mat_MPIAIJ*)ABC->data;
57: Mat_MatMatMatMult *matmatmatmult = a->matmatmatmult;
58: PetscErrorCode ierr;
61: if (!matmatmatmult) return(0);
63: MatDestroy(&matmatmatmult->BC);
64: ABC->ops->destroy = matmatmatmult->destroy;
65: PetscFree(a->matmatmatmult);
66: return(0);
67: }
69: PetscErrorCode MatDestroy_MPIAIJ_MatMatMatMult(Mat A)
70: {
71: PetscErrorCode ierr;
74: (*A->ops->freeintermediatedatastructures)(A);
75: (*A->ops->destroy)(A);
76: return(0);
77: }
79: PetscErrorCode MatMatMatMultSymbolic_MPIAIJ_MPIAIJ_MPIAIJ(Mat A,Mat B,Mat C,PetscReal fill,Mat D)
80: {
82: Mat BC;
83: PetscBool scalable;
84: Mat_Product *product = D->product;
87: MatCreate(PetscObjectComm((PetscObject)A),&BC);
88: if (product) {
89: PetscStrcmp(product->alg,"scalable",&scalable);
90: } else SETERRQ(PetscObjectComm((PetscObject)D),PETSC_ERR_ARG_NULL,"Call MatProductCreate() first");
92: if (scalable) {
93: MatMatMultSymbolic_MPIAIJ_MPIAIJ(B,C,fill,BC);
94: MatZeroEntries(BC); /* initialize value entries of BC */
95: MatMatMultSymbolic_MPIAIJ_MPIAIJ(A,BC,fill,D);
96: } else {
97: MatMatMultSymbolic_MPIAIJ_MPIAIJ_nonscalable(B,C,fill,BC);
98: MatZeroEntries(BC); /* initialize value entries of BC */
99: MatMatMultSymbolic_MPIAIJ_MPIAIJ_nonscalable(A,BC,fill,D);
100: }
101: product->Dwork = BC;
103: D->ops->matmatmultnumeric = MatMatMatMultNumeric_MPIAIJ_MPIAIJ_MPIAIJ;
104: D->ops->freeintermediatedatastructures = MatFreeIntermediateDataStructures_MPIAIJ_BC;
105: return(0);
106: }
108: PetscErrorCode MatMatMatMultNumeric_MPIAIJ_MPIAIJ_MPIAIJ(Mat A,Mat B,Mat C,Mat D)
109: {
111: Mat_Product *product = D->product;
112: Mat BC = product->Dwork;
115: (BC->ops->matmultnumeric)(B,C,BC);
116: (D->ops->matmultnumeric)(A,BC,D);
117: return(0);
118: }
120: /* ----------------------------------------------------- */
121: PetscErrorCode MatDestroy_MPIAIJ_RARt(Mat C)
122: {
124: Mat_MPIAIJ *c = (Mat_MPIAIJ*)C->data;
125: Mat_RARt *rart = c->rart;
128: MatDestroy(&rart->Rt);
130: C->ops->destroy = rart->destroy;
131: if (C->ops->destroy) {
132: (*C->ops->destroy)(C);
133: }
134: PetscFree(rart);
135: return(0);
136: }
138: PetscErrorCode MatProductNumeric_RARt_MPIAIJ_MPIAIJ(Mat C)
139: {
141: Mat_MPIAIJ *c = (Mat_MPIAIJ*)C->data;
142: Mat_RARt *rart = c->rart;
143: Mat_Product *product = C->product;
144: Mat A=product->A,R=product->B,Rt=rart->Rt;
147: MatTranspose(R,MAT_REUSE_MATRIX,&Rt);
148: (C->ops->matmatmultnumeric)(R,A,Rt,C);
149: return(0);
150: }
152: PetscErrorCode MatProductSymbolic_RARt_MPIAIJ_MPIAIJ(Mat C)
153: {
154: PetscErrorCode ierr;
155: Mat_Product *product = C->product;
156: Mat A=product->A,R=product->B,Rt;
157: PetscReal fill=product->fill;
158: Mat_RARt *rart;
159: Mat_MPIAIJ *c;
162: MatTranspose(R,MAT_INITIAL_MATRIX,&Rt);
163: /* product->Dwork is used to store A*Rt in MatMatMatMultSymbolic_MPIAIJ_MPIAIJ_MPIAIJ() */
164: MatMatMatMultSymbolic_MPIAIJ_MPIAIJ_MPIAIJ(R,A,Rt,fill,C);
165: C->ops->productnumeric = MatProductNumeric_RARt_MPIAIJ_MPIAIJ;
167: /* create a supporting struct */
168: PetscNew(&rart);
169: c = (Mat_MPIAIJ*)C->data;
170: c->rart = rart;
171: rart->Rt = Rt;
172: rart->destroy = C->ops->destroy;
173: C->ops->destroy = MatDestroy_MPIAIJ_RARt;
174: return(0);
175: }