Actual source code: curand2.cu

#include <petsc/private/randomimpl.h>
#include <thrust/transform.h>
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <exception>

  6: #if defined(PETSC_USE_COMPLEX)
  7: struct complexscalelw
  8:   #if PETSC_PKG_CUDA_VERSION_LT(12, 8, 0)
  9:   :
 10:   public thrust::unary_function<thrust::tuple<PetscReal, size_t>, PetscReal>
 11:   #endif
 12: {
 13:   PetscReal rl, rw;
 14:   PetscReal il, iw;

 16:   complexscalelw(PetscScalar low, PetscScalar width)
 17:   {
 18:     rl = PetscRealPart(low);
 19:     il = PetscImaginaryPart(low);
 20:     rw = PetscRealPart(width);
 21:     iw = PetscImaginaryPart(width);
 22:   }

 24:   __host__ __device__ PetscReal operator()(thrust::tuple<PetscReal, size_t> x) { return thrust::get<1>(x) % 2 ? thrust::get<0>(x) * iw + il : thrust::get<0>(x) * rw + rl; }
 25: };
 26: #endif

 28: struct realscalelw
 29: #if PETSC_PKG_CUDA_VERSION_LT(12, 8, 0) // To suppress the warning "thrust::THRUST_200700_860_NS::unary_function is deprecated"
 30:   :
 31:   public thrust::unary_function<PetscReal, PetscReal>
 32: #endif
 33: {
 34:   PetscReal l, w;

 36:   realscalelw(PetscReal low, PetscReal width) : l(low), w(width) { }

 38:   __host__ __device__ PetscReal operator()(PetscReal x) { return x * w + l; }
 39: };

 41: PETSC_INTERN PetscErrorCode PetscRandomCurandScale_Private(PetscRandom r, size_t n, PetscReal *val, PetscBool isneg)
 42: {
 43:   PetscFunctionBegin;
 44:   if (!r->iset) PetscFunctionReturn(PETSC_SUCCESS);
 45:   if (isneg) { /* complex case, need to scale differently */
 46: #if defined(PETSC_USE_COMPLEX)
 47:     thrust::device_ptr<PetscReal> pval  = thrust::device_pointer_cast(val);
 48:     auto                          zibit = thrust::make_zip_iterator(thrust::make_tuple(pval, thrust::counting_iterator<size_t>(0)));
 49:     thrust::transform(zibit, zibit + n, pval, complexscalelw(r->low, r->width));
 50: #else
 51:     SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Negative array size %" PetscInt_FMT, (PetscInt)n);
 52: #endif
 53:   } else {
 54:     PetscReal                     rl   = PetscRealPart(r->low);
 55:     PetscReal                     rw   = PetscRealPart(r->width);
 56:     thrust::device_ptr<PetscReal> pval = thrust::device_pointer_cast(val);
 57:     thrust::transform(pval, pval + n, pval, realscalelw(rl, rw));
 58:   }
 59:   PetscFunctionReturn(PETSC_SUCCESS);
 60: }