#ifndef HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_HELPER_H
#define HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_HELPER_H

// Runtime headers are unavailable under hipRTC; the compiler provides the
// needed declarations there.
#if !defined(__HIPCC_RTC__)
#include <hip/amd_detail/amd_hip_runtime.h>
#include <hip/amd_detail/amd_device_functions.h>
#endif

// Fallback definitions: only define these helper macros when the including
// translation unit has not already provided them.
#if !defined(__align__)
#define __align__(x) __attribute__((aligned(x)))
#endif

#if !defined(__CG_QUALIFIER__)
#define __CG_QUALIFIER__ __device__ __forceinline__
#endif

#if !defined(__CG_STATIC_QUALIFIER__)
#define __CG_STATIC_QUALIFIER__ __device__ static __forceinline__
#endif

#if !defined(_CG_STATIC_CONST_DECL_)
#define _CG_STATIC_CONST_DECL_ static constexpr
#endif

// lane_mask holds one bit per lane of a wavefront, so its width follows the
// target's wavefront size.
#if __AMDGCN_WAVEFRONT_SIZE == 32
using lane_mask = unsigned int;
#else
using lane_mask = unsigned long long int;
#endif
61namespace cooperative_groups {
// Compile-time test: true when `size` has at most one bit set
// (note size == 0 also passes the bit test).
template <unsigned int size>
using is_power_of_2 = std::integral_constant<bool, (size & (size - 1)) == 0>;
// Compile-time test: true when `size` does not exceed the target's
// wavefront size.
template <unsigned int size>
using is_valid_wavefront = std::integral_constant<bool, (size <= __AMDGCN_WAVEFRONT_SIZE)>;
// Compile-time test: a tile size is valid when it is a power of two and
// fits within one wavefront.
template <unsigned int size>
using is_valid_tile_size =
    std::integral_constant<bool, is_power_of_2<size>::value && is_valid_wavefront<size>::value>;
// Compile-time test: true when T is an integral or floating-point type.
// (The alias-template header line was missing in this copy, leaving a
// dangling expression; restored here.)
template <typename T>
using is_valid_type =
    std::integral_constant<bool, std::is_integral<T>::value || std::is_floating_point<T>::value>;
111namespace multi_grid {
// Number of grids in the current multi-grid launch, as reported by the
// __ockl_multi_grid_num_grids() runtime query.
__CG_STATIC_QUALIFIER__ uint32_t num_grids() {
  return static_cast<uint32_t>(__ockl_multi_grid_num_grids());
}
// Rank of this grid within the multi-grid launch, per
// __ockl_multi_grid_grid_rank().
__CG_STATIC_QUALIFIER__ uint32_t grid_rank() {
  return static_cast<uint32_t>(__ockl_multi_grid_grid_rank());
}
// Thread count of the multi-grid, per __ockl_multi_grid_size().
__CG_STATIC_QUALIFIER__ uint32_t size() {
  return static_cast<uint32_t>(__ockl_multi_grid_size());
}
// Rank of the calling thread within the multi-grid, per
// __ockl_multi_grid_thread_rank().
__CG_STATIC_QUALIFIER__ uint32_t thread_rank() {
  return static_cast<uint32_t>(__ockl_multi_grid_thread_rank());
}
// Whether the multi-grid handle is valid for this launch, per the
// __ockl_multi_grid_is_valid() runtime query.
__CG_STATIC_QUALIFIER__
bool is_valid() {
  return __ockl_multi_grid_is_valid() != 0;
}
// Multi-grid-wide synchronization; delegates to the OCKL runtime primitive.
__CG_STATIC_QUALIFIER__
void sync() {
  __ockl_multi_grid_sync();
}
// Total number of threads in the launched grid:
// (threads per workgroup) x (workgroups per grid), per dimension.
// (Closing brace was missing in this copy; restored.)
__CG_STATIC_QUALIFIER__ uint32_t size() {
  return static_cast<uint32_t>((blockDim.z * gridDim.z) * (blockDim.y * gridDim.y) *
                               (blockDim.x * gridDim.x));
}
// Linear rank of the calling thread within the whole grid (z-major
// ordering). (Closing brace was missing in this copy; restored.)
__CG_STATIC_QUALIFIER__ uint32_t thread_rank() {
  // Linear index of this workgroup within the grid.
  uint32_t blkIdx = static_cast<uint32_t>((blockIdx.z * gridDim.y * gridDim.x) +
                                          (blockIdx.y * gridDim.x) + (blockIdx.x));
  // Number of threads contained in all workgroups preceding this one.
  uint32_t num_threads_till_current_workgroup =
      static_cast<uint32_t>(blkIdx * (blockDim.x * blockDim.y * blockDim.z));
  // Linear index of this thread within its own workgroup.
  uint32_t local_thread_rank = static_cast<uint32_t>((threadIdx.z * blockDim.y * blockDim.x) +
                                                     (threadIdx.y * blockDim.x) + (threadIdx.x));
  return (num_threads_till_current_workgroup + local_thread_rank);
}
// Whether a grid-wide (cooperative) sync is available for this launch,
// per the __ockl_grid_is_valid() runtime query.
__CG_STATIC_QUALIFIER__
bool is_valid() {
  return __ockl_grid_is_valid() != 0;
}
// Grid-wide synchronization; delegates to the OCKL runtime primitive.
__CG_STATIC_QUALIFIER__
void sync() {
  __ockl_grid_sync();
}
// 3D index of the calling thread's workgroup within the grid.
// (Closing brace was missing in this copy; restored.)
__CG_STATIC_QUALIFIER__ dim3 group_index() {
  return dim3(static_cast<uint32_t>(blockIdx.x), static_cast<uint32_t>(blockIdx.y),
              static_cast<uint32_t>(blockIdx.z));
}
// 3D index of the calling thread within its workgroup.
// (Closing brace was missing in this copy; restored.)
__CG_STATIC_QUALIFIER__ dim3 thread_index() {
  return dim3(static_cast<uint32_t>(threadIdx.x), static_cast<uint32_t>(threadIdx.y),
              static_cast<uint32_t>(threadIdx.z));
}
// Number of threads in the calling thread's workgroup.
// (Closing brace was missing in this copy; restored.)
__CG_STATIC_QUALIFIER__ uint32_t size() {
  return static_cast<uint32_t>(blockDim.x * blockDim.y * blockDim.z);
}
// Linear (z-major) rank of the calling thread within its workgroup.
// (Closing brace was missing in this copy; restored.)
__CG_STATIC_QUALIFIER__ uint32_t thread_rank() {
  return static_cast<uint32_t>((threadIdx.z * blockDim.y * blockDim.x) +
                               (threadIdx.y * blockDim.x) + (threadIdx.x));
}
// Validity check for the workgroup-scope group. The function body was lost
// in this copy of the file; a thread block needs no runtime query once the
// kernel is executing, so report true unconditionally.
// NOTE(review): restored body — confirm against the canonical header.
__CG_STATIC_QUALIFIER__
bool is_valid() {
  return true;
}
// Workgroup-wide execution barrier: every thread of the block must reach it.
__CG_STATIC_QUALIFIER__
void sync() {
  __syncthreads();
}
// Dimensions (x, y, z) of the calling thread's workgroup as a dim3.
// (Closing brace was missing in this copy; restored.)
__CG_STATIC_QUALIFIER__ dim3 block_dim() {
  return dim3(static_cast<uint32_t>(blockDim.x), static_cast<uint32_t>(blockDim.y),
              static_cast<uint32_t>(blockDim.z));
}
203namespace tiled_group {
// Tiled-group sync: an acquire-release memory fence at agent (device)
// scope. This orders memory operations only; it is not an execution
// barrier.
__CG_STATIC_QUALIFIER__
void sync() {
  __builtin_amdgcn_fence(__ATOMIC_ACQ_REL, "agent");
}
210namespace coalesced_group {
// Coalesced-group sync: same acquire-release, agent-scope memory fence as
// tiled_group::sync — a fence, not an execution barrier.
__CG_STATIC_QUALIFIER__
void sync() {
  __builtin_amdgcn_fence(__ATOMIC_ACQ_REL, "agent");
}
219__CG_STATIC_QUALIFIER__
unsigned int masked_bit_count(lane_mask x,
unsigned int add = 0) {
220 unsigned int counter=0;
221 #if __AMDGCN_WAVEFRONT_SIZE == 32
222 counter = __builtin_amdgcn_mbcnt_lo(x, add);
224 counter = __builtin_amdgcn_mbcnt_lo(
static_cast<lane_mask
>(x), add);
225 counter = __builtin_amdgcn_mbcnt_hi(
static_cast<lane_mask
>(x >> 32), counter);