Actual source code: cupmstream.hpp
1: #pragma once
3: #include <petsc/private/cupminterface.hpp>
5: #include "../segmentedmempool.hpp"
6: #include "cupmevent.hpp"
8: namespace Petsc
9: {
11: namespace device
12: {
14: namespace cupm
15: {
// A thin wrapper around a cupmStream_t. It exists because we need to uniquely identify
// separate cupm streams, so that the memory pool can accelerate allocation calls by simply
// handing back a pointer to memory that was previously used on the same stream. Otherwise
// it must either serialize with another stream or allocate a new chunk. The address of the
// stream object does not suffice as an identifier, since cupmStreams are very likely reused
// internally.
23: template <DeviceType T>
24: class CUPMStream : public StreamBase<CUPMStream<T>>, impl::Interface<T> {
25: using crtp_base_type = StreamBase<CUPMStream<T>>;
26: friend crtp_base_type;
28: public:
29: PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T);
31: using stream_type = cupmStream_t;
32: using id_type = typename crtp_base_type::id_type;
33: using event_type = CUPMEvent<T>;
34: using flag_type = unsigned int;
36: CUPMStream() noexcept = default;
38: PetscErrorCode destroy() noexcept;
39: PetscErrorCode create(flag_type) noexcept;
40: PetscErrorCode change_type(PetscStreamType) noexcept;
42: private:
43: stream_type stream_{};
44: id_type id_ = new_id_();
46: PETSC_NODISCARD static id_type new_id_() noexcept;
48: // CRTP implementations
49: PETSC_NODISCARD const stream_type &get_stream_() const noexcept;
50: PETSC_NODISCARD id_type get_id_() const noexcept;
51: PetscErrorCode record_event_(event_type &) const noexcept;
52: PetscErrorCode wait_for_(event_type &) const noexcept;
53: };
55: template <DeviceType T>
56: inline PetscErrorCode CUPMStream<T>::destroy() noexcept
57: {
58: PetscFunctionBegin;
59: if (stream_) {
60: PetscCallCUPM(cupmStreamDestroy(stream_));
61: stream_ = cupmStream_t{};
62: id_ = 0;
63: }
64: PetscFunctionReturn(PETSC_SUCCESS);
65: }
67: template <DeviceType T>
68: inline PetscErrorCode CUPMStream<T>::create(flag_type flags) noexcept
69: {
70: PetscFunctionBegin;
71: if (stream_) {
72: if (PetscDefined(USE_DEBUG)) {
73: flag_type current_flags;
75: PetscCallCUPM(cupmStreamGetFlags(stream_, ¤t_flags));
76: PetscCheck(flags == current_flags, PETSC_COMM_SELF, PETSC_ERR_GPU, "Current flags %u != requested flags %u for stream %d", current_flags, flags, id_);
77: }
78: PetscFunctionReturn(PETSC_SUCCESS);
79: }
80: PetscCallCUPM(cupmStreamCreateWithFlags(&stream_, flags));
81: id_ = new_id_();
82: PetscFunctionReturn(PETSC_SUCCESS);
83: }
85: template <DeviceType T>
86: inline PetscErrorCode CUPMStream<T>::change_type(PetscStreamType newtype) noexcept
87: {
88: PetscFunctionBegin;
89: if (newtype == PETSC_STREAM_GLOBAL_BLOCKING) {
90: PetscCall(destroy());
91: } else {
92: const flag_type preferred = newtype == PETSC_STREAM_DEFAULT_BLOCKING ? cupmStreamDefault : cupmStreamNonBlocking;
94: if (stream_) {
95: flag_type flag;
97: PetscCallCUPM(cupmStreamGetFlags(stream_, &flag));
98: if (flag == preferred) PetscFunctionReturn(PETSC_SUCCESS);
99: PetscCall(destroy());
100: }
101: PetscCall(create(preferred));
102: }
103: PetscFunctionReturn(PETSC_SUCCESS);
104: }
106: template <DeviceType T>
107: inline typename CUPMStream<T>::id_type CUPMStream<T>::new_id_() noexcept
108: {
109: static id_type id = 0;
110: return id++;
111: }
113: // CRTP implementations
114: template <DeviceType T>
115: inline const typename CUPMStream<T>::stream_type &CUPMStream<T>::get_stream_() const noexcept
116: {
117: return stream_;
118: }
120: template <DeviceType T>
121: inline typename CUPMStream<T>::id_type CUPMStream<T>::get_id_() const noexcept
122: {
123: return id_;
124: }
126: template <DeviceType T>
127: inline PetscErrorCode CUPMStream<T>::record_event_(event_type &event) const noexcept
128: {
129: PetscFunctionBegin;
130: PetscCall(event.record(stream_));
131: PetscFunctionReturn(PETSC_SUCCESS);
132: }
134: template <DeviceType T>
135: inline PetscErrorCode CUPMStream<T>::wait_for_(event_type &event) const noexcept
136: {
137: PetscFunctionBegin;
138: PetscCallCUPM(cupmStreamWaitEvent(stream_, event.get(), 0));
139: PetscFunctionReturn(PETSC_SUCCESS);
140: }
142: } // namespace cupm
144: } // namespace device
146: } // namespace Petsc