Neko 1.99.2
A portable framework for high-order spectral element flow simulations
gs_device_shmem.F90
! Copyright (c) 2020-2024, The Neko Authors
! All rights reserved.
!
! Redistribution and use in source and binary forms, with or without
! modification, are permitted provided that the following conditions
! are met:
!
!   * Redistributions of source code must retain the above copyright
!     notice, this list of conditions and the following disclaimer.
!
!   * Redistributions in binary form must reproduce the above
!     copyright notice, this list of conditions and the following
!     disclaimer in the documentation and/or other materials provided
!     with the distribution.
!
!   * Neither the name of the authors nor the names of its
!     contributors may be used to endorse or promote products derived
!     from this software without specific prior written permission.
!
! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
! "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
! LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
! FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
! COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
! INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
! BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
! LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
! CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
! LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
! ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
! POSSIBILITY OF SUCH DAMAGE.
!
!> Defines GPU aware gather-scatter communication using device SHMEM
module gs_device_shmem
  use num_types, only : rp, c_rp
  use gs_comm, only : gs_comm_t
  use gs_ops
  use stack, only : stack_i4_t
  use htable, only : htable_i4_t
  use device
  use comm, only : pe_size, pe_rank, neko_comm
  use mpi_f08, only : mpi_allreduce, mpi_integer, &
       mpi_max, mpi_sendrecv, mpi_status_ignore
  use utils, only : neko_error
  use, intrinsic :: iso_c_binding, only : c_sizeof, c_int32_t, &
       c_ptr, c_null_ptr, c_size_t, c_associated
  implicit none
  private

  !> Buffers for non-blocking communication and packing/unpacking
  type, private :: gs_device_shmem_buf_t
     integer, allocatable :: ndofs(:)
     integer, allocatable :: offset(:)
     integer, allocatable :: remote_offset(:)
     integer :: total
     type(c_ptr) :: buf_d = c_null_ptr
     type(c_ptr) :: dof_d = c_null_ptr
   contains
     procedure, pass(this) :: init => gs_device_shmem_buf_init
     procedure, pass(this) :: free => gs_device_shmem_buf_free
  end type gs_device_shmem_buf_t

  !> Gather-scatter communication using device SHMEM
  type, public, extends(gs_comm_t) :: gs_device_shmem_t
     type(gs_device_shmem_buf_t) :: send_buf
     type(gs_device_shmem_buf_t) :: recv_buf
     type(c_ptr), allocatable :: stream(:)
     type(c_ptr), allocatable :: event(:)
     integer :: nvshmem_counter = 1
     type(c_ptr), allocatable :: notifydone(:)
     type(c_ptr), allocatable :: notifyready(:)
   contains
     procedure, pass(this) :: init => gs_device_shmem_init
     procedure, pass(this) :: free => gs_device_shmem_free
     procedure, pass(this) :: nbsend => gs_device_shmem_nbsend
     procedure, pass(this) :: nbrecv => gs_device_shmem_nbrecv
     procedure, pass(this) :: nbwait => gs_device_shmem_nbwait
  end type gs_device_shmem_t

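  ! Interfaces to the CUDA/NVSHMEM backend: symmetric-heap allocation and
  ! free, a fused kernel that packs send data and pushes it into the
  ! receiving PE's buffer over NVSHMEM, a wait for its completion signal,
  ! and the unpack kernel.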
#if defined (HAVE_CUDA) && defined(HAVE_NVSHMEM)

  interface
     subroutine cudamalloc_nvshmem(ptr, size) &
          bind(c, name='cudamalloc_nvshmem')
       use, intrinsic :: iso_c_binding
       implicit none
       type(c_ptr) :: ptr
       integer(c_size_t), value :: size
     end subroutine cudamalloc_nvshmem
  end interface

  interface
     subroutine cudafree_nvshmem(ptr) &
          bind(c, name='cudafree_nvshmem')
       use, intrinsic :: iso_c_binding
       implicit none
       type(c_ptr) :: ptr
     end subroutine cudafree_nvshmem
  end interface

  interface
     subroutine cuda_gs_pack_and_push(u_d, buf_d, dof_d, offset, n, stream, &
          srank, rbuf_d, roffset, remote_offset, &
          rrank, nvshmem_counter, notifyDone, &
          notifyReady, iter) &
          bind(c, name='cuda_gs_pack_and_push')
       use, intrinsic :: iso_c_binding
       implicit none
       integer(c_int), value :: n, offset, srank, roffset, rrank, iter
       integer(c_int), value :: nvshmem_counter
       type(c_ptr), value :: u_d, buf_d, dof_d, stream, rbuf_d, notifydone, notifyready
       integer(c_int), dimension(*) :: remote_offset
     end subroutine cuda_gs_pack_and_push
  end interface

  interface
     subroutine cuda_gs_pack_and_push_wait(stream, nvshmem_counter, notifyDone) &
          bind(c, name='cuda_gs_pack_and_push_wait')
       use, intrinsic :: iso_c_binding
       implicit none
       integer(c_int), value :: nvshmem_counter
       type(c_ptr), value :: stream, notifydone
     end subroutine cuda_gs_pack_and_push_wait
  end interface

  interface
     subroutine cuda_gs_unpack(u_d, op, buf_d, dof_d, offset, n, stream) &
          bind(c, name='cuda_gs_unpack')
       use, intrinsic :: iso_c_binding
       implicit none
       integer(c_int), value :: op, offset, n
       type(c_ptr), value :: u_d, buf_d, dof_d, stream
     end subroutine cuda_gs_unpack
  end interface
#endif

contains

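  !> Initialise a communication buffer: computes per-PE dof counts and
  !> offsets, allocates the data buffer on the NVSHMEM symmetric heap and
  !> uploads the packed dof indices to the device, optionally marking
  !> duplicated dofs.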
  subroutine gs_device_shmem_buf_init(this, pe_order, dof_stack, mark_dupes)
    class(gs_device_shmem_buf_t), intent(inout) :: this
    integer, allocatable, intent(inout) :: pe_order(:)
    type(stack_i4_t), allocatable, intent(inout) :: dof_stack(:)
    logical, intent(in) :: mark_dupes
    integer, allocatable :: dofs(:)
    integer :: i, j, total, max_total
    integer(c_size_t) :: sz
    type(htable_i4_t) :: doftable
    integer :: dupe, marked, k
    real(c_rp) :: rp_dummy
    integer(c_int32_t) :: i4_dummy

    allocate(this%ndofs(size(pe_order)))
    allocate(this%offset(size(pe_order)))
    allocate(this%remote_offset(size(pe_order)))

    do i = 1, size(pe_order)
       this%remote_offset(i) = -1
    end do

    total = 0
    do i = 1, size(pe_order)
       this%ndofs(i) = dof_stack(pe_order(i))%size()
       this%offset(i) = total
       total = total + this%ndofs(i)
    end do

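    ! NVSHMEM symmetric-heap allocations are collective and must request the
    ! same size on every PE, so size the buffer for the largest total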
    call mpi_allreduce(total, max_total, 1, mpi_integer, mpi_max, neko_comm)

    this%total = total

    sz = c_sizeof(rp_dummy) * max_total
#ifdef HAVE_NVSHMEM
    call cudamalloc_nvshmem(this%buf_d, sz)
#endif

    sz = c_sizeof(i4_dummy) * total
    call device_alloc(this%dof_d, sz)

    if (mark_dupes) call doftable%init(2 * total)
    allocate(dofs(total))

    ! Copy from dof_stack into dofs, optionally marking duplicates with doftable
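    ! Duplicates are flagged by negating the stored dof index, for both the
    ! first and the repeated occurrence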
    marked = 0
    do i = 1, size(pe_order)
       ! %array() breaks on cray
       select type (arr => dof_stack(pe_order(i))%data)
       type is (integer)
          do j = 1, this%ndofs(i)
             k = this%offset(i) + j
             if (mark_dupes) then
                if (doftable%get(arr(j), dupe) .eq. 0) then
                   if (dofs(dupe) .gt. 0) then
                      dofs(dupe) = -dofs(dupe)
                      marked = marked + 1
                   end if
                   dofs(k) = -arr(j)
                   marked = marked + 1
                else
                   call doftable%set(arr(j), k)
                   dofs(k) = arr(j)
                end if
             else
                dofs(k) = arr(j)
             end if
          end do
       end select
    end do

    call device_memcpy(dofs, this%dof_d, total, host_to_device, sync = .true.)

    deallocate(dofs)
    call doftable%free()

  end subroutine gs_device_shmem_buf_init

  subroutine gs_device_shmem_buf_free(this)
    class(gs_device_shmem_buf_t), intent(inout) :: this

    if (allocated(this%ndofs)) deallocate(this%ndofs)
    if (allocated(this%offset)) deallocate(this%offset)

#ifdef HAVE_NVSHMEM
    if (c_associated(this%buf_d)) call cudafree_nvshmem(this%buf_d)
#endif
    if (c_associated(this%dof_d)) call device_free(this%dof_d)

  end subroutine gs_device_shmem_buf_free

  !> Initialise device SHMEM based communication method
  subroutine gs_device_shmem_init(this, send_pe, recv_pe)
    class(gs_device_shmem_t), intent(inout) :: this
    type(stack_i4_t), intent(inout) :: send_pe
    type(stack_i4_t), intent(inout) :: recv_pe
    integer :: i

    call this%init_order(send_pe, recv_pe)

    call this%send_buf%init(this%send_pe, this%send_dof, .false.)
    call this%recv_buf%init(this%recv_pe, this%recv_dof, .true.)

#if defined(HAVE_HIP) || defined(HAVE_CUDA)
    ! Create a set of non-blocking streams
    allocate(this%stream(size(this%recv_pe)))
    do i = 1, size(this%recv_pe)
       call device_stream_create_with_priority(this%stream(i), 1, STRM_HIGH_PRIO)
    end do

    allocate(this%event(size(this%recv_pe)))
    do i = 1, size(this%recv_pe)
       call device_event_create(this%event(i), 2)
    end do

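    ! Allocate one 8-byte signal word per neighbour on the NVSHMEM symmetric
    ! heap; the pack/push kernels use these for completion signalling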
#ifdef HAVE_NVSHMEM
    allocate(this%notifyDone(size(this%recv_pe)))
    allocate(this%notifyReady(size(this%recv_pe)))
    do i = 1, size(this%recv_pe)
       call cudamalloc_nvshmem(this%notifyDone(i), 8_8)
       call cudamalloc_nvshmem(this%notifyReady(i), 8_8)
    end do
#endif
#endif

  end subroutine gs_device_shmem_init

  !> Deallocate device SHMEM based communication method
  subroutine gs_device_shmem_free(this)
    class(gs_device_shmem_t), intent(inout) :: this
    integer :: i

    call this%send_buf%free()
    call this%recv_buf%free()

    call this%free_order()
    call this%free_dofs()

#if defined(HAVE_HIP) || defined(HAVE_CUDA)
    if (allocated(this%stream)) then
       do i = 1, size(this%stream)
          call device_stream_destroy(this%stream(i))
       end do
       deallocate(this%stream)
    end if
#endif

  end subroutine gs_device_shmem_free

  !> Post non-blocking send operations
  subroutine gs_device_shmem_nbsend(this, u, n, deps, strm)
    class(gs_device_shmem_t), intent(inout) :: this
    integer, intent(in) :: n
    real(kind=rp), dimension(n), intent(inout) :: u
    type(c_ptr), intent(inout) :: deps
    type(c_ptr), intent(inout) :: strm
    integer :: i
    type(c_ptr) :: u_d

    u_d = device_get_ptr(u)

    do i = 1, size(this%send_pe)
       call device_stream_wait_event(this%stream(i), deps, 0)
       ! Not clear why this sync is required, but there seems to be a race
       ! condition without it for certain run configurations
       call device_sync(this%stream(i))
    end do

    ! The actual packing and push is done in the "wait" routine below

  end subroutine gs_device_shmem_nbsend

  !> Post non-blocking receive operations
  subroutine gs_device_shmem_nbrecv(this)
    class(gs_device_shmem_t), intent(inout) :: this
    integer :: i

    ! Everything is done in the "wait" routine below

  end subroutine gs_device_shmem_nbrecv

  !> Wait for non-blocking operations
  subroutine gs_device_shmem_nbwait(this, u, n, op, strm)
    class(gs_device_shmem_t), intent(inout) :: this
    integer, intent(in) :: n
    real(kind=rp), dimension(n), intent(inout) :: u
    type(c_ptr), intent(inout) :: strm
    integer :: op, done_req, i
    type(c_ptr) :: u_d

    u_d = device_get_ptr(u)
#ifdef HAVE_NVSHMEM
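    ! On first use, exchange buffer offsets with each neighbour so that the
    ! pack-and-push kernel knows where in the receiving PE's buffer to write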
    do i = 1, size(this%send_pe)
       if (this%recv_buf%remote_offset(i) .eq. -1) then
          call mpi_sendrecv(this%recv_buf%offset(i), 1, mpi_integer, &
               this%recv_pe(i), 0, &
               this%recv_buf%remote_offset(i), 1, mpi_integer, &
               this%send_pe(i), 0, neko_comm, mpi_status_ignore)
       end if

       call cuda_gs_pack_and_push(u_d, &
            this%send_buf%buf_d, &
            this%send_buf%dof_d, &
            this%send_buf%offset(i), &
            this%send_buf%ndofs(i), &
            this%stream(i), &
            this%send_pe(i), &
            this%recv_buf%buf_d, &
            this%recv_buf%offset(i), &
            this%recv_buf%remote_offset, &
            this%recv_pe(i), &
            this%nvshmem_counter, &
            this%notifyDone(i), &
            this%notifyReady(i), &
            i)
       this%nvshmem_counter = this%nvshmem_counter + 1
    end do

    do i = 1, size(this%send_pe)
       call cuda_gs_pack_and_push_wait(this%stream(i), &
            this%nvshmem_counter - size(this%send_pe) + i - 1, &
            this%notifyDone(i))
    end do

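    ! Unpack the data pushed by each neighbour into u on the per-neighbour
    ! streams, then make the caller's stream wait on the recorded events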
    do done_req = 1, size(this%recv_pe)
       call cuda_gs_unpack(u_d, op, &
            this%recv_buf%buf_d, &
            this%recv_buf%dof_d, &
            this%recv_buf%offset(done_req), &
            this%recv_buf%ndofs(done_req), &
            this%stream(done_req))
       call device_event_record(this%event(done_req), this%stream(done_req))
    end do

    ! Sync non-blocking streams
    do done_req = 1, size(this%recv_pe)
       call device_stream_wait_event(strm, &
            this%event(done_req), 0)
    end do
#endif
  end subroutine gs_device_shmem_nbwait

end module gs_device_shmem