Actual source code: cupmcontext.hpp

  1: #pragma once

  3: #include <petsc/private/deviceimpl.h>
  4: #include <petsc/private/cupmsolverinterface.hpp>
  5: #include <petsc/private/logimpl.h>

  7: #include <petsc/private/cpp/array.hpp>

  9: #include "../segmentedmempool.hpp"
 10: #include "cupmallocator.hpp"
 11: #include "cupmstream.hpp"
 12: #include "cupmevent.hpp"

 14: namespace Petsc
 15: {

 17: namespace device
 18: {

 20: namespace cupm
 21: {

 23: namespace impl
 24: {

 26: template <DeviceType T>
 27: class DeviceContext : SolverInterface<T> {
 28: public:
 29:   PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T);

 31: private:
 32:   template <typename H, std::size_t>
 33:   struct HandleTag {
 34:     using type = H;
 35:   };

 37:   using stream_tag = HandleTag<cupmStream_t, 0>;
 38:   using blas_tag   = HandleTag<cupmBlasHandle_t, 1>;
 39:   using solver_tag = HandleTag<cupmSolverHandle_t, 2>;

 41:   using stream_type = CUPMStream<T>;
 42:   using event_type  = CUPMEvent<T>;

 44: public:
 45:   // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
 46:   // header, but since we are using the power of templates it must be declared part of
 47:   // this class to have easy access the same typedefs. Technically one can make a
 48:   // templated struct outside the class but it's more code for the same result.
 49:   struct PetscDeviceContext_IMPLS {
 50:     stream_type stream{};
 51:     cupmEvent_t event{};
 52:     cupmEvent_t begin{}; // timer-only
 53:     cupmEvent_t end{};   // timer-only
 54: #if PetscDefined(USE_DEBUG)
 55:     PetscBool timerInUse{};
 56: #endif
 57:     cupmBlasHandle_t   blas{};
 58:     cupmSolverHandle_t solver{};

 60:     constexpr PetscDeviceContext_IMPLS() noexcept = default;

 62:     PETSC_NODISCARD const cupmStream_t &get(stream_tag) const noexcept { return this->stream.get_stream(); }

 64:     PETSC_NODISCARD const cupmBlasHandle_t &get(blas_tag) const noexcept { return this->blas; }

 66:     PETSC_NODISCARD const cupmSolverHandle_t &get(solver_tag) const noexcept { return this->solver; }
 67:   };

 69: private:
 70:   static bool initialized_;

 72:   static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES>   blashandles_;
 73:   static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_;

 75:   PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); }

 77:   PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { return static_cast<CUPMEvent<T> *>(event->data); }

 79:   PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; }

 81:   PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; }

 83:   // this exists purely to satisfy the compiler so the tag-based dispatch works for the other
 84:   // handles
 85:   static PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext) noexcept { return PETSC_SUCCESS; }

 87:   static PetscErrorCode initialize_handle_(blas_tag, PetscDeviceContext dctx) noexcept
 88:   {
 89:     const auto dci    = impls_cast_(dctx);
 90:     auto      &handle = blashandles_[dctx->device->deviceId];

 92:     PetscFunctionBegin;
 93:     if (!handle) {
 94:       PetscCall(PetscLogEventsPause());
 95:       PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
 96:       for (auto i = 0; i < 3; ++i) {
 97:         const auto cberr = cupmBlasCreate(handle.ptr_to());
 98:         if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break;
 99:         if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr);
100:         if (i != 2) {
101:           PetscCall(PetscSleep(3));
102:           continue;
103:         }
104:         PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName());
105:       }
106:       PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
107:       PetscCall(PetscLogEventsResume());
108:     }
109:     PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream()));
110:     dci->blas = handle;
111:     PetscFunctionReturn(PETSC_SUCCESS);
112:   }

114:   static PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx) noexcept
115:   {
116:     const auto dci    = impls_cast_(dctx);
117:     auto      &handle = solverhandles_[dctx->device->deviceId];

119:     PetscFunctionBegin;
120:     if (!handle) {
121:       PetscCall(PetscLogEventsPause());
122:       PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
123:       for (auto i = 0; i < 3; ++i) {
124:         const auto cerr = cupmSolverCreate(&handle);
125:         if (PetscLikely(cerr == CUPMSOLVER_STATUS_SUCCESS)) break;
126:         if ((cerr != CUPMSOLVER_STATUS_NOT_INITIALIZED) && (cerr != CUPMSOLVER_STATUS_ALLOC_FAILED)) PetscCallCUPMSOLVER(cerr);
127:         if (i < 2) {
128:           PetscCall(PetscSleep(3));
129:           continue;
130:         }
131:         PetscCheck(cerr == CUPMSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmSolverName());
132:       }
133:       PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
134:       PetscCall(PetscLogEventsResume());
135:     }
136:     PetscCallCUPMSOLVER(cupmSolverSetStream(handle, dci->stream.get_stream()));
137:     dci->solver = handle;
138:     PetscFunctionReturn(PETSC_SUCCESS);
139:   }

141:   static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept
142:   {
143:     const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId;

145:     PetscFunctionBegin;
146:     PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")",
147:                PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr);
148:     PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl));
149:     PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr));
150:     PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl)));
151:     PetscFunctionReturn(PETSC_SUCCESS);
152:   }

154:   static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { return check_current_device_(dctx, dctx); }

156:   static PetscErrorCode finalize_() noexcept
157:   {
158:     PetscFunctionBegin;
159:     for (auto &&handle : blashandles_) {
160:       if (handle) {
161:         PetscCallCUPMBLAS(cupmBlasDestroy(handle));
162:         handle = nullptr;
163:       }
164:     }
165:     for (auto &&handle : solverhandles_) {
166:       if (handle) {
167:         PetscCallCUPMSOLVER(cupmSolverDestroy(handle));
168:         handle = nullptr;
169:       }
170:     }
171:     initialized_ = false;
172:     PetscFunctionReturn(PETSC_SUCCESS);
173:   }

175:   template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>>
176:   PETSC_NODISCARD static PoolType &default_pool_() noexcept
177:   {
178:     static PoolType pool;
179:     return pool;
180:   }

182:   static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept
183:   {
184:     PetscFunctionBegin;
185:     PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess);
186:     PetscFunctionReturn(PETSC_SUCCESS);
187:   }

189: public:
190:   // All of these functions MUST be static in order to be callable from C, otherwise they
191:   // get the implicit 'this' pointer tacked on
192:   static PetscErrorCode destroy(PetscDeviceContext) noexcept;
193:   static PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType) noexcept;
194:   static PetscErrorCode setUp(PetscDeviceContext) noexcept;
195:   static PetscErrorCode query(PetscDeviceContext, PetscBool *) noexcept;
196:   static PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext) noexcept;
197:   static PetscErrorCode synchronize(PetscDeviceContext) noexcept;
198:   template <typename Handle_t>
199:   static PetscErrorCode getHandle(PetscDeviceContext, void *) noexcept;
200:   template <typename Handle_t>
201:   static PetscErrorCode getHandlePtr(PetscDeviceContext, void **) noexcept;
202:   static PetscErrorCode beginTimer(PetscDeviceContext) noexcept;
203:   static PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *) noexcept;
204:   static PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **) noexcept;
205:   static PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **) noexcept;
206:   static PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode) noexcept;
207:   static PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t) noexcept;
208:   static PetscErrorCode createEvent(PetscDeviceContext, PetscEvent) noexcept;
209:   static PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent) noexcept;
210:   static PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent) noexcept;

212:   // not a PetscDeviceContext method, this registers the class
213:   static PetscErrorCode initialize(PetscDevice) noexcept;

