Actual source code: fnsqrt.c
slepc-3.13.3 2020-06-14
1: /*
2: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
3: SLEPc - Scalable Library for Eigenvalue Problem Computations
4: Copyright (c) 2002-2020, Universitat Politecnica de Valencia, Spain
6: This file is part of SLEPc.
7: SLEPc is distributed under a 2-clause BSD license (see LICENSE).
8: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
9: */
10: /*
11: Square root function sqrt(x)
12: */
14: #include <slepc/private/fnimpl.h>
15: #include <slepcblaslapack.h>
17: PetscErrorCode FNEvaluateFunction_Sqrt(FN fn,PetscScalar x,PetscScalar *y)
18: {
20: #if !defined(PETSC_USE_COMPLEX)
21: if (x<0.0) SETERRQ(PETSC_COMM_SELF,1,"Function not defined in the requested value");
22: #endif
23: *y = PetscSqrtScalar(x);
24: return(0);
25: }
27: PetscErrorCode FNEvaluateDerivative_Sqrt(FN fn,PetscScalar x,PetscScalar *y)
28: {
30: if (x==0.0) SETERRQ(PETSC_COMM_SELF,1,"Derivative not defined in the requested value");
31: #if !defined(PETSC_USE_COMPLEX)
32: if (x<0.0) SETERRQ(PETSC_COMM_SELF,1,"Derivative not defined in the requested value");
33: #endif
34: *y = 1.0/(2.0*PetscSqrtScalar(x));
35: return(0);
36: }
38: PetscErrorCode FNEvaluateFunctionMat_Sqrt_Schur(FN fn,Mat A,Mat B)
39: {
41: PetscBLASInt n;
42: PetscScalar *T;
43: PetscInt m;
46: if (A!=B) { MatCopy(A,B,SAME_NONZERO_PATTERN); }
47: MatDenseGetArray(B,&T);
48: MatGetSize(A,&m,NULL);
49: PetscBLASIntCast(m,&n);
50: SlepcSqrtmSchur(n,T,n,PETSC_FALSE);
51: MatDenseRestoreArray(B,&T);
52: return(0);
53: }
55: PetscErrorCode FNEvaluateFunctionMatVec_Sqrt_Schur(FN fn,Mat A,Vec v)
56: {
58: PetscBLASInt n;
59: PetscScalar *T;
60: PetscInt m;
61: Mat B;
64: FN_AllocateWorkMat(fn,A,&B);
65: MatDenseGetArray(B,&T);
66: MatGetSize(A,&m,NULL);
67: PetscBLASIntCast(m,&n);
68: SlepcSqrtmSchur(n,T,n,PETSC_TRUE);
69: MatDenseRestoreArray(B,&T);
70: MatGetColumnVector(B,v,0);
71: FN_FreeWorkMat(fn,&B);
72: return(0);
73: }
75: PetscErrorCode FNEvaluateFunctionMat_Sqrt_DBP(FN fn,Mat A,Mat B)
76: {
78: PetscBLASInt n;
79: PetscScalar *T;
80: PetscInt m;
83: if (A!=B) { MatCopy(A,B,SAME_NONZERO_PATTERN); }
84: MatDenseGetArray(B,&T);
85: MatGetSize(A,&m,NULL);
86: PetscBLASIntCast(m,&n);
87: SlepcSqrtmDenmanBeavers(n,T,n,PETSC_FALSE);
88: MatDenseRestoreArray(B,&T);
89: return(0);
90: }
92: PetscErrorCode FNEvaluateFunctionMat_Sqrt_NS(FN fn,Mat A,Mat B)
93: {
95: PetscBLASInt n;
96: PetscScalar *Ba;
97: PetscInt m;
100: if (A!=B) { MatCopy(A,B,SAME_NONZERO_PATTERN); }
101: MatDenseGetArray(B,&Ba);
102: MatGetSize(A,&m,NULL);
103: PetscBLASIntCast(m,&n);
104: SlepcSqrtmNewtonSchulz(n,Ba,n,PETSC_FALSE);
105: MatDenseRestoreArray(B,&Ba);
106: return(0);
107: }
109: #define MAXIT 50
111: /*
112: Computes the principal square root of the matrix A using the
113: Sadeghi iteration. A is overwritten with sqrtm(A).
114: */
115: static PetscErrorCode SlepcSqrtmSadeghi(PetscBLASInt n,PetscScalar *A,PetscBLASInt ld)
116: {
117: PetscScalar *M,*M2,*G,*X=A,*work,work1,alpha,sqrtnrm;
118: PetscScalar szero=0.0,sone=1.0,smfive=-5.0,s1d16=1.0/16.0;
119: PetscReal tol,Mres=0.0,nrm,rwork[1];
120: PetscBLASInt N,i,it,*piv=NULL,info,lwork,query=-1;
121: const PetscBLASInt one=1;
122: PetscBool converged=PETSC_FALSE;
123: PetscErrorCode ierr;
124: unsigned int ftz;
127: N = n*n;
128: tol = PetscSqrtReal((PetscReal)n)*PETSC_MACHINE_EPSILON/2;
129: SlepcSetFlushToZero(&ftz);
131: /* query work size */
132: PetscStackCallBLAS("LAPACKgetri",LAPACKgetri_(&n,A,&ld,piv,&work1,&query,&info));
133: PetscBLASIntCast((PetscInt)PetscRealPart(work1),&lwork);
135: PetscMalloc5(N,&M,N,&M2,N,&G,lwork,&work,n,&piv);
136: PetscArraycpy(M,A,N);
138: /* scale M */
139: nrm = LAPACKlange_("fro",&n,&n,M,&n,rwork);
140: if (nrm>1.0) {
141: sqrtnrm = PetscSqrtReal(nrm);
142: alpha = 1.0/nrm;
143: PetscStackCallBLAS("BLASscal",BLASscal_(&N,&alpha,M,&one));
144: tol *= nrm;
145: }
146: PetscInfo2(NULL,"||A||_F = %g, new tol: %g\n",(double)nrm,(double)tol);
148: /* X = I */
149: PetscArrayzero(X,N);
150: for (i=0;i<n;i++) X[i+i*ld] = 1.0;
152: for (it=0;it<MAXIT && !converged;it++) {
154: /* G = (5/16)*I + (1/16)*M*(15*I-5*M+M*M) */
155: PetscStackCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,M,&ld,M,&ld,&szero,M2,&ld));
156: PetscStackCallBLAS("BLASaxpy",BLASaxpy_(&N,&smfive,M,&one,M2,&one));
157: for (i=0;i<n;i++) M2[i+i*ld] += 15.0;
158: PetscStackCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&s1d16,M,&ld,M2,&ld,&szero,G,&ld));
159: for (i=0;i<n;i++) G[i+i*ld] += 5.0/16.0;
161: /* X = X*G */
162: PetscArraycpy(M2,X,N);
163: PetscStackCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,M2,&ld,G,&ld,&szero,X,&ld));
165: /* M = M*inv(G*G) */
166: PetscStackCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,G,&ld,G,&ld,&szero,M2,&ld));
167: PetscStackCallBLAS("LAPACKgetrf",LAPACKgetrf_(&n,&n,M2,&ld,piv,&info));
168: SlepcCheckLapackInfo("getrf",info);
169: PetscStackCallBLAS("LAPACKgetri",LAPACKgetri_(&n,M2,&ld,piv,work,&lwork,&info));
170: SlepcCheckLapackInfo("getri",info);
172: PetscArraycpy(G,M,N);
173: PetscStackCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,G,&ld,M2,&ld,&szero,M,&ld));
175: /* check ||I-M|| */
176: PetscArraycpy(M2,M,N);
177: for (i=0;i<n;i++) M2[i+i*ld] -= 1.0;
178: Mres = LAPACKlange_("fro",&n,&n,M2,&n,rwork);
179: PetscIsNanReal(Mres);
180: if (Mres<=tol) converged = PETSC_TRUE;
181: PetscInfo2(NULL,"it: %D res: %g\n",it,(double)Mres);
182: PetscLogFlops(8.0*n*n*n+2.0*n*n+2.0*n*n*n/3.0+4.0*n*n*n/3.0+2.0*n*n*n+2.0*n*n);
183: }
185: if (Mres>tol) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_LIB,"SQRTM not converged after %d iterations",MAXIT);
187: /* undo scaling */
188: if (nrm>1.0) PetscStackCallBLAS("BLASscal",BLASscal_(&N,&sqrtnrm,A,&one));
190: PetscFree5(M,M2,G,work,piv);
191: SlepcResetFlushToZero(&ftz);
192: return(0);
193: }
195: PetscErrorCode FNEvaluateFunctionMat_Sqrt_Sadeghi(FN fn,Mat A,Mat B)
196: {
198: PetscBLASInt n;
199: PetscScalar *Ba;
200: PetscInt m;
203: if (A!=B) { MatCopy(A,B,SAME_NONZERO_PATTERN); }
204: MatDenseGetArray(B,&Ba);
205: MatGetSize(A,&m,NULL);
206: PetscBLASIntCast(m,&n);
207: SlepcSqrtmSadeghi(n,Ba,n);
208: MatDenseRestoreArray(B,&Ba);
209: return(0);
210: }
212: PetscErrorCode FNView_Sqrt(FN fn,PetscViewer viewer)
213: {
215: PetscBool isascii;
216: char str[50];
217: const char *methodname[] = {
218: "Schur method for the square root",
219: "Denman-Beavers (product form)",
220: "Newton-Schulz iteration",
221: "Sadeghi iteration"
222: };
223: const int nmeth=sizeof(methodname)/sizeof(methodname[0]);
226: PetscObjectTypeCompare((PetscObject)viewer,PETSCVIEWERASCII,&isascii);
227: if (isascii) {
228: if (fn->beta==(PetscScalar)1.0) {
229: if (fn->alpha==(PetscScalar)1.0) {
230: PetscViewerASCIIPrintf(viewer," Square root: sqrt(x)\n");
231: } else {
232: SlepcSNPrintfScalar(str,50,fn->alpha,PETSC_TRUE);
233: PetscViewerASCIIPrintf(viewer," Square root: sqrt(%s*x)\n",str);
234: }
235: } else {
236: SlepcSNPrintfScalar(str,50,fn->beta,PETSC_TRUE);
237: if (fn->alpha==(PetscScalar)1.0) {
238: PetscViewerASCIIPrintf(viewer," Square root: %s*sqrt(x)\n",str);
239: } else {
240: PetscViewerASCIIPrintf(viewer," Square root: %s",str);
241: PetscViewerASCIIUseTabs(viewer,PETSC_FALSE);
242: SlepcSNPrintfScalar(str,50,fn->alpha,PETSC_TRUE);
243: PetscViewerASCIIPrintf(viewer,"*sqrt(%s*x)\n",str);
244: PetscViewerASCIIUseTabs(viewer,PETSC_TRUE);
245: }
246: }
247: if (fn->method<nmeth) {
248: PetscViewerASCIIPrintf(viewer," computing matrix functions with: %s\n",methodname[fn->method]);
249: }
250: }
251: return(0);
252: }
254: SLEPC_EXTERN PetscErrorCode FNCreate_Sqrt(FN fn)
255: {
257: fn->ops->evaluatefunction = FNEvaluateFunction_Sqrt;
258: fn->ops->evaluatederivative = FNEvaluateDerivative_Sqrt;
259: fn->ops->evaluatefunctionmat[0] = FNEvaluateFunctionMat_Sqrt_Schur;
260: fn->ops->evaluatefunctionmat[1] = FNEvaluateFunctionMat_Sqrt_DBP;
261: fn->ops->evaluatefunctionmat[2] = FNEvaluateFunctionMat_Sqrt_NS;
262: fn->ops->evaluatefunctionmat[3] = FNEvaluateFunctionMat_Sqrt_Sadeghi;
263: fn->ops->evaluatefunctionmatvec[0] = FNEvaluateFunctionMatVec_Sqrt_Schur;
264: fn->ops->view = FNView_Sqrt;
265: return(0);
266: }