Actual source code: sfkok.kokkos.cxx

  1: #include <../src/vec/is/sf/impls/basic/sfpack.h>

  3: #include <petsc_kokkos.hpp>
  4: #include <petsc/private/kokkosimpl.hpp>

  6: using DeviceExecutionSpace = Kokkos::DefaultExecutionSpace;

  8: typedef Kokkos::View<char *, DefaultMemorySpace>    deviceBuffer_t;
  9: typedef Kokkos::View<char *, HostMirrorMemorySpace> HostBuffer_t;

 11: typedef Kokkos::View<const char *, DefaultMemorySpace>    deviceConstBuffer_t;
 12: typedef Kokkos::View<const char *, HostMirrorMemorySpace> HostConstBuffer_t;

 14: /*====================================================================================*/
 15: /*                             Regular operations                           */
 16: /*====================================================================================*/
 17: template <typename Type>
 18: struct Insert {
 19:   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
 20:   {
 21:     Type old = x;
 22:     x        = y;
 23:     return old;
 24:   }
 25: };
 26: template <typename Type>
 27: struct Add {
 28:   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
 29:   {
 30:     Type old = x;
 31:     x += y;
 32:     return old;
 33:   }
 34: };
 35: template <typename Type>
 36: struct Mult {
 37:   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
 38:   {
 39:     Type old = x;
 40:     x *= y;
 41:     return old;
 42:   }
 43: };
 44: template <typename Type>
 45: struct Min {
 46:   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
 47:   {
 48:     Type old = x;
 49:     x        = PetscMin(x, y);
 50:     return old;
 51:   }
 52: };
 53: template <typename Type>
 54: struct Max {
 55:   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
 56:   {
 57:     Type old = x;
 58:     x        = PetscMax(x, y);
 59:     return old;
 60:   }
 61: };
 62: template <typename Type>
 63: struct LAND {
 64:   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
 65:   {
 66:     Type old = x;
 67:     x        = x && y;
 68:     return old;
 69:   }
 70: };
 71: template <typename Type>
 72: struct LOR {
 73:   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
 74:   {
 75:     Type old = x;
 76:     x        = x || y;
 77:     return old;
 78:   }
 79: };
 80: template <typename Type>
 81: struct LXOR {
 82:   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
 83:   {
 84:     Type old = x;
 85:     x        = !x != !y;
 86:     return old;
 87:   }
 88: };
 89: template <typename Type>
 90: struct BAND {
 91:   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
 92:   {
 93:     Type old = x;
 94:     x        = x & y;
 95:     return old;
 96:   }
 97: };
 98: template <typename Type>
 99: struct BOR {
100:   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
101:   {
102:     Type old = x;
103:     x        = x | y;
104:     return old;
105:   }
106: };
107: template <typename Type>
108: struct BXOR {
109:   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
110:   {
111:     Type old = x;
112:     x        = x ^ y;
113:     return old;
114:   }
115: };
116: template <typename PairType>
117: struct Minloc {
118:   KOKKOS_INLINE_FUNCTION PairType operator()(PairType &x, PairType y) const
119:   {
120:     PairType old = x;
121:     if (y.first < x.first) x = y;
122:     else if (y.first == x.first) x.second = PetscMin(x.second, y.second);
123:     return old;
124:   }
125: };
126: template <typename PairType>
127: struct Maxloc {
128:   KOKKOS_INLINE_FUNCTION PairType operator()(PairType &x, PairType y) const
129:   {
130:     PairType old = x;
131:     if (y.first > x.first) x = y;
132:     else if (y.first == x.first) x.second = PetscMin(x.second, y.second); /* See MPI MAXLOC */
133:     return old;
134:   }
135: };

137: /*====================================================================================*/
138: /*                             Atomic operations                            */
139: /*====================================================================================*/
140: template <typename Type>
141: struct AtomicInsert {
142:   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_store(&x, y); }
143: };
144: template <typename Type>
145: struct AtomicAdd {
146:   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_add(&x, y); }
147: };
148: template <typename Type>
149: struct AtomicBAND {
150:   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_and(&x, y); }
151: };
152: template <typename Type>
153: struct AtomicBOR {
154:   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_or(&x, y); }
155: };
156: template <typename Type>
157: struct AtomicBXOR {
158:   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_xor(&x, y); }
159: };
160: template <typename Type>
161: struct AtomicLAND {
162:   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const
163:   {
164:     const Type zero = 0, one = ~0;
165:     Kokkos::atomic_and(&x, y ? one : zero);
166:   }
167: };
168: template <typename Type>
169: struct AtomicLOR {
170:   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const
171:   {
172:     const Type zero = 0, one = 1;
173:     Kokkos::atomic_or(&x, y ? one : zero);
174:   }
175: };
176: template <typename Type>
177: struct AtomicMult {
178:   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_mul(&x, y); }
179: };
180: template <typename Type>
181: struct AtomicMin {
182:   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_min(&x, y); }
183: };
184: template <typename Type>
185: struct AtomicMax {
186:   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_max(&x, y); }
187: };
188: /* TODO: struct AtomicLXOR  */
189: template <typename Type>
190: struct AtomicFetchAdd {
191:   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { return Kokkos::atomic_fetch_add(&x, y); }
192: };

194: /* Map a thread id to an index in root/leaf space through a series of 3D subdomains. See PetscSFPackOpt. */
195: static KOKKOS_INLINE_FUNCTION PetscInt MapTidToIndex(const PetscInt *opt, PetscInt tid)
196: {
197:   PetscInt        i, j, k, m, n, r;
198:   const PetscInt *offset, *start, *dx, *dy, *X, *Y;

200:   n      = opt[0];
201:   offset = opt + 1;
202:   start  = opt + n + 2;
203:   dx     = opt + 2 * n + 2;
204:   dy     = opt + 3 * n + 2;
205:   X      = opt + 5 * n + 2;
206:   Y      = opt + 6 * n + 2;
207:   for (r = 0; r < n; r++) {
208:     if (tid < offset[r + 1]) break;
209:   }
210:   m = (tid - offset[r]);
211:   k = m / (dx[r] * dy[r]);
212:   j = (m - k * dx[r] * dy[r]) / dx[r];
213:   i = m - k * dx[r] * dy[r] - j * dx[r];

215:   return start[r] + k * X[r] * Y[r] + j * X[r] + i;
216: }

218: /*====================================================================================*/
219: /*  Wrappers for Pack/Unpack/Scatter kernels. Function pointers are stored in 'link'         */
220: /*====================================================================================*/

222: /* Suppose user calls PetscSFReduce(sf,unit,...) and <unit> is an MPI data type made of 16 PetscReals, then
223:    <Type> is PetscReal, which is the primitive type we operate on.
224:    <bs>   is 16, which says <unit> contains 16 primitive types.
225:    <BS>   is 8, which is the maximal SIMD width we will try to vectorize operations on <unit>.
226:    <EQ>   is 0, which is (bs == BS ? 1 : 0)

228:   If instead, <unit> has 8 PetscReals, then bs=8, BS=8, EQ=1, rendering MBS below to a compile time constant.
229:   For the common case in VecScatter, bs=1, BS=1, EQ=1, MBS=1, the inner for-loops below will be totally unrolled.
230: */
231: template <typename Type, PetscInt BS, PetscInt EQ>
232: static PetscErrorCode Pack(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, const void *data_, void *buf_)
233: {
234:   const PetscInt      *iopt = opt ? opt->array : NULL;
235:   const PetscInt       M = EQ ? 1 : link->bs / BS, MBS = M * BS; /* If EQ, then MBS will be a compile-time const */
236:   const Type          *data = static_cast<const Type *>(data_);
237:   Type                *buf  = static_cast<Type *>(buf_);
238:   DeviceExecutionSpace exec = PetscGetKokkosExecutionSpace();

240:   PetscFunctionBegin;
241:   Kokkos::parallel_for(
242:     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
243:       /* iopt != NULL ==> idx == NULL, i.e., the indices have patterns but not contiguous;
244:        iopt == NULL && idx == NULL ==> the indices are contiguous;
245:      */
246:       PetscInt t = (iopt ? MapTidToIndex(iopt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
247:       PetscInt s = tid * MBS;
248:       for (int i = 0; i < MBS; i++) buf[s + i] = data[t + i];
249:     });
250:   PetscFunctionReturn(PETSC_SUCCESS);
251: }

253: template <typename Type, class Op, PetscInt BS, PetscInt EQ>
254: static PetscErrorCode UnpackAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data_, const void *buf_)
255: {
256:   Op                   op;
257:   const PetscInt      *iopt = opt ? opt->array : NULL;
258:   const PetscInt       M = EQ ? 1 : link->bs / BS, MBS = M * BS;
259:   Type                *data = static_cast<Type *>(data_);
260:   const Type          *buf  = static_cast<const Type *>(buf_);
261:   DeviceExecutionSpace exec = PetscGetKokkosExecutionSpace();

263:   PetscFunctionBegin;
264:   Kokkos::parallel_for(
265:     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
266:       PetscInt t = (iopt ? MapTidToIndex(iopt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
267:       PetscInt s = tid * MBS;
268:       for (int i = 0; i < MBS; i++) op(data[t + i], buf[s + i]);
269:     });
270:   PetscFunctionReturn(PETSC_SUCCESS);
271: }

273: template <typename Type, class Op, PetscInt BS, PetscInt EQ>
274: static PetscErrorCode FetchAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data, void *buf)
275: {
276:   Op                   op;
277:   const PetscInt      *ropt = opt ? opt->array : NULL;
278:   const PetscInt       M = EQ ? 1 : link->bs / BS, MBS = M * BS;
279:   Type                *rootdata = static_cast<Type *>(data), *leafbuf = static_cast<Type *>(buf);
280:   DeviceExecutionSpace exec = PetscGetKokkosExecutionSpace();

282:   PetscFunctionBegin;
283:   Kokkos::parallel_for(
284:     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
285:       PetscInt r = (ropt ? MapTidToIndex(ropt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
286:       PetscInt l = tid * MBS;
287:       for (int i = 0; i < MBS; i++) leafbuf[l + i] = op(rootdata[r + i], leafbuf[l + i]);
288:     });
289:   PetscFunctionReturn(PETSC_SUCCESS);
290: }

292: template <typename Type, class Op, PetscInt BS, PetscInt EQ>
293: static PetscErrorCode ScatterAndOp(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src_, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst_)
294: {
295:   PetscInt             srcx = 0, srcy = 0, srcX = 0, srcY = 0, dstx = 0, dsty = 0, dstX = 0, dstY = 0;
296:   const PetscInt       M = (EQ) ? 1 : link->bs / BS, MBS = M * BS;
297:   const Type          *src  = static_cast<const Type *>(src_);
298:   Type                *dst  = static_cast<Type *>(dst_);
299:   DeviceExecutionSpace exec = PetscGetKokkosExecutionSpace();

301:   PetscFunctionBegin;
302:   /* The 3D shape of source subdomain may be different than that of the destination, which makes it difficult to use CUDA 3D grid and block */
303:   if (srcOpt) {
304:     srcx     = srcOpt->dx[0];
305:     srcy     = srcOpt->dy[0];
306:     srcX     = srcOpt->X[0];
307:     srcY     = srcOpt->Y[0];
308:     srcStart = srcOpt->start[0];
309:     srcIdx   = NULL;
310:   } else if (!srcIdx) {
311:     srcx = srcX = count;
312:     srcy = srcY = 1;
313:   }

315:   if (dstOpt) {
316:     dstx     = dstOpt->dx[0];
317:     dsty     = dstOpt->dy[0];
318:     dstX     = dstOpt->X[0];
319:     dstY     = dstOpt->Y[0];
320:     dstStart = dstOpt->start[0];
321:     dstIdx   = NULL;
322:   } else if (!dstIdx) {
323:     dstx = dstX = count;
324:     dsty = dstY = 1;
325:   }

327:   Kokkos::parallel_for(
328:     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
329:       PetscInt i, j, k, s, t;
330:       Op       op;
331:       if (!srcIdx) { /* src is in 3D */
332:         k = tid / (srcx * srcy);
333:         j = (tid - k * srcx * srcy) / srcx;
334:         i = tid - k * srcx * srcy - j * srcx;
335:         s = srcStart + k * srcX * srcY + j * srcX + i;
336:       } else { /* src is contiguous */
337:         s = srcIdx[tid];
338:       }

340:       if (!dstIdx) { /* 3D */
341:         k = tid / (dstx * dsty);
342:         j = (tid - k * dstx * dsty) / dstx;
343:         i = tid - k * dstx * dsty - j * dstx;
344:         t = dstStart + k * dstX * dstY + j * dstX + i;
345:       } else { /* contiguous */
346:         t = dstIdx[tid];
347:       }

349:       s *= MBS;
350:       t *= MBS;
351:       for (i = 0; i < MBS; i++) op(dst[t + i], src[s + i]);
352:     });
353:   PetscFunctionReturn(PETSC_SUCCESS);
354: }

356: /* Specialization for Insert since we may use memcpy */
357: template <typename Type, PetscInt BS, PetscInt EQ>
358: static PetscErrorCode ScatterAndInsert(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src_, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst_)
359: {
360:   const Type          *src  = static_cast<const Type *>(src_);
361:   Type                *dst  = static_cast<Type *>(dst_);
362:   DeviceExecutionSpace exec = PetscGetKokkosExecutionSpace();

364:   PetscFunctionBegin;
365:   if (!count) PetscFunctionReturn(PETSC_SUCCESS);
366:   /*src and dst are contiguous */
367:   if ((!srcOpt && !srcIdx) && (!dstOpt && !dstIdx) && src != dst) {
368:     size_t              sz = count * link->unitbytes;
369:     deviceBuffer_t      dbuf(reinterpret_cast<char *>(dst + dstStart * link->bs), sz);
370:     deviceConstBuffer_t sbuf(reinterpret_cast<const char *>(src + srcStart * link->bs), sz);
371:     Kokkos::deep_copy(exec, dbuf, sbuf);
372:   } else {
373:     PetscCall(ScatterAndOp<Type, Insert<Type>, BS, EQ>(link, count, srcStart, srcOpt, srcIdx, src, dstStart, dstOpt, dstIdx, dst));
374:   }
375:   PetscFunctionReturn(PETSC_SUCCESS);
376: }

378: template <typename Type, class Op, PetscInt BS, PetscInt EQ>
379: static PetscErrorCode FetchAndOpLocal(PetscSFLink link, PetscInt count, PetscInt rootstart, PetscSFPackOpt rootopt, const PetscInt *rootidx, void *rootdata_, PetscInt leafstart, PetscSFPackOpt leafopt, const PetscInt *leafidx, const void *leafdata_, void *leafupdate_)
380: {
381:   Op                   op;
382:   const PetscInt       M = (EQ) ? 1 : link->bs / BS, MBS = M * BS;
383:   const PetscInt      *ropt     = rootopt ? rootopt->array : NULL;
384:   const PetscInt      *lopt     = leafopt ? leafopt->array : NULL;
385:   Type                *rootdata = static_cast<Type *>(rootdata_), *leafupdate = static_cast<Type *>(leafupdate_);
386:   const Type          *leafdata = static_cast<const Type *>(leafdata_);
387:   DeviceExecutionSpace exec     = PetscGetKokkosExecutionSpace();

389:   PetscFunctionBegin;
390:   Kokkos::parallel_for(
391:     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
392:       PetscInt r = (ropt ? MapTidToIndex(ropt, tid) : (rootidx ? rootidx[tid] : rootstart + tid)) * MBS;
393:       PetscInt l = (lopt ? MapTidToIndex(lopt, tid) : (leafidx ? leafidx[tid] : leafstart + tid)) * MBS;
394:       for (int i = 0; i < MBS; i++) leafupdate[l + i] = op(rootdata[r + i], leafdata[l + i]);
395:     });
396:   PetscFunctionReturn(PETSC_SUCCESS);
397: }

399: /*====================================================================================*/
400: /*  Init various types and instantiate pack/unpack function pointers                  */
401: /*====================================================================================*/
402: template <typename Type, PetscInt BS, PetscInt EQ>
403: static void PackInit_RealType(PetscSFLink link)
404: {
405:   /* Pack/unpack for remote communication */
406:   link->d_Pack            = Pack<Type, BS, EQ>;
407:   link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
408:   link->d_UnpackAndAdd    = UnpackAndOp<Type, Add<Type>, BS, EQ>;
409:   link->d_UnpackAndMult   = UnpackAndOp<Type, Mult<Type>, BS, EQ>;
410:   link->d_UnpackAndMin    = UnpackAndOp<Type, Min<Type>, BS, EQ>;
411:   link->d_UnpackAndMax    = UnpackAndOp<Type, Max<Type>, BS, EQ>;
412:   link->d_FetchAndAdd     = FetchAndOp<Type, Add<Type>, BS, EQ>;
413:   /* Scatter for local communication */
414:   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>; /* Has special optimizations */
415:   link->d_ScatterAndAdd    = ScatterAndOp<Type, Add<Type>, BS, EQ>;
416:   link->d_ScatterAndMult   = ScatterAndOp<Type, Mult<Type>, BS, EQ>;
417:   link->d_ScatterAndMin    = ScatterAndOp<Type, Min<Type>, BS, EQ>;
418:   link->d_ScatterAndMax    = ScatterAndOp<Type, Max<Type>, BS, EQ>;
419:   link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>;
420:   /* Atomic versions when there are data-race possibilities */
421:   link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
422:   link->da_UnpackAndAdd    = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
423:   link->da_UnpackAndMult   = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
424:   link->da_UnpackAndMin    = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>;
425:   link->da_UnpackAndMax    = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>;
426:   link->da_FetchAndAdd     = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>;

428:   link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
429:   link->da_ScatterAndAdd    = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
430:   link->da_ScatterAndMult   = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
431:   link->da_ScatterAndMin    = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>;
432:   link->da_ScatterAndMax    = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>;
433:   link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>;
434: }

436: template <typename Type, PetscInt BS, PetscInt EQ>
437: static void PackInit_IntegerType(PetscSFLink link)
438: {
439:   link->d_Pack            = Pack<Type, BS, EQ>;
440:   link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
441:   link->d_UnpackAndAdd    = UnpackAndOp<Type, Add<Type>, BS, EQ>;
442:   link->d_UnpackAndMult   = UnpackAndOp<Type, Mult<Type>, BS, EQ>;
443:   link->d_UnpackAndMin    = UnpackAndOp<Type, Min<Type>, BS, EQ>;
444:   link->d_UnpackAndMax    = UnpackAndOp<Type, Max<Type>, BS, EQ>;
445:   link->d_UnpackAndLAND   = UnpackAndOp<Type, LAND<Type>, BS, EQ>;
446:   link->d_UnpackAndLOR    = UnpackAndOp<Type, LOR<Type>, BS, EQ>;
447:   link->d_UnpackAndLXOR   = UnpackAndOp<Type, LXOR<Type>, BS, EQ>;
448:   link->d_UnpackAndBAND   = UnpackAndOp<Type, BAND<Type>, BS, EQ>;
449:   link->d_UnpackAndBOR    = UnpackAndOp<Type, BOR<Type>, BS, EQ>;
450:   link->d_UnpackAndBXOR   = UnpackAndOp<Type, BXOR<Type>, BS, EQ>;
451:   link->d_FetchAndAdd     = FetchAndOp<Type, Add<Type>, BS, EQ>;

453:   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
454:   link->d_ScatterAndAdd    = ScatterAndOp<Type, Add<Type>, BS, EQ>;
455:   link->d_ScatterAndMult   = ScatterAndOp<Type, Mult<Type>, BS, EQ>;
456:   link->d_ScatterAndMin    = ScatterAndOp<Type, Min<Type>, BS, EQ>;
457:   link->d_ScatterAndMax    = ScatterAndOp<Type, Max<Type>, BS, EQ>;
458:   link->d_ScatterAndLAND   = ScatterAndOp<Type, LAND<Type>, BS, EQ>;
459:   link->d_ScatterAndLOR    = ScatterAndOp<Type, LOR<Type>, BS, EQ>;
460:   link->d_ScatterAndLXOR   = ScatterAndOp<Type, LXOR<Type>, BS, EQ>;
461:   link->d_ScatterAndBAND   = ScatterAndOp<Type, BAND<Type>, BS, EQ>;
462:   link->d_ScatterAndBOR    = ScatterAndOp<Type, BOR<Type>, BS, EQ>;
463:   link->d_ScatterAndBXOR   = ScatterAndOp<Type, BXOR<Type>, BS, EQ>;
464:   link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>;

466:   link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
467:   link->da_UnpackAndAdd    = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
468:   link->da_UnpackAndMult   = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
469:   link->da_UnpackAndMin    = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>;
470:   link->da_UnpackAndMax    = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>;
471:   link->da_UnpackAndLAND   = UnpackAndOp<Type, AtomicLAND<Type>, BS, EQ>;
472:   link->da_UnpackAndLOR    = UnpackAndOp<Type, AtomicLOR<Type>, BS, EQ>;
473:   link->da_UnpackAndBAND   = UnpackAndOp<Type, AtomicBAND<Type>, BS, EQ>;
474:   link->da_UnpackAndBOR    = UnpackAndOp<Type, AtomicBOR<Type>, BS, EQ>;
475:   link->da_UnpackAndBXOR   = UnpackAndOp<Type, AtomicBXOR<Type>, BS, EQ>;
476:   link->da_FetchAndAdd     = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>;

478:   link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
479:   link->da_ScatterAndAdd    = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
480:   link->da_ScatterAndMult   = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
481:   link->da_ScatterAndMin    = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>;
482:   link->da_ScatterAndMax    = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>;
483:   link->da_ScatterAndLAND   = ScatterAndOp<Type, AtomicLAND<Type>, BS, EQ>;
484:   link->da_ScatterAndLOR    = ScatterAndOp<Type, AtomicLOR<Type>, BS, EQ>;
485:   link->da_ScatterAndBAND   = ScatterAndOp<Type, AtomicBAND<Type>, BS, EQ>;
486:   link->da_ScatterAndBOR    = ScatterAndOp<Type, AtomicBOR<Type>, BS, EQ>;
487:   link->da_ScatterAndBXOR   = ScatterAndOp<Type, AtomicBXOR<Type>, BS, EQ>;
488:   link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>;
489: }

#if defined(PETSC_HAVE_COMPLEX)
/* Install device kernels for a complex primitive Type. Only Insert, Add, Mult and fetch-and-add are provided
   (no ordering-based ops), each as a plain (d_) and an atomic (da_) variant. */
template <typename Type, PetscInt BS, PetscInt EQ>
static void PackInit_ComplexType(PetscSFLink link)
{
  link->d_Pack = Pack<Type, BS, EQ>;

  /* Remote communication: unpack and fetch */
  link->d_UnpackAndInsert  = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
  link->d_UnpackAndAdd     = UnpackAndOp<Type, Add<Type>, BS, EQ>;
  link->d_UnpackAndMult    = UnpackAndOp<Type, Mult<Type>, BS, EQ>;
  link->d_FetchAndAdd      = FetchAndOp<Type, Add<Type>, BS, EQ>;
  link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
  link->da_UnpackAndAdd    = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
  link->da_UnpackAndMult   = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
  link->da_FetchAndAdd     = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>;

  /* Local communication: scatter and local fetch */
  link->d_ScatterAndInsert  = ScatterAndInsert<Type, BS, EQ>;
  link->d_ScatterAndAdd     = ScatterAndOp<Type, Add<Type>, BS, EQ>;
  link->d_ScatterAndMult    = ScatterAndOp<Type, Mult<Type>, BS, EQ>;
  link->d_FetchAndAddLocal  = FetchAndOpLocal<Type, Add<Type>, BS, EQ>;
  link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
  link->da_ScatterAndAdd    = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
  link->da_ScatterAndMult   = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
  link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>;
}
#endif

518: template <typename Type>
519: static void PackInit_PairType(PetscSFLink link)
520: {
521:   link->d_Pack            = Pack<Type, 1, 1>;
522:   link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, 1, 1>;
523:   link->d_UnpackAndMaxloc = UnpackAndOp<Type, Maxloc<Type>, 1, 1>;
524:   link->d_UnpackAndMinloc = UnpackAndOp<Type, Minloc<Type>, 1, 1>;

526:   link->d_ScatterAndInsert = ScatterAndOp<Type, Insert<Type>, 1, 1>;
527:   link->d_ScatterAndMaxloc = ScatterAndOp<Type, Maxloc<Type>, 1, 1>;
528:   link->d_ScatterAndMinloc = ScatterAndOp<Type, Minloc<Type>, 1, 1>;
529:   /* Atomics for pair types are not implemented yet */
530: }

532: template <typename Type, PetscInt BS, PetscInt EQ>
533: static void PackInit_DumbType(PetscSFLink link)
534: {
535:   link->d_Pack             = Pack<Type, BS, EQ>;
536:   link->d_UnpackAndInsert  = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
537:   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
538:   /* Atomics for dumb types are not implemented yet */
539: }