215:   // clang-format off
216:   static constexpr _DeviceContextOps ops = {
217:     PetscDesignatedInitializer(destroy, destroy),
218:     PetscDesignatedInitializer(changestreamtype, changeStreamType),
219:     PetscDesignatedInitializer(setup, setUp),
220:     PetscDesignatedInitializer(query, query),
221:     PetscDesignatedInitializer(waitforcontext, waitForContext),
222:     PetscDesignatedInitializer(synchronize, synchronize),
223:     PetscDesignatedInitializer(getblashandle, getHandle<blas_tag>),
224:     PetscDesignatedInitializer(getsolverhandle, getHandle<solver_tag>),
225:     PetscDesignatedInitializer(getstreamhandle, getHandlePtr<stream_tag>),
226:     PetscDesignatedInitializer(begintimer, beginTimer),
227:     PetscDesignatedInitializer(endtimer, endTimer),
228:     PetscDesignatedInitializer(memalloc, memAlloc),
229:     PetscDesignatedInitializer(memfree, memFree),
230:     PetscDesignatedInitializer(memcopy, memCopy),
231:     PetscDesignatedInitializer(memset, memSet),
232:     PetscDesignatedInitializer(createevent, createEvent),
233:     PetscDesignatedInitializer(recordevent, recordEvent),
234:     PetscDesignatedInitializer(waitforevent, waitForEvent)
235:   };
236:   // clang-format on
237: };

239: // not a PetscDeviceContext method, this initializes the CLASS
240: template <DeviceType T>
241: inline PetscErrorCode DeviceContext<T>::initialize(PetscDevice device) noexcept
242: {
243:   PetscFunctionBegin;
244:   if (PetscUnlikely(!initialized_)) {
245:     uint64_t      threshold = UINT64_MAX;
246:     cupmMemPool_t mempool;

248:     initialized_ = true;
249:     PetscCallCUPM(cupmDeviceGetMemPool(&mempool, static_cast<int>(device->deviceId)));
250:     PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold));
251:     blashandles_.fill(nullptr);
252:     solverhandles_.fill(nullptr);
253:     PetscCall(PetscRegisterFinalize(finalize_));
254:   }
255:   PetscFunctionReturn(PETSC_SUCCESS);
256: }

258: template <DeviceType T>
259: inline PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx) noexcept
260: {
261:   PetscFunctionBegin;
262:   if (const auto dci = impls_cast_(dctx)) {
263:     PetscCall(dci->stream.destroy());
264:     if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(&dci->event));
265:     if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin));
266:     if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end));
267:     delete dci;
268:     dctx->data = nullptr;
269:   }
270:   PetscFunctionReturn(PETSC_SUCCESS);
271: }

273: template <DeviceType T>
274: inline PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype) noexcept
275: {
276:   const auto dci = impls_cast_(dctx);

278:   PetscFunctionBegin;
279:   PetscCall(dci->stream.destroy());
280:   // set these to null so they aren't usable until setup is called again
281:   dci->blas   = nullptr;
282:   dci->solver = nullptr;
283:   PetscFunctionReturn(PETSC_SUCCESS);
284: }

286: template <DeviceType T>
287: inline PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx) noexcept
288: {
289:   const auto dci   = impls_cast_(dctx);
290:   auto      &event = dci->event;

292:   PetscFunctionBegin;
293:   PetscCall(check_current_device_(dctx));
294:   PetscCall(dci->stream.change_type(dctx->streamType));
295:   if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event));
296: #if PetscDefined(USE_DEBUG)
297:   dci->timerInUse = PETSC_FALSE;
298: #endif
299:   PetscFunctionReturn(PETSC_SUCCESS);
300: }

