Neko 1.99.1
A portable framework for high-order spectral element flow simulations
gs_device_mpi.F90
Go to the documentation of this file.
1! Copyright (c) 2020-2025, The Neko Authors
2! All rights reserved.
3!
4! Redistribution and use in source and binary forms, with or without
5! modification, are permitted provided that the following conditions
6! are met:
7!
8! * Redistributions of source code must retain the above copyright
9! notice, this list of conditions and the following disclaimer.
10!
11! * Redistributions in binary form must reproduce the above
12! copyright notice, this list of conditions and the following
13! disclaimer in the documentation and/or other materials provided
14! with the distribution.
15!
16! * Neither the name of the authors nor the names of its
17! contributors may be used to endorse or promote products derived
18! from this software without specific prior written permission.
19!
20! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21! "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22! LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
23! FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
24! COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
25! INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
26! BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27! LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28! CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29! LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
30! ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31! POSSIBILITY OF SUCH DAMAGE.
32!
33!> Defines GPU aware MPI gather-scatter communication
34module gs_device_mpi
35 use num_types, only : rp, c_rp
36 use gs_comm, only : gs_comm_t
37 use stack, only : stack_i4_t
38 use comm, only : pe_size, pe_rank
39 use htable, only : htable_i4_t
40 use device, only : device_alloc, device_free, device_memcpy, &
41 device_memset, device_get_ptr, device_sync, host_to_device, &
42 device_stream_create_with_priority, device_stream_destroy, &
43 device_stream_wait_event, device_event_create, device_event_record, &
44 strm_high_prio
45 use utils, only : neko_error
46 use, intrinsic :: iso_c_binding, only : c_sizeof, c_int32_t, &
47 c_ptr, c_null_ptr, c_size_t, c_associated
48 implicit none
49 private
50
51 !> Buffers for non-blocking communication and packing/unpacking
52 type, private :: gs_device_mpi_buf_t
53 integer, allocatable :: ndofs(:)
54 integer, allocatable :: offset(:)
55 integer :: total
56 type(c_ptr) :: reqs = c_null_ptr
57 type(c_ptr) :: buf_d = c_null_ptr
58 type(c_ptr) :: dof_d = c_null_ptr
59 contains
60 procedure, pass(this) :: init => gs_device_mpi_buf_init
61 procedure, pass(this) :: free => gs_device_mpi_buf_free
62 end type gs_device_mpi_buf_t
63
64 !> Gather-scatter communication using device MPI.
65 !! The arrays are indexed per PE, like @a send_pe and @a recv_pe.
66 type, public, extends(gs_comm_t) :: gs_device_mpi_t
67 type(gs_device_mpi_buf_t) :: send_buf
68 type(gs_device_mpi_buf_t) :: recv_buf
69 type(c_ptr), allocatable :: stream(:)
70 type(c_ptr), allocatable :: event(:)
71 integer :: nb_strtgy
72 type(c_ptr) :: send_event = c_null_ptr
73 contains
74 procedure, pass(this) :: init => gs_device_mpi_init
75 procedure, pass(this) :: free => gs_device_mpi_free
76 procedure, pass(this) :: nbsend => gs_device_mpi_nbsend
77 procedure, pass(this) :: nbrecv => gs_device_mpi_nbrecv
78 procedure, pass(this) :: nbwait => gs_device_mpi_nbwait
79 end type gs_device_mpi_t
80
81#ifdef HAVE_HIP
82 interface
83 subroutine hip_gs_pack(u_d, buf_d, dof_d, offset, n, stream) &
84 bind(c, name='hip_gs_pack')
85 use, intrinsic :: iso_c_binding
86 implicit none
87 integer(c_int), value :: n, offset
88 type(c_ptr), value :: u_d, buf_d, dof_d, stream
89 end subroutine hip_gs_pack
90 end interface
91
92 interface
93 subroutine hip_gs_unpack(u_d, op, buf_d, dof_d, offset, n, stream) &
94 bind(c, name='hip_gs_unpack')
95 use, intrinsic :: iso_c_binding
96 implicit none
97 integer(c_int), value :: op, offset, n
98 type(c_ptr), value :: u_d, buf_d, dof_d, stream
99 end subroutine hip_gs_unpack
100 end interface
101#elif HAVE_CUDA
102 interface
103 subroutine cuda_gs_pack(u_d, buf_d, dof_d, offset, n, stream) &
104 bind(c, name='cuda_gs_pack')
105 use, intrinsic :: iso_c_binding
106 implicit none
107 integer(c_int), value :: n, offset
108 type(c_ptr), value :: u_d, buf_d, dof_d, stream
109 end subroutine cuda_gs_pack
110 end interface
111
112 interface
113 subroutine cuda_gs_unpack(u_d, op, buf_d, dof_d, offset, n, stream) &
114 bind(c, name='cuda_gs_unpack')
115 use, intrinsic :: iso_c_binding
116 implicit none
117 integer(c_int), value :: op, offset, n
118 type(c_ptr), value :: u_d, buf_d, dof_d, stream
119 end subroutine cuda_gs_unpack
120 end interface
121#endif
122
123 interface
124 subroutine device_mpi_init_reqs(n, reqs) &
125 bind(c, name='device_mpi_init_reqs')
126 use, intrinsic :: iso_c_binding
127 implicit none
128 integer(c_int), value :: n
129 type(c_ptr) :: reqs
130 end subroutine device_mpi_init_reqs
131 end interface
132
133 interface
134 subroutine device_mpi_free_reqs(reqs) &
135 bind(c, name='device_mpi_free_reqs')
136 use, intrinsic :: iso_c_binding
137 implicit none
138 type(c_ptr) :: reqs
139 end subroutine device_mpi_free_reqs
140 end interface
141
142 interface
143 subroutine device_mpi_isend(buf_d, offset, nbytes, rank, reqs, i) &
144 bind(c, name='device_mpi_isend')
145 use, intrinsic :: iso_c_binding
146 implicit none
147 integer(c_int), value :: offset, nbytes, rank, i
148 type(c_ptr), value :: buf_d, reqs
149 end subroutine device_mpi_isend
150 end interface
151
152 interface
153 subroutine device_mpi_irecv(buf_d, offset, nbytes, rank, reqs, i) &
154 bind(c, name='device_mpi_irecv')
155 use, intrinsic :: iso_c_binding
156 implicit none
157 integer(c_int), value :: offset, nbytes, rank, i
158 type(c_ptr), value :: buf_d, reqs
159 end subroutine device_mpi_irecv
160 end interface
161
162 interface
163 integer(c_int) function device_mpi_test(reqs, i) &
164 bind(c, name='device_mpi_test')
165 use, intrinsic :: iso_c_binding
166 implicit none
167 integer(c_int), value :: i
168 type(c_ptr), value :: reqs
169 end function device_mpi_test
170 end interface
171
172 interface
173 subroutine device_mpi_waitall(n, reqs) &
174 bind(c, name='device_mpi_waitall')
175 use, intrinsic :: iso_c_binding
176 implicit none
177 integer(c_int), value :: n
178 type(c_ptr), value :: reqs
179 end subroutine device_mpi_waitall
180 end interface
181
182 interface
183 integer(c_int) function device_mpi_waitany(n, reqs, i) &
184 bind(c, name='device_mpi_waitany')
185 use, intrinsic :: iso_c_binding
186 implicit none
187 integer(c_int), value :: n
188 integer(c_int) :: i
189 type(c_ptr), value :: reqs
190 end function device_mpi_waitany
191 end interface
192
193contains
194
195 subroutine gs_device_mpi_buf_init(this, pe_order, dof_stack, mark_dupes)
196 class(gs_device_mpi_buf_t), intent(inout) :: this
197 integer, allocatable, intent(inout) :: pe_order(:)
198 type(stack_i4_t), allocatable, intent(inout) :: dof_stack(:)
199 logical, intent(in) :: mark_dupes
200 integer, allocatable :: dofs(:)
201 integer :: i, j, total
202 integer(c_size_t) :: sz
203 type(htable_i4_t) :: doftable
204 integer :: dupe, marked, k
205 real(c_rp) :: rp_dummy
206 integer(c_int32_t) :: i4_dummy
207
208 call device_mpi_init_reqs(size(pe_order), this%reqs)
209
210 allocate(this%ndofs(size(pe_order)))
211 allocate(this%offset(size(pe_order)))
212
213 total = 0
214 do i = 1, size(pe_order)
215 this%ndofs(i) = dof_stack(pe_order(i))%size()
216 this%offset(i) = total
217 total = total + this%ndofs(i)
218 end do
219
220 this%total = total
221
222 sz = c_sizeof(rp_dummy) * total
223 call device_alloc(this%buf_d, sz)
224 call device_memset(this%buf_d, 0, sz, sync=.true.)
225
226 sz = c_sizeof(i4_dummy) * total
227 call device_alloc(this%dof_d, sz)
228
229 if (mark_dupes) call doftable%init(2*total)
230 allocate(dofs(total))
231
232 ! Copy from dof_stack into dofs; duplicated dofs (received from more than one PE) are flagged by negating the index so unpacking can detect them
233 marked = 0
234 do i = 1, size(pe_order)
235 ! %array() breaks on cray
236 select type (arr => dof_stack(pe_order(i))%data)
237 type is (integer)
238 do j = 1, this%ndofs(i)
239 k = this%offset(i) + j
240 if (mark_dupes) then
241 if (doftable%get(arr(j), dupe) .eq. 0) then
242 if (dofs(dupe) .gt. 0) then
243 dofs(dupe) = -dofs(dupe)
244 marked = marked + 1
245 end if
246 dofs(k) = -arr(j)
247 marked = marked + 1
248 else
249 call doftable%set(arr(j), k)
250 dofs(k) = arr(j)
251 end if
252 else
253 dofs(k) = arr(j)
254 end if
255 end do
256 end select
257 end do
258 call device_memcpy(dofs, this%dof_d, total, host_to_device, sync=.true.)
259 ! Syncing here prevents the memory in dofs from accidentally being
260 ! corrupted while this memcpy is in flight.
261 ! This might be happening in many other places as well. Karp 4/6-25
262
263 deallocate(dofs)
264 call doftable%free()
265
266 end subroutine gs_device_mpi_buf_init
267
268 subroutine gs_device_mpi_buf_free(this)
269 class(gs_device_mpi_buf_t), intent(inout) :: this
270
271 if (c_associated(this%reqs)) call device_mpi_free_reqs(this%reqs)
272
273 if (allocated(this%ndofs)) deallocate(this%ndofs)
274 if (allocated(this%offset)) deallocate(this%offset)
275
276 if (c_associated(this%buf_d)) call device_free(this%buf_d)
277 if (c_associated(this%dof_d)) call device_free(this%dof_d)
278 end subroutine gs_device_mpi_buf_free
279
280 !> Initialise MPI based communication method
281 subroutine gs_device_mpi_init(this, send_pe, recv_pe)
282 class(gs_device_mpi_t), intent(inout) :: this
283 type(stack_i4_t), intent(inout) :: send_pe
284 type(stack_i4_t), intent(inout) :: recv_pe
285 integer :: i
286
287 call this%init_order(send_pe, recv_pe)
288
289 call this%send_buf%init(this%send_pe, this%send_dof, .false.)
290 call this%recv_buf%init(this%recv_pe, this%recv_dof, .true.)
291
292#if defined(HAVE_HIP) || defined(HAVE_CUDA)
293 ! Create a set of non-blocking streams
294 allocate(this%stream(size(this%recv_pe)))
295 do i = 1, size(this%recv_pe)
296 call device_stream_create_with_priority(this%stream(i), 1, strm_high_prio)
297 end do
298
299 allocate(this%event(size(this%recv_pe)))
300 do i = 1, size(this%recv_pe)
301 call device_event_create(this%event(i), 2)
302 end do
303#endif
304
305
306 this%nb_strtgy = 0
307
308 end subroutine gs_device_mpi_init
309
310 !> Deallocate MPI based communication method
311 subroutine gs_device_mpi_free(this)
312 class(gs_device_mpi_t), intent(inout) :: this
313 integer :: i
314
315 call this%send_buf%free()
316 call this%recv_buf%free()
317
318 call this%free_order()
319 call this%free_dofs()
320
321#if defined(HAVE_HIP) || defined(HAVE_CUDA)
322 if (allocated(this%stream)) then
323 do i = 1, size(this%stream)
324 call device_stream_destroy(this%stream(i))
325 end do
326 deallocate(this%stream)
327 end if
328#endif
329
330 end subroutine gs_device_mpi_free
331
332 !> Post non-blocking send operations
333 subroutine gs_device_mpi_nbsend(this, u, n, deps, strm)
334 class(gs_device_mpi_t), intent(inout) :: this
335 integer, intent(in) :: n
336 real(kind=rp), dimension(n), intent(inout) :: u
337 type(c_ptr), intent(inout) :: deps
338 type(c_ptr), intent(inout) :: strm
339 integer :: i
340 type(c_ptr) :: u_d
341
342 u_d = device_get_ptr(u)
343
344 if (iand(this%nb_strtgy, 1) .eq. 0) then
345
346#ifdef HAVE_HIP
347 call hip_gs_pack(u_d, &
348 this%send_buf%buf_d, &
349 this%send_buf%dof_d, &
350 0, this%send_buf%total, &
351 strm)
352#elif HAVE_CUDA
353 call cuda_gs_pack(u_d, &
354 this%send_buf%buf_d, &
355 this%send_buf%dof_d, &
356 0, this%send_buf%total, &
357 strm)
358#else
359 call neko_error('gs_device_mpi: no backend')
360#endif
361
362 call device_sync(strm)
363
364 do i = 1, size(this%send_pe)
365 call device_mpi_isend(this%send_buf%buf_d, &
366 rp*this%send_buf%offset(i), &
367 rp*this%send_buf%ndofs(i), this%send_pe(i), &
368 this%send_buf%reqs, i)
369 end do
370
371 else
372
373 do i = 1, size(this%send_pe)
374 call device_stream_wait_event(this%stream(i), deps, 0)
375#ifdef HAVE_HIP
376 call hip_gs_pack(u_d, &
377 this%send_buf%buf_d, &
378 this%send_buf%dof_d, &
379 this%send_buf%offset(i), &
380 this%send_buf%ndofs(i), &
381 this%stream(i))
382#elif HAVE_CUDA
383 call cuda_gs_pack(u_d, &
384 this%send_buf%buf_d, &
385 this%send_buf%dof_d, &
386 this%send_buf%offset(i), &
387 this%send_buf%ndofs(i), &
388 this%stream(i))
389#else
390 call neko_error('gs_device_mpi: no backend')
391#endif
392 end do
393
394 ! Consider adding a poll loop here once we have device_query in place
395 do i = 1, size(this%send_pe)
396 call device_sync(this%stream(i))
397 call device_mpi_isend(this%send_buf%buf_d, &
398 rp*this%send_buf%offset(i), &
399 rp*this%send_buf%ndofs(i), this%send_pe(i), &
400 this%send_buf%reqs, i)
401 end do
402 end if
403
404 end subroutine gs_device_mpi_nbsend
405
406 !> Post non-blocking receive operations
407 subroutine gs_device_mpi_nbrecv(this)
408 class(gs_device_mpi_t), intent(inout) :: this
409 integer :: i
410
411 do i = 1, size(this%recv_pe)
412 call device_mpi_irecv(this%recv_buf%buf_d, rp*this%recv_buf%offset(i), &
413 rp*this%recv_buf%ndofs(i), this%recv_pe(i), &
414 this%recv_buf%reqs, i)
415 end do
416
417 end subroutine gs_device_mpi_nbrecv
418
419 !> Wait for non-blocking operations
420 subroutine gs_device_mpi_nbwait(this, u, n, op, strm)
421 class(gs_device_mpi_t), intent(inout) :: this
422 integer, intent(in) :: n
423 real(kind=rp), dimension(n), intent(inout) :: u
424 type(c_ptr), intent(inout) :: strm
425 integer :: op, done_req, i
426 type(c_ptr) :: u_d
427
428 u_d = device_get_ptr(u)
429
430 if (iand(this%nb_strtgy, 2) .eq. 0) then
431 call device_mpi_waitall(size(this%recv_pe), this%recv_buf%reqs)
432
433#ifdef HAVE_HIP
434 call hip_gs_unpack(u_d, op, &
435 this%recv_buf%buf_d, &
436 this%recv_buf%dof_d, &
437 0, this%recv_buf%total, &
438 strm)
439#elif HAVE_CUDA
440 call cuda_gs_unpack(u_d, op, &
441 this%recv_buf%buf_d, &
442 this%recv_buf%dof_d, &
443 0, this%recv_buf%total, &
444 strm)
445#else
446 call neko_error('gs_device_mpi: no backend')
447#endif
448
449 call device_mpi_waitall(size(this%send_pe), this%send_buf%reqs)
450
451 ! Syncing here seems to prevent some race condition
452 call device_sync(strm)
453
454 else
455
456 do while(device_mpi_waitany(size(this%recv_pe), &
457 this%recv_buf%reqs, done_req) .ne. 0)
458
459#ifdef HAVE_HIP
460 call hip_gs_unpack(u_d, op, &
461 this%recv_buf%buf_d, &
462 this%recv_buf%dof_d, &
463 this%recv_buf%offset(done_req), &
464 this%recv_buf%ndofs(done_req), &
465 this%stream(done_req))
466#elif HAVE_CUDA
467 call cuda_gs_unpack(u_d, op, &
468 this%recv_buf%buf_d, &
469 this%recv_buf%dof_d, &
470 this%recv_buf%offset(done_req), &
471 this%recv_buf%ndofs(done_req), &
472 this%stream(done_req))
473#else
474 call neko_error('gs_device_mpi: no backend')
475#endif
476 call device_event_record(this%event(done_req), this%stream(done_req))
477 end do
478
479 call device_mpi_waitall(size(this%send_pe), this%send_buf%reqs)
480
481 ! Sync non-blocking streams
482 do done_req = 1, size(this%recv_pe)
483 call device_stream_wait_event(strm, &
484 this%event(done_req), 0)
485 end do
486
487 end if
488
489 end subroutine gs_device_mpi_nbwait
490
491end module gs_device_mpi
void cuda_gs_unpack(real *u_d, int op, real *buf_d, int *dof_d, int offset, int n, cudaStream_t stream)
Definition gs.cu:132
void cuda_gs_pack(void *u_d, void *buf_d, void *dof_d, int offset, int n, cudaStream_t stream)
Definition gs.cu:116
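The pack and unpack kernels implement the two halves of the exchange: pack gathers dof-indexed values from the field into a contiguous staging buffer before the sends, and unpack scatters received values back into the field, combining them with the selected reduction. The Fortran routine below is purely a conceptual sketch of those semantics (the real kernels run on the device in gs.cu and its HIP counterpart); it ignores the sign flagging of duplicated dofs and assumes the reduction is an add, and the name gs_pack_unpack_sketch and its argument shapes are illustrative.

subroutine gs_pack_unpack_sketch(u, buf, dof, offset, n)
  use num_types, only : rp
  implicit none
  integer, intent(in) :: offset, n
  integer, intent(in) :: dof(:)
  real(kind=rp), intent(inout) :: u(:), buf(:)
  integer :: i

  ! pack: gather dof-indexed entries into the send buffer
  do i = 1, n
     buf(offset + i) = u(dof(offset + i))
  end do

  ! unpack: scatter received entries back, combining with the reduction (add)
  do i = 1, n
     u(dof(offset + i)) = u(dof(offset + i)) + buf(offset + i)
  end do
end subroutine gs_pack_unpack_sketch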
Return the device pointer for an associated Fortran array.
Definition device.F90:96
Copy data between host and device (or device and device)
Definition device.F90:66
Synchronize a device or stream.
Definition device.F90:102
Definition comm.F90:1
integer, public pe_size
MPI size of communicator.
Definition comm.F90:58
integer, public pe_rank
MPI rank.
Definition comm.F90:55
Device abstraction, common interface for various accelerators.
Definition device.F90:34
subroutine, public device_event_record(event, stream)
Record a device event.
Definition device.F90:1290
integer, parameter, public host_to_device
Definition device.F90:47
subroutine, public device_free(x_d)
Deallocate memory on the device.
Definition device.F90:214
subroutine, public device_event_destroy(event)
Destroy a device event.
Definition device.F90:1274
subroutine, public device_alloc(x_d, s)
Allocate memory on the device.
Definition device.F90:187
subroutine, public device_stream_create_with_priority(stream, flags, prio)
Create a device stream/command queue with priority.
Definition device.F90:1168
subroutine, public device_stream_wait_event(stream, event, flags)
Synchronize a device stream with an event.
Definition device.F90:1203
subroutine, public device_event_create(event, flags)
Create a device event queue.
Definition device.F90:1244
integer, public strm_high_prio
High priority stream setting.
Definition device.F90:60
subroutine, public device_memset(x_d, v, s, sync, strm)
Set memory on the device to a value.
Definition device.F90:233
subroutine, public device_stream_destroy(stream)
Destroy a device stream/command queue.
Definition device.F90:1185
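gs_device_mpi_buf_init uses these device routines in a simple allocate / copy pattern: raw device buffers are sized with c_sizeof, the dof index array is copied with a synchronous device_memcpy so the host copy can be released immediately, and gs_device_mpi_buf_free releases the buffers on teardown. Below is a minimal sketch of that pattern, based only on the calls made in this file; the names device_index_copy_sketch, idx, m and idx_d are illustrative placeholders.

subroutine device_index_copy_sketch(idx, m, idx_d)
  use device, only : device_alloc, device_memcpy, host_to_device
  use, intrinsic :: iso_c_binding, only : c_ptr, c_size_t, c_sizeof, c_int32_t
  implicit none
  integer, intent(in) :: m
  integer, intent(inout) :: idx(m)
  type(c_ptr), intent(inout) :: idx_d
  integer(c_size_t) :: sz
  integer(c_int32_t) :: i4_dummy

  sz = c_sizeof(i4_dummy) * m                 ! bytes for m 32-bit integers
  call device_alloc(idx_d, sz)                ! raw device allocation
  ! synchronous copy so idx can be reused or deallocated right away
  call device_memcpy(idx, idx_d, m, host_to_device, sync = .true.)
end subroutine device_index_copy_sketch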
Defines a gather-scatter communication method.
Definition gs_comm.f90:34
Defines GPU aware MPI gather-scatter communication.
subroutine gs_device_mpi_buf_init(this, pe_order, dof_stack, mark_dupes)
subroutine gs_device_mpi_nbrecv(this)
Post non-blocking receive operations.
subroutine gs_device_mpi_nbwait(this, u, n, op, strm)
Wait for non-blocking operations.
subroutine gs_device_mpi_free(this)
Deallocate MPI based communication method.
subroutine gs_device_mpi_nbsend(this, u, n, deps, strm)
Post non-blocking send operations.
subroutine gs_device_mpi_buf_free(this)
subroutine gs_device_mpi_init(this, send_pe, recv_pe)
Initialise MPI based communication method.
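Taken together, the communication routines above implement a non-blocking halo exchange: receives are posted, data is packed on the device and sent, and the wait unpacks received data and applies the reduction, with nb_strtgy selecting between a single aggregated pack/unpack and per-PE packing/unpacking on dedicated streams. The sketch below only illustrates the call order; in Neko this sequence is driven internally by the gather-scatter backend, and exchange_sketch, u, n, op, deps and strm are illustrative names set up by the caller.

subroutine exchange_sketch(gs, u, n, op, deps, strm)
  use gs_device_mpi, only : gs_device_mpi_t
  use num_types, only : rp
  use, intrinsic :: iso_c_binding, only : c_ptr
  implicit none
  type(gs_device_mpi_t), intent(inout) :: gs
  integer, intent(in) :: n, op
  real(kind=rp), dimension(n), intent(inout) :: u
  type(c_ptr), intent(inout) :: deps, strm

  call gs%nbrecv()                    ! post receives before the sends
  call gs%nbsend(u, n, deps, strm)    ! pack on the device and post sends
  call gs%nbwait(u, n, op, strm)      ! wait, unpack and apply the reduction
end subroutine exchange_sketch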
Implements a hash table ADT.
Definition htable.f90:36
integer, parameter, public c_rp
Definition num_types.f90:13
integer, parameter, public rp
Global precision used in computations.
Definition num_types.f90:12
Implements a dynamic stack ADT.
Definition stack.f90:35
Utilities.
Definition utils.f90:35
Gather-scatter communication method.
Definition gs_comm.f90:46
Buffers for non-blocking communication and packing/unpacking.
Gather-scatter communication using device MPI. The arrays are indexed per PE, like send_pe and recv_pe.
Integer based hash table.
Definition htable.f90:82
Integer based stack.
Definition stack.f90:63