/*
  Kokkos::DefaultExecutionSpace(stream) is a reference-counted pointer object, and it has a bug:
  the object cannot be repeatedly created and destroyed. SF's original design gave each SFLink its
  own stream (NULL or not) and hence its own execution space object. The bug prevents us from
  destroying multiple SFLinks that use the NULL stream and the default execution space object. To
  avoid memory leaks, SF_Kokkos supports only the NULL stream, which is also PETSc's default scheme.
  SF_Kokkos does not do its own new/delete; it just uses Kokkos::DefaultExecutionSpace(), which is a
  singleton object in Kokkos.
*/
550: /*
551: static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSFLink link)
552: {
553:   PetscFunctionBegin;
554:   PetscFunctionReturn(PETSC_SUCCESS);
555: }
556: */

558: /* Some device-specific utilities */
559: static PetscErrorCode PetscSFLinkSyncDevice_Kokkos(PetscSFLink PETSC_UNUSED link)
560: {
561:   PetscFunctionBegin;
562:   Kokkos::fence();
563:   PetscFunctionReturn(PETSC_SUCCESS);
564: }

566: static PetscErrorCode PetscSFLinkSyncStream_Kokkos(PetscSFLink PETSC_UNUSED link)
567: {
568:   DeviceExecutionSpace exec = PetscGetKokkosExecutionSpace();

570:   PetscFunctionBegin;
571:   exec.fence();
572:   PetscFunctionReturn(PETSC_SUCCESS);
573: }

575: static PetscErrorCode PetscSFLinkMemcpy_Kokkos(PetscSFLink PETSC_UNUSED link, PetscMemType dstmtype, void *dst, PetscMemType srcmtype, const void *src, size_t n)
576: {
577:   DeviceExecutionSpace exec = PetscGetKokkosExecutionSpace();

579:   PetscFunctionBegin;
580:   if (!n) PetscFunctionReturn(PETSC_SUCCESS);
581:   if (PetscMemTypeHost(dstmtype) && PetscMemTypeHost(srcmtype)) { // H2H
582:     PetscCallCXX(exec.fence());                                   // make sure async kernels on src are finished, in case of unified memory as on AMD MI300A.
583:     PetscCall(PetscMemcpy(dst, src, n));
584:   } else {
585:     if (PetscMemTypeDevice(dstmtype) && PetscMemTypeHost(srcmtype)) { // H2D
586:       deviceBuffer_t    dbuf(static_cast<char *>(dst), n);
587:       HostConstBuffer_t sbuf(static_cast<const char *>(src), n);
588:       PetscCallCXX(Kokkos::deep_copy(exec, dbuf, sbuf));
589:       PetscCall(PetscLogCpuToGpu(n));
590:     } else if (PetscMemTypeHost(dstmtype) && PetscMemTypeDevice(srcmtype)) { // D2H
591:       HostBuffer_t        dbuf(static_cast<char *>(dst), n);
592:       deviceConstBuffer_t sbuf(static_cast<const char *>(src), n);
593:       PetscCallCXX(Kokkos::deep_copy(exec, dbuf, sbuf));
594:       PetscCallCXX(exec.fence()); // make sure dbuf is ready for use immediately on host
595:       PetscCall(PetscLogGpuToCpu(n));
596:     } else if (PetscMemTypeDevice(dstmtype) && PetscMemTypeDevice(srcmtype)) { // D2D
597:       deviceBuffer_t      dbuf(static_cast<char *>(dst), n);
598:       deviceConstBuffer_t sbuf(static_cast<const char *>(src), n);
599:       PetscCallCXX(Kokkos::deep_copy(exec, dbuf, sbuf));
600:     }
601:   }
602:   PetscFunctionReturn(PETSC_SUCCESS);
603: }

605: PetscErrorCode PetscSFMalloc_Kokkos(PetscMemType mtype, size_t size, void **ptr)
606: {
607:   PetscFunctionBegin;
608:   if (PetscMemTypeHost(mtype)) PetscCall(PetscMalloc(size, ptr));
609:   else if (PetscMemTypeDevice(mtype)) {
610:     if (!PetscKokkosInitialized) PetscCall(PetscKokkosInitializeCheck());
611:     PetscCallCXX(*ptr = Kokkos::kokkos_malloc<DefaultMemorySpace>(size));
612:   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype);
613:   PetscFunctionReturn(PETSC_SUCCESS);
614: }

616: PetscErrorCode PetscSFFree_Kokkos(PetscMemType mtype, void *ptr)
617: {
618:   PetscFunctionBegin;
619:   if (PetscMemTypeHost(mtype)) PetscCall(PetscFree(ptr));
620:   else if (PetscMemTypeDevice(mtype)) {
621:     PetscCallCXX(Kokkos::kokkos_free<DefaultMemorySpace>(ptr));
622:   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype);
623:   PetscFunctionReturn(PETSC_SUCCESS);
624: }

626: /* Destructor when the link uses MPI for communication */
627: static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSF sf, PetscSFLink link)
628: {
629:   PetscFunctionBegin;
630:   for (int i = PETSCSF_LOCAL; i <= PETSCSF_REMOTE; i++) {
631:     PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_DEVICE, link->rootbuf_alloc[i][PETSC_MEMTYPE_DEVICE]));
632:     PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_DEVICE, link->leafbuf_alloc[i][PETSC_MEMTYPE_DEVICE]));
633:   }
634:   PetscFunctionReturn(PETSC_SUCCESS);
635: }

