Actual source code: cupmblasinterface.hpp
1: #pragma once
3: #include <petsc/private/cupminterface.hpp>
4: #include <petsc/private/petscadvancedmacros.h>
6: #include <limits> // std::numeric_limits
8: namespace Petsc
9: {
11: namespace device
12: {
14: namespace cupm
15: {
17: namespace impl
18: {
// PetscCallCUPMBLAS_() - Shared implementation behind the cupmBlas error-checking macros
//
// input params:
// __abort_fn__ - error reporting macro to invoke on failure (SETERRQ or SETERRABORT below)
// __comm__     - communicator handed to __abort_fn__
// ...          - the expression to check, it must yield a cupmBlasError_t
//
// notes:
// The expansion refers to cupmBlasError_t, CUPMBLAS_STATUS_SUCCESS,
// CUPMBLAS_STATUS_NOT_INITIALIZED, CUPMBLAS_STATUS_ALLOC_FAILED, cupmBlasName() and
// cupmBlasGetErrorName() unqualified, so it may only be used where those names are visible
// (e.g. inside a class which inherits the blas interface).
//
// A "not initialized" or "alloc failed" status reported while the device is initialized is
// raised as PETSC_ERR_GPU_RESOURCE, every other failure is raised as PETSC_ERR_GPU.
//
// The status is printed via %d, so it is converted to int (not to PetscErrorCode, which is an
// unrelated type describing PETSc errors rather than blas statuses).
#define PetscCallCUPMBLAS_(__abort_fn__, __comm__, ...) \
  do { \
    PetscStackUpdateLine; \
    const cupmBlasError_t cberr_p_ = __VA_ARGS__; \
    if (PetscUnlikely(cberr_p_ != CUPMBLAS_STATUS_SUCCESS)) { \
      if (((cberr_p_ == CUPMBLAS_STATUS_NOT_INITIALIZED) || (cberr_p_ == CUPMBLAS_STATUS_ALLOC_FAILED)) && PetscDeviceInitialized(PETSC_DEVICE_CUPM())) { \
        __abort_fn__(__comm__, PETSC_ERR_GPU_RESOURCE, \
                     "%s error %d (%s). Reports not initialized or alloc failed; " \
                     "this indicates the GPU may have run out resources", \
                     cupmBlasName(), static_cast<int>(cberr_p_), cupmBlasGetErrorName(cberr_p_)); \
      } \
      __abort_fn__(__comm__, PETSC_ERR_GPU, "%s error %d (%s)", cupmBlasName(), static_cast<int>(cberr_p_), cupmBlasGetErrorName(cberr_p_)); \
    } \
  } while (0)

// Check a cupmBlas call, reporting failures through SETERRQ() on PETSC_COMM_SELF
#define PetscCallCUPMBLAS(...) PetscCallCUPMBLAS_(SETERRQ, PETSC_COMM_SELF, __VA_ARGS__)
// Check a cupmBlas call, reporting failures through SETERRABORT() on the communicator comm_
#define PetscCallCUPMBLASAbort(comm_, ...) PetscCallCUPMBLAS_(SETERRABORT, comm_, __VA_ARGS__)
38: // given cupmBlas<T>axpy() then
39: // T = PETSC_CUPBLAS_FP_TYPE
40: // given cupmBlas<T><u>nrm2() then
41: // T = PETSC_CUPMBLAS_FP_INPUT_TYPE
42: // u = PETSC_CUPMBLAS_FP_RETURN_TYPE
43: #if PetscDefined(USE_COMPLEX)
44: #if PetscDefined(USE_REAL_SINGLE)
45: #define PETSC_CUPMBLAS_FP_TYPE_U C
46: #define PETSC_CUPMBLAS_FP_TYPE_L c
47: #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U S
48: #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L s
49: #elif PetscDefined(USE_REAL_DOUBLE)
50: #define PETSC_CUPMBLAS_FP_TYPE_U Z
51: #define PETSC_CUPMBLAS_FP_TYPE_L z
52: #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U D
53: #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L d
54: #endif
55: #define PETSC_CUPMBLAS_FP_RETURN_TYPE_U PETSC_CUPMBLAS_FP_TYPE_U
56: #define PETSC_CUPMBLAS_FP_RETURN_TYPE_L PETSC_CUPMBLAS_FP_TYPE_L
57: #else
58: #if PetscDefined(USE_REAL_SINGLE)
59: #define PETSC_CUPMBLAS_FP_TYPE_U S
60: #define PETSC_CUPMBLAS_FP_TYPE_L s
61: #elif PetscDefined(USE_REAL_DOUBLE)
62: #define PETSC_CUPMBLAS_FP_TYPE_U D
63: #define PETSC_CUPMBLAS_FP_TYPE_L d
64: #endif
65: #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U PETSC_CUPMBLAS_FP_TYPE_U
66: #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L PETSC_CUPMBLAS_FP_TYPE_L
67: #define PETSC_CUPMBLAS_FP_RETURN_TYPE_U
68: #define PETSC_CUPMBLAS_FP_RETURN_TYPE_L
69: #endif // USE_COMPLEX
71: #if !defined(PETSC_CUPMBLAS_FP_TYPE_U) && !PetscDefined(USE_REAL___FLOAT128)
72: #error "Unsupported floating-point type for CUDA/HIP BLAS"
73: #endif
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED() - Build the name of a "modified" blas
// function, i.e. one whose return type does not match its input type
//
// input param:
// func - base suffix of the blas function, e.g. nrm2
//
// notes:
// The result is <input letter><return letter><func>.
//
// PETSC_CUPMBLAS_FP_INPUT_TYPE must be defined as the blas floating point input type letter
// ("S" for real/complex single, "D" for real/complex double).
//
// PETSC_CUPMBLAS_FP_RETURN_TYPE must be defined as the blas floating point output type letter
// ("c" for complex single, "z" for complex double and <absolutely nothing> for real
// single/double).
//
// In their infinite wisdom nvidia/amd have made the upper-case vs lower-case scheme
// infuriatingly inconsistent...
//
// example usage:
// #define PETSC_CUPMBLAS_FP_INPUT_TYPE  S
// #define PETSC_CUPMBLAS_FP_RETURN_TYPE
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(nrm2) -> Snrm2
//
// #define PETSC_CUPMBLAS_FP_INPUT_TYPE  D
// #define PETSC_CUPMBLAS_FP_RETURN_TYPE z
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(nrm2) -> Dznrm2
#define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(func) \
  PetscConcat(PetscConcat(PETSC_CUPMBLAS_FP_INPUT_TYPE, PETSC_CUPMBLAS_FP_RETURN_TYPE), func)
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE() - Build the names of Iamax and Iamin,
// which follow their own naming scheme
//
// input param:
// func - base suffix of the blas function, either amax or amin
//
// notes:
// The macro name literally stands for "I" ## "floating point type": the result is
// I<lower-case type letter><func>.
//
// PETSC_CUPMBLAS_FP_TYPE_L must be defined as the lower-case blas floating point type letter
// ("c" for complex single, "z" for complex double, "s" for real single, and "d" for real
// double).
//
// example usage:
// #define PETSC_CUPMBLAS_FP_TYPE_L s
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(amax) -> Isamax
//
// #define PETSC_CUPMBLAS_FP_TYPE_L z
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(amin) -> Izamin
#define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(func) \
  PetscConcat(I, PetscConcat(PETSC_CUPMBLAS_FP_TYPE_L, func))
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD() - Build the name of a "standard" blas
// function, i.e. <type letter><func>
//
// input param:
// func - base suffix of the blas function, e.g. axpy, scal
//
// notes:
// PETSC_CUPMBLAS_FP_TYPE must be defined as the blas floating-point letter ("C" for complex
// single, "Z" for complex double, "S" for real single, "D" for real double).
//
// example usage:
// #define PETSC_CUPMBLAS_FP_TYPE S
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(axpy) -> Saxpy
//
// #define PETSC_CUPMBLAS_FP_TYPE Z
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(axpy) -> Zaxpy
#define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(func) \
  PetscConcat(PETSC_CUPMBLAS_FP_TYPE, func)
// PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT() - Alias a CUDA/HIP blas function under a name of
// our choosing, for when CUDA/HIP don't agree with our suffix
//
// input params:
// MACRO_SUFFIX - suffix to one of the above blas function builder macros, e.g. STANDARD or
//                IFPTYPE
// our_suffix   - the suffix of the alias function, the alias is named cupmBlasX<our_suffix>
// their_suffix - the suffix of the function being aliased
//
// notes:
// PETSC_CUPMBLAS_PREFIX must be defined as the specific CUDA/HIP blas function prefix, and
// whatever the selected builder macro requires must be defined as well. The alias itself is
// produced by PETSC_CUPM_ALIAS_FUNCTION(), see that macro for the exact expansion.
//
// example usage:
// #define PETSC_CUPMBLAS_PREFIX  cublas
// #define PETSC_CUPMBLAS_FP_TYPE C
// PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD,dot,dotc) ->
// template <typename... T>
// static constexpr auto cupmBlasXdot(T&&... args) *noexcept and returntype detection*
// {
//   return cublasCdotc(std::forward<T>(args)...);
// }
#define PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(MACRO_SUFFIX, our_suffix, their_suffix) \
  PETSC_CUPM_ALIAS_FUNCTION( \
    PetscConcat(cupmBlasX, our_suffix), \
    PetscConcat(PETSC_CUPMBLAS_PREFIX, PetscConcat(PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_, MACRO_SUFFIX)(their_suffix)))
// PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION() - Alias a CUDA/HIP blas function whose suffix is the
// same for us and for the vendor library
//
// input params:
// MACRO_SUFFIX - suffix to one of the above blas function builder macros, e.g. STANDARD or
//                IFPTYPE
// suffix       - the common suffix between CUDA and HIP of the alias function
//
// notes:
// see PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(), this macro just calls that one with "suffix"
// as both "our_suffix" and "their_suffix"
#define PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MACRO_SUFFIX, suffix) \
  PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(MACRO_SUFFIX, suffix, suffix)
// PETSC_CUPMBLAS_ALIAS_FUNCTION() - Alias a CUDA/HIP blas library (i.e. non-numerical)
// function as cupmBlas<suffix>
//
// input params:
// suffix - the common suffix between CUDA and HIP of the alias function
//
// notes:
// PETSC_CUPMBLAS_PREFIX must be defined as the specific CUDA/HIP blas library prefix. See
// PETSC_CUPM_ALIAS_FUNCTION() for the precise expansion of this macro.
//
// example usage:
// #define PETSC_CUPMBLAS_PREFIX hipblas
// PETSC_CUPMBLAS_ALIAS_FUNCTION(Create) ->
// template <typename... T>
// static constexpr auto cupmBlasCreate(T&&... args) *noexcept and returntype detection*
// {
//   return hipblasCreate(std::forward<T>(args)...);
// }
#define PETSC_CUPMBLAS_ALIAS_FUNCTION(suffix) \
  PETSC_CUPM_ALIAS_FUNCTION(PetscConcat(cupmBlas, suffix), PetscConcat(PETSC_CUPMBLAS_PREFIX, suffix))
200: template <DeviceType>
201: struct BlasInterfaceImpl;
// cupmBlasHandleWrapper - Thin strong-typedef around a vendor library handle
//
// template params:
// T - the wrapped handle type, must be assignable from nullptr
// I - tag value, wrappers with different tags are distinct types even for the same T
//
// notes:
// Exists because HIP (for whatever godforsaken reason) has elected to define both their
// blas and solver handle types as void *. So we cannot disambiguate them for overload
// resolution and hence need to wrap their types in this mess.
//
// The wrapper converts implicitly from and to T so it can be handed straight to the vendor
// routines, and ptr_to() exposes the address of the stored handle for routines which write
// the handle through a pointer.
template <typename T, std::size_t I>
class cupmBlasHandleWrapper {
public:
  // value-initializes the stored handle
  constexpr cupmBlasHandleWrapper() noexcept = default;

  // intentionally implicit so that raw handles convert to the wrapper
  constexpr cupmBlasHandleWrapper(T h) noexcept : handle_{std::move(h)}
  {
    // the wrapper must not add anything on top of the raw handle layout-wise
    static_assert(std::is_standard_layout<cupmBlasHandleWrapper<T, I>>::value, "");
  }

  // reset the stored handle, nothing is destroyed or released here
  cupmBlasHandleWrapper &operator=(std::nullptr_t) noexcept
  {
    handle_ = nullptr;
    return *this;
  }

  // intentionally implicit so that the wrapper can be passed wherever a T is expected
  operator T() const noexcept(std::is_nothrow_copy_constructible<T>::value) { return handle_; }

  // address of the stored handle
  const T *ptr_to() const noexcept { return &handle_; }
  T       *ptr_to() noexcept { return &handle_; }

private:
  T handle_{};
};
227: #if PetscDefined(HAVE_CUDA)
228: #define PETSC_CUPMBLAS_PREFIX cublas
229: #define PETSC_CUPMBLAS_PREFIX_U CUBLAS
230: #define PETSC_CUPMBLAS_FP_TYPE PETSC_CUPMBLAS_FP_TYPE_U
231: #define PETSC_CUPMBLAS_FP_INPUT_TYPE PETSC_CUPMBLAS_FP_INPUT_TYPE_U
232: #define PETSC_CUPMBLAS_FP_RETURN_TYPE PETSC_CUPMBLAS_FP_RETURN_TYPE_L
233: template <>
234: struct BlasInterfaceImpl<DeviceType::CUDA> : Interface<DeviceType::CUDA> {
235: // typedefs
236: using cupmBlasHandle_t = cupmBlasHandleWrapper<cublasHandle_t, 0>;
237: using cupmBlasError_t = cublasStatus_t;
238: using cupmBlasInt_t = int;
239: using cupmBlasPointerMode_t = cublasPointerMode_t;
241: // values
242: static const auto CUPMBLAS_STATUS_SUCCESS = CUBLAS_STATUS_SUCCESS;
243: static const auto CUPMBLAS_STATUS_NOT_INITIALIZED = CUBLAS_STATUS_NOT_INITIALIZED;
244: static const auto CUPMBLAS_STATUS_ALLOC_FAILED = CUBLAS_STATUS_ALLOC_FAILED;
245: static const auto CUPMBLAS_POINTER_MODE_HOST = CUBLAS_POINTER_MODE_HOST;
246: static const auto CUPMBLAS_POINTER_MODE_DEVICE = CUBLAS_POINTER_MODE_DEVICE;
247: static const auto CUPMBLAS_OP_T = CUBLAS_OP_T;
248: static const auto CUPMBLAS_OP_N = CUBLAS_OP_N;
249: static const auto CUPMBLAS_OP_C = CUBLAS_OP_C;
250: static const auto CUPMBLAS_FILL_MODE_LOWER = CUBLAS_FILL_MODE_LOWER;
251: static const auto CUPMBLAS_FILL_MODE_UPPER = CUBLAS_FILL_MODE_UPPER;
252: static const auto CUPMBLAS_SIDE_LEFT = CUBLAS_SIDE_LEFT;
253: static const auto CUPMBLAS_DIAG_NON_UNIT = CUBLAS_DIAG_NON_UNIT;
255: // utility functions
256: PETSC_CUPMBLAS_ALIAS_FUNCTION(Create)
257: PETSC_CUPMBLAS_ALIAS_FUNCTION(Destroy)
258: PETSC_CUPMBLAS_ALIAS_FUNCTION(GetStream)
259: PETSC_CUPMBLAS_ALIAS_FUNCTION(SetStream)
260: PETSC_CUPMBLAS_ALIAS_FUNCTION(GetPointerMode)
261: PETSC_CUPMBLAS_ALIAS_FUNCTION(SetPointerMode)
263: // level 1 BLAS
264: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, axpy)
265: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, scal)
266: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dot, PetscIfPetscDefined(USE_COMPLEX, dotc, dot))
267: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dotu, PetscIfPetscDefined(USE_COMPLEX, dotu, dot))
268: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, swap)
269: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, nrm2)
270: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(IFPTYPE, amax)
271: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, asum)
273: // level 2 BLAS
274: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemv)
276: // level 3 BLAS
277: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemm)
278: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, trsm)
280: // BLAS extensions
281: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, geam)
283: PETSC_NODISCARD static const char *cupmBlasGetErrorName(cupmBlasError_t status) noexcept { return PetscCUBLASGetErrorName(status); }
284: };
285: #undef PETSC_CUPMBLAS_PREFIX
286: #undef PETSC_CUPMBLAS_PREFIX_U
287: #undef PETSC_CUPMBLAS_FP_TYPE
288: #undef PETSC_CUPMBLAS_FP_INPUT_TYPE
289: #undef PETSC_CUPMBLAS_FP_RETURN_TYPE
290: #endif // PetscDefined(HAVE_CUDA)
292: #if PetscDefined(HAVE_HIP)
293: #define PETSC_CUPMBLAS_PREFIX hipblas
294: #define PETSC_CUPMBLAS_PREFIX_U HIPBLAS
295: #define PETSC_CUPMBLAS_FP_TYPE PETSC_CUPMBLAS_FP_TYPE_U
296: #define PETSC_CUPMBLAS_FP_INPUT_TYPE PETSC_CUPMBLAS_FP_INPUT_TYPE_U
297: #define PETSC_CUPMBLAS_FP_RETURN_TYPE PETSC_CUPMBLAS_FP_RETURN_TYPE_L
298: template <>
299: struct BlasInterfaceImpl<DeviceType::HIP> : Interface<DeviceType::HIP> {
300: // typedefs
301: using cupmBlasHandle_t = cupmBlasHandleWrapper<hipblasHandle_t, 0>;
302: using cupmBlasError_t = hipblasStatus_t;
303: using cupmBlasInt_t = int; // rocblas will have its own
304: using cupmBlasPointerMode_t = hipblasPointerMode_t;
306: // values
307: static const auto CUPMBLAS_STATUS_SUCCESS = HIPBLAS_STATUS_SUCCESS;
308: static const auto CUPMBLAS_STATUS_NOT_INITIALIZED = HIPBLAS_STATUS_NOT_INITIALIZED;
309: static const auto CUPMBLAS_STATUS_ALLOC_FAILED = HIPBLAS_STATUS_ALLOC_FAILED;
310: static const auto CUPMBLAS_POINTER_MODE_HOST = HIPBLAS_POINTER_MODE_HOST;
311: static const auto CUPMBLAS_POINTER_MODE_DEVICE = HIPBLAS_POINTER_MODE_DEVICE;
312: static const auto CUPMBLAS_OP_T = HIPBLAS_OP_T;
313: static const auto CUPMBLAS_OP_N = HIPBLAS_OP_N;
314: static const auto CUPMBLAS_OP_C = HIPBLAS_OP_C;
315: static const auto CUPMBLAS_FILL_MODE_LOWER = HIPBLAS_FILL_MODE_LOWER;
316: static const auto CUPMBLAS_FILL_MODE_UPPER = HIPBLAS_FILL_MODE_UPPER;
317: static const auto CUPMBLAS_SIDE_LEFT = HIPBLAS_SIDE_LEFT;
318: static const auto CUPMBLAS_DIAG_NON_UNIT = HIPBLAS_DIAG_NON_UNIT;
320: // utility functions
321: PETSC_CUPMBLAS_ALIAS_FUNCTION(Create)
322: PETSC_CUPMBLAS_ALIAS_FUNCTION(Destroy)
323: PETSC_CUPMBLAS_ALIAS_FUNCTION(GetStream)
324: PETSC_CUPMBLAS_ALIAS_FUNCTION(SetStream)
325: PETSC_CUPMBLAS_ALIAS_FUNCTION(GetPointerMode)
326: PETSC_CUPMBLAS_ALIAS_FUNCTION(SetPointerMode)
328: // level 1 BLAS
329: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, axpy)
330: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, scal)
331: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dot, PetscIfPetscDefined(USE_COMPLEX, dotc, dot))
332: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dotu, PetscIfPetscDefined(USE_COMPLEX, dotu, dot))
333: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, swap)
334: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, nrm2)
335: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(IFPTYPE, amax)
336: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, asum)
338: // level 2 BLAS
339: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemv)
341: // level 3 BLAS
342: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemm)
343: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, trsm)
345: // BLAS extensions
346: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, geam)
348: PETSC_NODISCARD static const char *cupmBlasGetErrorName(cupmBlasError_t status) noexcept { return PetscHIPBLASGetErrorName(status); }
349: };
350: #undef PETSC_CUPMBLAS_PREFIX
351: #undef PETSC_CUPMBLAS_PREFIX_U
352: #undef PETSC_CUPMBLAS_FP_TYPE
353: #undef PETSC_CUPMBLAS_FP_INPUT_TYPE
354: #undef PETSC_CUPMBLAS_FP_RETURN_TYPE
355: #endif // PetscDefined(HAVE_HIP)
// PETSC_CUPMBLAS_IMPL_CLASS_HEADER() - Pull every member of BlasInterfaceImpl<T> into the
// scope of the class in which the macro is expanded
//
// input param:
// T - the DeviceType selecting the BlasInterfaceImpl specialization
//
// notes:
// Expands PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T) first, followed by one using
// declaration (functions and values) or alias (types) per member. The expansion does not end
// in a semicolon, the user of the macro supplies it.
#define PETSC_CUPMBLAS_IMPL_CLASS_HEADER(T) \
  PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T); \
  /* introspection */ \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasGetErrorName; \
  /* types */ \
  using cupmBlasHandle_t      = typename ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasHandle_t; \
  using cupmBlasError_t       = typename ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasError_t; \
  using cupmBlasInt_t         = typename ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasInt_t; \
  using cupmBlasPointerMode_t = typename ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasPointerMode_t; \
  /* values */ \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_STATUS_SUCCESS; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_STATUS_NOT_INITIALIZED; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_STATUS_ALLOC_FAILED; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_POINTER_MODE_HOST; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_POINTER_MODE_DEVICE; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_OP_T; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_OP_N; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_OP_C; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_FILL_MODE_LOWER; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_FILL_MODE_UPPER; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_SIDE_LEFT; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_DIAG_NON_UNIT; \
  /* utility functions */ \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasCreate; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasDestroy; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasGetStream; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasSetStream; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasGetPointerMode; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasSetPointerMode; \
  /* level 1 BLAS */ \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXaxpy; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXscal; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXdot; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXdotu; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXswap; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXnrm2; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXamax; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXasum; \
  /* level 2 BLAS */ \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXgemv; \
  /* level 3 BLAS */ \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXgemm; \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXtrsm; \
  /* BLAS extensions */ \
  using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXgeam
403: // The actual interface class
404: template <DeviceType T>
405: struct BlasInterface : BlasInterfaceImpl<T> {
406: PETSC_CUPMBLAS_IMPL_CLASS_HEADER(T);
408: PETSC_NODISCARD static constexpr const char *cupmBlasName() noexcept { return T == DeviceType::CUDA ? "cuBLAS" : "hipBLAS"; }
410: static PetscErrorCode PetscCUPMBlasSetPointerModeFromPointer(cupmBlasHandle_t handle, const void *ptr) noexcept
411: {
412: auto mtype = PETSC_MEMTYPE_HOST;
414: PetscFunctionBegin;
415: PetscCall(PetscCUPMGetMemType(ptr, &mtype));
416: PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, PetscMemTypeDevice(mtype) ? CUPMBLAS_POINTER_MODE_DEVICE : CUPMBLAS_POINTER_MODE_HOST));
417: PetscFunctionReturn(PETSC_SUCCESS);
418: }
420: static PetscErrorCode checkCupmBlasIntCast(PetscInt x) noexcept
421: {
422: PetscFunctionBegin;
423: PetscCheck((std::is_same<PetscInt, cupmBlasInt_t>::value) || (x <= std::numeric_limits<cupmBlasInt_t>::max()), PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "%" PetscInt_FMT " is too big for %s, which may be restricted to 32-bit integers", x, cupmBlasName());
424: PetscCheck(x >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Passing negative integer (%" PetscInt_FMT ") to %s routine", x, cupmBlasName());
425: PetscFunctionReturn(PETSC_SUCCESS);
426: }
428: static PetscErrorCode PetscCUPMBlasIntCast(PetscInt x, cupmBlasInt_t *y) noexcept
429: {
430: PetscFunctionBegin;
431: *y = static_cast<cupmBlasInt_t>(x);
432: PetscCall(checkCupmBlasIntCast(x));
433: PetscFunctionReturn(PETSC_SUCCESS);
434: }
436: class CUPMBlasPointerModeGuard {
437: public:
438: CUPMBlasPointerModeGuard(const cupmBlasHandle_t &handle, cupmBlasPointerMode_t mode) noexcept : handle_{handle}
439: {
440: PetscFunctionBegin;
441: PetscCallCUPMBLASAbort(PETSC_COMM_SELF, cupmBlasGetPointerMode(handle, &this->old_));
442: if (this->old_ == mode) {
443: this->set_ = false;
444: } else {
445: this->set_ = true;
446: PetscCallCUPMBLASAbort(PETSC_COMM_SELF, cupmBlasSetPointerMode(handle, mode));
447: }
448: PetscFunctionReturnVoid();
449: }
451: CUPMBlasPointerModeGuard(const cupmBlasHandle_t &handle, PetscMemType mtype) noexcept : CUPMBlasPointerModeGuard{handle, PetscMemTypeDevice(mtype) ? CUPMBLAS_POINTER_MODE_DEVICE : CUPMBLAS_POINTER_MODE_HOST} { }
453: ~CUPMBlasPointerModeGuard() noexcept
454: {
455: PetscFunctionBegin;
456: if (this->set_) PetscCallCUPMBLASAbort(PETSC_COMM_SELF, cupmBlasSetPointerMode(this->handle_, this->old_));
457: PetscFunctionReturnVoid();
458: }
460: private:
461: cupmBlasHandle_t handle_;
462: cupmBlasPointerMode_t old_;
463: bool set_;
464: };
465: };
// PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING() - Pull the complete blas interface for
// device type T into the scope of the class in which the macro is expanded
//
// input param:
// T - the DeviceType selecting the interface
//
// notes:
// Expands PETSC_CUPMBLAS_IMPL_CLASS_HEADER(T) and then additionally imports the members
// which BlasInterface<T> itself adds. The expansion does not end in a semicolon, the user of
// the macro supplies it.
#define PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(T) \
  PETSC_CUPMBLAS_IMPL_CLASS_HEADER(T); \
  using ::Petsc::device::cupm::impl::BlasInterface<T>::cupmBlasName; \
  using ::Petsc::device::cupm::impl::BlasInterface<T>::PetscCUPMBlasSetPointerModeFromPointer; \
  using ::Petsc::device::cupm::impl::BlasInterface<T>::checkCupmBlasIntCast; \
  using ::Petsc::device::cupm::impl::BlasInterface<T>::PetscCUPMBlasIntCast; \
  using CUPMBlasPointerModeGuard = typename ::Petsc::device::cupm::impl::BlasInterface<T>::CUPMBlasPointerModeGuard
475: #if PetscDefined(HAVE_CUDA)
476: extern template struct PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL BlasInterface<DeviceType::CUDA>;
477: #endif
479: #if PetscDefined(HAVE_HIP)
480: extern template struct PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL BlasInterface<DeviceType::HIP>;
481: #endif
483: } // namespace impl
485: } // namespace cupm
487: } // namespace device
489: } // namespace Petsc