Actual source code: sfcupm_impl.hpp
1: #pragma once
3: #include "sfcupm.hpp"
4: #include <../src/sys/objects/device/impls/cupm/kernels.hpp>
5: #include <petsc/private/cupmatomics.hpp>
7: namespace Petsc
8: {
10: namespace sf
11: {
13: namespace cupm
14: {
16: namespace kernels
17: {
19: /* Map a thread id to an index in root/leaf space through a series of 3D subdomains. See PetscSFPackOpt. */
20: PETSC_NODISCARD static PETSC_DEVICE_INLINE_DECL PetscInt MapTidToIndex(const PetscInt *opt, PetscInt tid) noexcept
21: {
22: PetscInt i, j, k, m, n, r;
23: const PetscInt *offset, *start, *dx, *dy, *X, *Y;
25: n = opt[0];
26: offset = opt + 1;
27: start = opt + n + 2;
28: dx = opt + 2 * n + 2;
29: dy = opt + 3 * n + 2;
30: X = opt + 5 * n + 2;
31: Y = opt + 6 * n + 2;
32: for (r = 0; r < n; r++) {
33: if (tid < offset[r + 1]) break;
34: }
35: m = (tid - offset[r]);
36: k = m / (dx[r] * dy[r]);
37: j = (m - k * dx[r] * dy[r]) / dx[r];
38: i = m - k * dx[r] * dy[r] - j * dx[r];
40: return (start[r] + k * X[r] * Y[r] + j * X[r] + i);
41: }
43: /*====================================================================================*/
44: /* Templated CUPM kernels for pack/unpack. The Op can be regular or atomic */
45: /*====================================================================================*/
47: /* Suppose user calls PetscSFReduce(sf,unit,...) and <unit> is an MPI data type made of 16 PetscReals, then
48: <Type> is PetscReal, which is the primitive type we operate on.
49: <bs> is 16, which says <unit> contains 16 primitive types.
50: <BS> is 8, which is the maximal SIMD width we will try to vectorize operations on <unit>.
51: <EQ> is 0, which is (bs == BS ? 1 : 0)
53: If instead, <unit> has 8 PetscReals, then bs=8, BS=8, EQ=1, rendering MBS below to a compile time constant.
54: For the common case in VecScatter, bs=1, BS=1, EQ=1, MBS=1, the inner for-loops below will be totally unrolled.
55: */
56: template <class Type, PetscInt BS, PetscInt EQ>
57: PETSC_KERNEL_DECL static void d_Pack(PetscInt bs, PetscInt count, PetscInt start, const PetscInt *opt, const PetscInt *idx, const Type *data, Type *buf)
58: {
59: const PetscInt M = (EQ) ? 1 : bs / BS; /* If EQ, then M=1 enables compiler's const-propagation */
60: const PetscInt MBS = M * BS; /* MBS=bs. We turn MBS into a compile-time const when EQ=1. */
62: ::Petsc::device::cupm::kernels::util::grid_stride_1D(count, [&](PetscInt tid) {
63: PetscInt t = (opt ? MapTidToIndex(opt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
64: PetscInt s = tid * MBS;
65: for (PetscInt i = 0; i < MBS; i++) buf[s + i] = data[t + i];
66: });
67: }
69: template <class Type, class Op, PetscInt BS, PetscInt EQ>
70: PETSC_KERNEL_DECL static void d_UnpackAndOp(PetscInt bs, PetscInt count, PetscInt start, const PetscInt *opt, const PetscInt *idx, Type *data, const Type *buf)
71: {
72: const PetscInt M = (EQ) ? 1 : bs / BS, MBS = M * BS;
73: Op op;
75: ::Petsc::device::cupm::kernels::util::grid_stride_1D(count, [&](PetscInt tid) {
76: PetscInt t = (opt ? MapTidToIndex(opt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
77: PetscInt s = tid * MBS;
78: for (PetscInt i = 0; i < MBS; i++) op(data[t + i], buf[s + i]);
79: });
80: }
82: template <class Type, class Op, PetscInt BS, PetscInt EQ>
83: PETSC_KERNEL_DECL static void d_FetchAndOp(PetscInt bs, PetscInt count, PetscInt rootstart, const PetscInt *rootopt, const PetscInt *rootidx, Type *rootdata, Type *leafbuf)
84: {
85: const PetscInt M = (EQ) ? 1 : bs / BS, MBS = M * BS;
86: Op op;
88: ::Petsc::device::cupm::kernels::util::grid_stride_1D(count, [&](PetscInt tid) {
89: PetscInt r = (rootopt ? MapTidToIndex(rootopt, tid) : (rootidx ? rootidx[tid] : rootstart + tid)) * MBS;
90: PetscInt l = tid * MBS;
91: for (PetscInt i = 0; i < MBS; i++) leafbuf[l + i] = op(rootdata[r + i], leafbuf[l + i]);
92: });
93: }
95: template <class Type, class Op, PetscInt BS, PetscInt EQ>
96: PETSC_KERNEL_DECL static void d_ScatterAndOp(PetscInt bs, PetscInt count, PetscInt srcx, PetscInt srcy, PetscInt srcX, PetscInt srcY, PetscInt srcStart, const PetscInt *srcIdx, const Type *src, PetscInt dstx, PetscInt dsty, PetscInt dstX, PetscInt dstY, PetscInt dstStart, const PetscInt *dstIdx, Type *dst)
97: {
98: const PetscInt M = (EQ) ? 1 : bs / BS, MBS = M * BS;
99: Op op;
101: ::Petsc::device::cupm::kernels::util::grid_stride_1D(count, [&](PetscInt tid) {
102: PetscInt s, t;
104: if (!srcIdx) { /* src is either contiguous or 3D */
105: PetscInt k = tid / (srcx * srcy);
106: PetscInt j = (tid - k * srcx * srcy) / srcx;
107: PetscInt i = tid - k * srcx * srcy - j * srcx;
109: s = srcStart + k * srcX * srcY + j * srcX + i;
110: } else {
111: s = srcIdx[tid];
112: }
114: if (!dstIdx) { /* dst is either contiguous or 3D */
115: PetscInt k = tid / (dstx * dsty);
116: PetscInt j = (tid - k * dstx * dsty) / dstx;
117: PetscInt i = tid - k * dstx * dsty - j * dstx;
119: t = dstStart + k * dstX * dstY + j * dstX + i;
120: } else {
121: t = dstIdx[tid];
122: }
124: s *= MBS;
125: t *= MBS;
126: for (PetscInt i = 0; i < MBS; i++) op(dst[t + i], src[s + i]);
127: });
128: }
130: template <class Type, class Op, PetscInt BS, PetscInt EQ>
131: PETSC_KERNEL_DECL static void d_FetchAndOpLocal(PetscInt bs, PetscInt count, PetscInt rootstart, const PetscInt *rootopt, const PetscInt *rootidx, Type *rootdata, PetscInt leafstart, const PetscInt *leafopt, const PetscInt *leafidx, const Type *leafdata, Type *leafupdate)
132: {
133: const PetscInt M = (EQ) ? 1 : bs / BS, MBS = M * BS;
134: Op op;
136: ::Petsc::device::cupm::kernels::util::grid_stride_1D(count, [&](PetscInt tid) {
137: PetscInt r = (rootopt ? MapTidToIndex(rootopt, tid) : (rootidx ? rootidx[tid] : rootstart + tid)) * MBS;
138: PetscInt l = (leafopt ? MapTidToIndex(leafopt, tid) : (leafidx ? leafidx[tid] : leafstart + tid)) * MBS;
139: for (PetscInt i = 0; i < MBS; i++) leafupdate[l + i] = op(rootdata[r + i], leafdata[l + i]);
140: });
141: }
143: /*====================================================================================*/
144: /* Regular operations on device */
145: /*====================================================================================*/
146: template <typename Type>
147: struct Insert {
148: PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
149: {
150: Type old = x;
151: x = y;
152: return old;
153: }
154: };
155: template <typename Type>
156: struct Add {
157: PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
158: {
159: Type old = x;
160: x += y;
161: return old;
162: }
163: };
164: template <typename Type>
165: struct Mult {
166: PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
167: {
168: Type old = x;
169: x *= y;
170: return old;
171: }
172: };
173: template <typename Type>
174: struct Min {
175: PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
176: {
177: Type old = x;
178: x = PetscMin(x, y);
179: return old;
180: }
181: };
182: template <typename Type>
183: struct Max {
184: PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
185: {
186: Type old = x;
187: x = PetscMax(x, y);
188: return old;
189: }
190: };
191: template <typename Type>
192: struct LAND {
193: PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
194: {
195: Type old = x;
196: x = x && y;
197: return old;
198: }
199: };
200: template <typename Type>
201: struct LOR {
202: PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
203: {
204: Type old = x;
205: x = x || y;
206: return old;
207: }
208: };
209: template <typename Type>
210: struct LXOR {
211: PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
212: {
213: Type old = x;
214: x = !x != !y;
215: return old;
216: }
217: };
218: template <typename Type>
219: struct BAND {
220: PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
221: {
222: Type old = x;
223: x = x & y;
224: return old;
225: }
226: };
227: template <typename Type>
228: struct BOR {
229: PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
230: {
231: Type old = x;
232: x = x | y;
233: return old;
234: }
235: };
236: template <typename Type>
237: struct BXOR {
238: PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
239: {
240: Type old = x;
241: x = x ^ y;
242: return old;
243: }
244: };
245: template <typename Type>
246: struct Minloc {
247: PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
248: {
249: Type old = x;
250: if (y.a < x.a) x = y;
251: else if (y.a == x.a) x.b = min(x.b, y.b);
252: return old;
253: }
254: };
255: template <typename Type>
256: struct Maxloc {
257: PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
258: {
259: Type old = x;
260: if (y.a > x.a) x = y;
261: else if (y.a == x.a) x.b = min(x.b, y.b); /* See MPI MAXLOC */
262: return old;
263: }
264: };
266: } // namespace kernels
268: namespace impl
269: {
271: /*====================================================================================*/
272: /* Wrapper functions of cupm kernels. Function pointers are stored in 'link' */
273: /*====================================================================================*/
274: template <device::cupm::DeviceType T>
275: template <typename Type, PetscInt BS, PetscInt EQ>
276: inline PetscErrorCode SfInterface<T>::Pack(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, const void *data, void *buf) noexcept
277: {
278: const PetscInt *iarray = opt ? opt->array : NULL;
280: PetscFunctionBegin;
281: if (!count) PetscFunctionReturn(PETSC_SUCCESS);
282: if (PetscDefined(USING_NVCC) && !opt && !idx) { /* It is a 'CUDA data to nvshmem buf' memory copy */
283: PetscCallCUPM(cupmMemcpyAsync(buf, (char *)data + start * link->unitbytes, count * link->unitbytes, cupmMemcpyDeviceToDevice, link->stream));
284: } else {
285: PetscCall(PetscCUPMLaunchKernel1D(count, 0, link->stream, kernels::d_Pack<Type, BS, EQ>, link->bs, count, start, iarray, idx, (const Type *)data, (Type *)buf));
286: }
287: PetscFunctionReturn(PETSC_SUCCESS);
288: }
290: template <device::cupm::DeviceType T>
291: template <typename Type, class Op, PetscInt BS, PetscInt EQ>
292: inline PetscErrorCode SfInterface<T>::UnpackAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data, const void *buf) noexcept
293: {
294: const PetscInt *iarray = opt ? opt->array : NULL;
296: PetscFunctionBegin;
297: if (!count) PetscFunctionReturn(PETSC_SUCCESS);
298: if (PetscDefined(USING_NVCC) && std::is_same<Op, kernels::Insert<Type>>::value && !opt && !idx) { /* It is a 'nvshmem buf to CUDA data' memory copy */
299: PetscCallCUPM(cupmMemcpyAsync((char *)data + start * link->unitbytes, buf, count * link->unitbytes, cupmMemcpyDeviceToDevice, link->stream));
300: } else {
301: PetscCall(PetscCUPMLaunchKernel1D(count, 0, link->stream, kernels::d_UnpackAndOp<Type, Op, BS, EQ>, link->bs, count, start, iarray, idx, (Type *)data, (const Type *)buf));
302: }
303: PetscFunctionReturn(PETSC_SUCCESS);
304: }
306: template <device::cupm::DeviceType T>
307: template <typename Type, class Op, PetscInt BS, PetscInt EQ>
308: inline PetscErrorCode SfInterface<T>::FetchAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data, void *buf) noexcept
309: {
310: const PetscInt *iarray = opt ? opt->array : NULL;
312: PetscFunctionBegin;
313: if (!count) PetscFunctionReturn(PETSC_SUCCESS);
314: PetscCall(PetscCUPMLaunchKernel1D(count, 0, link->stream, kernels::d_FetchAndOp<Type, Op, BS, EQ>, link->bs, count, start, iarray, idx, (Type *)data, (const Type *)buf));
315: PetscFunctionReturn(PETSC_SUCCESS);
316: }
318: template <device::cupm::DeviceType T>
319: template <typename Type, class Op, PetscInt BS, PetscInt EQ>
320: inline PetscErrorCode SfInterface<T>::ScatterAndOp(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst) noexcept
321: {
322: PetscInt nthreads = 256;
323: PetscInt nblocks = (count + nthreads - 1) / nthreads;
324: PetscInt srcx = 0, srcy = 0, srcX = 0, srcY = 0, dstx = 0, dsty = 0, dstX = 0, dstY = 0;
326: PetscFunctionBegin;
327: if (!count) PetscFunctionReturn(PETSC_SUCCESS);
328: nblocks = PetscMin(nblocks, link->maxResidentThreadsPerGPU / nthreads);
330: /* The 3D shape of source subdomain may be different than that of the destination, which makes it difficult to use 3D grid and block */
331: if (srcOpt) {
332: srcx = srcOpt->dx[0];
333: srcy = srcOpt->dy[0];
334: srcX = srcOpt->X[0];
335: srcY = srcOpt->Y[0];
336: srcStart = srcOpt->start[0];
337: srcIdx = NULL;
338: } else if (!srcIdx) {
339: srcx = srcX = count;
340: srcy = srcY = 1;
341: }
343: if (dstOpt) {
344: dstx = dstOpt->dx[0];
345: dsty = dstOpt->dy[0];
346: dstX = dstOpt->X[0];
347: dstY = dstOpt->Y[0];
348: dstStart = dstOpt->start[0];
349: dstIdx = NULL;
350: } else if (!dstIdx) {
351: dstx = dstX = count;
352: dsty = dstY = 1;
353: }
355: PetscCall(PetscCUPMLaunchKernel1D(count, 0, link->stream, kernels::d_ScatterAndOp<Type, Op, BS, EQ>, link->bs, count, srcx, srcy, srcX, srcY, srcStart, srcIdx, (const Type *)src, dstx, dsty, dstX, dstY, dstStart, dstIdx, (Type *)dst));
356: PetscFunctionReturn(PETSC_SUCCESS);
357: }
359: template <device::cupm::DeviceType T>
360: /* Specialization for Insert since we may use cupmMemcpyAsync */
361: template <typename Type, PetscInt BS, PetscInt EQ>
362: inline PetscErrorCode SfInterface<T>::ScatterAndInsert(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst) noexcept
363: {
364: PetscFunctionBegin;
365: if (!count) PetscFunctionReturn(PETSC_SUCCESS);
366: /*src and dst are contiguous */
367: if ((!srcOpt && !srcIdx) && (!dstOpt && !dstIdx) && src != dst) {
368: PetscCallCUPM(cupmMemcpyAsync((Type *)dst + dstStart * link->bs, (const Type *)src + srcStart * link->bs, count * link->unitbytes, cupmMemcpyDeviceToDevice, link->stream));
369: } else {
370: PetscCall(ScatterAndOp<Type, kernels::Insert<Type>, BS, EQ>(link, count, srcStart, srcOpt, srcIdx, src, dstStart, dstOpt, dstIdx, dst));
371: }
372: PetscFunctionReturn(PETSC_SUCCESS);
373: }
375: template <device::cupm::DeviceType T>
376: template <typename Type, class Op, PetscInt BS, PetscInt EQ>
377: inline PetscErrorCode SfInterface<T>::FetchAndOpLocal(PetscSFLink link, PetscInt count, PetscInt rootstart, PetscSFPackOpt rootopt, const PetscInt *rootidx, void *rootdata, PetscInt leafstart, PetscSFPackOpt leafopt, const PetscInt *leafidx, const void *leafdata, void *leafupdate) noexcept
378: {
379: const PetscInt *rarray = rootopt ? rootopt->array : NULL;
380: const PetscInt *larray = leafopt ? leafopt->array : NULL;
382: PetscFunctionBegin;
383: if (!count) PetscFunctionReturn(PETSC_SUCCESS);
384: PetscCall(PetscCUPMLaunchKernel1D(count, 0, link->stream, kernels::d_FetchAndOpLocal<Type, Op, BS, EQ>, link->bs, count, rootstart, rarray, rootidx, (Type *)rootdata, leafstart, larray, leafidx, (const Type *)leafdata, (Type *)leafupdate));
385: PetscFunctionReturn(PETSC_SUCCESS);
386: }
388: /*====================================================================================*/
389: /* Init various types and instantiate pack/unpack function pointers */
390: /*====================================================================================*/
391: template <device::cupm::DeviceType T>
392: template <typename Type, PetscInt BS, PetscInt EQ>
393: inline void SfInterface<T>::PackInit_RealType(PetscSFLink link) noexcept
394: {
395: /* Pack/unpack for remote communication */
396: link->d_Pack = Pack<Type, BS, EQ>;
397: link->d_UnpackAndInsert = UnpackAndOp<Type, kernels::Insert<Type>, BS, EQ>;
398: link->d_UnpackAndAdd = UnpackAndOp<Type, kernels::Add<Type>, BS, EQ>;
399: link->d_UnpackAndMult = UnpackAndOp<Type, kernels::Mult<Type>, BS, EQ>;
400: link->d_UnpackAndMin = UnpackAndOp<Type, kernels::Min<Type>, BS, EQ>;
401: link->d_UnpackAndMax = UnpackAndOp<Type, kernels::Max<Type>, BS, EQ>;
402: link->d_FetchAndAdd = FetchAndOp<Type, kernels::Add<Type>, BS, EQ>;
404: /* Scatter for local communication */
405: link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>; /* Has special optimizations */
406: link->d_ScatterAndAdd = ScatterAndOp<Type, kernels::Add<Type>, BS, EQ>;
407: link->d_ScatterAndMult = ScatterAndOp<Type, kernels::Mult<Type>, BS, EQ>;
408: link->d_ScatterAndMin = ScatterAndOp<Type, kernels::Min<Type>, BS, EQ>;
409: link->d_ScatterAndMax = ScatterAndOp<Type, kernels::Max<Type>, BS, EQ>;
410: link->d_FetchAndAddLocal = FetchAndOpLocal<Type, kernels::Add<Type>, BS, EQ>;
412: /* Atomic versions when there are data-race possibilities */
413: link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
414: link->da_UnpackAndAdd = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
415: link->da_UnpackAndMult = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
416: link->da_UnpackAndMin = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>;
417: link->da_UnpackAndMax = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>;
418: link->da_FetchAndAdd = FetchAndOp<Type, AtomicAdd<Type>, BS, EQ>;
420: link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
421: link->da_ScatterAndAdd = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
422: link->da_ScatterAndMult = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
423: link->da_ScatterAndMin = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>;
424: link->da_ScatterAndMax = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>;
425: link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicAdd<Type>, BS, EQ>;
426: }
428: /* Have this templated class to specialize for char integers */
429: template <device::cupm::DeviceType T>
430: template <typename Type, PetscInt BS, PetscInt EQ, PetscInt size /*sizeof(Type)*/>
431: struct SfInterface<T>::PackInit_IntegerType_Atomic {
432: static inline void Init(PetscSFLink link) noexcept
433: {
434: link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
435: link->da_UnpackAndAdd = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
436: link->da_UnpackAndMult = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
437: link->da_UnpackAndMin = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>;
438: link->da_UnpackAndMax = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>;
439: link->da_UnpackAndLAND = UnpackAndOp<Type, AtomicLAND<Type>, BS, EQ>;
440: link->da_UnpackAndLOR = UnpackAndOp<Type, AtomicLOR<Type>, BS, EQ>;
441: link->da_UnpackAndLXOR = UnpackAndOp<Type, AtomicLXOR<Type>, BS, EQ>;
442: link->da_UnpackAndBAND = UnpackAndOp<Type, AtomicBAND<Type>, BS, EQ>;
443: link->da_UnpackAndBOR = UnpackAndOp<Type, AtomicBOR<Type>, BS, EQ>;
444: link->da_UnpackAndBXOR = UnpackAndOp<Type, AtomicBXOR<Type>, BS, EQ>;
445: link->da_FetchAndAdd = FetchAndOp<Type, AtomicAdd<Type>, BS, EQ>;
447: link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
448: link->da_ScatterAndAdd = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
449: link->da_ScatterAndMult = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
450: link->da_ScatterAndMin = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>;
451: link->da_ScatterAndMax = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>;
452: link->da_ScatterAndLAND = ScatterAndOp<Type, AtomicLAND<Type>, BS, EQ>;
453: link->da_ScatterAndLOR = ScatterAndOp<Type, AtomicLOR<Type>, BS, EQ>;
454: link->da_ScatterAndLXOR = ScatterAndOp<Type, AtomicLXOR<Type>, BS, EQ>;
455: link->da_ScatterAndBAND = ScatterAndOp<Type, AtomicBAND<Type>, BS, EQ>;
456: link->da_ScatterAndBOR = ScatterAndOp<Type, AtomicBOR<Type>, BS, EQ>;
457: link->da_ScatterAndBXOR = ScatterAndOp<Type, AtomicBXOR<Type>, BS, EQ>;
458: link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicAdd<Type>, BS, EQ>;
459: }
460: };
462: /* CUDA does not support atomics on chars. It is TBD in PETSc. */
463: template <device::cupm::DeviceType T>
464: template <typename Type, PetscInt BS, PetscInt EQ>
465: struct SfInterface<T>::PackInit_IntegerType_Atomic<Type, BS, EQ, 1> {
466: static inline void Init(PetscSFLink)
467: { /* Nothing to leave function pointers NULL */
468: }
469: };
471: template <device::cupm::DeviceType T>
472: template <typename Type, PetscInt BS, PetscInt EQ>
473: inline void SfInterface<T>::PackInit_IntegerType(PetscSFLink link) noexcept
474: {
475: link->d_Pack = Pack<Type, BS, EQ>;
476: link->d_UnpackAndInsert = UnpackAndOp<Type, kernels::Insert<Type>, BS, EQ>;
477: link->d_UnpackAndAdd = UnpackAndOp<Type, kernels::Add<Type>, BS, EQ>;
478: link->d_UnpackAndMult = UnpackAndOp<Type, kernels::Mult<Type>, BS, EQ>;
479: link->d_UnpackAndMin = UnpackAndOp<Type, kernels::Min<Type>, BS, EQ>;
480: link->d_UnpackAndMax = UnpackAndOp<Type, kernels::Max<Type>, BS, EQ>;
481: link->d_UnpackAndLAND = UnpackAndOp<Type, kernels::LAND<Type>, BS, EQ>;
482: link->d_UnpackAndLOR = UnpackAndOp<Type, kernels::LOR<Type>, BS, EQ>;
483: link->d_UnpackAndLXOR = UnpackAndOp<Type, kernels::LXOR<Type>, BS, EQ>;
484: link->d_UnpackAndBAND = UnpackAndOp<Type, kernels::BAND<Type>, BS, EQ>;
485: link->d_UnpackAndBOR = UnpackAndOp<Type, kernels::BOR<Type>, BS, EQ>;
486: link->d_UnpackAndBXOR = UnpackAndOp<Type, kernels::BXOR<Type>, BS, EQ>;
487: link->d_FetchAndAdd = FetchAndOp<Type, kernels::Add<Type>, BS, EQ>;
489: link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
490: link->d_ScatterAndAdd = ScatterAndOp<Type, kernels::Add<Type>, BS, EQ>;
491: link->d_ScatterAndMult = ScatterAndOp<Type, kernels::Mult<Type>, BS, EQ>;
492: link->d_ScatterAndMin = ScatterAndOp<Type, kernels::Min<Type>, BS, EQ>;
493: link->d_ScatterAndMax = ScatterAndOp<Type, kernels::Max<Type>, BS, EQ>;
494: link->d_ScatterAndLAND = ScatterAndOp<Type, kernels::LAND<Type>, BS, EQ>;
495: link->d_ScatterAndLOR = ScatterAndOp<Type, kernels::LOR<Type>, BS, EQ>;
496: link->d_ScatterAndLXOR = ScatterAndOp<Type, kernels::LXOR<Type>, BS, EQ>;
497: link->d_ScatterAndBAND = ScatterAndOp<Type, kernels::BAND<Type>, BS, EQ>;
498: link->d_ScatterAndBOR = ScatterAndOp<Type, kernels::BOR<Type>, BS, EQ>;
499: link->d_ScatterAndBXOR = ScatterAndOp<Type, kernels::BXOR<Type>, BS, EQ>;
500: link->d_FetchAndAddLocal = FetchAndOpLocal<Type, kernels::Add<Type>, BS, EQ>;
501: PackInit_IntegerType_Atomic<Type, BS, EQ, sizeof(Type)>::Init(link);
502: }
504: #if defined(PETSC_HAVE_COMPLEX)
505: template <device::cupm::DeviceType T>
506: template <typename Type, PetscInt BS, PetscInt EQ>
507: inline void SfInterface<T>::PackInit_ComplexType(PetscSFLink link) noexcept
508: {
509: link->d_Pack = Pack<Type, BS, EQ>;
510: link->d_UnpackAndInsert = UnpackAndOp<Type, kernels::Insert<Type>, BS, EQ>;
511: link->d_UnpackAndAdd = UnpackAndOp<Type, kernels::Add<Type>, BS, EQ>;
512: link->d_UnpackAndMult = UnpackAndOp<Type, kernels::Mult<Type>, BS, EQ>;
513: link->d_FetchAndAdd = FetchAndOp<Type, kernels::Add<Type>, BS, EQ>;
515: link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
516: link->d_ScatterAndAdd = ScatterAndOp<Type, kernels::Add<Type>, BS, EQ>;
517: link->d_ScatterAndMult = ScatterAndOp<Type, kernels::Mult<Type>, BS, EQ>;
518: link->d_FetchAndAddLocal = FetchAndOpLocal<Type, kernels::Add<Type>, BS, EQ>;
520: link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
521: link->da_UnpackAndAdd = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
522: link->da_UnpackAndMult = NULL; /* Not implemented yet */
523: link->da_FetchAndAdd = NULL; /* Return value of atomicAdd on complex is not atomic */
525: link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
526: link->da_ScatterAndAdd = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
527: }
528: #endif
530: typedef signed char SignedChar;
531: typedef unsigned char UnsignedChar;
532: typedef struct {
533: int a;
534: int b;
535: } PairInt;
536: typedef struct {
537: PetscInt a;
538: PetscInt b;
539: } PairPetscInt;
541: template <device::cupm::DeviceType T>
542: template <typename Type>
543: inline void SfInterface<T>::PackInit_PairType(PetscSFLink link) noexcept
544: {
545: link->d_Pack = Pack<Type, 1, 1>;
546: link->d_UnpackAndInsert = UnpackAndOp<Type, kernels::Insert<Type>, 1, 1>;
547: link->d_UnpackAndMaxloc = UnpackAndOp<Type, kernels::Maxloc<Type>, 1, 1>;
548: link->d_UnpackAndMinloc = UnpackAndOp<Type, kernels::Minloc<Type>, 1, 1>;
550: link->d_ScatterAndInsert = ScatterAndOp<Type, kernels::Insert<Type>, 1, 1>;
551: link->d_ScatterAndMaxloc = ScatterAndOp<Type, kernels::Maxloc<Type>, 1, 1>;
552: link->d_ScatterAndMinloc = ScatterAndOp<Type, kernels::Minloc<Type>, 1, 1>;
553: /* Atomics for pair types are not implemented yet */
554: }
556: template <device::cupm::DeviceType T>
557: template <typename Type, PetscInt BS, PetscInt EQ>
558: inline void SfInterface<T>::PackInit_DumbType(PetscSFLink link) noexcept
559: {
560: link->d_Pack = Pack<Type, BS, EQ>;
561: link->d_UnpackAndInsert = UnpackAndOp<Type, kernels::Insert<Type>, BS, EQ>;
562: link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
563: /* Atomics for dumb types are not implemented yet */
564: }
566: /* Some device-specific utilities */
567: template <device::cupm::DeviceType T>
568: inline PetscErrorCode SfInterface<T>::LinkSyncDevice(PetscSFLink) noexcept
569: {
570: PetscFunctionBegin;
571: PetscCallCUPM(cupmDeviceSynchronize());
572: PetscFunctionReturn(PETSC_SUCCESS);
573: }
575: template <device::cupm::DeviceType T>
576: inline PetscErrorCode SfInterface<T>::LinkSyncStream(PetscSFLink link) noexcept
577: {
578: PetscFunctionBegin;
579: PetscCallCUPM(cupmStreamSynchronize(link->stream));
580: PetscFunctionReturn(PETSC_SUCCESS);
581: }
583: template <device::cupm::DeviceType T>
584: inline PetscErrorCode SfInterface<T>::LinkMemcpy(PetscSFLink link, PetscMemType dstmtype, void *dst, PetscMemType srcmtype, const void *src, size_t n) noexcept
585: {
586: PetscFunctionBegin;
587: cupmMemcpyKind_t kinds[2][2] = {
588: {cupmMemcpyHostToHost, cupmMemcpyHostToDevice },
589: {cupmMemcpyDeviceToHost, cupmMemcpyDeviceToDevice}
590: };
592: if (n) {
593: if (PetscMemTypeHost(dstmtype) && PetscMemTypeHost(srcmtype)) { /* Separate HostToHost so that pure-cpu code won't call cupm runtime */
594: PetscCall(PetscMemcpy(dst, src, n));
595: } else {
596: int stype = PetscMemTypeDevice(srcmtype) ? 1 : 0;
597: int dtype = PetscMemTypeDevice(dstmtype) ? 1 : 0;
598: PetscCallCUPM(cupmMemcpyAsync(dst, src, n, kinds[stype][dtype], link->stream));
599: }
600: }
601: PetscFunctionReturn(PETSC_SUCCESS);
602: }
604: template <device::cupm::DeviceType T>
605: inline PetscErrorCode SfInterface<T>::Malloc(PetscMemType mtype, size_t size, void **ptr) noexcept
606: {
607: PetscFunctionBegin;
608: if (PetscMemTypeHost(mtype)) PetscCall(PetscMalloc(size, ptr));
609: else if (PetscMemTypeDevice(mtype)) {
610: PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUPM()));
611: PetscCallCUPM(cupmMalloc(ptr, size));
612: } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype);
613: PetscFunctionReturn(PETSC_SUCCESS);
614: }
616: template <device::cupm::DeviceType T>
617: inline PetscErrorCode SfInterface<T>::Free(PetscMemType mtype, void *ptr) noexcept
618: {
619: PetscFunctionBegin;
620: if (PetscMemTypeHost(mtype)) PetscCall(PetscFree(ptr));
621: else if (PetscMemTypeDevice(mtype)) PetscCallCUPM(cupmFree(ptr));
622: else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype);
623: PetscFunctionReturn(PETSC_SUCCESS);
624: }
626: /* Destructor when the link uses MPI for communication on CUPM device */
627: template <device::cupm::DeviceType T>
628: inline PetscErrorCode SfInterface<T>::LinkDestroy_MPI(PetscSF, PetscSFLink link) noexcept
629: {
630: PetscFunctionBegin;
631: for (int i = PETSCSF_LOCAL; i <= PETSCSF_REMOTE; i++) {
632: PetscCallCUPM(cupmFree(link->rootbuf_alloc[i][PETSC_MEMTYPE_DEVICE]));
633: PetscCallCUPM(cupmFree(link->leafbuf_alloc[i][PETSC_MEMTYPE_DEVICE]));
634: }
635: PetscFunctionReturn(PETSC_SUCCESS);
636: }
638: /*====================================================================================*/
639: /* Main driver to init MPI datatype on device */
640: /*====================================================================================*/
642: /* Some fields of link are initialized by PetscSFPackSetUp_Host. This routine only does what needed on device */
643: template <device::cupm::DeviceType T>
644: inline PetscErrorCode SfInterface<T>::LinkSetUp(PetscSF sf, PetscSFLink link, MPI_Datatype unit) noexcept
645: {
646: PetscInt nSignedChar = 0, nUnsignedChar = 0, nInt = 0, nPetscInt = 0, nPetscReal = 0;
647: PetscBool is2Int, is2PetscInt;
648: #if defined(PETSC_HAVE_COMPLEX)
649: PetscInt nPetscComplex = 0;
650: #endif
652: PetscFunctionBegin;
653: if (link->deviceinited) PetscFunctionReturn(PETSC_SUCCESS);
654: PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_SIGNED_CHAR, &nSignedChar));
655: PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_UNSIGNED_CHAR, &nUnsignedChar));
656: /* MPI_CHAR is treated below as a dumb type that does not support reduction according to MPI standard */
657: PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_INT, &nInt));
658: PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_INT, &nPetscInt));
659: PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_REAL, &nPetscReal));
660: #if defined(PETSC_HAVE_COMPLEX)
661: PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_COMPLEX, &nPetscComplex));
662: #endif
663: PetscCall(MPIPetsc_Type_compare(unit, MPI_2INT, &is2Int));
664: PetscCall(MPIPetsc_Type_compare(unit, MPIU_2INT, &is2PetscInt));
666: if (is2Int) {
667: PackInit_PairType<PairInt>(link);
668: } else if (is2PetscInt) { /* TODO: when is2PetscInt and nPetscInt=2, we don't know which path to take. The two paths support different ops. */
669: PackInit_PairType<PairPetscInt>(link);
670: } else if (nPetscReal) {
671: #if !defined(PETSC_HAVE_DEVICE)
672: if (nPetscReal == 8) PackInit_RealType<PetscReal, 8, 1>(link);
673: else if (nPetscReal % 8 == 0) PackInit_RealType<PetscReal, 8, 0>(link);
674: else if (nPetscReal == 4) PackInit_RealType<PetscReal, 4, 1>(link);
675: else if (nPetscReal % 4 == 0) PackInit_RealType<PetscReal, 4, 0>(link);
676: else if (nPetscReal == 2) PackInit_RealType<PetscReal, 2, 1>(link);
677: else if (nPetscReal % 2 == 0) PackInit_RealType<PetscReal, 2, 0>(link);
678: else if (nPetscReal == 1) PackInit_RealType<PetscReal, 1, 1>(link);
679: else if (nPetscReal % 1 == 0)
680: #endif
681: PackInit_RealType<PetscReal, 1, 0>(link);
682: } else if (nPetscInt && sizeof(PetscInt) == sizeof(llint)) {
683: #if !defined(PETSC_HAVE_DEVICE)
684: if (nPetscInt == 8) PackInit_IntegerType<llint, 8, 1>(link);
685: else if (nPetscInt % 8 == 0) PackInit_IntegerType<llint, 8, 0>(link);
686: else if (nPetscInt == 4) PackInit_IntegerType<llint, 4, 1>(link);
687: else if (nPetscInt % 4 == 0) PackInit_IntegerType<llint, 4, 0>(link);
688: else if (nPetscInt == 2) PackInit_IntegerType<llint, 2, 1>(link);
689: else if (nPetscInt % 2 == 0) PackInit_IntegerType<llint, 2, 0>(link);
690: else if (nPetscInt == 1) PackInit_IntegerType<llint, 1, 1>(link);
691: else if (nPetscInt % 1 == 0)
692: #endif
693: PackInit_IntegerType<llint, 1, 0>(link);
694: } else if (nInt) {
695: #if !defined(PETSC_HAVE_DEVICE)
696: if (nInt == 8) PackInit_IntegerType<int, 8, 1>(link);
697: else if (nInt % 8 == 0) PackInit_IntegerType<int, 8, 0>(link);
698: else if (nInt == 4) PackInit_IntegerType<int, 4, 1>(link);
699: else if (nInt % 4 == 0) PackInit_IntegerType<int, 4, 0>(link);
700: else if (nInt == 2) PackInit_IntegerType<int, 2, 1>(link);
701: else if (nInt % 2 == 0) PackInit_IntegerType<int, 2, 0>(link);
702: else if (nInt == 1) PackInit_IntegerType<int, 1, 1>(link);
703: else if (nInt % 1 == 0)
704: #endif
705: PackInit_IntegerType<int, 1, 0>(link);
706: } else if (nSignedChar) {
707: #if !defined(PETSC_HAVE_DEVICE)
708: if (nSignedChar == 8) PackInit_IntegerType<SignedChar, 8, 1>(link);
709: else if (nSignedChar % 8 == 0) PackInit_IntegerType<SignedChar, 8, 0>(link);
710: else if (nSignedChar == 4) PackInit_IntegerType<SignedChar, 4, 1>(link);
711: else if (nSignedChar % 4 == 0) PackInit_IntegerType<SignedChar, 4, 0>(link);
712: else if (nSignedChar == 2) PackInit_IntegerType<SignedChar, 2, 1>(link);
713: else if (nSignedChar % 2 == 0) PackInit_IntegerType<SignedChar, 2, 0>(link);
714: else if (nSignedChar == 1) PackInit_IntegerType<SignedChar, 1, 1>(link);
715: else if (nSignedChar % 1 == 0)
716: #endif
717: PackInit_IntegerType<SignedChar, 1, 0>(link);
718: } else if (nUnsignedChar) {
719: #if !defined(PETSC_HAVE_DEVICE)
720: if (nUnsignedChar == 8) PackInit_IntegerType<UnsignedChar, 8, 1>(link);
721: else if (nUnsignedChar % 8 == 0) PackInit_IntegerType<UnsignedChar, 8, 0>(link);
722: else if (nUnsignedChar == 4) PackInit_IntegerType<UnsignedChar, 4, 1>(link);
723: else if (nUnsignedChar % 4 == 0) PackInit_IntegerType<UnsignedChar, 4, 0>(link);
724: else if (nUnsignedChar == 2) PackInit_IntegerType<UnsignedChar, 2, 1>(link);
725: else if (nUnsignedChar % 2 == 0) PackInit_IntegerType<UnsignedChar, 2, 0>(link);
726: else if (nUnsignedChar == 1) PackInit_IntegerType<UnsignedChar, 1, 1>(link);
727: else if (nUnsignedChar % 1 == 0)
728: #endif
729: PackInit_IntegerType<UnsignedChar, 1, 0>(link);
730: #if defined(PETSC_HAVE_COMPLEX)
731: } else if (nPetscComplex) {
732: #if !defined(PETSC_HAVE_DEVICE)
733: if (nPetscComplex == 8) PackInit_ComplexType<PetscComplex, 8, 1>(link);
734: else if (nPetscComplex % 8 == 0) PackInit_ComplexType<PetscComplex, 8, 0>(link);
735: else if (nPetscComplex == 4) PackInit_ComplexType<PetscComplex, 4, 1>(link);
736: else if (nPetscComplex % 4 == 0) PackInit_ComplexType<PetscComplex, 4, 0>(link);
737: else if (nPetscComplex == 2) PackInit_ComplexType<PetscComplex, 2, 1>(link);
738: else if (nPetscComplex % 2 == 0) PackInit_ComplexType<PetscComplex, 2, 0>(link);
739: else if (nPetscComplex == 1) PackInit_ComplexType<PetscComplex, 1, 1>(link);
740: else if (nPetscComplex % 1 == 0)
741: #endif
742: PackInit_ComplexType<PetscComplex, 1, 0>(link);
743: #endif
744: } else {
745: MPI_Aint lb, nbyte;
746: PetscCallMPI(MPI_Type_get_extent(unit, &lb, &nbyte));
747: PetscCheck(lb == 0, PETSC_COMM_SELF, PETSC_ERR_SUP, "Datatype with nonzero lower bound %ld", (long)lb);
748: if (nbyte % sizeof(int)) { /* If the type size is not multiple of int */
749: #if !defined(PETSC_HAVE_DEVICE)
750: if (nbyte == 4) PackInit_DumbType<char, 4, 1>(link);
751: else if (nbyte % 4 == 0) PackInit_DumbType<char, 4, 0>(link);
752: else if (nbyte == 2) PackInit_DumbType<char, 2, 1>(link);
753: else if (nbyte % 2 == 0) PackInit_DumbType<char, 2, 0>(link);
754: else if (nbyte == 1) PackInit_DumbType<char, 1, 1>(link);
755: else if (nbyte % 1 == 0)
756: #endif
757: PackInit_DumbType<char, 1, 0>(link);
758: } else {
759: nInt = nbyte / sizeof(int);
760: #if !defined(PETSC_HAVE_DEVICE)
761: if (nInt == 8) PackInit_DumbType<int, 8, 1>(link);
762: else if (nInt % 8 == 0) PackInit_DumbType<int, 8, 0>(link);
763: else if (nInt == 4) PackInit_DumbType<int, 4, 1>(link);
764: else if (nInt % 4 == 0) PackInit_DumbType<int, 4, 0>(link);
765: else if (nInt == 2) PackInit_DumbType<int, 2, 1>(link);
766: else if (nInt % 2 == 0) PackInit_DumbType<int, 2, 0>(link);
767: else if (nInt == 1) PackInit_DumbType<int, 1, 1>(link);
768: else if (nInt % 1 == 0)
769: #endif
770: PackInit_DumbType<int, 1, 0>(link);
771: }
772: }
774: if (!sf->maxResidentThreadsPerGPU) { /* Not initialized */
775: int device;
776: cupmDeviceProp_t props;
777: PetscCallCUPM(cupmGetDevice(&device));
778: PetscCallCUPM(cupmGetDeviceProperties(&props, device));
779: sf->maxResidentThreadsPerGPU = props.maxThreadsPerMultiProcessor * props.multiProcessorCount;
780: }
781: link->maxResidentThreadsPerGPU = sf->maxResidentThreadsPerGPU;
783: {
784: cupmStream_t *stream;
785: PetscDeviceContext dctx;
787: PetscCall(PetscDeviceContextGetCurrentContextAssertType_Internal(&dctx, PETSC_DEVICE_CUPM()));
788: PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream));
789: link->stream = *stream;
790: }
791: link->Destroy = LinkDestroy_MPI;
792: link->SyncDevice = LinkSyncDevice;
793: link->SyncStream = LinkSyncStream;
794: link->Memcpy = LinkMemcpy;
795: link->deviceinited = PETSC_TRUE;
796: PetscFunctionReturn(PETSC_SUCCESS);
797: }
799: } // namespace impl
801: } // namespace cupm
803: } // namespace sf
805: } // namespace Petsc