637: /* Some fields of link are initialized by PetscSFPackSetUp_Host. This routine only does what needed on device */
638: PetscErrorCode PetscSFLinkSetUp_Kokkos(PetscSF PETSC_UNUSED sf, PetscSFLink link, MPI_Datatype unit)
639: {
640:   PetscInt  nSignedChar = 0, nUnsignedChar = 0, nInt = 0, nPetscInt = 0, nPetscReal = 0;
641:   PetscBool is2Int, is2PetscInt;
642: #if defined(PETSC_HAVE_COMPLEX)
643:   PetscInt nPetscComplex = 0;
644: #endif

646:   PetscFunctionBegin;
647:   if (link->deviceinited) PetscFunctionReturn(PETSC_SUCCESS);
648:   PetscCall(PetscKokkosInitializeCheck());
649:   PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_SIGNED_CHAR, &nSignedChar));
650:   PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_UNSIGNED_CHAR, &nUnsignedChar));
651:   /* MPI_CHAR is treated below as a dumb type that does not support reduction according to MPI standard */
652:   PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_INT, &nInt));
653:   PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_INT, &nPetscInt));
654:   PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_REAL, &nPetscReal));
655: #if defined(PETSC_HAVE_COMPLEX)
656:   PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_COMPLEX, &nPetscComplex));
657: #endif
658:   PetscCall(MPIPetsc_Type_compare(unit, MPI_2INT, &is2Int));
659:   PetscCall(MPIPetsc_Type_compare(unit, MPIU_2INT, &is2PetscInt));

