Actual source code: fsolvebaij.F

  1: !
  2: !
  3: !    Fortran kernel for sparse triangular solve in the BAIJ matrix format
  4: ! This ONLY works for factorizations in the NATURAL ORDERING, i.e.
  5: ! with MatSolve_SeqBAIJ_4_NaturalOrdering()
  6: !
  7: #include <finclude/petscsysdef.h>
  8: !

 10:       subroutine FortranSolveBAIJ4Unroll(n,x,ai,aj,adiag,a,b)
 11:       implicit none
 12:       MatScalar   a(0:*)
 13:       PetscScalar x(0:*),b(0:*)
 14:       PetscInt    n,ai(0:*),aj(0:*),adiag(0:*)

 16:       PetscInt    i,j,jstart,jend,idx,ax,jdx
 17:       PetscScalar s1,s2,s3,s4
 18:       PetscScalar x1,x2,x3,x4
 19: !
 20: !     Forward Solve
 21: !
 22:       PETSC_AssertAlignx(16,a(1))
 23:       PETSC_AssertAlignx(16,x(1))
 24:       PETSC_AssertAlignx(16,b(1))
 25:       PETSC_AssertAlignx(16,ai(1))
 26:       PETSC_AssertAlignx(16,aj(1))
 27:       PETSC_AssertAlignx(16,adiag(1))
 28: 
 29:          x(0) = b(0)
 30:          x(1) = b(1)
 31:          x(2) = b(2)
 32:          x(3) = b(3)
 33:          idx  = 0
 34:          do 20 i=1,n-1
 35:             jstart = ai(i)
 36:             jend   = adiag(i) - 1
 37:             ax    = 16*jstart
 38:             idx    = idx + 4
 39:             s1     = b(idx)
 40:             s2     = b(idx+1)
 41:             s3     = b(idx+2)
 42:             s4     = b(idx+3)
 43:             do 30 j=jstart,jend
 44:               jdx   = 4*aj(j)
 45: 
 46:               x1    = x(jdx)
 47:               x2    = x(jdx+1)
 48:               x3    = x(jdx+2)
 49:               x4    = x(jdx+3)
 50:               s1 = s1-(a(ax)*x1  +a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
 51:               s2 = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
 52:               s3 = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
 53:               s4 = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
 54:               ax = ax + 16
 55:  30         continue
 56:             x(idx)   = s1
 57:             x(idx+1) = s2
 58:             x(idx+2) = s3
 59:             x(idx+3) = s4
 60:  20      continue

 62: 
 63: !
 64: !     Backward solve the upper triangular
 65: !
 66:          do 40 i=n-1,0,-1
 67:             jstart  = adiag(i) + 1
 68:             jend    = ai(i+1) - 1
 69:             ax     = 16*jstart
 70:             s1      = x(idx)
 71:             s2      = x(idx+1)
 72:             s3      = x(idx+2)
 73:             s4      = x(idx+3)
 74:             do 50 j=jstart,jend
 75:               jdx   = 4*aj(j)
 76:               x1    = x(jdx)
 77:               x2    = x(jdx+1)
 78:               x3    = x(jdx+2)
 79:               x4    = x(jdx+3)
 80:               s1 = s1-(a(ax)*x1  +a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
 81:               s2 = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
 82:               s3 = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
 83:               s4 = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
 84:               ax = ax + 16
 85:  50         continue
 86:             ax      = 16*adiag(i)
 87:             x(idx)   = a(ax)*s1  +a(ax+4)*s2+a(ax+8)*s3 +a(ax+12)*s4
 88:             x(idx+1) = a(ax+1)*s1+a(ax+5)*s2+a(ax+9)*s3 +a(ax+13)*s4
 89:             x(idx+2) = a(ax+2)*s1+a(ax+6)*s2+a(ax+10)*s3+a(ax+14)*s4
 90:             x(idx+3) = a(ax+3)*s1+a(ax+7)*s2+a(ax+11)*s3+a(ax+15)*s4
 91:             idx      = idx - 4
 92:  40      continue
 93:       return
 94:       end
 95: 
 96: !
 97: !   version that calls BLAS 2 operation for each row block
 98: !
 99:       subroutine FortranSolveBAIJ4BLAS(n,x,ai,aj,adiag,a,b,w)
100:       implicit none
101:       MatScalar   a(0:*),w(0:*)
102:       PetscScalar x(0:*),b(0:*)
103:       PetscInt n,ai(0:*),aj(0:*),adiag(0:*)

105:       PetscInt i,j,jstart,jend,idx,ax,jdx,kdx
106:       MatScalar s(0:3)
107:       integer   align7
108: !
109: !     Forward Solve
110: !


113:       PETSC_AssertAlignx(16,a(1))
114:       PETSC_AssertAlignx(16,w(1))
115:       PETSC_AssertAlignx(16,x(1))
116:       PETSC_AssertAlignx(16,b(1))
117:       PETSC_AssertAlignx(16,ai(1))
118:       PETSC_AssertAlignx(16,aj(1))
119:       PETSC_AssertAlignx(16,adiag(1))

121:       x(0) = b(0)
122:       x(1) = b(1)
123:       x(2) = b(2)
124:       x(3) = b(3)
125:       idx  = 0
126:       do 20 i=1,n-1
127: !
128: !        Pack required part of vector into work array
129: !
130:          kdx    = 0
131:          jstart = ai(i)
132:          jend   = adiag(i) - 1
133:          if (jend - jstart .ge. 500) then
134:            write(6,*) 'Overflowing vector FortranSolveBAIJ4BLAS()'
135:          endif
136:          do 30 j=jstart,jend
137: 
138:            jdx       = 4*aj(j)
139: 
140:            w(kdx)    = x(jdx)
141:            w(kdx+1)  = x(jdx+1)
142:            w(kdx+2)  = x(jdx+2)
143:            w(kdx+3)  = x(jdx+3)
144:            kdx       = kdx + 4
145:  30      continue

147:          ax      = 16*jstart
148:          idx      = idx + 4
149:          s(0)     = b(idx)
150:          s(1)     = b(idx+1)
151:          s(2)     = b(idx+2)
152:          s(3)     = b(idx+3)
153: !
154: !    s = s - a(ax:)*w
155: !
156:          call dgemv('n',4,4*(jend-jstart+1),-1.d0,a(ax),4,w,1,1.d0,s,1)
157: !         call sgemv('n',4,4*(jend-jstart+1),-1.0,a(ax),4,w,1,1.0,s,1)

159:          x(idx)   = s(0)
160:          x(idx+1) = s(1)
161:          x(idx+2) = s(2)
162:          x(idx+3) = s(3)
163:  20   continue
164: 
165: !
166: !     Backward solve the upper triangular
167: !
168:       do 40 i=n-1,0,-1
169:          jstart    = adiag(i) + 1
170:          jend      = ai(i+1) - 1
171:          ax       = 16*jstart
172:          s(0)      = x(idx)
173:          s(1)      = x(idx+1)
174:          s(2)      = x(idx+2)
175:          s(3)      = x(idx+3)
176: !
177: !   Pack each chunk of vector needed
178: !
179:          kdx = 0
180:          if (jend - jstart .ge. 500) then
181:            write(6,*) 'Overflowing vector FortranSolveBAIJ4BLAS()'
182:          endif
183:          do 50 j=jstart,jend
184:            jdx      = 4*aj(j)
185:            w(kdx)   = x(jdx)
186:            w(kdx+1) = x(jdx+1)
187:            w(kdx+2) = x(jdx+2)
188:            w(kdx+3) = x(jdx+3)
189:            kdx      = kdx + 4
190:  50      continue
191: !         call sgemv('n',4,4*(jend-jstart+1),-1.0,a(ax),4,w,1,1.0,s,1)
192:          call dgemv('n',4,4*(jend-jstart+1),-1.d0,a(ax),4,w,1,1.d0,s,1)

194:          ax      = 16*adiag(i)
195:          x(idx)  = a(ax)*s(0)  +a(ax+4)*s(1)+a(ax+8)*s(2) +a(ax+12)*s(3)
196:          x(idx+1)= a(ax+1)*s(0)+a(ax+5)*s(1)+a(ax+9)*s(2) +a(ax+13)*s(3)
197:          x(idx+2)= a(ax+2)*s(0)+a(ax+6)*s(1)+a(ax+10)*s(2)+a(ax+14)*s(3)
198:          x(idx+3)= a(ax+3)*s(0)+a(ax+7)*s(1)+a(ax+11)*s(2)+a(ax+15)*s(3)
199:          idx     = idx - 4
200:  40   continue

202:       return
203:       end
204: 

206: !
207: !   version that does not call BLAS 2 operation for each row block
208: !
209:       subroutine FortranSolveBAIJ4(n,x,ai,aj,adiag,a,b,w)
210:       implicit none
211:       MatScalar   a(0:*)
212:       PetscScalar x(0:*),b(0:*),w(0:*)
213:       PetscInt  n,ai(0:*),aj(0:*),adiag(0:*)
214:       PetscInt  ii,jj,i,j

216:       PetscInt  jstart,jend,idx,ax,jdx,kdx,nn
217:       PetscScalar s(0:3)

219: !
220: !     Forward Solve
221: !

223:       PETSC_AssertAlignx(16,a(1))
224:       PETSC_AssertAlignx(16,w(1))
225:       PETSC_AssertAlignx(16,x(1))
226:       PETSC_AssertAlignx(16,b(1))
227:       PETSC_AssertAlignx(16,ai(1))
228:       PETSC_AssertAlignx(16,aj(1))
229:       PETSC_AssertAlignx(16,adiag(1))

231:       x(0) = b(0)
232:       x(1) = b(1)
233:       x(2) = b(2)
234:       x(3) = b(3)
235:       idx  = 0
236:       do 20 i=1,n-1
237: !
238: !        Pack required part of vector into work array
239: !
240:          kdx    = 0
241:          jstart = ai(i)
242:          jend   = adiag(i) - 1
243:          if (jend - jstart .ge. 500) then
244:            write(6,*) 'Overflowing vector FortranSolveBAIJ4()'
245:          endif
246:          do 30 j=jstart,jend
247: 
248:            jdx       = 4*aj(j)
249: 
250:            w(kdx)    = x(jdx)
251:            w(kdx+1)  = x(jdx+1)
252:            w(kdx+2)  = x(jdx+2)
253:            w(kdx+3)  = x(jdx+3)
254:            kdx       = kdx + 4
255:  30      continue

257:          ax       = 16*jstart
258:          idx      = idx + 4
259:          s(0)     = b(idx)
260:          s(1)     = b(idx+1)
261:          s(2)     = b(idx+2)
262:          s(3)     = b(idx+3)
263: !
264: !    s = s - a(ax:)*w
265: !
266:          nn = 4*(jend - jstart + 1) - 1
267:          do 100, ii=0,3
268:            do 110, jj=0,nn
269:              s(ii) = s(ii) - a(ax+4*jj+ii)*w(jj)
270:  110       continue
271:  100     continue

273:          x(idx)   = s(0)
274:          x(idx+1) = s(1)
275:          x(idx+2) = s(2)
276:          x(idx+3) = s(3)
277:  20   continue
278: 
279: !
280: !     Backward solve the upper triangular
281: !
282:       do 40 i=n-1,0,-1
283:          jstart    = adiag(i) + 1
284:          jend      = ai(i+1) - 1
285:          ax        = 16*jstart
286:          s(0)      = x(idx)
287:          s(1)      = x(idx+1)
288:          s(2)      = x(idx+2)
289:          s(3)      = x(idx+3)
290: !
291: !   Pack each chunk of vector needed
292: !
293:          kdx = 0
294:          if (jend - jstart .ge. 500) then
295:            write(6,*) 'Overflowing vector FortranSolveBAIJ4()'
296:          endif
297:          do 50 j=jstart,jend
298:            jdx      = 4*aj(j)
299:            w(kdx)   = x(jdx)
300:            w(kdx+1) = x(jdx+1)
301:            w(kdx+2) = x(jdx+2)
302:            w(kdx+3) = x(jdx+3)
303:            kdx      = kdx + 4
304:  50      continue
305:          nn = 4*(jend - jstart + 1) - 1
306:          do 200, ii=0,3
307:            do 210, jj=0,nn
308:              s(ii) = s(ii) - a(ax+4*jj+ii)*w(jj)
309:  210       continue
310:  200     continue

312:          ax      = 16*adiag(i)
313:          x(idx)  = a(ax)*s(0)  +a(ax+4)*s(1)+a(ax+8)*s(2) +a(ax+12)*s(3)
314:          x(idx+1)= a(ax+1)*s(0)+a(ax+5)*s(1)+a(ax+9)*s(2) +a(ax+13)*s(3)
315:          x(idx+2)= a(ax+2)*s(0)+a(ax+6)*s(1)+a(ax+10)*s(2)+a(ax+14)*s(3)
316:          x(idx+3)= a(ax+3)*s(0)+a(ax+7)*s(1)+a(ax+11)*s(2)+a(ax+15)*s(3)
317:          idx     = idx - 4
318:  40   continue

320:       return
321:       end
322: