Neko 1.99.1
A portable framework for high-order spectral element flow simulations
gs_device_shmem.F90
! Copyright (c) 2020-2024, The Neko Authors
! All rights reserved.
!
! Redistribution and use in source and binary forms, with or without
! modification, are permitted provided that the following conditions
! are met:
!
!   * Redistributions of source code must retain the above copyright
!     notice, this list of conditions and the following disclaimer.
!
!   * Redistributions in binary form must reproduce the above
!     copyright notice, this list of conditions and the following
!     disclaimer in the documentation and/or other materials provided
!     with the distribution.
!
!   * Neither the name of the authors nor the names of its
!     contributors may be used to endorse or promote products derived
!     from this software without specific prior written permission.
!
! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
! "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
! LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
! FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
! COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
! INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
! BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
! LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
! CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
! LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
! ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
! POSSIBILITY OF SUCH DAMAGE.
!
!> Defines GPU aware MPI gather-scatter communication
module gs_device_shmem
  use num_types, only : rp, c_rp
  use gs_comm, only : gs_comm_t
  use gs_ops
  use stack, only : stack_i4_t
  use htable, only : htable_i4_t
  use device
  use comm
  use mpi_f08
  use utils, only : neko_error
  use, intrinsic :: iso_c_binding, only : c_sizeof, c_int32_t, &
       c_ptr, c_null_ptr, c_size_t, c_associated
  implicit none
  private

  !> Buffers for non-blocking communication and packing/unpacking
  type, private :: gs_device_shmem_buf_t
     integer, allocatable :: ndofs(:)
     integer, allocatable :: offset(:)
     integer, allocatable :: remote_offset(:)
     integer :: total
     type(c_ptr) :: buf_d = c_null_ptr
     type(c_ptr) :: dof_d = c_null_ptr
   contains
     procedure, pass(this) :: init => gs_device_shmem_buf_init
     procedure, pass(this) :: free => gs_device_shmem_buf_free
  end type gs_device_shmem_buf_t

  !> Gather-scatter communication using device SHMEM. The arrays are indexed per PE like @a send_pe and @a recv_pe
  type, public, extends(gs_comm_t) :: gs_device_shmem_t
     type(gs_device_shmem_buf_t) :: send_buf
     type(gs_device_shmem_buf_t) :: recv_buf
     type(c_ptr), allocatable :: stream(:)
     type(c_ptr), allocatable :: event(:)
     integer :: nvshmem_counter = 1
     type(c_ptr), allocatable :: notifydone(:)
     type(c_ptr), allocatable :: notifyready(:)
   contains
     procedure, pass(this) :: init => gs_device_shmem_init
     procedure, pass(this) :: free => gs_device_shmem_free
     procedure, pass(this) :: nbsend => gs_device_shmem_nbsend
     procedure, pass(this) :: nbrecv => gs_device_shmem_nbrecv
     procedure, pass(this) :: nbwait => gs_device_shmem_nbwait
  end type gs_device_shmem_t

#if defined(HAVE_CUDA) && defined(HAVE_NVSHMEM)

  interface
     subroutine cudamalloc_nvshmem(ptr, size) &
          bind(c, name='cudamalloc_nvshmem')
       use, intrinsic :: iso_c_binding
       implicit none
       type(c_ptr) :: ptr
       integer(c_size_t), value :: size
     end subroutine cudamalloc_nvshmem
  end interface

  interface
     subroutine cudafree_nvshmem(ptr) &
          bind(c, name='cudafree_nvshmem')
       use, intrinsic :: iso_c_binding
       implicit none
       type(c_ptr) :: ptr
     end subroutine cudafree_nvshmem
  end interface

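  ! Packs the send-buffer entries for one neighbour PE on the device and
  ! pushes them directly into that PE's receive buffer (one-sided put),
  ! signalling progress via the notifyDone/notifyReady flags.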
  interface
     subroutine cuda_gs_pack_and_push(u_d, buf_d, dof_d, offset, n, stream, &
          srank, rbuf_d, roffset, remote_offset, &
          rrank, nvshmem_counter, notifyDone, &
          notifyReady, iter) &
          bind(c, name='cuda_gs_pack_and_push')
       use, intrinsic :: iso_c_binding
       implicit none
       integer(c_int), value :: n, offset, srank, roffset, rrank, iter
       integer(c_int), value :: nvshmem_counter
       type(c_ptr), value :: u_d, buf_d, dof_d, stream, rbuf_d, notifydone, notifyready
       integer(c_int), dimension(*) :: remote_offset
     end subroutine cuda_gs_pack_and_push
  end interface

  interface
     subroutine cuda_gs_pack_and_push_wait(stream, nvshmem_counter, notifyDone) &
          bind(c, name='cuda_gs_pack_and_push_wait')
       use, intrinsic :: iso_c_binding
       implicit none
       integer(c_int), value :: nvshmem_counter
       type(c_ptr), value :: stream, notifydone
     end subroutine cuda_gs_pack_and_push_wait
  end interface

  interface
     subroutine cuda_gs_unpack(u_d, op, buf_d, dof_d, offset, n, stream) &
          bind(c, name='cuda_gs_unpack')
       use, intrinsic :: iso_c_binding
       implicit none
       integer(c_int), value :: op, offset, n
       type(c_ptr), value :: u_d, buf_d, dof_d, stream
     end subroutine cuda_gs_unpack
  end interface
#endif

contains

  subroutine gs_device_shmem_buf_init(this, pe_order, dof_stack, mark_dupes)
    class(gs_device_shmem_buf_t), intent(inout) :: this
    integer, allocatable, intent(inout) :: pe_order(:)
    type(stack_i4_t), allocatable, intent(inout) :: dof_stack(:)
    logical, intent(in) :: mark_dupes
    integer, allocatable :: dofs(:)
    integer :: i, j, total, max_total
    integer(c_size_t) :: sz
    type(htable_i4_t) :: doftable
    integer :: dupe, marked, k
    real(c_rp) :: rp_dummy
    integer(c_int32_t) :: i4_dummy

    allocate(this%ndofs(size(pe_order)))
    allocate(this%offset(size(pe_order)))
    allocate(this%remote_offset(size(pe_order)))

    do i = 1, size(pe_order)
       this%remote_offset(i) = -1
    end do

    total = 0
    do i = 1, size(pe_order)
       this%ndofs(i) = dof_stack(pe_order(i))%size()
       this%offset(i) = total
       total = total + this%ndofs(i)
    end do

    call mpi_allreduce(total, max_total, 1, mpi_integer, mpi_max, neko_comm)

    this%total = total

    sz = c_sizeof(rp_dummy) * max_total
#ifdef HAVE_NVSHMEM
    call cudamalloc_nvshmem(this%buf_d, sz)
#endif

    sz = c_sizeof(i4_dummy) * total
    call device_alloc(this%dof_d, sz)

    if (mark_dupes) call doftable%init(2*total)
    allocate(dofs(total))

    ! Copy from dof_stack into dofs, optionally marking duplicates with doftable
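    ! Duplicates are flagged by storing a negated dof id (for both the first
    ! occurrence and the repeat), so that the device unpack kernel can
    ! recognise entries that target the same local dof.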
    marked = 0
    do i = 1, size(pe_order)
       ! %array() breaks on cray
       select type (arr => dof_stack(pe_order(i))%data)
       type is (integer)
          do j = 1, this%ndofs(i)
             k = this%offset(i) + j
             if (mark_dupes) then
                if (doftable%get(arr(j), dupe) .eq. 0) then
                   if (dofs(dupe) .gt. 0) then
                      dofs(dupe) = -dofs(dupe)
                      marked = marked + 1
                   end if
                   dofs(k) = -arr(j)
                   marked = marked + 1
                else
                   call doftable%set(arr(j), k)
                   dofs(k) = arr(j)
                end if
             else
                dofs(k) = arr(j)
             end if
          end do
       end select
    end do

    call device_memcpy(dofs, this%dof_d, total, host_to_device, sync=.true.)

    deallocate(dofs)
    call doftable%free()

  end subroutine gs_device_shmem_buf_init

  subroutine gs_device_shmem_buf_free(this)
    class(gs_device_shmem_buf_t), intent(inout) :: this

    if (allocated(this%ndofs)) deallocate(this%ndofs)
    if (allocated(this%offset)) deallocate(this%offset)
    if (allocated(this%remote_offset)) deallocate(this%remote_offset)

#ifdef HAVE_NVSHMEM
    if (c_associated(this%buf_d)) call cudafree_nvshmem(this%buf_d)
#endif
    if (c_associated(this%dof_d)) call device_free(this%dof_d)

  end subroutine gs_device_shmem_buf_free

  !> Initialise MPI based communication method
  subroutine gs_device_shmem_init(this, send_pe, recv_pe)
    class(gs_device_shmem_t), intent(inout) :: this
    type(stack_i4_t), intent(inout) :: send_pe
    type(stack_i4_t), intent(inout) :: recv_pe
    integer :: i

    call this%init_order(send_pe, recv_pe)

    call this%send_buf%init(this%send_pe, this%send_dof, .false.)
    call this%recv_buf%init(this%recv_pe, this%recv_dof, .true.)

#if defined(HAVE_HIP) || defined(HAVE_CUDA)
    ! Create a set of non-blocking streams
    allocate(this%stream(size(this%recv_pe)))
    do i = 1, size(this%recv_pe)
       call device_stream_create_with_priority(this%stream(i), 1, STRM_HIGH_PRIO)
    end do

    allocate(this%event(size(this%recv_pe)))
    do i = 1, size(this%recv_pe)
       call device_event_create(this%event(i), 2)
    end do

#ifdef HAVE_NVSHMEM
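    ! Allocate one 8-byte NVSHMEM flag per neighbour PE; these serve as the
    ! signalling locations (notifyDone/notifyReady) used by the one-sided pushes.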
    allocate(this%notifyDone(size(this%recv_pe)))
    allocate(this%notifyReady(size(this%recv_pe)))
    do i = 1, size(this%recv_pe)
       call cudamalloc_nvshmem(this%notifyDone(i), 8_8)
       call cudamalloc_nvshmem(this%notifyReady(i), 8_8)
    end do
#endif
#endif

  end subroutine gs_device_shmem_init

  !> Deallocate MPI based communication method
  subroutine gs_device_shmem_free(this)
    class(gs_device_shmem_t), intent(inout) :: this
    integer :: i

    call this%send_buf%free()
    call this%recv_buf%free()

    call this%free_order()
    call this%free_dofs()

#if defined(HAVE_HIP) || defined(HAVE_CUDA)
    if (allocated(this%stream)) then
       do i = 1, size(this%stream)
          call device_stream_destroy(this%stream(i))
       end do
       deallocate(this%stream)
    end if
#endif

  end subroutine gs_device_shmem_free

  !> Post non-blocking send operations
  subroutine gs_device_shmem_nbsend(this, u, n, deps, strm)
    class(gs_device_shmem_t), intent(inout) :: this
    integer, intent(in) :: n
    real(kind=rp), dimension(n), intent(inout) :: u
    type(c_ptr), intent(inout) :: deps
    type(c_ptr), intent(inout) :: strm
    integer :: i
    type(c_ptr) :: u_d

    u_d = device_get_ptr(u)

    do i = 1, size(this%send_pe)
       call device_stream_wait_event(this%stream(i), deps, 0)
       ! Not clear why this sync is required, but there seems to be a race
       ! condition without it for certain run configs
       call device_sync(this%stream(i))
    end do

    ! We do the rest in the "wait" routine below

  end subroutine gs_device_shmem_nbsend

  !> Post non-blocking receive operations
  subroutine gs_device_shmem_nbrecv(this)
    class(gs_device_shmem_t), intent(inout) :: this
    integer :: i

    ! We do everything in the "wait" routine below

  end subroutine gs_device_shmem_nbrecv

  !> Wait for non-blocking operations
  subroutine gs_device_shmem_nbwait(this, u, n, op, strm)
    class(gs_device_shmem_t), intent(inout) :: this
    integer, intent(in) :: n
    real(kind=rp), dimension(n), intent(inout) :: u
    type(c_ptr), intent(inout) :: strm
    integer :: op, done_req, i
    type(c_ptr) :: u_d

    u_d = device_get_ptr(u)
#ifdef HAVE_NVSHMEM
    do i = 1, size(this%send_pe)
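       ! First exchange with this neighbour: swap receive-buffer offsets so
       ! the push knows where in the remote symmetric buffer to write.
       ! The result is cached in remote_offset for subsequent calls.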
       if (this%recv_buf%remote_offset(i) .eq. -1) then
          call mpi_sendrecv(this%recv_buf%offset(i), 1, mpi_integer, &
               this%recv_pe(i), 0, &
               this%recv_buf%remote_offset(i), 1, mpi_integer, &
               this%send_pe(i), 0, neko_comm, mpi_status_ignore)
       end if

       call cuda_gs_pack_and_push(u_d, &
            this%send_buf%buf_d, &
            this%send_buf%dof_d, &
            this%send_buf%offset(i), &
            this%send_buf%ndofs(i), &
            this%stream(i), &
            this%send_pe(i), &
            this%recv_buf%buf_d, &
            this%recv_buf%offset(i), &
            this%recv_buf%remote_offset, &
            this%recv_pe(i), &
            this%nvshmem_counter, &
            this%notifyDone(i), &
            this%notifyReady(i), &
            i)
       this%nvshmem_counter = this%nvshmem_counter + 1
    end do

    do i = 1, size(this%send_pe)
       call cuda_gs_pack_and_push_wait(this%stream(i), &
            this%nvshmem_counter - size(this%send_pe) + i - 1, &
            this%notifyDone(i))
    end do

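    ! Unpack what each neighbour delivered into u, applying the reduction op,
    ! and record an event on that neighbour's stream so the caller's stream
    ! can be synchronised against it below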
    do done_req = 1, size(this%recv_pe)
       call cuda_gs_unpack(u_d, op, &
            this%recv_buf%buf_d, &
            this%recv_buf%dof_d, &
            this%recv_buf%offset(done_req), &
            this%recv_buf%ndofs(done_req), &
            this%stream(done_req))
       call device_event_record(this%event(done_req), this%stream(done_req))
    end do

    ! Sync non-blocking streams
    do done_req = 1, size(this%recv_pe)
       call device_stream_wait_event(strm, &
            this%event(done_req), 0)
    end do
#endif
  end subroutine gs_device_shmem_nbwait

end module gs_device_shmem
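
For orientation, the following is a minimal caller-side sketch (not part of this file) of how the backend's non-blocking interface is typically driven. The helper name exchange_sketch is hypothetical; GS_OP_ADD is one of the reduction codes from gs_ops, and the device streams, events and field array are assumed to be managed by the surrounding gather-scatter machinery.

  ! Hypothetical caller-side sketch; not part of gs_device_shmem.F90
  subroutine exchange_sketch(gs_comm, u, n, deps, strm)
    use num_types, only : rp
    use gs_ops, only : GS_OP_ADD
    use gs_device_shmem, only : gs_device_shmem_t
    use, intrinsic :: iso_c_binding, only : c_ptr
    implicit none
    type(gs_device_shmem_t), intent(inout) :: gs_comm
    integer, intent(in) :: n
    real(kind=rp), dimension(n), intent(inout) :: u
    type(c_ptr), intent(inout) :: deps, strm

    call gs_comm%nbsend(u, n, deps, strm)      ! per-PE streams wait on deps
    call gs_comm%nbrecv()                      ! no-op for this backend
    call gs_comm%nbwait(u, n, GS_OP_ADD, strm) ! pack, push, unpack, sync strm
  end subroutine exchange_sketch
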
void cuda_gs_unpack(real *u_d, int op, real *buf_d, int *dof_d, int offset, int n, cudaStream_t stream)
Definition gs.cu:132
Return the device pointer for an associated Fortran array.
Definition device.F90:96
Copy data between host and device (or device and device)
Definition device.F90:66
Synchronize a device or stream.
Definition device.F90:102
Definition comm.F90:1
type(mpi_comm), public neko_comm
MPI communicator.
Definition comm.F90:42
Device abstraction, common interface for various accelerators.
Definition device.F90:34
subroutine, public device_event_record(event, stream)
Record a device event.
Definition device.F90:1290
integer, parameter, public host_to_device
Definition device.F90:47
subroutine, public device_free(x_d)
Deallocate memory on the device.
Definition device.F90:214
subroutine, public device_alloc(x_d, s)
Allocate memory on the device.
Definition device.F90:187
subroutine, public device_stream_create_with_priority(stream, flags, prio)
Create a device stream/command queue with priority.
Definition device.F90:1168
subroutine, public device_stream_wait_event(stream, event, flags)
Synchronize a device stream with an event.
Definition device.F90:1203
subroutine, public device_event_create(event, flags)
Create a device event queue.
Definition device.F90:1244
integer, public strm_high_prio
High priority stream setting.
Definition device.F90:60
subroutine, public device_stream_destroy(stream)
Destroy a device stream/command queue.
Definition device.F90:1185
Defines a gather-scatter communication method.
Definition gs_comm.f90:34
Defines GPU aware MPI gather-scatter communication.
subroutine gs_device_shmem_nbsend(this, u, n, deps, strm)
Post non-blocking send operations.
subroutine gs_device_shmem_nbrecv(this)
Post non-blocking receive operations.
subroutine gs_device_shmem_nbwait(this, u, n, op, strm)
Wait for non-blocking operations.
subroutine gs_device_shmem_buf_init(this, pe_order, dof_stack, mark_dupes)
subroutine gs_device_shmem_free(this)
Deallocate MPI based communication method.
subroutine gs_device_shmem_buf_free(this)
subroutine gs_device_shmem_init(this, send_pe, recv_pe)
Initialise MPI based communication method.
Defines Gather-scatter operations.
Definition gs_ops.f90:34
Implements a hash table ADT.
Definition htable.f90:36
integer, parameter, public c_rp
Definition num_types.f90:13
integer, parameter, public rp
Global precision used in computations.
Definition num_types.f90:12
Implements a dynamic stack ADT.
Definition stack.f90:35
Utilities.
Definition utils.f90:35
Gather-scatter communication method.
Definition gs_comm.f90:46
Buffers for non-blocking communication and packing/unpacking.
Gather-scatter communication using device SHMEM. The arrays are indexed per PE like send_pe and recv_pe.
Integer based hash table.
Definition htable.f90:82
Integer based stack.
Definition stack.f90:63