Actual source code: vpbjacobi_kok.kokkos.cxx

  1: #include <petscvec_kokkos.hpp>
  2: #include <../src/vec/vec/impls/seq/kokkos/veckokkosimpl.hpp>
  3: #include <petscdevice.h>
  4: #include <../src/ksp/pc/impls/vpbjacobi/vpbjacobi.h>
  5: #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
  6: #include <../src/mat/impls/aij/mpi/mpiaij.h>
  7: #include <KokkosBlas2_gemv.hpp>

  9: /* A class that manages helper arrays assisting parallel PCApply() with Kokkos */
 10: struct PC_VPBJacobi_Kokkos {
 11:   /* Cache the old sizes to check if we need realloc */
 12:   PetscInt n;       /* number of rows of the local matrix */
 13:   PetscInt nblocks; /* number of point blocks */
 14:   PetscInt nsize;   /* sum of sizes (elements) of the point blocks */

 16:   /* Helper arrays that are pre-computed on host and then copied to device.
 17:     bs:     [nblocks+1], "csr" version of bsizes[]
 18:     bs2:    [nblocks+1], "csr" version of squares of bsizes[]
 19:     blkMap: [n], row i of the local matrix belongs to the blkMap[i]-th block
 20:   */
 21:   PetscIntKokkosDualView bs_dual, bs2_dual, blkMap_dual;
 22:   PetscScalarKokkosView  diag; // buffer to store diagonal blocks
 23:   PetscScalarKokkosView  work; // work buffer, with the same size as diag[]
 24:   PetscLogDouble         setupFlops;

 26:   // clang-format off
 27:   // n:               size of the matrix
 28:   // nblocks:         number of blocks
 29:   // nsize:           sum bsizes[i]^2 for i=0..nblocks
 30:   // bsizes[nblocks]: sizes of blocks
 31:   PC_VPBJacobi_Kokkos(PetscInt n, PetscInt nblocks, PetscInt nsize, const PetscInt *bsizes) :
 32:     n(n), nblocks(nblocks), nsize(nsize), bs_dual(NoInit("bs_dual"), nblocks + 1),
 33:     bs2_dual(NoInit("bs2_dual"), nblocks + 1), blkMap_dual(NoInit("blkMap_dual"), n),
 34:     diag(NoInit("diag"), nsize), work(NoInit("work"), nsize)
 35:   {
 36:     PetscCallVoid(BuildHelperArrays(bsizes));
 37:   }
 38:   // clang-format on

 40: private:
 41:   PetscErrorCode BuildHelperArrays(const PetscInt *bsizes)
 42:   {
 43:     PetscInt *bs_h     = bs_dual.view_host().data();
 44:     PetscInt *bs2_h    = bs2_dual.view_host().data();
 45:     PetscInt *blkMap_h = blkMap_dual.view_host().data();

 47:     PetscFunctionBegin;
 48:     setupFlops = 0.0;
 49:     bs_h[0] = bs2_h[0] = 0;
 50:     for (PetscInt i = 0; i < nblocks; i++) {
 51:       PetscInt m   = bsizes[i];
 52:       bs_h[i + 1]  = bs_h[i] + m;
 53:       bs2_h[i + 1] = bs2_h[i] + m * m;
 54:       for (PetscInt j = 0; j < m; j++) blkMap_h[bs_h[i] + j] = i;
 55:       // m^3/3 FMA for A=LU factorization; m^3 FMA for solving (LU)X=I to get the inverse
 56:       setupFlops += 8.0 * m * m * m / 3;
 57:     }

 59:     PetscCallCXX(bs_dual.modify_host());
 60:     PetscCallCXX(bs2_dual.modify_host());
 61:     PetscCallCXX(blkMap_dual.modify_host());
 62:     PetscCallCXX(bs_dual.sync_device());
 63:     PetscCallCXX(bs2_dual.sync_device());
 64:     PetscCallCXX(blkMap_dual.sync_device());
 65:     PetscCall(PetscLogCpuToGpu(sizeof(PetscInt) * (2 * (nblocks + 1) + n)));
 66:     PetscFunctionReturn(PETSC_SUCCESS);
 67:   }
 68: };

 70: template <PetscBool transpose>
 71: static PetscErrorCode PCApplyOrTranspose_VPBJacobi_Kokkos(PC pc, Vec x, Vec y)
 72: {
 73:   PC_VPBJacobi              *jac   = (PC_VPBJacobi *)pc->data;
 74:   PC_VPBJacobi_Kokkos       *pckok = static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr);
 75:   ConstPetscScalarKokkosView xv;
 76:   PetscScalarKokkosView      yv;
 77:   PetscScalarKokkosView      diag   = pckok->diag;
 78:   PetscIntKokkosView         bs     = pckok->bs_dual.view_device();
 79:   PetscIntKokkosView         bs2    = pckok->bs2_dual.view_device();
 80:   PetscIntKokkosView         blkMap = pckok->blkMap_dual.view_device();
 81:   const char                *label  = transpose ? "PCApplyTranspose_VPBJacobi" : "PCApply_VPBJacobi";

 83:   PetscFunctionBegin;
 84:   PetscCall(PetscLogGpuTimeBegin());
 85:   VecErrorIfNotKokkos(x);
 86:   VecErrorIfNotKokkos(y);
 87:   PetscCall(VecGetKokkosView(x, &xv));
 88:   PetscCall(VecGetKokkosViewWrite(y, &yv));
 89: #if 0 // TODO: Why the TeamGemv version is 2x worse than the naive one?
 90:   PetscCallCXX(Kokkos::parallel_for(
 91:     label, Kokkos::TeamPolicy<>(jac->nblocks, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &team) {
 92:       PetscInt           bid  = team.league_rank();    // block id
 93:       PetscInt           n    = bs(bid + 1) - bs(bid); // size of this block
 94:       const PetscScalar *bbuf = &diag(bs2(bid));
 95:       const PetscScalar *xbuf = &xv(bs(bid));
 96:       PetscScalar       *ybuf = &yv(bs(bid));
 97:       const auto        &B    = Kokkos::View<const PetscScalar **, Kokkos::LayoutLeft>(bbuf, n, n); // wrap it in a 2D view in column-major order
 98:       const auto        &x1   = ConstPetscScalarKokkosView(xbuf, n);
 99:       const auto        &y1   = PetscScalarKokkosView(ybuf, n);
100:       if (transpose) {
101:         KokkosBlas::TeamGemv<KokkosTeamMemberType, KokkosBlas::Trans::Transpose>::invoke(team, 1., B, x1, 0., y1); // y1 = 0.0 * y1 + 1.0 * B^T * x1
102:       } else {
103:         KokkosBlas::TeamGemv<KokkosTeamMemberType, KokkosBlas::Trans::NoTranspose>::invoke(team, 1., B, x1, 0., y1); // y1 = 0.0 * y1 + 1.0 * B * x1
104:       }
105:     }));
106: #else
107:   PetscCallCXX(Kokkos::parallel_for(
108:     label, pckok->n, KOKKOS_LAMBDA(PetscInt row) {
109:       const PetscScalar *Bp, *xp;
110:       PetscScalar       *yp;
111:       PetscInt           i, j, k, m;

113:       k  = blkMap(row);                             /* k-th block/matrix */
114:       m  = bs(k + 1) - bs(k);                       /* block size of the k-th block */
115:       i  = row - bs(k);                             /* i-th row of the block */
116:       Bp = &diag(bs2(k) + i * (transpose ? m : 1)); /* Bp points to the first entry of i-th row/column */
117:       xp = &xv(bs(k));
118:       yp = &yv(bs(k));

120:       yp[i] = 0.0;
121:       for (j = 0; j < m; j++) {
122:         yp[i] += Bp[0] * xp[j];
123:         Bp += transpose ? 1 : m;
124:       }
125:     }));
126: #endif
127:   PetscCall(VecRestoreKokkosView(x, &xv));
128:   PetscCall(VecRestoreKokkosViewWrite(y, &yv));
129:   PetscCall(PetscLogGpuFlops(pckok->nsize * 2)); /* FMA on entries in all blocks */
130:   PetscCall(PetscLogGpuTimeEnd());
131:   PetscFunctionReturn(PETSC_SUCCESS);
132: }

