Actual source code: fsolvebaij.F90

  1: !
  2: !
  3: !    Fortran kernel for sparse triangular solve in the BAIJ matrix format
  4: ! This ONLY works for factorizations in the NATURAL ORDERING, i.e.
  5: ! with MatSolve_SeqBAIJ_4_NaturalOrdering()
  6: !
  7: #include <petsc/finclude/petscsys.h>
  8: !

 10:       subroutine FortranSolveBAIJ4Unroll(n,x,ai,aj,adiag,a,b)
 11:       implicit none
 12:       MatScalar   a(0:*)
 13:       PetscScalar x(0:*)
 14:       PetscScalar b(0:*)
 15:       PetscInt    n
 16:       PetscInt    ai(0:*)
 17:       PetscInt    aj(0:*)
 18:       PetscInt    adiag(0:*)

 20:       PetscInt    i,j,jstart,jend
 21:       PetscInt    idx,ax,jdx
 22:       PetscScalar s1,s2,s3,s4
 23:       PetscScalar x1,x2,x3,x4
 24: !
 25: !     Forward Solve
 26: !
 27:       PETSC_AssertAlignx(16,a(1))
 28:       PETSC_AssertAlignx(16,x(1))
 29:       PETSC_AssertAlignx(16,b(1))
 30:       PETSC_AssertAlignx(16,ai(1))
 31:       PETSC_AssertAlignx(16,aj(1))
 32:       PETSC_AssertAlignx(16,adiag(1))

 34:          x(0) = b(0)
 35:          x(1) = b(1)
 36:          x(2) = b(2)
 37:          x(3) = b(3)
 38:          idx  = 0
 39:          do 20 i=1,n-1
 40:             jstart = ai(i)
 41:             jend   = adiag(i) - 1
 42:             ax    = 16*jstart
 43:             idx    = idx + 4
 44:             s1     = b(idx)
 45:             s2     = b(idx+1)
 46:             s3     = b(idx+2)
 47:             s4     = b(idx+3)
 48:             do 30 j=jstart,jend
 49:               jdx   = 4*aj(j)

 51:               x1    = x(jdx)
 52:               x2    = x(jdx+1)
 53:               x3    = x(jdx+2)
 54:               x4    = x(jdx+3)
 55:               s1 = s1-(a(ax)*x1  +a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
 56:               s2 = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
 57:               s3 = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
 58:               s4 = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
 59:               ax = ax + 16
 60:  30         continue
 61:             x(idx)   = s1
 62:             x(idx+1) = s2
 63:             x(idx+2) = s3
 64:             x(idx+3) = s4
 65:  20      continue

 67: !
 68: !     Backward solve the upper triangular
 69: !
 70:          do 40 i=n-1,0,-1
 71:             jstart  = adiag(i) + 1
 72:             jend    = ai(i+1) - 1
 73:             ax     = 16*jstart
 74:             s1      = x(idx)
 75:             s2      = x(idx+1)
 76:             s3      = x(idx+2)
 77:             s4      = x(idx+3)
 78:             do 50 j=jstart,jend
 79:               jdx   = 4*aj(j)
 80:               x1    = x(jdx)
 81:               x2    = x(jdx+1)
 82:               x3    = x(jdx+2)
 83:               x4    = x(jdx+3)
 84:               s1 = s1-(a(ax)*x1  +a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
 85:               s2 = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
 86:               s3 = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
 87:               s4 = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
 88:               ax = ax + 16
 89:  50         continue
 90:             ax      = 16*adiag(i)
 91:             x(idx)   = a(ax)*s1  +a(ax+4)*s2+a(ax+8)*s3 +a(ax+12)*s4
 92:             x(idx+1) = a(ax+1)*s1+a(ax+5)*s2+a(ax+9)*s3 +a(ax+13)*s4
 93:             x(idx+2) = a(ax+2)*s1+a(ax+6)*s2+a(ax+10)*s3+a(ax+14)*s4
 94:             x(idx+3) = a(ax+3)*s1+a(ax+7)*s2+a(ax+11)*s3+a(ax+15)*s4
 95:             idx      = idx - 4
 96:  40      continue
 97:       return
 98:       end

!
!   Solve (L U) x = b for a block size 4 BAIJ factorization stored in the
!   natural ordering.  Instead of unrolling each 4x4 block product, this
!   version gathers the needed 4-entry blocks of x into the contiguous work
!   array w and applies the off-diagonal blocks of a block row with one
!   simple dense loop (no BLAS 2 call per block row).
!
!   All index arrays are 0-based (their values are used directly as offsets):
!     n      - number of block rows; x and b hold 4*n entries
!     x      - solution vector (written)
!     ai     - block row pointers into aj and a (n+1 entries)
!     aj     - block column index of each stored block
!     adiag  - position in aj/a of the diagonal block of each block row;
!              blocks ai(i)..adiag(i)-1 form the strictly lower part and
!              adiag(i)+1..ai(i+1)-1 the strictly upper part of row i
!     a      - the blocks, 16 entries each, stored column by column:
!              entry (r,c) of block k is a(16*k + 4*c + r)
!     b      - right-hand side (read only)
!     w      - scratch array; assumed to hold at least 4*maxblk entries
!
!   The forward sweep seeds x from b without diagonal scaling (identity
!   diagonal blocks in L).  The stored diagonal block is applied by
!   multiplication in the backward sweep, so it is presumably already
!   inverted by the factorization -- TODO confirm against the factor routine.
!
!   Block rows with more than maxblk off-diagonal blocks are processed in
!   chunks of maxblk blocks, so w is never written past 4*maxblk entries.
!   The previous code only printed a warning and then overran w.
!
      subroutine FortranSolveBAIJ4(n,x,ai,aj,adiag,a,b,w)
      implicit none
      MatScalar,   intent(in)    :: a(0:*)
      PetscScalar, intent(out)   :: x(0:*)
      PetscScalar, intent(in)    :: b(0:*)
      PetscScalar, intent(inout) :: w(0:*)
      PetscInt,    intent(in)    :: n,ai(0:*),aj(0:*),adiag(0:*)
!
!     Maximum number of 4-entry blocks packed into w at once.  The original
!     overflow test (jend - jstart .ge. 500) implies a caller-supplied w of
!     4*500 = 2000 entries -- TODO confirm against the C caller.
!
      PetscInt, parameter :: maxblk = 500
      PetscInt    ii,jj,i,k
      PetscInt    jstart,jend,kstart,kend,idx,ax,jdx,kdx,nn
      PetscScalar s(0:3)

      PETSC_AssertAlignx(16,a(1))
      PETSC_AssertAlignx(16,w(1))
      PETSC_AssertAlignx(16,x(1))
      PETSC_AssertAlignx(16,b(1))
      PETSC_AssertAlignx(16,ai(1))
      PETSC_AssertAlignx(16,aj(1))
      PETSC_AssertAlignx(16,adiag(1))

!     An empty system has nothing to solve; without this guard the first
!     block of x and b below would be accessed out of bounds.
      if (n .le. 0) return
!
!     Forward solve with the (unit block diagonal) lower triangular factor
!
      x(0) = b(0)
      x(1) = b(1)
      x(2) = b(2)
      x(3) = b(3)
      do i = 1, n-1
         idx    = 4*i
         s(0)   = b(idx)
         s(1)   = b(idx+1)
         s(2)   = b(idx+2)
         s(3)   = b(idx+3)
         jstart = ai(i)
         jend   = adiag(i) - 1
!
!        s = s - (strictly lower blocks of row i) * x, at most maxblk blocks
!        per pass.  Each s(ii) still accumulates its terms in increasing
!        block order, so the result equals a single unchunked pass.
!
         do kstart = jstart, jend, maxblk
            kend = min(kstart + maxblk - 1, jend)
!           Pack the required blocks of x into the work array
            kdx = 0
            do k = kstart, kend
               jdx      = 4*aj(k)
               w(kdx)   = x(jdx)
               w(kdx+1) = x(jdx+1)
               w(kdx+2) = x(jdx+2)
               w(kdx+3) = x(jdx+3)
               kdx      = kdx + 4
            end do
            ax = 16*kstart
            nn = kdx - 1
            do ii = 0, 3
               do jj = 0, nn
                  s(ii) = s(ii) - a(ax+4*jj+ii)*w(jj)
               end do
            end do
         end do
         x(idx)   = s(0)
         x(idx+1) = s(1)
         x(idx+2) = s(2)
         x(idx+3) = s(3)
      end do
!
!     Backward solve with the upper triangular factor
!
      do i = n-1, 0, -1
         idx    = 4*i
         s(0)   = x(idx)
         s(1)   = x(idx+1)
         s(2)   = x(idx+2)
         s(3)   = x(idx+3)
         jstart = adiag(i) + 1
         jend   = ai(i+1) - 1
!
!        s = s - (strictly upper blocks of row i) * x, chunked as above
!
         do kstart = jstart, jend, maxblk
            kend = min(kstart + maxblk - 1, jend)
!           Pack each chunk of the vector needed
            kdx = 0
            do k = kstart, kend
               jdx      = 4*aj(k)
               w(kdx)   = x(jdx)
               w(kdx+1) = x(jdx+1)
               w(kdx+2) = x(jdx+2)
               w(kdx+3) = x(jdx+3)
               kdx      = kdx + 4
            end do
            ax = 16*kstart
            nn = kdx - 1
            do ii = 0, 3
               do jj = 0, nn
                  s(ii) = s(ii) - a(ax+4*jj+ii)*w(jj)
               end do
            end do
         end do
!        x_i = D_i * s, multiplying by the stored diagonal block of row i
         ax       = 16*adiag(i)
         x(idx)   = a(ax)*s(0)  +a(ax+4)*s(1)+a(ax+8)*s(2) +a(ax+12)*s(3)
         x(idx+1) = a(ax+1)*s(0)+a(ax+5)*s(1)+a(ax+9)*s(2) +a(ax+13)*s(3)
         x(idx+2) = a(ax+2)*s(0)+a(ax+6)*s(1)+a(ax+10)*s(2)+a(ax+14)*s(3)
         x(idx+3) = a(ax+3)*s(0)+a(ax+7)*s(1)+a(ax+11)*s(2)+a(ax+15)*s(3)
      end do
      end subroutine FortranSolveBAIJ4