Actual source code: ex23.c

  // Help message handed to PetscInitialize() in main()
  static const char help[] = "Test PetscSF with integers and MPIU_2INT "
                             "\n\n";

  3: #include <petscvec.h>
  4: #include <petscsf.h>
  5: #include <petscdevice.h>

  7: int main(int argc, char *argv[])
  8: {
  9:   PetscInt           n, n2, N = 12;
 10:   PetscInt          *indices;
 11:   IS                 ix, iy;
 12:   VecScatter         vscat;
 13:   Vec                x, y;
 14:   PetscInt           rstart, rend;
 15:   PetscInt          *xh, *yh, *xd, *yd;
 16:   PetscDeviceContext dctx;

 18:   PetscFunctionBeginUser;
 19:   PetscCall(PetscInitialize(&argc, &argv, NULL, help));
 20:   PetscCall(VecCreateFromOptions(PETSC_COMM_WORLD, NULL, 1, PETSC_DECIDE, N, &x));
 21:   PetscCall(VecDuplicate(x, &y));
 22:   PetscCall(VecGetLocalSize(x, &n));

 24:   PetscCall(VecGetOwnershipRange(x, &rstart, &rend));
 25:   PetscCall(ISCreateStride(PETSC_COMM_WORLD, n, rstart, 1, &ix));
 26:   PetscCall(PetscMalloc1(n, &indices));
 27:   for (int i = rstart; i < rend; i++) indices[i - rstart] = i / 2;
 28:   PetscCall(ISCreateGeneral(PETSC_COMM_WORLD, n, indices, PETSC_OWN_POINTER, &iy));
 29:   // connect y[0] to x[0..1], y[1] to x[2..3], etc
 30:   PetscCall(VecScatterCreate(y, iy, x, ix, &vscat)); // y has roots, x has leaves

 32:   PetscCall(PetscDeviceContextGetCurrentContext(&dctx));

 34:   // double the allocation since we will use MPIU_2INT later
 35:   n2 = 2 * n;
 36:   PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_HOST, n2, &xh));
 37:   PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_HOST, n2, &yh));
 38:   PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_DEVICE, n2, &xd));
 39:   PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_DEVICE, n2, &yd));

 41:   for (PetscInt i = 0; i < n; i++) {
 42:     xh[i] = xh[i + n] = i + rstart;
 43:     yh[i] = yh[i + n] = i + rstart;
 44:   }
 45:   PetscCall(PetscDeviceMemcpy(dctx, xd, xh, sizeof(PetscInt) * n2));
 46:   PetscCall(PetscDeviceMemcpy(dctx, yd, yh, sizeof(PetscInt) * n2));

 48:   PetscCall(PetscSFReduceWithMemTypeBegin(vscat, MPIU_INT, PETSC_MEMTYPE_DEVICE, xd, PETSC_MEMTYPE_DEVICE, yd, MPI_SUM));
 49:   PetscCall(PetscSFReduceEnd(vscat, MPIU_INT, xd, yd, MPI_SUM));
 50:   PetscCall(PetscDeviceMemcpy(dctx, yh, yd, sizeof(PetscInt) * n));
 51:   PetscCall(PetscDeviceContextSynchronize(dctx)); // finish the async memcpy
 52:   PetscCall(PetscIntView(n, yh, PETSC_VIEWER_STDOUT_WORLD));

 54:   PetscCall(PetscSFBcastWithMemTypeBegin(vscat, MPIU_2INT, PETSC_MEMTYPE_DEVICE, yd, PETSC_MEMTYPE_DEVICE, xd, MPI_MINLOC));
 55:   PetscCall(PetscSFBcastEnd(vscat, MPIU_2INT, yd, xd, MPI_MINLOC));
 56:   PetscCall(PetscDeviceMemcpy(dctx, xh, xd, sizeof(PetscInt) * n2));
 57:   PetscCall(PetscDeviceContextSynchronize(dctx)); // finish the async memcpy
 58:   PetscCall(PetscIntView(n2, xh, PETSC_VIEWER_STDOUT_WORLD));

 60:   PetscCall(PetscDeviceFree(dctx, xh));
 61:   PetscCall(PetscDeviceFree(dctx, yh));
 62:   PetscCall(PetscDeviceFree(dctx, xd));
 63:   PetscCall(PetscDeviceFree(dctx, yd));
 64:   PetscCall(ISDestroy(&ix));
 65:   PetscCall(ISDestroy(&iy));
 66:   PetscCall(VecDestroy(&x));
 67:   PetscCall(VecDestroy(&y));
 68:   PetscCall(VecScatterDestroy(&vscat));
 69:   PetscCall(PetscFinalize());
 70: }

 72: /*TEST
 73:   testset:
 74:     output_file: output/ex23.out
 75:     nsize: 3

 77:     test:
 78:       suffix: 1
 79:       requires: cuda

 81:     test:
 82:       suffix: 2
 83:       requires: hip

 85:     test:
 86:       suffix: 3
 87:       requires: sycl

 89: TEST*/