134: static PetscErrorCode PCDestroy_VPBJacobi_Kokkos(PC pc)
135: {
136:   PC_VPBJacobi *jac = (PC_VPBJacobi *)pc->data;

138:   PetscFunctionBegin;
139:   PetscCallCXX(delete static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr));
140:   PetscCall(PCDestroy_VPBJacobi(pc));
141:   PetscFunctionReturn(PETSC_SUCCESS);
142: }

144: PETSC_INTERN PetscErrorCode PCSetUp_VPBJacobi_Kokkos(PC pc)
145: {
146:   PC_VPBJacobi        *jac   = (PC_VPBJacobi *)pc->data;
147:   PC_VPBJacobi_Kokkos *pckok = static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr);
148:   PetscInt             i, nlocal, nblocks, nsize = 0;
149:   const PetscInt      *bsizes;
150:   PetscBool            ismpi;
151:   Mat                  A;

153:   PetscFunctionBegin;
154:   PetscCall(MatGetVariableBlockSizes(pc->pmat, &nblocks, &bsizes));
155:   PetscCall(MatGetLocalSize(pc->pmat, &nlocal, NULL));
156:   PetscCheck(!nlocal || nblocks, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must call MatSetVariableBlockSizes() before using PCVPBJACOBI");

158:   if (!jac->diag) {
159:     PetscInt max_bs = -1, min_bs = PETSC_MAX_INT;
160:     for (i = 0; i < nblocks; i++) {
161:       min_bs = PetscMin(min_bs, bsizes[i]);
162:       max_bs = PetscMax(max_bs, bsizes[i]);
163:       nsize += bsizes[i] * bsizes[i];
164:     }
165:     jac->nblocks = nblocks;
166:     jac->min_bs  = min_bs;
167:     jac->max_bs  = max_bs;
168:   }

170:   // If one calls MatSetVariableBlockSizes() multiple times and sizes have been changed (is it allowed?), we delete the old and rebuild anyway
171:   if (pckok && (pckok->n != nlocal || pckok->nblocks != nblocks || pckok->nsize != nsize)) {
172:     PetscCallCXX(delete pckok);
173:     pckok = nullptr;
174:   }

176:   PetscCall(PetscLogGpuTimeBegin());
177:   if (!pckok) {
178:     PetscCallCXX(pckok = new PC_VPBJacobi_Kokkos(nlocal, nblocks, nsize, bsizes));
179:     jac->spptr = pckok;
180:   }

182:   // Extract diagonal blocks from the matrix and compute their inverse
183:   const auto &bs     = pckok->bs_dual.view_device();
184:   const auto &bs2    = pckok->bs2_dual.view_device();
185:   const auto &blkMap = pckok->blkMap_dual.view_device();
186:   PetscCall(PetscObjectBaseTypeCompare((PetscObject)pc->pmat, MATMPIAIJ, &ismpi));
187:   A = ismpi ? static_cast<Mat_MPIAIJ *>((pc->pmat)->data)->A : pc->pmat;
188:   PetscCall(MatInvertVariableBlockDiagonal_SeqAIJKokkos(A, bs, bs2, blkMap, pckok->work, pckok->diag));
189:   pc->ops->apply          = PCApplyOrTranspose_VPBJacobi_Kokkos<PETSC_FALSE>;
190:   pc->ops->applytranspose = PCApplyOrTranspose_VPBJacobi_Kokkos<PETSC_TRUE>;
191:   pc->ops->destroy        = PCDestroy_VPBJacobi_Kokkos;
192:   PetscCall(PetscLogGpuTimeEnd());
193:   PetscCall(PetscLogGpuFlops(pckok->setupFlops));
194:   PetscFunctionReturn(PETSC_SUCCESS);
195: }