HIP: Heterogeneous-computing Interface for Portability
amd_hip_cooperative_groups.h
/*
Copyright (c) 2015 - 2023 Advanced Micro Devices, Inc. All rights reserved.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/

#ifndef HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_H
#define HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_H

#if __cplusplus
#if !defined(__HIPCC_RTC__)
#endif

#define __hip_abort() \
  { abort(); }
#if defined(NDEBUG)
#define __hip_assert(COND)
#else
// Parenthesize COND so that compound conditions expand correctly.
#define __hip_assert(COND) \
  {                        \
    if (!(COND)) {         \
      __hip_abort();       \
    }                      \
  }
#endif

namespace cooperative_groups {

class thread_group {
 protected:
  uint32_t _type;  // thread_group type
  uint32_t _size;  // total number of threads in the thread_group
  uint64_t _mask;  // Lanemask for coalesced and tiled partitioned group types,
                   // LSB represents lane 0, and MSB represents lane 63

  // Construct a thread group, and set the thread group type and other essential
  // thread group properties. This generic thread group is directly constructed
  // only when the group is supposed to contain only the calling thread
  // (through the API `this_thread()`); in all other cases, this thread
  // group object is a sub-object of some other derived thread group object
  __CG_QUALIFIER__ thread_group(internal::group_type type, uint32_t size = static_cast<uint32_t>(0),
                                uint64_t mask = static_cast<uint64_t>(0)) {
    _type = type;
    _size = size;
    _mask = mask;
  }

  struct _tiled_info {
    bool is_tiled;
    unsigned int size;
    unsigned int meta_group_rank;
    unsigned int meta_group_size;
  };

  struct _coalesced_info {
    lane_mask member_mask;
    unsigned int size;
    struct _tiled_info tiled_info;
  } coalesced_info;

  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                                       unsigned int tile_size);
  friend class thread_block;

 public:
  // Total number of threads in the thread group; this also serves all derived
  // cooperative group types, since their `size` is saved directly during construction
  __CG_QUALIFIER__ uint32_t size() const { return _size; }
  __CG_QUALIFIER__ unsigned int cg_type() const { return _type; }
  // Rank of the calling thread within [0, size())
  __CG_QUALIFIER__ uint32_t thread_rank() const;
  // Is this cooperative group type valid?
  __CG_QUALIFIER__ bool is_valid() const;
  // Synchronize the threads in the thread group
  __CG_QUALIFIER__ void sync() const;
};
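
// Illustrative usage sketch (not part of this header, assumptions only): because every
// concrete group type derives from thread_group, device code can accept a thread_group
// by value and rely on the common size()/thread_rank()/sync() interface. The function
// name `block_sum` and the caller-provided `scratch` buffer (at least g.size() elements,
// typically in shared memory) are hypothetical names for this example.
//
//   __device__ unsigned int block_sum(cooperative_groups::thread_group g,
//                                     unsigned int* scratch, unsigned int value) {
//     unsigned int rank = g.thread_rank();   // rank within [0, g.size())
//     scratch[rank] = value;                 // each member publishes its value
//     g.sync();                              // wait for every member of the group
//     if (rank == 0) {
//       unsigned int total = 0;
//       for (unsigned int i = 0; i < g.size(); ++i) total += scratch[i];
//       scratch[0] = total;
//     }
//     g.sync();
//     return scratch[0];                     // every member reads the reduced value
//   }
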
class multi_grid_group : public thread_group {
  // Only these friend functions are allowed to construct an object of this class
  // and access its resources
  friend __CG_QUALIFIER__ multi_grid_group this_multi_grid();

 protected:
  // Construct multi-grid thread group (through the API this_multi_grid())
  explicit __CG_QUALIFIER__ multi_grid_group(uint32_t size)
      : thread_group(internal::cg_multi_grid, size) {}

 public:
  // Number of invocations participating in this multi-grid group. In other
  // words, the number of GPUs
  __CG_QUALIFIER__ uint32_t num_grids() { return internal::multi_grid::num_grids(); }
  // Rank of this invocation. In other words, an ID number within the range
  // [0, num_grids()) of the GPU this kernel is running on
  __CG_QUALIFIER__ uint32_t grid_rank() { return internal::multi_grid::grid_rank(); }
  __CG_QUALIFIER__ uint32_t thread_rank() const { return internal::multi_grid::thread_rank(); }
  __CG_QUALIFIER__ bool is_valid() const { return internal::multi_grid::is_valid(); }
  __CG_QUALIFIER__ void sync() const { internal::multi_grid::sync(); }
};

__CG_QUALIFIER__ multi_grid_group this_multi_grid() {
  return multi_grid_group(internal::multi_grid::size());
}

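// Illustrative sketch (not part of this header, assumptions only): a kernel launched
// across several devices with the multi-device cooperative launch API can synchronize
// all participating grids through this_multi_grid(). The kernel name and its argument
// are hypothetical.
//
//   __global__ void multi_gpu_step(float* local_data) {
//     cooperative_groups::multi_grid_group mg = cooperative_groups::this_multi_grid();
//     // ... each GPU (identified by mg.grid_rank()) updates its own local_data ...
//     mg.sync();   // all grids on all participating GPUs reach this point
//     // ... results produced by the other GPUs can now be consumed safely ...
//   }
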
class grid_group : public thread_group {
  // Only these friend functions are allowed to construct an object of this class
  // and access its resources
  friend __CG_QUALIFIER__ grid_group this_grid();

 protected:
  // Construct grid thread group (through the API this_grid())
  explicit __CG_QUALIFIER__ grid_group(uint32_t size) : thread_group(internal::cg_grid, size) {}

 public:
  __CG_QUALIFIER__ uint32_t thread_rank() const { return internal::grid::thread_rank(); }
  __CG_QUALIFIER__ bool is_valid() const { return internal::grid::is_valid(); }
  __CG_QUALIFIER__ void sync() const { internal::grid::sync(); }
};

__CG_QUALIFIER__ grid_group this_grid() { return grid_group(internal::grid::size()); }

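// Illustrative sketch (not part of this header, assumptions only): grid-wide
// synchronization is only valid when the kernel is launched cooperatively (for example
// via hipLaunchCooperativeKernel). The kernel below, whose name and arguments are
// hypothetical, runs a two-phase computation with a full grid barrier in between.
//
//   __global__ void two_phase(float* data, size_t n) {
//     cooperative_groups::grid_group grid = cooperative_groups::this_grid();
//     for (size_t i = grid.thread_rank(); i < n; i += grid.size()) {
//       data[i] *= 2.0f;                 // phase 1: grid-stride loop over all elements
//     }
//     grid.sync();                       // every block of the grid must arrive here
//     // phase 2 can now safely read any element written during phase 1
//   }
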
class thread_block : public thread_group {
  // Only these friend functions are allowed to construct an object of this
  // class and access its resources
  friend __CG_QUALIFIER__ thread_block this_thread_block();
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                                       unsigned int tile_size);
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_block& parent,
                                                       unsigned int tile_size);
 protected:
  // Construct a workgroup thread group (through the API this_thread_block())
  explicit __CG_QUALIFIER__ thread_block(uint32_t size)
      : thread_group(internal::cg_workgroup, size) {}

  __CG_QUALIFIER__ thread_group new_tiled_group(unsigned int tile_size) const {
    const bool pow2 = ((tile_size & (tile_size - 1)) == 0);
    // Invalid tile size, assert
    if (!tile_size || (tile_size > __AMDGCN_WAVEFRONT_SIZE) || !pow2) {
      __hip_assert(false && "invalid tile size")
    }

    thread_group tiledGroup = thread_group(internal::cg_tiled_group, tile_size);
    tiledGroup.coalesced_info.tiled_info.size = tile_size;
    tiledGroup.coalesced_info.tiled_info.is_tiled = true;
    tiledGroup.coalesced_info.tiled_info.meta_group_rank = thread_rank() / tile_size;
    tiledGroup.coalesced_info.tiled_info.meta_group_size = (size() + tile_size - 1) / tile_size;
    return tiledGroup;
  }

 public:
  // 3-dimensional block index within the grid
  __CG_STATIC_QUALIFIER__ dim3 group_index() { return internal::workgroup::group_index(); }
  // 3-dimensional thread index within the block
  __CG_STATIC_QUALIFIER__ dim3 thread_index() { return internal::workgroup::thread_index(); }
  __CG_STATIC_QUALIFIER__ uint32_t thread_rank() { return internal::workgroup::thread_rank(); }
  __CG_STATIC_QUALIFIER__ uint32_t size() { return internal::workgroup::size(); }
  __CG_STATIC_QUALIFIER__ bool is_valid() { return internal::workgroup::is_valid(); }
  __CG_STATIC_QUALIFIER__ void sync() { internal::workgroup::sync(); }
  __CG_QUALIFIER__ dim3 group_dim() { return internal::workgroup::block_dim(); }
};

__CG_QUALIFIER__ thread_block this_thread_block() {
  return thread_block(internal::workgroup::size());
}

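// Illustrative sketch (not part of this header, assumptions only): this_thread_block()
// is the cooperative-groups counterpart of __syncthreads() plus the built-in index
// variables. The kernel name, buffers, and the blockDim.x == 256 assumption are
// hypothetical and chosen just for this example.
//
//   __global__ void stage_through_lds(const float* in, float* out) {
//     __shared__ float tile[256];                       // assumes blockDim.x == 256
//     cooperative_groups::thread_block block = cooperative_groups::this_thread_block();
//     unsigned int t = block.thread_rank();
//     tile[t] = in[block.group_index().x * block.size() + t];
//     block.sync();                                     // equivalent to __syncthreads()
//     out[block.group_index().x * block.size() + t] = tile[block.size() - 1 - t];
//   }
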
class tiled_group : public thread_group {
 private:
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                                       unsigned int tile_size);
  friend __CG_QUALIFIER__ tiled_group tiled_partition(const tiled_group& parent,
                                                      unsigned int tile_size);

  __CG_QUALIFIER__ tiled_group new_tiled_group(unsigned int tile_size) const {
    const bool pow2 = ((tile_size & (tile_size - 1)) == 0);

    if (!tile_size || (tile_size > __AMDGCN_WAVEFRONT_SIZE) || !pow2) {
      __hip_assert(false && "invalid tile size")
    }

    if (size() <= tile_size) {
      return *this;
    }

    tiled_group tiledGroup = tiled_group(tile_size);
    tiledGroup.coalesced_info.tiled_info.is_tiled = true;
    return tiledGroup;
  }

 protected:
  explicit __CG_QUALIFIER__ tiled_group(unsigned int tileSize)
      : thread_group(internal::cg_tiled_group, tileSize) {
    coalesced_info.tiled_info.size = tileSize;
    coalesced_info.tiled_info.is_tiled = true;
  }

 public:
  __CG_QUALIFIER__ unsigned int size() const { return (coalesced_info.tiled_info.size); }

  __CG_QUALIFIER__ unsigned int thread_rank() const {
    return (internal::workgroup::thread_rank() & (coalesced_info.tiled_info.size - 1));
  }

  __CG_QUALIFIER__ void sync() const {
    internal::tiled_group::sync();
  }
};

class coalesced_group : public thread_group {
 private:
  friend __CG_QUALIFIER__ coalesced_group coalesced_threads();
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                                       unsigned int tile_size);
  friend __CG_QUALIFIER__ coalesced_group tiled_partition(const coalesced_group& parent,
                                                          unsigned int tile_size);

  __CG_QUALIFIER__ coalesced_group new_tiled_group(unsigned int tile_size) const {
    const bool pow2 = ((tile_size & (tile_size - 1)) == 0);

    if (!tile_size || (tile_size > size()) || !pow2) {
      return coalesced_group(0);
    }

    // If a tiled group is passed to be partitioned further into a coalesced_group,
    // prepare a mask for the further partitioning so that it stays coalesced.
    if (coalesced_info.tiled_info.is_tiled) {
      unsigned int base_offset = (thread_rank() & (~(tile_size - 1)));
      unsigned int masklength = min(static_cast<unsigned int>(size()) - base_offset, tile_size);
      lane_mask member_mask = static_cast<lane_mask>(-1) >> (__AMDGCN_WAVEFRONT_SIZE - masklength);

      member_mask <<= (__lane_id() & ~(tile_size - 1));
      coalesced_group coalesced_tile = coalesced_group(member_mask);
      coalesced_tile.coalesced_info.tiled_info.is_tiled = true;
      coalesced_tile.coalesced_info.tiled_info.meta_group_rank = thread_rank() / tile_size;
      coalesced_tile.coalesced_info.tiled_info.meta_group_size = size() / tile_size;
      return coalesced_tile;
    }
    // Here the parent coalesced_group is not partitioned.
    else {
      lane_mask member_mask = 0;
      unsigned int tile_rank = 0;
      int lanes_to_skip = ((thread_rank()) / tile_size) * tile_size;

      for (unsigned int i = 0; i < __AMDGCN_WAVEFRONT_SIZE; i++) {
        lane_mask active = coalesced_info.member_mask & (static_cast<lane_mask>(1) << i);
        // Make sure the lane is active
        if (active) {
          if (lanes_to_skip <= 0 && tile_rank < tile_size) {
            // Prepare a member_mask that is appropriate for a tile
            member_mask |= active;
            tile_rank++;
          }
          lanes_to_skip--;
        }
      }
      coalesced_group coalesced_tile = coalesced_group(member_mask);
      coalesced_tile.coalesced_info.tiled_info.meta_group_rank = thread_rank() / tile_size;
      coalesced_tile.coalesced_info.tiled_info.meta_group_size =
          (size() + tile_size - 1) / tile_size;
      return coalesced_tile;
    }
    return coalesced_group(0);
  }

 protected:
  // Constructor
  explicit __CG_QUALIFIER__ coalesced_group(lane_mask member_mask)
      : thread_group(internal::cg_coalesced_group) {
    coalesced_info.member_mask = member_mask;                    // Which threads are active
    coalesced_info.size = __popcll(coalesced_info.member_mask);  // How many threads are active
    coalesced_info.tiled_info.is_tiled = false;                  // Not a partitioned group
    coalesced_info.tiled_info.meta_group_rank = 0;
    coalesced_info.tiled_info.meta_group_size = 1;
  }

 public:
  __CG_QUALIFIER__ unsigned int size() const {
    return coalesced_info.size;
  }

  __CG_QUALIFIER__ unsigned int thread_rank() const {
    return internal::coalesced_group::masked_bit_count(coalesced_info.member_mask);
  }

  __CG_QUALIFIER__ void sync() const {
    internal::coalesced_group::sync();
  }

  __CG_QUALIFIER__ unsigned int meta_group_rank() const {
    return coalesced_info.tiled_info.meta_group_rank;
  }

  __CG_QUALIFIER__ unsigned int meta_group_size() const {
    return coalesced_info.tiled_info.meta_group_size;
  }

  template <class T>
  __CG_QUALIFIER__ T shfl(T var, int srcRank) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");

    srcRank = srcRank % static_cast<int>(size());

    int lane = (size() == __AMDGCN_WAVEFRONT_SIZE) ? srcRank
        : (__AMDGCN_WAVEFRONT_SIZE == 64)
            ? __fns64(coalesced_info.member_mask, 0, (srcRank + 1))
            : __fns32(coalesced_info.member_mask, 0, (srcRank + 1));

    return __shfl(var, lane, __AMDGCN_WAVEFRONT_SIZE);
  }

  template <class T>
  __CG_QUALIFIER__ T shfl_down(T var, unsigned int lane_delta) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");

    // Note: The CUDA implementation appears to use the remainder of lane_delta
    // and WARP_SIZE as the shift value rather than lane_delta itself.
    // This is not described in the documentation and is not done here.

    if (size() == __AMDGCN_WAVEFRONT_SIZE) {
      return __shfl_down(var, lane_delta, __AMDGCN_WAVEFRONT_SIZE);
    }

    int lane;
    if (__AMDGCN_WAVEFRONT_SIZE == 64) {
      lane = __fns64(coalesced_info.member_mask, __lane_id(), lane_delta + 1);
    } else {
      lane = __fns32(coalesced_info.member_mask, __lane_id(), lane_delta + 1);
    }

    if (lane == -1) {
      lane = __lane_id();
    }

    return __shfl(var, lane, __AMDGCN_WAVEFRONT_SIZE);
  }

  template <class T>
  __CG_QUALIFIER__ T shfl_up(T var, unsigned int lane_delta) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");

    // Note: The CUDA implementation appears to use the remainder of lane_delta
    // and WARP_SIZE as the shift value rather than lane_delta itself.
    // This is not described in the documentation and is not done here.

    if (size() == __AMDGCN_WAVEFRONT_SIZE) {
      return __shfl_up(var, lane_delta, __AMDGCN_WAVEFRONT_SIZE);
    }

    int lane;
    if (__AMDGCN_WAVEFRONT_SIZE == 64) {
      lane = __fns64(coalesced_info.member_mask, __lane_id(), -(lane_delta + 1));
    } else if (__AMDGCN_WAVEFRONT_SIZE == 32) {
      lane = __fns32(coalesced_info.member_mask, __lane_id(), -(lane_delta + 1));
    }

    if (lane == -1) {
      lane = __lane_id();
    }

    return __shfl(var, lane, __AMDGCN_WAVEFRONT_SIZE);
  }
};

__CG_QUALIFIER__ coalesced_group coalesced_threads() {
  return cooperative_groups::coalesced_group(__builtin_amdgcn_read_exec());
}

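// Illustrative sketch (not part of this header, assumptions only): inside a divergent
// branch, coalesced_threads() captures exactly the lanes that took the branch, so they
// can cooperate, e.g. electing one lane to perform an atomic and broadcasting the
// result (a warp-aggregated atomic). The function name `reserve_slot` is hypothetical.
//
//   __device__ int reserve_slot(int* counter) {
//     cooperative_groups::coalesced_group active = cooperative_groups::coalesced_threads();
//     int base = 0;
//     if (active.thread_rank() == 0) {
//       base = atomicAdd(counter, active.size());   // one atomic for all active lanes
//     }
//     base = active.shfl(base, 0);                   // broadcast from the elected lane
//     return base + active.thread_rank();            // each active lane gets a unique slot
//   }
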
__CG_QUALIFIER__ uint32_t thread_group::thread_rank() const {
  switch (this->_type) {
    case internal::cg_multi_grid: {
      return (static_cast<const multi_grid_group*>(this)->thread_rank());
    }
    case internal::cg_grid: {
      return (static_cast<const grid_group*>(this)->thread_rank());
    }
    case internal::cg_workgroup: {
      return (static_cast<const thread_block*>(this)->thread_rank());
    }
    case internal::cg_tiled_group: {
      return (static_cast<const tiled_group*>(this)->thread_rank());
    }
    case internal::cg_coalesced_group: {
      return (static_cast<const coalesced_group*>(this)->thread_rank());
    }
    default: {
      __hip_assert(false && "invalid cooperative group type")
      return -1;
    }
  }
}

__CG_QUALIFIER__ bool thread_group::is_valid() const {
  switch (this->_type) {
    case internal::cg_multi_grid: {
      return (static_cast<const multi_grid_group*>(this)->is_valid());
    }
    case internal::cg_grid: {
      return (static_cast<const grid_group*>(this)->is_valid());
    }
    case internal::cg_workgroup: {
      return (static_cast<const thread_block*>(this)->is_valid());
    }
    case internal::cg_tiled_group: {
      return (static_cast<const tiled_group*>(this)->is_valid());
    }
    case internal::cg_coalesced_group: {
      return (static_cast<const coalesced_group*>(this)->is_valid());
    }
    default: {
      __hip_assert(false && "invalid cooperative group type")
      return false;
    }
  }
}

__CG_QUALIFIER__ void thread_group::sync() const {
  switch (this->_type) {
    case internal::cg_multi_grid: {
      static_cast<const multi_grid_group*>(this)->sync();
      break;
    }
    case internal::cg_grid: {
      static_cast<const grid_group*>(this)->sync();
      break;
    }
    case internal::cg_workgroup: {
      static_cast<const thread_block*>(this)->sync();
      break;
    }
    case internal::cg_tiled_group: {
      static_cast<const tiled_group*>(this)->sync();
      break;
    }
    case internal::cg_coalesced_group: {
      static_cast<const coalesced_group*>(this)->sync();
      break;
    }
    default: {
      __hip_assert(false && "invalid cooperative group type")
    }
  }
}

template <class CGTy> __CG_QUALIFIER__ uint32_t group_size(CGTy const& g) { return g.size(); }

template <class CGTy> __CG_QUALIFIER__ uint32_t thread_rank(CGTy const& g) {
  return g.thread_rank();
}

template <class CGTy> __CG_QUALIFIER__ bool is_valid(CGTy const& g) { return g.is_valid(); }

template <class CGTy> __CG_QUALIFIER__ void sync(CGTy const& g) { g.sync(); }
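
// Illustrative sketch (not part of this header, assumptions only): the free functions
// above mirror the member functions, which lets templated device code stay agnostic to
// the concrete group type. The helper name `barrier_and_report` is hypothetical.
//
//   template <class Group>
//   __device__ void barrier_and_report(const Group& g, unsigned int* out) {
//     cooperative_groups::sync(g);                       // same as g.sync()
//     if (cooperative_groups::thread_rank(g) == 0) {     // same as g.thread_rank()
//       *out = cooperative_groups::group_size(g);        // same as g.size()
//     }
//   }
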
template <unsigned int tileSize> class tile_base {
 protected:
  _CG_STATIC_CONST_DECL_ unsigned int numThreads = tileSize;

 public:
  // Rank of the thread within this tile
  _CG_STATIC_CONST_DECL_ unsigned int thread_rank() {
    return (internal::workgroup::thread_rank() & (numThreads - 1));
  }

  // Number of threads within this tile
  __CG_STATIC_QUALIFIER__ unsigned int size() { return numThreads; }
};

template <unsigned int size> class thread_block_tile_base : public tile_base<size> {
  static_assert(is_valid_tile_size<size>::value,
                "Tile size is either not a power of 2 or greater than the wavefront size");
  using tile_base<size>::numThreads;

 public:
  __CG_STATIC_QUALIFIER__ void sync() {
    internal::tiled_group::sync();
  }

  template <class T> __CG_QUALIFIER__ T shfl(T var, int srcRank) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");
    return (__shfl(var, srcRank, numThreads));
  }

  template <class T> __CG_QUALIFIER__ T shfl_down(T var, unsigned int lane_delta) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");
    return (__shfl_down(var, lane_delta, numThreads));
  }

  template <class T> __CG_QUALIFIER__ T shfl_up(T var, unsigned int lane_delta) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");
    return (__shfl_up(var, lane_delta, numThreads));
  }

  template <class T> __CG_QUALIFIER__ T shfl_xor(T var, unsigned int laneMask) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");
    return (__shfl_xor(var, laneMask, numThreads));
  }
};
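
// Illustrative sketch (not part of this header, assumptions only): the tile-level
// shuffles declared above support the usual shuffle-down reduction across a statically
// sized tile (as returned by the templated tiled_partition later in this file). The
// function name `tile_reduce_sum` is hypothetical.
//
//   template <unsigned int TileSize>
//   __device__ float tile_reduce_sum(
//       const cooperative_groups::thread_block_tile<TileSize>& tile, float value) {
//     for (unsigned int offset = TileSize / 2; offset > 0; offset /= 2) {
//       value += tile.shfl_down(value, offset);   // fold the upper half onto the lower half
//     }
//     return tile.shfl(value, 0);                  // lane 0 holds the total; broadcast it
//   }
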
template <unsigned int tileSize, typename ParentCGTy>
class parent_group_info {
 public:
  // Returns the linear rank of the group within the set of tiles partitioned
  // from a parent group (bounded by meta_group_size)
  __CG_STATIC_QUALIFIER__ unsigned int meta_group_rank() {
    return ParentCGTy::thread_rank() / tileSize;
  }

  // Returns the number of groups created when the parent group was partitioned.
  __CG_STATIC_QUALIFIER__ unsigned int meta_group_size() {
    return (ParentCGTy::size() + tileSize - 1) / tileSize;
  }
};

template <unsigned int tileSize, class ParentCGTy>
class thread_block_tile_type : public thread_block_tile_base<tileSize>,
                               public tiled_group,
                               public parent_group_info<tileSize, ParentCGTy> {
  _CG_STATIC_CONST_DECL_ unsigned int numThreads = tileSize;

 protected:
  __CG_QUALIFIER__ thread_block_tile_type() : tiled_group(numThreads) {
    coalesced_info.tiled_info.size = numThreads;
    coalesced_info.tiled_info.is_tiled = true;
  }
};

// Partial template specialization
template <unsigned int tileSize>
class thread_block_tile_type<tileSize, void> : public thread_block_tile_base<tileSize>,
                                               public tiled_group {
  _CG_STATIC_CONST_DECL_ unsigned int numThreads = tileSize;

  typedef thread_block_tile_base<numThreads> tbtBase;

 protected:
  __CG_QUALIFIER__ thread_block_tile_type(unsigned int meta_group_rank,
                                          unsigned int meta_group_size)
      : tiled_group(numThreads) {
    coalesced_info.tiled_info.size = numThreads;
    coalesced_info.tiled_info.is_tiled = true;
    coalesced_info.tiled_info.meta_group_rank = meta_group_rank;
    coalesced_info.tiled_info.meta_group_size = meta_group_size;
  }

 public:
  using tbtBase::size;
  using tbtBase::sync;
  using tbtBase::thread_rank;

  __CG_QUALIFIER__ unsigned int meta_group_rank() const {
    return coalesced_info.tiled_info.meta_group_rank;
  }

  __CG_QUALIFIER__ unsigned int meta_group_size() const {
    return coalesced_info.tiled_info.meta_group_size;
  }
  // end of operative group
};

__CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent, unsigned int tile_size) {
  if (parent.cg_type() == internal::cg_tiled_group) {
    const tiled_group* cg = static_cast<const tiled_group*>(&parent);
    return cg->new_tiled_group(tile_size);
  } else if (parent.cg_type() == internal::cg_coalesced_group) {
    const coalesced_group* cg = static_cast<const coalesced_group*>(&parent);
    return cg->new_tiled_group(tile_size);
  } else {
    const thread_block* tb = static_cast<const thread_block*>(&parent);
    return tb->new_tiled_group(tile_size);
  }
}

// Thread block type overload
__CG_QUALIFIER__ thread_group tiled_partition(const thread_block& parent, unsigned int tile_size) {
  return (parent.new_tiled_group(tile_size));
}

__CG_QUALIFIER__ tiled_group tiled_partition(const tiled_group& parent, unsigned int tile_size) {
  return (parent.new_tiled_group(tile_size));
}

// If a coalesced group is passed to be partitioned, it should remain coalesced
__CG_QUALIFIER__ coalesced_group tiled_partition(const coalesced_group& parent,
                                                 unsigned int tile_size) {
  return (parent.new_tiled_group(tile_size));
}

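// Illustrative sketch (not part of this header, assumptions only): tiled_partition with
// a runtime tile size returns a generic thread_group (or preserves the tiled/coalesced
// type when the parent already has one). The kernel name and output buffer are
// hypothetical.
//
//   __global__ void partition_demo(unsigned int* out) {
//     cooperative_groups::thread_block block = cooperative_groups::this_thread_block();
//     cooperative_groups::thread_group tile16 = cooperative_groups::tiled_partition(block, 16);
//     out[block.thread_rank()] = tile16.thread_rank();   // rank within the 16-wide tile
//     tile16.sync();                                      // barrier among the 16 lanes only
//   }
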
template <unsigned int size, class ParentCGTy> class thread_block_tile;

namespace impl {
template <unsigned int size, class ParentCGTy> class thread_block_tile_internal;

template <unsigned int size, class ParentCGTy>
class thread_block_tile_internal : public thread_block_tile_type<size, ParentCGTy> {
 protected:
  template <unsigned int tbtSize, class tbtParentT>
  __CG_QUALIFIER__ thread_block_tile_internal(
      const thread_block_tile_internal<tbtSize, tbtParentT>& g)
      : thread_block_tile_type<size, ParentCGTy>(g.meta_group_rank(), g.meta_group_size()) {}

  __CG_QUALIFIER__ thread_block_tile_internal(const thread_block& g)
      : thread_block_tile_type<size, ParentCGTy>() {}
};
}  // namespace impl

template <unsigned int size, class ParentCGTy>
class thread_block_tile : public impl::thread_block_tile_internal<size, ParentCGTy> {
 protected:
  __CG_QUALIFIER__ thread_block_tile(const ParentCGTy& g)
      : impl::thread_block_tile_internal<size, ParentCGTy>(g) {}

 public:
  __CG_QUALIFIER__ operator thread_block_tile<size, void>() const {
    return thread_block_tile<size, void>(*this);
  }
};


template <unsigned int size>
class thread_block_tile<size, void> : public impl::thread_block_tile_internal<size, void> {
  template <unsigned int, class ParentCGTy> friend class thread_block_tile;

 protected:
 public:
  template <class ParentCGTy>
  __CG_QUALIFIER__ thread_block_tile(const thread_block_tile<size, ParentCGTy>& g)
      : impl::thread_block_tile_internal<size, void>(g) {}
};

template <unsigned int size, class ParentCGTy = void> class thread_block_tile;

namespace impl {
template <unsigned int size, class ParentCGTy> struct tiled_partition_internal;

template <unsigned int size>
struct tiled_partition_internal<size, thread_block> : public thread_block_tile<size, thread_block> {
  __CG_QUALIFIER__ tiled_partition_internal(const thread_block& g)
      : thread_block_tile<size, thread_block>(g) {}
};

}  // namespace impl

template <unsigned int size, class ParentCGTy>
__CG_QUALIFIER__ thread_block_tile<size, ParentCGTy> tiled_partition(const ParentCGTy& g) {
  static_assert(is_valid_tile_size<size>::value,
                "Tiled partition with size > wavefront size is currently not supported");
  return impl::tiled_partition_internal<size, ParentCGTy>(g);
}
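
// Illustrative sketch (not part of this header, assumptions only): the compile-time
// overload yields a thread_block_tile<Size>, which exposes the tile-wide shuffle
// operations in addition to the basic group interface. The kernel name and output
// buffer are hypothetical.
//
//   __global__ void warp_tile_demo(int* out) {
//     cooperative_groups::thread_block block = cooperative_groups::this_thread_block();
//     cooperative_groups::thread_block_tile<32> tile =
//         cooperative_groups::tiled_partition<32>(block);   // 32-lane static tile
//     int v = tile.thread_rank();
//     for (unsigned int offset = 16; offset > 0; offset /= 2) {
//       v += tile.shfl_down(v, offset);       // reduce within the tile
//     }
//     if (tile.thread_rank() == 0) out[tile.meta_group_rank()] = v;
//   }
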
}  // namespace cooperative_groups

#endif  // __cplusplus
#endif  // HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_H
Device-side implementation of the cooperative groups feature.