302: template <DeviceType T>
303: inline PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept
304: {
305:   PetscFunctionBegin;
306:   PetscCall(check_current_device_(dctx));
307:   switch (auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) {
308:   case cupmSuccess:
309:     *idle = PETSC_TRUE;
310:     break;
311:   case cupmErrorNotReady:
312:     *idle = PETSC_FALSE;
313:     // reset the error
314:     cerr = cupmGetLastError();
315:     static_cast<void>(cerr);
316:     break;
317:   default:
318:     PetscCallCUPM(cerr);
319:     PetscUnreachable();
320:   }
321:   PetscFunctionReturn(PETSC_SUCCESS);
322: }

324: template <DeviceType T>
325: inline PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept
326: {
327:   const auto dcib  = impls_cast_(dctxb);
328:   const auto event = dcib->event;

330:   PetscFunctionBegin;
331:   PetscCall(check_current_device_(dctxa, dctxb));
332:   PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream()));
333:   PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0));
334:   PetscFunctionReturn(PETSC_SUCCESS);
335: }

337: template <DeviceType T>
338: inline PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx) noexcept
339: {
340:   auto idle = PETSC_TRUE;

342:   PetscFunctionBegin;
343:   PetscCall(query(dctx, &idle));
344:   if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream()));
345:   PetscFunctionReturn(PETSC_SUCCESS);
346: }

348: template <DeviceType T>
349: template <typename handle_t>
350: inline PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle) noexcept
351: {
352:   PetscFunctionBegin;
353:   PetscCall(initialize_handle_(handle_t{}, dctx));
354:   *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{});
355:   PetscFunctionReturn(PETSC_SUCCESS);
356: }

358: template <DeviceType T>
359: template <typename handle_t>
360: inline PetscErrorCode DeviceContext<T>::getHandlePtr(PetscDeviceContext dctx, void **handle) noexcept
361: {
362:   using handle_type = typename handle_t::type;

364:   PetscFunctionBegin;
365:   PetscCall(initialize_handle_(handle_t{}, dctx));
366:   *reinterpret_cast<handle_type **>(handle) = const_cast<handle_type *>(std::addressof(impls_cast_(dctx)->get(handle_t{})));
367:   PetscFunctionReturn(PETSC_SUCCESS);
368: }

370: template <DeviceType T>
371: inline PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx) noexcept
372: {
373:   const auto dci = impls_cast_(dctx);

375:   PetscFunctionBegin;
376:   PetscCall(check_current_device_(dctx));
377: #if PetscDefined(USE_DEBUG)
378:   PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?");
379:   dci->timerInUse = PETSC_TRUE;
380: #endif
381:   if (!dci->begin) {
382:     PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event");
383:     PetscCallCUPM(cupmEventCreate(&dci->begin));
384:     PetscCallCUPM(cupmEventCreate(&dci->end));
385:   }
386:   PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream()));
387:   PetscFunctionReturn(PETSC_SUCCESS);
388: }

390: template <DeviceType T>
391: inline PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed) noexcept
392: {
393:   float      gtime;
394:   const auto dci = impls_cast_(dctx);
395:   const auto end = dci->end;

397:   PetscFunctionBegin;
398:   PetscCall(check_current_device_(dctx));
399: #if PetscDefined(USE_DEBUG)
400:   PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?");
401:   dci->timerInUse = PETSC_FALSE;
402: #endif
403:   PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream()));
404:   PetscCallCUPM(cupmEventSynchronize(end));
405:   PetscCallCUPM(cupmEventElapsedTime(&gtime, dci->begin, end));
406:   *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime);
407:   PetscFunctionReturn(PETSC_SUCCESS);
408: }

410: template <DeviceType T>
411: inline PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest) noexcept
412: {
413:   const auto &stream = impls_cast_(dctx)->stream;

415:   PetscFunctionBegin;
416:   PetscCall(check_current_device_(dctx));
417:   PetscCall(check_memtype_(mtype, "allocating"));
418:   if (PetscMemTypeHost(mtype)) {
419:     PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
420:   } else {
421:     PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
422:   }
423:   if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream()));
424:   PetscFunctionReturn(PETSC_SUCCESS);
425: }