661:   if (is2Int) {
662:     PackInit_PairType<Kokkos::pair<int, int>>(link);
663:   } else if (is2PetscInt) { /* TODO: when is2PetscInt and nPetscInt=2, we don't know which path to take. The two paths support different ops. */
664:     PackInit_PairType<Kokkos::pair<PetscInt, PetscInt>>(link);
665:   } else if (nPetscReal) {
666: #if !defined(PETSC_HAVE_DEVICE) /* Skip the unimportant stuff to speed up SF device compilation time */
667:     if (nPetscReal == 8) PackInit_RealType<PetscReal, 8, 1>(link);
668:     else if (nPetscReal % 8 == 0) PackInit_RealType<PetscReal, 8, 0>(link);
669:     else if (nPetscReal == 4) PackInit_RealType<PetscReal, 4, 1>(link);
670:     else if (nPetscReal % 4 == 0) PackInit_RealType<PetscReal, 4, 0>(link);
671:     else if (nPetscReal == 2) PackInit_RealType<PetscReal, 2, 1>(link);
672:     else if (nPetscReal % 2 == 0) PackInit_RealType<PetscReal, 2, 0>(link);
673:     else if (nPetscReal == 1) PackInit_RealType<PetscReal, 1, 1>(link);
674:     else if (nPetscReal % 1 == 0)
675: #endif
676:       PackInit_RealType<PetscReal, 1, 0>(link);
677:   } else if (nPetscInt && sizeof(PetscInt) == sizeof(llint)) {
678: #if !defined(PETSC_HAVE_DEVICE)
679:     if (nPetscInt == 8) PackInit_IntegerType<llint, 8, 1>(link);
680:     else if (nPetscInt % 8 == 0) PackInit_IntegerType<llint, 8, 0>(link);
681:     else if (nPetscInt == 4) PackInit_IntegerType<llint, 4, 1>(link);
682:     else if (nPetscInt % 4 == 0) PackInit_IntegerType<llint, 4, 0>(link);
683:     else if (nPetscInt == 2) PackInit_IntegerType<llint, 2, 1>(link);
684:     else if (nPetscInt % 2 == 0) PackInit_IntegerType<llint, 2, 0>(link);
685:     else if (nPetscInt == 1) PackInit_IntegerType<llint, 1, 1>(link);
686:     else if (nPetscInt % 1 == 0)
687: #endif
688:       PackInit_IntegerType<llint, 1, 0>(link);
689:   } else if (nInt) {
690: #if !defined(PETSC_HAVE_DEVICE)
691:     if (nInt == 8) PackInit_IntegerType<int, 8, 1>(link);
692:     else if (nInt % 8 == 0) PackInit_IntegerType<int, 8, 0>(link);
693:     else if (nInt == 4) PackInit_IntegerType<int, 4, 1>(link);
694:     else if (nInt % 4 == 0) PackInit_IntegerType<int, 4, 0>(link);
695:     else if (nInt == 2) PackInit_IntegerType<int, 2, 1>(link);
696:     else if (nInt % 2 == 0) PackInit_IntegerType<int, 2, 0>(link);
697:     else if (nInt == 1) PackInit_IntegerType<int, 1, 1>(link);
698:     else if (nInt % 1 == 0)
699: #endif
700:       PackInit_IntegerType<int, 1, 0>(link);
701:   } else if (nSignedChar) {
702: #if !defined(PETSC_HAVE_DEVICE)
703:     if (nSignedChar == 8) PackInit_IntegerType<char, 8, 1>(link);
704:     else if (nSignedChar % 8 == 0) PackInit_IntegerType<char, 8, 0>(link);
705:     else if (nSignedChar == 4) PackInit_IntegerType<char, 4, 1>(link);
706:     else if (nSignedChar % 4 == 0) PackInit_IntegerType<char, 4, 0>(link);
707:     else if (nSignedChar == 2) PackInit_IntegerType<char, 2, 1>(link);
708:     else if (nSignedChar % 2 == 0) PackInit_IntegerType<char, 2, 0>(link);
709:     else if (nSignedChar == 1) PackInit_IntegerType<char, 1, 1>(link);
710:     else if (nSignedChar % 1 == 0)
711: #endif
712:       PackInit_IntegerType<char, 1, 0>(link);
713:   } else if (nUnsignedChar) {
714: #if !defined(PETSC_HAVE_DEVICE)
715:     if (nUnsignedChar == 8) PackInit_IntegerType<unsigned char, 8, 1>(link);
716:     else if (nUnsignedChar % 8 == 0) PackInit_IntegerType<unsigned char, 8, 0>(link);
717:     else if (nUnsignedChar == 4) PackInit_IntegerType<unsigned char, 4, 1>(link);
718:     else if (nUnsignedChar % 4 == 0) PackInit_IntegerType<unsigned char, 4, 0>(link);
719:     else if (nUnsignedChar == 2) PackInit_IntegerType<unsigned char, 2, 1>(link);
720:     else if (nUnsignedChar % 2 == 0) PackInit_IntegerType<unsigned char, 2, 0>(link);
721:     else if (nUnsignedChar == 1) PackInit_IntegerType<unsigned char, 1, 1>(link);
722:     else if (nUnsignedChar % 1 == 0)
723: #endif
724:       PackInit_IntegerType<unsigned char, 1, 0>(link);
725: #if defined(PETSC_HAVE_COMPLEX)
726:   } else if (nPetscComplex) {
727:   #if !defined(PETSC_HAVE_DEVICE)
728:     if (nPetscComplex == 8) PackInit_ComplexType<Kokkos::complex<PetscReal>, 8, 1>(link);
729:     else if (nPetscComplex % 8 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 8, 0>(link);
730:     else if (nPetscComplex == 4) PackInit_ComplexType<Kokkos::complex<PetscReal>, 4, 1>(link);
731:     else if (nPetscComplex % 4 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 4, 0>(link);
732:     else if (nPetscComplex == 2) PackInit_ComplexType<Kokkos::complex<PetscReal>, 2, 1>(link);
733:     else if (nPetscComplex % 2 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 2, 0>(link);
734:     else if (nPetscComplex == 1) PackInit_ComplexType<Kokkos::complex<PetscReal>, 1, 1>(link);
735:     else if (nPetscComplex % 1 == 0)
736:   #endif
737:       PackInit_ComplexType<Kokkos::complex<PetscReal>, 1, 0>(link);
738: #endif
739:   } else {
740:     MPI_Aint nbyte;

742:     PetscCall(PetscSFGetDatatypeSize_Internal(PETSC_COMM_SELF, unit, &nbyte));
743:     if (nbyte % sizeof(int)) { /* If the type size is not multiple of int */
744: #if !defined(PETSC_HAVE_DEVICE)
745:       if (nbyte == 4) PackInit_DumbType<char, 4, 1>(link);
746:       else if (nbyte % 4 == 0) PackInit_DumbType<char, 4, 0>(link);
747:       else if (nbyte == 2) PackInit_DumbType<char, 2, 1>(link);
748:       else if (nbyte % 2 == 0) PackInit_DumbType<char, 2, 0>(link);
749:       else if (nbyte == 1) PackInit_DumbType<char, 1, 1>(link);
750:       else if (nbyte % 1 == 0)
751: #endif
752:         PackInit_DumbType<char, 1, 0>(link);
753:     } else {
754:       PetscCall(PetscIntCast(nbyte / sizeof(int), &nInt));
755: #if !defined(PETSC_HAVE_DEVICE)
756:       if (nInt == 8) PackInit_DumbType<int, 8, 1>(link);
757:       else if (nInt % 8 == 0) PackInit_DumbType<int, 8, 0>(link);
758:       else if (nInt == 4) PackInit_DumbType<int, 4, 1>(link);
759:       else if (nInt % 4 == 0) PackInit_DumbType<int, 4, 0>(link);
760:       else if (nInt == 2) PackInit_DumbType<int, 2, 1>(link);
761:       else if (nInt % 2 == 0) PackInit_DumbType<int, 2, 0>(link);
762:       else if (nInt == 1) PackInit_DumbType<int, 1, 1>(link);
763:       else if (nInt % 1 == 0)
764: #endif
765:         PackInit_DumbType<int, 1, 0>(link);
766:     }
767:   }

769:   link->SyncDevice   = PetscSFLinkSyncDevice_Kokkos;
770:   link->SyncStream   = PetscSFLinkSyncStream_Kokkos;
771:   link->Memcpy       = PetscSFLinkMemcpy_Kokkos;
772:   link->Destroy      = PetscSFLinkDestroy_Kokkos;
773:   link->deviceinited = PETSC_TRUE;
774:   PetscFunctionReturn(PETSC_SUCCESS);
775: }