Actual source code: vpbjacobi_kok.kokkos.cxx
#include <petscvec_kokkos.hpp>
#include <../src/vec/vec/impls/seq/kokkos/veckokkosimpl.hpp>
#include <petscdevice.h>
#include <../src/ksp/pc/impls/vpbjacobi/vpbjacobi.h>
#include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
#include <../src/mat/impls/aij/mpi/mpiaij.h>
#include <KokkosBlas2_gemv.hpp>
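
/* Typical usage (an illustrative sketch, not part of the original file): the caller first describes
   the variable point-block layout and then selects this preconditioner, e.g.

     PetscCall(MatSetVariableBlockSizes(A, nblocks, bsizes)); // must be called before PCSetUp()
     PetscCall(PCSetType(pc, PCVPBJACOBI));

   with Kokkos-enabled objects (e.g. -mat_type aijkokkos -vec_type kokkos) so that the device code
   in this file is used. */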
/* A class that manages helper arrays assisting parallel PCApply() with Kokkos */
struct PC_VPBJacobi_Kokkos {
  /* Cache the old sizes to check if we need realloc */
  PetscInt n;       /* number of rows of the local matrix */
  PetscInt nblocks; /* number of point blocks */
  PetscInt nsize;   /* total number of entries in all point blocks, i.e., sum of bsizes[i]^2 */

  /* Helper arrays that are pre-computed on host and then copied to device.
      bs:     [nblocks+1], "csr" version of bsizes[]
      bs2:    [nblocks+1], "csr" version of squares of bsizes[]
      blkMap: [n], row i of the local matrix belongs to the blkMap[i]-th block
  */
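  /* Illustrative example (values made up for documentation): with two point blocks of sizes
     bsizes[] = {2, 3}, we have n = 5 and nsize = 2*2 + 3*3 = 13, and the helper arrays are
       bs[]     = {0, 2, 5}
       bs2[]    = {0, 4, 13}
       blkMap[] = {0, 0, 1, 1, 1}
  */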
  PetscIntKokkosDualView bs_dual, bs2_dual, blkMap_dual;
  PetscScalarKokkosView  diag; // buffer to store diagonal blocks
  PetscScalarKokkosView  work; // work buffer, with the same size as diag[]
  PetscLogDouble         setupFlops;

  // clang-format off
  // n:               size of the matrix
  // nblocks:         number of blocks
  // nsize:           sum of bsizes[i]^2 for i = 0..nblocks-1
  // bsizes[nblocks]: sizes of blocks
  PC_VPBJacobi_Kokkos(PetscInt n, PetscInt nblocks, PetscInt nsize, const PetscInt *bsizes) :
    n(n), nblocks(nblocks), nsize(nsize), bs_dual(NoInit("bs_dual"), nblocks + 1),
    bs2_dual(NoInit("bs2_dual"), nblocks + 1), blkMap_dual(NoInit("blkMap_dual"), n),
    diag(NoInit("diag"), nsize), work(NoInit("work"), nsize)
  {
    PetscCallVoid(BuildHelperArrays(bsizes));
  }
  // clang-format on

private:
  PetscErrorCode BuildHelperArrays(const PetscInt *bsizes)
  {
    PetscInt *bs_h     = bs_dual.view_host().data();
    PetscInt *bs2_h    = bs2_dual.view_host().data();
    PetscInt *blkMap_h = blkMap_dual.view_host().data();

    PetscFunctionBegin;
    setupFlops = 0.0;
    bs_h[0] = bs2_h[0] = 0;
    for (PetscInt i = 0; i < nblocks; i++) {
      PetscInt m   = bsizes[i];
      bs_h[i + 1]  = bs_h[i] + m;
      bs2_h[i + 1] = bs2_h[i] + m * m;
      for (PetscInt j = 0; j < m; j++) blkMap_h[bs_h[i] + j] = i;
      // m^3/3 FMAs for the A=LU factorization plus m^3 FMAs for solving (LU)X=I to get the inverse;
      // counting each FMA as 2 flops gives (8/3) m^3 flops per block
      setupFlops += 8.0 * m * m * m / 3;
    }

    PetscCallCXX(bs_dual.modify_host());
    PetscCallCXX(bs2_dual.modify_host());
    PetscCallCXX(blkMap_dual.modify_host());
    PetscCallCXX(bs_dual.sync_device());
    PetscCallCXX(bs2_dual.sync_device());
    PetscCallCXX(blkMap_dual.sync_device());
    PetscCall(PetscLogCpuToGpu(sizeof(PetscInt) * (2 * (nblocks + 1) + n)));
    PetscFunctionReturn(PETSC_SUCCESS);
  }
};
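
/* Apply y = inv(D) x (or y = inv(D)^T x in the transpose version), where D is the block-diagonal
   matrix formed by the variable-size point blocks; the inverted blocks were precomputed in
   PCSetUp_VPBJacobi_Kokkos() and stored contiguously in pckok->diag */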
template <PetscBool transpose>
static PetscErrorCode PCApplyOrTranspose_VPBJacobi_Kokkos(PC pc, Vec x, Vec y)
{
  PC_VPBJacobi              *jac    = (PC_VPBJacobi *)pc->data;
  PC_VPBJacobi_Kokkos       *pckok  = static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr);
  ConstPetscScalarKokkosView xv;
  PetscScalarKokkosView      yv;
  PetscScalarKokkosView      diag   = pckok->diag;
  PetscIntKokkosView         bs     = pckok->bs_dual.view_device();
  PetscIntKokkosView         bs2    = pckok->bs2_dual.view_device();
  PetscIntKokkosView         blkMap = pckok->blkMap_dual.view_device();
  const char                *label  = transpose ? "PCApplyTranspose_VPBJacobi" : "PCApply_VPBJacobi";

  PetscFunctionBegin;
  PetscCall(PetscLogGpuTimeBegin());
  PetscCall(VecErrorIfNotKokkos(x));
  PetscCall(VecErrorIfNotKokkos(y));
  PetscCall(VecGetKokkosView(x, &xv));
  PetscCall(VecGetKokkosViewWrite(y, &yv));
#if 0 // TODO: Why is the TeamGemv version 2x slower than the naive one?
  PetscCallCXX(Kokkos::parallel_for(
    label, Kokkos::TeamPolicy<>(jac->nblocks, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &team) {
      PetscInt           bid  = team.league_rank();    // block id
      PetscInt           n    = bs(bid + 1) - bs(bid); // size of this block
      const PetscScalar *bbuf = &diag(bs2(bid));
      const PetscScalar *xbuf = &xv(bs(bid));
      PetscScalar       *ybuf = &yv(bs(bid));
      const auto        &B    = Kokkos::View<const PetscScalar **, Kokkos::LayoutLeft>(bbuf, n, n); // wrap it in a 2D view in column-major order
      const auto        &x1   = ConstPetscScalarKokkosView(xbuf, n);
      const auto        &y1   = PetscScalarKokkosView(ybuf, n);
      if (transpose) {
        KokkosBlas::TeamGemv<KokkosTeamMemberType, KokkosBlas::Trans::Transpose>::invoke(team, 1., B, x1, 0., y1); // y1 = 0.0 * y1 + 1.0 * B^T * x1
      } else {
        KokkosBlas::TeamGemv<KokkosTeamMemberType, KokkosBlas::Trans::NoTranspose>::invoke(team, 1., B, x1, 0., y1); // y1 = 0.0 * y1 + 1.0 * B * x1
      }
    }));
#else
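  /* Naive row-parallel kernel: each thread handles one row of the local matrix. It maps the row to
     its point block k via blkMap, then forms the dot product of the corresponding row (no transpose)
     or column (transpose) of the inverted block with the matching slice of x. The inverted blocks in
     diag[] are stored contiguously per block in column-major order (as the LayoutLeft view in the
     disabled TeamGemv branch above also indicates), hence the stride-m walk in the non-transpose case. */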
  PetscCallCXX(Kokkos::parallel_for(
    label, pckok->n, KOKKOS_LAMBDA(PetscInt row) {
      const PetscScalar *Bp, *xp;
      PetscScalar       *yp;
      PetscInt           i, j, k, m;

      k  = blkMap(row);                             /* k-th block/matrix */
      m  = bs(k + 1) - bs(k);                       /* block size of the k-th block */
      i  = row - bs(k);                             /* i-th row of the block */
      Bp = &diag(bs2(k) + i * (transpose ? m : 1)); /* Bp points to the first entry of the i-th row/column */
      xp = &xv(bs(k));
      yp = &yv(bs(k));

      yp[i] = 0.0;
      for (j = 0; j < m; j++) {
        yp[i] += Bp[0] * xp[j];
        Bp += transpose ? 1 : m;
      }
    }));
#endif
  PetscCall(VecRestoreKokkosView(x, &xv));
  PetscCall(VecRestoreKokkosViewWrite(y, &yv));
  PetscCall(PetscLogGpuFlops(pckok->nsize * 2)); /* one FMA (i.e., 2 flops) per entry of all blocks */
  PetscCall(PetscLogGpuTimeEnd());
  PetscFunctionReturn(PETSC_SUCCESS);
}

static PetscErrorCode PCDestroy_VPBJacobi_Kokkos(PC pc)
{
  PC_VPBJacobi *jac = (PC_VPBJacobi *)pc->data;

  PetscFunctionBegin;
  PetscCallCXX(delete static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr));
  PetscCall(PCDestroy_VPBJacobi(pc));
  PetscFunctionReturn(PETSC_SUCCESS);
}

PETSC_INTERN PetscErrorCode PCSetUp_VPBJacobi_Kokkos(PC pc)
{
  PC_VPBJacobi        *jac   = (PC_VPBJacobi *)pc->data;
  PC_VPBJacobi_Kokkos *pckok = static_cast<PC_VPBJacobi_Kokkos *>(jac->spptr);
  PetscInt             i, nlocal, nblocks, nsize = 0;
  const PetscInt      *bsizes;
  PetscBool            ismpi;
  Mat                  A;

  PetscFunctionBegin;
  PetscCall(MatGetVariableBlockSizes(pc->pmat, &nblocks, &bsizes));
  PetscCall(MatGetLocalSize(pc->pmat, &nlocal, NULL));
  PetscCheck(!nlocal || nblocks, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must call MatSetVariableBlockSizes() before using PCVPBJACOBI");

  if (!jac->diag) {
    PetscInt max_bs = -1, min_bs = PETSC_MAX_INT;
    for (i = 0; i < nblocks; i++) {
      min_bs = PetscMin(min_bs, bsizes[i]);
      max_bs = PetscMax(max_bs, bsizes[i]);
      nsize += bsizes[i] * bsizes[i];
    }
    jac->nblocks = nblocks;
    jac->min_bs  = min_bs;
    jac->max_bs  = max_bs;
  }

  // If MatSetVariableBlockSizes() has been called again and the block sizes have changed (is that allowed?),
  // delete the old helper struct and rebuild it anyway
  if (pckok && (pckok->n != nlocal || pckok->nblocks != nblocks || pckok->nsize != nsize)) {
    PetscCallCXX(delete pckok);
    pckok = nullptr;
  }

  PetscCall(PetscLogGpuTimeBegin());
  if (!pckok) {
    PetscCallCXX(pckok = new PC_VPBJacobi_Kokkos(nlocal, nblocks, nsize, bsizes));
    jac->spptr = pckok;
  }

  // Extract the diagonal blocks from the matrix and compute their inverses
  const auto &bs     = pckok->bs_dual.view_device();
  const auto &bs2    = pckok->bs2_dual.view_device();
  const auto &blkMap = pckok->blkMap_dual.view_device();
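  // For an MPIAIJ matrix the square diagonal point blocks lie entirely within its local diagonal
  // part (the sequential matrix "A" extracted below), so that is the matrix whose blocks we invert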
  PetscCall(PetscObjectBaseTypeCompare((PetscObject)pc->pmat, MATMPIAIJ, &ismpi));
  A = ismpi ? static_cast<Mat_MPIAIJ *>((pc->pmat)->data)->A : pc->pmat;
  PetscCall(MatInvertVariableBlockDiagonal_SeqAIJKokkos(A, bs, bs2, blkMap, pckok->work, pckok->diag));
  pc->ops->apply          = PCApplyOrTranspose_VPBJacobi_Kokkos<PETSC_FALSE>;
  pc->ops->applytranspose = PCApplyOrTranspose_VPBJacobi_Kokkos<PETSC_TRUE>;
  pc->ops->destroy        = PCDestroy_VPBJacobi_Kokkos;
  PetscCall(PetscLogGpuTimeEnd());
  PetscCall(PetscLogGpuFlops(pckok->setupFlops));
  PetscFunctionReturn(PETSC_SUCCESS);
}