427: template <DeviceType T>
428: inline PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr) noexcept
429: {
430:   const auto &stream = impls_cast_(dctx)->stream;

432:   PetscFunctionBegin;
433:   PetscCall(check_current_device_(dctx));
434:   PetscCall(check_memtype_(mtype, "freeing"));
435:   if (!*ptr) PetscFunctionReturn(PETSC_SUCCESS);
436:   if (PetscMemTypeHost(mtype)) {
437:     PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
438:     // if ptr exists still exists the pool didn't own it
439:     if (*ptr) {
440:       auto registered = PETSC_FALSE, managed = PETSC_FALSE;

442:       PetscCall(PetscCUPMGetMemType(*ptr, nullptr, &registered, &managed));
443:       if (registered) {
444:         PetscCallCUPM(cupmFreeHost(*ptr));
445:       } else if (managed) {
446:         PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
447:       }
448:     }
449:   } else {
450:     PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
451:     // if ptr still exists the pool didn't own it
452:     if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
453:   }
454:   PetscFunctionReturn(PETSC_SUCCESS);
455: }

457: template <DeviceType T>
458: inline PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode) noexcept
459: {
460:   const auto stream = impls_cast_(dctx)->stream.get_stream();

462:   PetscFunctionBegin;
463:   // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)...
464:   if (mode == PETSC_DEVICE_COPY_HTOH) {
465:     const auto cerr = cupmStreamQuery(stream);

467:     // yes this is faster
468:     if (cerr == cupmSuccess) {
469:       PetscCall(PetscMemcpy(dest, src, n));
470:       PetscFunctionReturn(PETSC_SUCCESS);
471:     } else if (cerr == cupmErrorNotReady) {
472:       auto PETSC_UNUSED unused = cupmGetLastError();

474:       static_cast<void>(unused);
475:     } else {
476:       PetscCallCUPM(cerr);
477:     }
478:   }
479:   PetscCallCUPM(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream));
480:   PetscFunctionReturn(PETSC_SUCCESS);
481: }

483: template <DeviceType T>
484: inline PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n) noexcept
485: {
486:   PetscFunctionBegin;
487:   PetscCall(check_current_device_(dctx));
488:   PetscCall(check_memtype_(mtype, "zeroing"));
489:   PetscCallCUPM(cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream()));
490:   PetscFunctionReturn(PETSC_SUCCESS);
491: }

493: template <DeviceType T>
494: inline PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext, PetscEvent event) noexcept
495: {
496:   PetscFunctionBegin;
497:   PetscCallCXX(event->data = new event_type{});
498:   event->destroy = [](PetscEvent event) {
499:     PetscFunctionBegin;
500:     delete event_cast_(event);
501:     event->data = nullptr;
502:     PetscFunctionReturn(PETSC_SUCCESS);
503:   };
504:   PetscFunctionReturn(PETSC_SUCCESS);
505: }

507: template <DeviceType T>
508: inline PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
509: {
510:   PetscFunctionBegin;
511:   PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event)));
512:   PetscFunctionReturn(PETSC_SUCCESS);
513: }

515: template <DeviceType T>
516: inline PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
517: {
518:   PetscFunctionBegin;
519:   PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event)));
520:   PetscFunctionReturn(PETSC_SUCCESS);
521: }

523: // initialize the static member variables
524: template <DeviceType T>
525: bool DeviceContext<T>::initialized_ = false;

527: template <DeviceType T>
528: std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {};

530: template <DeviceType T>
531: std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {};

533: template <DeviceType T>
534: constexpr _DeviceContextOps DeviceContext<T>::ops;

536: } // namespace impl

538: // shorten this one up a bit (and instantiate the templates)
539: using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>;
540: using CUPMContextHip  = impl::DeviceContext<DeviceType::HIP>;

542: // shorthand for what is an EXTREMELY long name
543: #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS

545: } // namespace cupm

547: } // namespace device

549: } // namespace Petsc