// Actual source code: cupmcontext.hpp
1: #pragma once
3: #include <petsc/private/deviceimpl.h>
4: #include <petsc/private/cupmsolverinterface.hpp>
5: #include <petsc/private/logimpl.h>
7: #include <petsc/private/cpp/array.hpp>
9: #include "../segmentedmempool.hpp"
10: #include "cupmallocator.hpp"
11: #include "cupmstream.hpp"
12: #include "cupmevent.hpp"
14: namespace Petsc
15: {
17: namespace device
18: {
20: namespace cupm
21: {
23: namespace impl
24: {
26: template <DeviceType T>
27: class DeviceContext : SolverInterface<T> {
28: public:
29: PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T);
31: private:
32: template <typename H, std::size_t>
33: struct HandleTag {
34: using type = H;
35: };
37: using stream_tag = HandleTag<cupmStream_t, 0>;
38: using blas_tag = HandleTag<cupmBlasHandle_t, 1>;
39: using solver_tag = HandleTag<cupmSolverHandle_t, 2>;
41: using stream_type = CUPMStream<T>;
42: using event_type = CUPMEvent<T>;
44: public:
45: // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
46: // header, but since we are using the power of templates it must be declared part of
47: // this class to have easy access the same typedefs. Technically one can make a
48: // templated struct outside the class but it's more code for the same result.
49: struct PetscDeviceContext_IMPLS {
50: stream_type stream{};
51: cupmEvent_t event{};
52: cupmEvent_t begin{}; // timer-only
53: cupmEvent_t end{}; // timer-only
54: #if PetscDefined(USE_DEBUG)
55: PetscBool timerInUse{};
56: #endif
57: cupmBlasHandle_t blas{};
58: cupmSolverHandle_t solver{};
60: constexpr PetscDeviceContext_IMPLS() noexcept = default;
62: PETSC_NODISCARD const cupmStream_t &get(stream_tag) const noexcept { return this->stream.get_stream(); }
64: PETSC_NODISCARD const cupmBlasHandle_t &get(blas_tag) const noexcept { return this->blas; }
66: PETSC_NODISCARD const cupmSolverHandle_t &get(solver_tag) const noexcept { return this->solver; }
67: };
69: private:
70: static bool initialized_;
72: static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> blashandles_;
73: static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_;
75: PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); }
77: PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { return static_cast<CUPMEvent<T> *>(event->data); }
79: PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; }
81: PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; }
83: // this exists purely to satisfy the compiler so the tag-based dispatch works for the other
84: // handles
85: static PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext) noexcept { return PETSC_SUCCESS; }
87: static PetscErrorCode initialize_handle_(blas_tag, PetscDeviceContext dctx) noexcept
88: {
89: const auto dci = impls_cast_(dctx);
90: auto &handle = blashandles_[dctx->device->deviceId];
92: PetscFunctionBegin;
93: if (!handle) {
94: PetscCall(PetscLogEventsPause());
95: PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
96: for (auto i = 0; i < 3; ++i) {
97: const auto cberr = cupmBlasCreate(handle.ptr_to());
98: if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break;
99: if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr);
100: if (i != 2) {
101: PetscCall(PetscSleep(3));
102: continue;
103: }
104: PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName());
105: }
106: PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
107: PetscCall(PetscLogEventsResume());
108: }
109: PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream()));
110: dci->blas = handle;
111: PetscFunctionReturn(PETSC_SUCCESS);
112: }
114: static PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx) noexcept
115: {
116: const auto dci = impls_cast_(dctx);
117: auto &handle = solverhandles_[dctx->device->deviceId];
119: PetscFunctionBegin;
120: if (!handle) {
121: PetscCall(PetscLogEventsPause());
122: PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
123: for (auto i = 0; i < 3; ++i) {
124: const auto cerr = cupmSolverCreate(&handle);
125: if (PetscLikely(cerr == CUPMSOLVER_STATUS_SUCCESS)) break;
126: if ((cerr != CUPMSOLVER_STATUS_NOT_INITIALIZED) && (cerr != CUPMSOLVER_STATUS_ALLOC_FAILED)) PetscCallCUPMSOLVER(cerr);
127: if (i < 2) {
128: PetscCall(PetscSleep(3));
129: continue;
130: }
131: PetscCheck(cerr == CUPMSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmSolverName());
132: }
133: PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
134: PetscCall(PetscLogEventsResume());
135: }
136: PetscCallCUPMSOLVER(cupmSolverSetStream(handle, dci->stream.get_stream()));
137: dci->solver = handle;
138: PetscFunctionReturn(PETSC_SUCCESS);
139: }
141: static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept
142: {
143: const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId;
145: PetscFunctionBegin;
146: PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")",
147: PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr);
148: PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl));
149: PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr));
150: PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl)));
151: PetscFunctionReturn(PETSC_SUCCESS);
152: }
154: static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { return check_current_device_(dctx, dctx); }
156: static PetscErrorCode finalize_() noexcept
157: {
158: PetscFunctionBegin;
159: for (auto &&handle : blashandles_) {
160: if (handle) {
161: PetscCallCUPMBLAS(cupmBlasDestroy(handle));
162: handle = nullptr;
163: }
164: }
165: for (auto &&handle : solverhandles_) {
166: if (handle) {
167: PetscCallCUPMSOLVER(cupmSolverDestroy(handle));
168: handle = nullptr;
169: }
170: }
171: initialized_ = false;
172: PetscFunctionReturn(PETSC_SUCCESS);
173: }
175: template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>>
176: PETSC_NODISCARD static PoolType &default_pool_() noexcept
177: {
178: static PoolType pool;
179: return pool;
180: }
182: static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept
183: {
184: PetscFunctionBegin;
185: PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess);
186: PetscFunctionReturn(PETSC_SUCCESS);
187: }
189: public:
190: // All of these functions MUST be static in order to be callable from C, otherwise they
191: // get the implicit 'this' pointer tacked on
192: static PetscErrorCode destroy(PetscDeviceContext) noexcept;
193: static PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType) noexcept;
194: static PetscErrorCode setUp(PetscDeviceContext) noexcept;
195: static PetscErrorCode query(PetscDeviceContext, PetscBool *) noexcept;
196: static PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext) noexcept;
197: static PetscErrorCode synchronize(PetscDeviceContext) noexcept;
198: template <typename Handle_t>
199: static PetscErrorCode getHandle(PetscDeviceContext, void *) noexcept;
200: template <typename Handle_t>
201: static PetscErrorCode getHandlePtr(PetscDeviceContext, void **) noexcept;
202: static PetscErrorCode beginTimer(PetscDeviceContext) noexcept;
203: static PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *) noexcept;
204: static PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **) noexcept;
205: static PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **) noexcept;
206: static PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode) noexcept;
207: static PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t) noexcept;
208: static PetscErrorCode createEvent(PetscDeviceContext, PetscEvent) noexcept;
209: static PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent) noexcept;
210: static PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent) noexcept;
212: // not a PetscDeviceContext method, this registers the class
213: static PetscErrorCode initialize(PetscDevice) noexcept;
215: // clang-format off
216: static constexpr _DeviceContextOps ops = {
217: PetscDesignatedInitializer(destroy, destroy),
218: PetscDesignatedInitializer(changestreamtype, changeStreamType),
219: PetscDesignatedInitializer(setup, setUp),
220: PetscDesignatedInitializer(query, query),
221: PetscDesignatedInitializer(waitforcontext, waitForContext),
222: PetscDesignatedInitializer(synchronize, synchronize),
223: PetscDesignatedInitializer(getblashandle, getHandle<blas_tag>),
224: PetscDesignatedInitializer(getsolverhandle, getHandle<solver_tag>),
225: PetscDesignatedInitializer(getstreamhandle, getHandlePtr<stream_tag>),
226: PetscDesignatedInitializer(begintimer, beginTimer),
227: PetscDesignatedInitializer(endtimer, endTimer),
228: PetscDesignatedInitializer(memalloc, memAlloc),
229: PetscDesignatedInitializer(memfree, memFree),
230: PetscDesignatedInitializer(memcopy, memCopy),
231: PetscDesignatedInitializer(memset, memSet),
232: PetscDesignatedInitializer(createevent, createEvent),
233: PetscDesignatedInitializer(recordevent, recordEvent),
234: PetscDesignatedInitializer(waitforevent, waitForEvent)
235: };
236: // clang-format on
237: };
239: // not a PetscDeviceContext method, this initializes the CLASS
240: template <DeviceType T>
241: inline PetscErrorCode DeviceContext<T>::initialize(PetscDevice device) noexcept
242: {
243: PetscFunctionBegin;
244: if (PetscUnlikely(!initialized_)) {
245: uint64_t threshold = UINT64_MAX;
246: cupmMemPool_t mempool;
248: initialized_ = true;
249: PetscCallCUPM(cupmDeviceGetMemPool(&mempool, static_cast<int>(device->deviceId)));
250: PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold));
251: blashandles_.fill(nullptr);
252: solverhandles_.fill(nullptr);
253: PetscCall(PetscRegisterFinalize(finalize_));
254: }
255: PetscFunctionReturn(PETSC_SUCCESS);
256: }
258: template <DeviceType T>
259: inline PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx) noexcept
260: {
261: PetscFunctionBegin;
262: if (const auto dci = impls_cast_(dctx)) {
263: PetscCall(dci->stream.destroy());
264: if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(&dci->event));
265: if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin));
266: if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end));
267: delete dci;
268: dctx->data = nullptr;
269: }
270: PetscFunctionReturn(PETSC_SUCCESS);
271: }
273: template <DeviceType T>
274: inline PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype) noexcept
275: {
276: const auto dci = impls_cast_(dctx);
278: PetscFunctionBegin;
279: PetscCall(dci->stream.destroy());
280: // set these to null so they aren't usable until setup is called again
281: dci->blas = nullptr;
282: dci->solver = nullptr;
283: PetscFunctionReturn(PETSC_SUCCESS);
284: }
286: template <DeviceType T>
287: inline PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx) noexcept
288: {
289: const auto dci = impls_cast_(dctx);
290: auto &event = dci->event;
292: PetscFunctionBegin;
293: PetscCall(check_current_device_(dctx));
294: PetscCall(dci->stream.change_type(dctx->streamType));
295: if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event));
296: #if PetscDefined(USE_DEBUG)
297: dci->timerInUse = PETSC_FALSE;
298: #endif
299: PetscFunctionReturn(PETSC_SUCCESS);
300: }
302: template <DeviceType T>
303: inline PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept
304: {
305: PetscFunctionBegin;
306: PetscCall(check_current_device_(dctx));
307: switch (auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) {
308: case cupmSuccess:
309: *idle = PETSC_TRUE;
310: break;
311: case cupmErrorNotReady:
312: *idle = PETSC_FALSE;
313: // reset the error
314: cerr = cupmGetLastError();
315: static_cast<void>(cerr);
316: break;
317: default:
318: PetscCallCUPM(cerr);
319: PetscUnreachable();
320: }
321: PetscFunctionReturn(PETSC_SUCCESS);
322: }
324: template <DeviceType T>
325: inline PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept
326: {
327: const auto dcib = impls_cast_(dctxb);
328: const auto event = dcib->event;
330: PetscFunctionBegin;
331: PetscCall(check_current_device_(dctxa, dctxb));
332: PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream()));
333: PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0));
334: PetscFunctionReturn(PETSC_SUCCESS);
335: }
337: template <DeviceType T>
338: inline PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx) noexcept
339: {
340: auto idle = PETSC_TRUE;
342: PetscFunctionBegin;
343: PetscCall(query(dctx, &idle));
344: if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream()));
345: PetscFunctionReturn(PETSC_SUCCESS);
346: }
348: template <DeviceType T>
349: template <typename handle_t>
350: inline PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle) noexcept
351: {
352: PetscFunctionBegin;
353: PetscCall(initialize_handle_(handle_t{}, dctx));
354: *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{});
355: PetscFunctionReturn(PETSC_SUCCESS);
356: }
358: template <DeviceType T>
359: template <typename handle_t>
360: inline PetscErrorCode DeviceContext<T>::getHandlePtr(PetscDeviceContext dctx, void **handle) noexcept
361: {
362: using handle_type = typename handle_t::type;
364: PetscFunctionBegin;
365: PetscCall(initialize_handle_(handle_t{}, dctx));
366: *reinterpret_cast<handle_type **>(handle) = const_cast<handle_type *>(std::addressof(impls_cast_(dctx)->get(handle_t{})));
367: PetscFunctionReturn(PETSC_SUCCESS);
368: }
370: template <DeviceType T>
371: inline PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx) noexcept
372: {
373: const auto dci = impls_cast_(dctx);
375: PetscFunctionBegin;
376: PetscCall(check_current_device_(dctx));
377: #if PetscDefined(USE_DEBUG)
378: PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?");
379: dci->timerInUse = PETSC_TRUE;
380: #endif
381: if (!dci->begin) {
382: PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event");
383: PetscCallCUPM(cupmEventCreate(&dci->begin));
384: PetscCallCUPM(cupmEventCreate(&dci->end));
385: }
386: PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream()));
387: PetscFunctionReturn(PETSC_SUCCESS);
388: }
390: template <DeviceType T>
391: inline PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed) noexcept
392: {
393: float gtime;
394: const auto dci = impls_cast_(dctx);
395: const auto end = dci->end;
397: PetscFunctionBegin;
398: PetscCall(check_current_device_(dctx));
399: #if PetscDefined(USE_DEBUG)
400: PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?");
401: dci->timerInUse = PETSC_FALSE;
402: #endif
403: PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream()));
404: PetscCallCUPM(cupmEventSynchronize(end));
405: PetscCallCUPM(cupmEventElapsedTime(>ime, dci->begin, end));
406: *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime);
407: PetscFunctionReturn(PETSC_SUCCESS);
408: }
410: template <DeviceType T>
411: inline PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest) noexcept
412: {
413: const auto &stream = impls_cast_(dctx)->stream;
415: PetscFunctionBegin;
416: PetscCall(check_current_device_(dctx));
417: PetscCall(check_memtype_(mtype, "allocating"));
418: if (PetscMemTypeHost(mtype)) {
419: PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
420: } else {
421: PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
422: }
423: if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream()));
424: PetscFunctionReturn(PETSC_SUCCESS);
425: }
427: template <DeviceType T>
428: inline PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr) noexcept
429: {
430: const auto &stream = impls_cast_(dctx)->stream;
432: PetscFunctionBegin;
433: PetscCall(check_current_device_(dctx));
434: PetscCall(check_memtype_(mtype, "freeing"));
435: if (!*ptr) PetscFunctionReturn(PETSC_SUCCESS);
436: if (PetscMemTypeHost(mtype)) {
437: PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
438: // if ptr exists still exists the pool didn't own it
439: if (*ptr) {
440: auto registered = PETSC_FALSE, managed = PETSC_FALSE;
442: PetscCall(PetscCUPMGetMemType(*ptr, nullptr, ®istered, &managed));
443: if (registered) {
444: PetscCallCUPM(cupmFreeHost(*ptr));
445: } else if (managed) {
446: PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
447: }
448: }
449: } else {
450: PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
451: // if ptr still exists the pool didn't own it
452: if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
453: }
454: PetscFunctionReturn(PETSC_SUCCESS);
455: }
457: template <DeviceType T>
458: inline PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode) noexcept
459: {
460: const auto stream = impls_cast_(dctx)->stream.get_stream();
462: PetscFunctionBegin;
463: // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)...
464: if (mode == PETSC_DEVICE_COPY_HTOH) {
465: const auto cerr = cupmStreamQuery(stream);
467: // yes this is faster
468: if (cerr == cupmSuccess) {
469: PetscCall(PetscMemcpy(dest, src, n));
470: PetscFunctionReturn(PETSC_SUCCESS);
471: } else if (cerr == cupmErrorNotReady) {
472: auto PETSC_UNUSED unused = cupmGetLastError();
474: static_cast<void>(unused);
475: } else {
476: PetscCallCUPM(cerr);
477: }
478: }
479: PetscCallCUPM(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream));
480: PetscFunctionReturn(PETSC_SUCCESS);
481: }
483: template <DeviceType T>
484: inline PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n) noexcept
485: {
486: PetscFunctionBegin;
487: PetscCall(check_current_device_(dctx));
488: PetscCall(check_memtype_(mtype, "zeroing"));
489: PetscCallCUPM(cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream()));
490: PetscFunctionReturn(PETSC_SUCCESS);
491: }
493: template <DeviceType T>
494: inline PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext, PetscEvent event) noexcept
495: {
496: PetscFunctionBegin;
497: PetscCallCXX(event->data = new event_type{});
498: event->destroy = [](PetscEvent event) {
499: PetscFunctionBegin;
500: delete event_cast_(event);
501: event->data = nullptr;
502: PetscFunctionReturn(PETSC_SUCCESS);
503: };
504: PetscFunctionReturn(PETSC_SUCCESS);
505: }
507: template <DeviceType T>
508: inline PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
509: {
510: PetscFunctionBegin;
511: PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event)));
512: PetscFunctionReturn(PETSC_SUCCESS);
513: }
515: template <DeviceType T>
516: inline PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
517: {
518: PetscFunctionBegin;
519: PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event)));
520: PetscFunctionReturn(PETSC_SUCCESS);
521: }
523: // initialize the static member variables
524: template <DeviceType T>
525: bool DeviceContext<T>::initialized_ = false;
527: template <DeviceType T>
528: std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {};
530: template <DeviceType T>
531: std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {};
533: template <DeviceType T>
534: constexpr _DeviceContextOps DeviceContext<T>::ops;
536: } // namespace impl
538: // shorten this one up a bit (and instantiate the templates)
539: using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>;
540: using CUPMContextHip = impl::DeviceContext<DeviceType::HIP>;
542: // shorthand for what is an EXTREMELY long name
543: #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS
545: } // namespace cupm
547: } // namespace device
549: } // namespace Petsc