Neko 1.99.1
A portable framework for high-order spectral element flow simulations
Loading...
Searching...
No Matches
fusedcg_cpld_device.F90
Go to the documentation of this file.
1! Copyright (c) 2021-2024, The Neko Authors
2! All rights reserved.
3!
4! Redistribution and use in source and binary forms, with or without
5! modification, are permitted provided that the following conditions
6! are met:
7!
8! * Redistributions of source code must retain the above copyright
9! notice, this list of conditions and the following disclaimer.
10!
11! * Redistributions in binary form must reproduce the above
12! copyright notice, this list of conditions and the following
13! disclaimer in the documentation and/or other materials provided
14! with the distribution.
15!
16! * Neither the name of the authors nor the names of its
17! contributors may be used to endorse or promote products derived
18! from this software without specific prior written permission.
19!
20! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21! "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22! LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
23! FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
24! COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
25! INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
26! BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27! LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28! CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29! LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
30! ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31! POSSIBILITY OF SUCH DAMAGE.
32!
36 use precon, only : pc_t
37 use ax_product, only : ax_t
38 use num_types, only: rp, c_rp
39 use field, only : field_t
40 use coefs, only : coef_t
41 use gather_scatter, only : gs_t, gs_op_add
42 use bc_list, only : bc_list_t
43 use math, only : glsc3, rzero, copy, abscmp
45 use device
46 use utils, only : neko_error
48 use mpi_f08, only : mpi_in_place, mpi_allreduce, &
49 mpi_sum
50 use, intrinsic :: iso_c_binding, only : c_ptr, c_null_ptr, &
51 c_associated, c_size_t, c_sizeof, c_int, c_loc
52 implicit none
53 private
54
55 integer, parameter :: device_fusedcg_cpld_p_space = 10
56
58 type, public, extends(ksp_t) :: fusedcg_cpld_device_t
59 real(kind=rp), allocatable :: w1(:)
60 real(kind=rp), allocatable :: w2(:)
61 real(kind=rp), allocatable :: w3(:)
62 real(kind=rp), allocatable :: r1(:)
63 real(kind=rp), allocatable :: r2(:)
64 real(kind=rp), allocatable :: r3(:)
65 real(kind=rp), allocatable :: z1(:)
66 real(kind=rp), allocatable :: z2(:)
67 real(kind=rp), allocatable :: z3(:)
68 real(kind=rp), allocatable :: tmp(:)
69 real(kind=rp), allocatable :: p1(:,:)
70 real(kind=rp), allocatable :: p2(:,:)
71 real(kind=rp), allocatable :: p3(:,:)
72 real(kind=rp), allocatable :: alpha(:)
73 type(c_ptr) :: w1_d = c_null_ptr
74 type(c_ptr) :: w2_d = c_null_ptr
75 type(c_ptr) :: w3_d = c_null_ptr
76 type(c_ptr) :: r1_d = c_null_ptr
77 type(c_ptr) :: r2_d = c_null_ptr
78 type(c_ptr) :: r3_d = c_null_ptr
79 type(c_ptr) :: z1_d = c_null_ptr
80 type(c_ptr) :: z2_d = c_null_ptr
81 type(c_ptr) :: z3_d = c_null_ptr
82 type(c_ptr) :: alpha_d = c_null_ptr
83 type(c_ptr) :: p1_d_d = c_null_ptr
84 type(c_ptr) :: p2_d_d = c_null_ptr
85 type(c_ptr) :: p3_d_d = c_null_ptr
86 type(c_ptr) :: tmp_d = c_null_ptr
87 type(c_ptr), allocatable :: p1_d(:)
88 type(c_ptr), allocatable :: p2_d(:)
89 type(c_ptr), allocatable :: p3_d(:)
90 type(c_ptr) :: gs_event1 = c_null_ptr
91 type(c_ptr) :: gs_event2 = c_null_ptr
92 type(c_ptr) :: gs_event3 = c_null_ptr
93 contains
94 procedure, pass(this) :: init => fusedcg_cpld_device_init
95 procedure, pass(this) :: free => fusedcg_cpld_device_free
96 procedure, pass(this) :: solve => fusedcg_cpld_device_solve
97 procedure, pass(this) :: solve_coupled => fusedcg_cpld_device_solve_coupled
99
! C-callable device kernels implementing the fused steps of the coupled CG.
! Device buffers are passed as opaque c_ptr values; scalars (n, p_cur, beta,
! alpha) are passed by reference, matching the backend prototypes (e.g.
! hip_fusedcg_cpld_part1 takes `int *n`). Only one backend is bound, and it
! is selected with the same precedence as the dispatch wrappers below
! (HIP first, then CUDA), so the two sections never disagree.
#ifdef HAVE_HIP
interface
  !> HIP implementation of the first fused kernel.
  subroutine hip_fusedcg_cpld_part1(a1_d, a2_d, a3_d, &
       b1_d, b2_d, b3_d, tmp_d, n) &
       bind(c, name='hip_fusedcg_cpld_part1')
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int
    implicit none
    type(c_ptr), value :: a1_d, a2_d, a3_d
    type(c_ptr), value :: b1_d, b2_d, b3_d
    type(c_ptr), value :: tmp_d
    integer(c_int) :: n
  end subroutine hip_fusedcg_cpld_part1

  !> HIP implementation of the search-direction update kernel.
  subroutine hip_fusedcg_cpld_update_p(p1_d, p2_d, p3_d, z1_d, z2_d, z3_d, &
       po1_d, po2_d, po3_d, beta, n) &
       bind(c, name='hip_fusedcg_cpld_update_p')
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int
    import :: c_rp
    implicit none
    type(c_ptr), value :: p1_d, p2_d, p3_d
    type(c_ptr), value :: z1_d, z2_d, z3_d
    type(c_ptr), value :: po1_d, po2_d, po3_d
    real(c_rp) :: beta
    integer(c_int) :: n
  end subroutine hip_fusedcg_cpld_update_p

  !> HIP implementation of the solution update kernel.
  subroutine hip_fusedcg_cpld_update_x(x1_d, x2_d, x3_d, p1_d, p2_d, p3_d, &
       alpha, p_cur, n) &
       bind(c, name='hip_fusedcg_cpld_update_x')
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int
    implicit none
    type(c_ptr), value :: x1_d, x2_d, x3_d
    type(c_ptr), value :: p1_d, p2_d, p3_d
    type(c_ptr), value :: alpha
    integer(c_int) :: p_cur, n
  end subroutine hip_fusedcg_cpld_update_x

  !> HIP implementation of the second fused kernel (returns a rank-local value).
  function hip_fusedcg_cpld_part2(a1_d, a2_d, a3_d, b_d, &
       c1_d, c2_d, c3_d, alpha_d, alpha, p_cur, n) &
       result(res) bind(c, name='hip_fusedcg_cpld_part2')
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int
    import :: c_rp
    implicit none
    type(c_ptr), value :: a1_d, a2_d, a3_d, b_d
    type(c_ptr), value :: c1_d, c2_d, c3_d, alpha_d
    real(c_rp) :: alpha
    integer(c_int) :: n, p_cur
    real(c_rp) :: res
  end function hip_fusedcg_cpld_part2
end interface
#elif HAVE_CUDA
interface
  !> CUDA implementation of the first fused kernel.
  subroutine cuda_fusedcg_cpld_part1(a1_d, a2_d, a3_d, &
       b1_d, b2_d, b3_d, tmp_d, n) &
       bind(c, name='cuda_fusedcg_cpld_part1')
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int
    implicit none
    type(c_ptr), value :: a1_d, a2_d, a3_d
    type(c_ptr), value :: b1_d, b2_d, b3_d
    type(c_ptr), value :: tmp_d
    integer(c_int) :: n
  end subroutine cuda_fusedcg_cpld_part1

  !> CUDA implementation of the search-direction update kernel.
  subroutine cuda_fusedcg_cpld_update_p(p1_d, p2_d, p3_d, z1_d, z2_d, z3_d, &
       po1_d, po2_d, po3_d, beta, n) &
       bind(c, name='cuda_fusedcg_cpld_update_p')
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int
    import :: c_rp
    implicit none
    type(c_ptr), value :: p1_d, p2_d, p3_d
    type(c_ptr), value :: z1_d, z2_d, z3_d
    type(c_ptr), value :: po1_d, po2_d, po3_d
    real(c_rp) :: beta
    integer(c_int) :: n
  end subroutine cuda_fusedcg_cpld_update_p

  !> CUDA implementation of the solution update kernel.
  subroutine cuda_fusedcg_cpld_update_x(x1_d, x2_d, x3_d, p1_d, p2_d, p3_d, &
       alpha, p_cur, n) &
       bind(c, name='cuda_fusedcg_cpld_update_x')
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int
    implicit none
    type(c_ptr), value :: x1_d, x2_d, x3_d
    type(c_ptr), value :: p1_d, p2_d, p3_d
    type(c_ptr), value :: alpha
    integer(c_int) :: p_cur, n
  end subroutine cuda_fusedcg_cpld_update_x

  !> CUDA implementation of the second fused kernel (returns a rank-local value).
  function cuda_fusedcg_cpld_part2(a1_d, a2_d, a3_d, b_d, &
       c1_d, c2_d, c3_d, alpha_d, alpha, p_cur, n) &
       result(res) bind(c, name='cuda_fusedcg_cpld_part2')
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int
    import :: c_rp
    implicit none
    type(c_ptr), value :: a1_d, a2_d, a3_d, b_d
    type(c_ptr), value :: c1_d, c2_d, c3_d, alpha_d
    real(c_rp) :: alpha
    integer(c_int) :: n, p_cur
    real(c_rp) :: res
  end function cuda_fusedcg_cpld_part2
end interface
#endif
197
198contains
199
!> Launch the first fused kernel of the coupled CG on the active backend.
!! All arguments are forwarded unchanged. HIP takes precedence over CUDA,
!! and a build with neither backend reports an error through neko_error.
!! NOTE(review): from how the solver reduces tmp_d afterwards, it appears
!! to receive the pointwise sum a1*b1 + a2*b2 + a3*b3 — confirm against
!! the kernel source.
subroutine device_fusedcg_cpld_part1(a1_d, a2_d, a3_d, &
     b1_d, b2_d, b3_d, tmp_d, n)
  type(c_ptr), value :: a1_d, b1_d
  type(c_ptr), value :: a2_d, b2_d
  type(c_ptr), value :: a3_d, b3_d
  type(c_ptr), value :: tmp_d
  integer(c_int) :: n

#ifdef HAVE_HIP
  call hip_fusedcg_cpld_part1(a1_d, a2_d, a3_d, &
       b1_d, b2_d, b3_d, &
       tmp_d, n)
#elif HAVE_CUDA
  call cuda_fusedcg_cpld_part1(a1_d, a2_d, a3_d, &
       b1_d, b2_d, b3_d, &
       tmp_d, n)
#else
  call neko_error('No device backend configured')
#endif
end subroutine device_fusedcg_cpld_part1
213
!> Launch the search-direction update kernel of the coupled CG.
!! Inputs are the preconditioned residuals (z1_d, z2_d, z3_d), the previous
!! directions (po1_d, po2_d, po3_d) and the scalar beta. The result is
!! written to p1_d, p2_d, p3_d. The solver sets beta = 0 on its first
!! iteration. NOTE(review): the update is presumably p = z + beta*po per
!! component — confirm against the kernel source.
subroutine device_fusedcg_cpld_update_p(p1_d, p2_d, p3_d, z1_d, z2_d, z3_d, &
     po1_d, po2_d, po3_d, beta, n)
  type(c_ptr), value :: p1_d, z1_d, po1_d
  type(c_ptr), value :: p2_d, z2_d, po2_d
  type(c_ptr), value :: p3_d, z3_d, po3_d
  real(c_rp) :: beta
  integer(c_int) :: n

#ifdef HAVE_HIP
  call hip_fusedcg_cpld_update_p(p1_d, p2_d, p3_d, &
       z1_d, z2_d, z3_d, &
       po1_d, po2_d, po3_d, &
       beta, n)
#elif HAVE_CUDA
  call cuda_fusedcg_cpld_update_p(p1_d, p2_d, p3_d, &
       z1_d, z2_d, z3_d, &
       po1_d, po2_d, po3_d, &
       beta, n)
#else
  call neko_error('No device backend configured')
#endif
end subroutine device_fusedcg_cpld_update_p
230
!> Launch the solution update kernel of the coupled CG.
!! x1_d, x2_d, x3_d are the solution vectors. In the solver, p1_d, p2_d,
!! p3_d are device copies of the arrays of per-slot direction pointers, and
!! alpha is the device array of step lengths. p_cur is the index of the most
!! recently filled slot.
!! NOTE(review): the kernel presumably accumulates slots 1..p_cur into x —
!! confirm against the kernel source.
subroutine device_fusedcg_cpld_update_x(x1_d, x2_d, x3_d, &
     p1_d, p2_d, p3_d, alpha, p_cur, n)
  type(c_ptr), value :: x1_d, p1_d
  type(c_ptr), value :: x2_d, p2_d
  type(c_ptr), value :: x3_d, p3_d
  type(c_ptr), value :: alpha
  integer(c_int) :: p_cur
  integer(c_int) :: n

#ifdef HAVE_HIP
  call hip_fusedcg_cpld_update_x(x1_d, x2_d, x3_d, &
       p1_d, p2_d, p3_d, &
       alpha, p_cur, n)
#elif HAVE_CUDA
  call cuda_fusedcg_cpld_update_x(x1_d, x2_d, x3_d, &
       p1_d, p2_d, p3_d, &
       alpha, p_cur, n)
#else
  call neko_error('No device backend configured')
#endif
end subroutine device_fusedcg_cpld_update_x
245
!> Launch the second fused kernel of the coupled CG and return its value
!! reduced over all MPI ranks.
!!
!! The backend kernel (HIP first, then CUDA) returns a rank-local value.
!! Unless the build uses device-aware MPI (HAVE_DEVICE_MPI), that value is
!! summed over neko_comm here whenever more than one rank is running.
!! The solver passes alpha(p_cur) as the host value `alpha` and treats the
!! returned value as the squared residual norm (it takes sqrt of it).
!! @return Globally reduced value of the kernel result.
function device_fusedcg_cpld_part2(a1_d, a2_d, a3_d, b_d, &
     c1_d, c2_d, c3_d, alpha_d, alpha, p_cur, n) result(res)
  type(c_ptr), value :: a1_d, a2_d, a3_d, b_d
  type(c_ptr), value :: c1_d, c2_d, c3_d, alpha_d
  real(c_rp) :: alpha
  ! NOTE(review): default integers are forwarded to integer(c_int) dummies;
  ! this relies on default integer having the same kind as c_int.
  integer :: n, p_cur
  real(kind=rp) :: res
  integer :: ierr

  ! Give the result a defined value on every preprocessor path. Without a
  ! device backend, the only statement would otherwise be the neko_error call.
  res = 0.0_rp

#ifdef HAVE_HIP
  res = hip_fusedcg_cpld_part2(a1_d, a2_d, a3_d, b_d, &
       c1_d, c2_d, c3_d, alpha_d, alpha, p_cur, n)
#elif HAVE_CUDA
  res = cuda_fusedcg_cpld_part2(a1_d, a2_d, a3_d, b_d, &
       c1_d, c2_d, c3_d, alpha_d, alpha, p_cur, n)
#else
  call neko_error('No device backend configured')
#endif

#ifndef HAVE_DEVICE_MPI
  ! Host-side global reduction of the rank-local kernel result
  if (pe_size .gt. 1) then
    call mpi_allreduce(mpi_in_place, res, 1, &
         mpi_real_precision, mpi_sum, neko_comm, ierr)
  end if
#endif

end function device_fusedcg_cpld_part2
272
274 subroutine fusedcg_cpld_device_init(this, n, max_iter, M, &
275 rel_tol, abs_tol, monitor)
276 class(fusedcg_cpld_device_t), target, intent(inout) :: this
277 class(pc_t), optional, intent(in), target :: M
278 integer, intent(in) :: n
279 integer, intent(in) :: max_iter
280 real(kind=rp), optional, intent(in) :: rel_tol
281 real(kind=rp), optional, intent(in) :: abs_tol
282 logical, optional, intent(in) :: monitor
283 type(c_ptr) :: ptr
284 integer(c_size_t) :: p_size
285 integer :: i
286
287 call this%free()
288
289 allocate(this%w1(n))
290 allocate(this%w2(n))
291 allocate(this%w3(n))
292 allocate(this%r1(n))
293 allocate(this%r2(n))
294 allocate(this%r3(n))
295 allocate(this%z1(n))
296 allocate(this%z2(n))
297 allocate(this%z3(n))
298 allocate(this%tmp(n))
299 allocate(this%p1(n, device_fusedcg_cpld_p_space))
300 allocate(this%p2(n, device_fusedcg_cpld_p_space))
301 allocate(this%p3(n, device_fusedcg_cpld_p_space))
302 allocate(this%p1_d(device_fusedcg_cpld_p_space))
303 allocate(this%p2_d(device_fusedcg_cpld_p_space))
304 allocate(this%p3_d(device_fusedcg_cpld_p_space))
305 allocate(this%alpha(device_fusedcg_cpld_p_space))
306
307 if (present(m)) then
308 this%M => m
309 end if
310
311 call device_map(this%w1, this%w1_d, n)
312 call device_map(this%w2, this%w2_d, n)
313 call device_map(this%w3, this%w3_d, n)
314 call device_map(this%r1, this%r1_d, n)
315 call device_map(this%r2, this%r2_d, n)
316 call device_map(this%r3, this%r3_d, n)
317 call device_map(this%z1, this%z1_d, n)
318 call device_map(this%z2, this%z2_d, n)
319 call device_map(this%z3, this%z3_d, n)
320 call device_map(this%tmp, this%tmp_d, n)
321 call device_map(this%alpha, this%alpha_d, device_fusedcg_cpld_p_space)
323 this%p1_d(i) = c_null_ptr
324 call device_map(this%p1(:,i), this%p1_d(i), n)
325
326 this%p2_d(i) = c_null_ptr
327 call device_map(this%p2(:,i), this%p2_d(i), n)
328
329 this%p3_d(i) = c_null_ptr
330 call device_map(this%p3(:,i), this%p3_d(i), n)
331 end do
332
333 p_size = c_sizeof(c_null_ptr) * (device_fusedcg_cpld_p_space)
334 call device_alloc(this%p1_d_d, p_size)
335 call device_alloc(this%p2_d_d, p_size)
336 call device_alloc(this%p3_d_d, p_size)
337 ptr = c_loc(this%p1_d)
338 call device_memcpy(ptr, this%p1_d_d, p_size, &
339 host_to_device, sync=.false.)
340 ptr = c_loc(this%p2_d)
341 call device_memcpy(ptr, this%p2_d_d, p_size, &
342 host_to_device, sync=.false.)
343 ptr = c_loc(this%p3_d)
344 call device_memcpy(ptr, this%p3_d_d, p_size, &
345 host_to_device, sync=.false.)
346 if (present(rel_tol) .and. present(abs_tol) .and. present(monitor)) then
347 call this%ksp_init(max_iter, rel_tol, abs_tol, monitor = monitor)
348 else if (present(rel_tol) .and. present(abs_tol)) then
349 call this%ksp_init(max_iter, rel_tol, abs_tol)
350 else if (present(monitor) .and. present(abs_tol)) then
351 call this%ksp_init(max_iter, abs_tol = abs_tol, monitor = monitor)
352 else if (present(rel_tol) .and. present(monitor)) then
353 call this%ksp_init(max_iter, rel_tol, monitor = monitor)
354 else if (present(rel_tol)) then
355 call this%ksp_init(max_iter, rel_tol = rel_tol)
356 else if (present(abs_tol)) then
357 call this%ksp_init(max_iter, abs_tol = abs_tol)
358 else if (present(monitor)) then
359 call this%ksp_init(max_iter, monitor = monitor)
360 else
361 call this%ksp_init(max_iter)
362 end if
363
364 call device_event_create(this%gs_event1, 2)
365 call device_event_create(this%gs_event2, 2)
366 call device_event_create(this%gs_event3, 2)
367
368 end subroutine fusedcg_cpld_device_init
369
372 class(fusedcg_cpld_device_t), intent(inout) :: this
373 integer :: i
374
375 call this%ksp_free()
376
377 if (allocated(this%w1)) then
378 deallocate(this%w1)
379 end if
380
381 if (allocated(this%w2)) then
382 deallocate(this%w2)
383 end if
384
385 if (allocated(this%w3)) then
386 deallocate(this%w3)
387 end if
388
389 if (allocated(this%r1)) then
390 deallocate(this%r1)
391 end if
392
393 if (allocated(this%r2)) then
394 deallocate(this%r2)
395 end if
396
397 if (allocated(this%r3)) then
398 deallocate(this%r3)
399 end if
400
401 if (allocated(this%z1)) then
402 deallocate(this%z1)
403 end if
404
405 if (allocated(this%z2)) then
406 deallocate(this%z2)
407 end if
408
409 if (allocated(this%z3)) then
410 deallocate(this%z3)
411 end if
412
413 if (allocated(this%tmp)) then
414 deallocate(this%tmp)
415 end if
416
417 if (allocated(this%alpha)) then
418 deallocate(this%alpha)
419 end if
420
421 if (allocated(this%p1)) then
422 deallocate(this%p1)
423 end if
424
425 if (allocated(this%p2)) then
426 deallocate(this%p2)
427 end if
428
429 if (allocated(this%p3)) then
430 deallocate(this%p3)
431 end if
432
433 if (c_associated(this%w1_d)) then
434 call device_free(this%w1_d)
435 end if
436
437 if (c_associated(this%w2_d)) then
438 call device_free(this%w2_d)
439 end if
440
441 if (c_associated(this%w3_d)) then
442 call device_free(this%w3_d)
443 end if
444
445 if (c_associated(this%r1_d)) then
446 call device_free(this%r1_d)
447 end if
448
449 if (c_associated(this%r2_d)) then
450 call device_free(this%r2_d)
451 end if
452
453 if (c_associated(this%r3_d)) then
454 call device_free(this%r3_d)
455 end if
456
457 if (c_associated(this%z1_d)) then
458 call device_free(this%z1_d)
459 end if
460
461 if (c_associated(this%z2_d)) then
462 call device_free(this%z2_d)
463 end if
464
465 if (c_associated(this%z3_d)) then
466 call device_free(this%z3_d)
467 end if
468
469 if (c_associated(this%alpha_d)) then
470 call device_free(this%alpha_d)
471 end if
472
473 if (c_associated(this%tmp_d)) then
474 call device_free(this%tmp_d)
475 end if
476
477 if (allocated(this%p1_d)) then
479 if (c_associated(this%p1_d(i))) then
480 call device_free(this%p1_d(i))
481 end if
482 end do
483 end if
484
485 if (allocated(this%p2_d)) then
487 if (c_associated(this%p2_d(i))) then
488 call device_free(this%p2_d(i))
489 end if
490 end do
491 end if
492
493 if (allocated(this%p3_d)) then
495 if (c_associated(this%p3_d(i))) then
496 call device_free(this%p3_d(i))
497 end if
498 end do
499 end if
500
501 nullify(this%M)
502
503 if (c_associated(this%gs_event1)) then
504 call device_event_destroy(this%gs_event1)
505 end if
506
507 if (c_associated(this%gs_event2)) then
508 call device_event_destroy(this%gs_event2)
509 end if
510
511 if (c_associated(this%gs_event3)) then
512 call device_event_destroy(this%gs_event3)
513 end if
514
515 end subroutine fusedcg_cpld_device_free
516
518 function fusedcg_cpld_device_solve_coupled(this, Ax, x, y, z, fx, fy, fz, &
519 n, coef, blstx, blsty, blstz, gs_h, niter) result(ksp_results)
520 class(fusedcg_cpld_device_t), intent(inout) :: this
521 class(ax_t), intent(in) :: ax
522 type(field_t), intent(inout) :: x
523 type(field_t), intent(inout) :: y
524 type(field_t), intent(inout) :: z
525 integer, intent(in) :: n
526 real(kind=rp), dimension(n), intent(in) :: fx
527 real(kind=rp), dimension(n), intent(in) :: fy
528 real(kind=rp), dimension(n), intent(in) :: fz
529 type(coef_t), intent(inout) :: coef
530 type(bc_list_t), intent(inout) :: blstx
531 type(bc_list_t), intent(inout) :: blsty
532 type(bc_list_t), intent(inout) :: blstz
533 type(gs_t), intent(inout) :: gs_h
534 type(ksp_monitor_t), dimension(3) :: ksp_results
535 integer, optional, intent(in) :: niter
536 integer :: iter, max_iter, ierr, i, p_cur, p_prev
537 real(kind=rp) :: rnorm, rtr, norm_fac, rtz1, rtz2
538 real(kind=rp) :: pap, beta
539 type(c_ptr) :: fx_d
540 type(c_ptr) :: fy_d
541 type(c_ptr) :: fz_d
542
543 fx_d = device_get_ptr(fx)
544 fy_d = device_get_ptr(fy)
545 fz_d = device_get_ptr(fz)
546
547 if (present(niter)) then
548 max_iter = niter
549 else
550 max_iter = ksp_max_iter
551 end if
552 norm_fac = 1.0_rp / sqrt(coef%volume)
553
554 associate(w1 => this%w1, w2 => this%w2, w3 => this%w3, r1 => this%r1, &
555 r2 => this%r2, r3 => this%r3, p1 => this%p1, p2 => this%p2, &
556 p3 => this%p3, z1 => this%z1, z2 => this%z2, z3 => this%z3, &
557 tmp_d => this%tmp_d, alpha => this%alpha, alpha_d => this%alpha_d, &
558 w1_d => this%w1_d, w2_d => this%w2_d, w3_d => this%w3_d, &
559 r1_d => this%r1_d, r2_d => this%r2_d, r3_d => this%r3_d, &
560 z1_d => this%z1_d, z2_d => this%z2_d, z3_d => this%z3_d, &
561 p1_d => this%p1_d, p2_d => this%p2_d, p3_d => this%p3_d, &
562 p1_d_d => this%p1_d_d, p2_d_d => this%p2_d_d, p3_d_d => this%p3_d_d)
563
564 rtz1 = 1.0_rp
566 p_cur = 1
567
568
569 call device_rzero(x%x_d, n)
570 call device_rzero(y%x_d, n)
571 call device_rzero(z%x_d, n)
572 call device_rzero(p1_d(1), n)
573 call device_rzero(p2_d(1), n)
574 call device_rzero(p3_d(1), n)
575 call device_copy(r1_d, fx_d, n)
576 call device_copy(r2_d, fy_d, n)
577 call device_copy(r3_d, fz_d, n)
578
579 call device_fusedcg_cpld_part1(r1_d, r2_d, r3_d, r1_d, &
580 r2_d, r3_d, tmp_d, n)
581
582 rtr = device_glsc3(tmp_d, coef%mult_d, coef%binv_d, n)
583
584 rnorm = sqrt(rtr)*norm_fac
585 ksp_results%res_start = rnorm
586 ksp_results%res_final = rnorm
587 ksp_results(1)%iter = 0
588 ksp_results(2:3)%iter = -1
589 if(abscmp(rnorm, 0.0_rp)) then
590 ksp_results%converged = .true.
591 return
592 end if
593 call this%monitor_start('fcpldCG')
594 do iter = 1, max_iter
595 call this%M%solve(z1, r1, n)
596 call this%M%solve(z2, r2, n)
597 call this%M%solve(z3, r3, n)
598 rtz2 = rtz1
599 call device_fusedcg_cpld_part1(z1_d, z2_d, z3_d, &
600 r1_d, r2_d, r3_d, tmp_d, n)
601 rtz1 = device_glsc2(tmp_d, coef%mult_d, n)
602
603 beta = rtz1 / rtz2
604 if (iter .eq. 1) beta = 0.0_rp
605
606 call device_fusedcg_cpld_update_p(p1_d(p_cur), p2_d(p_cur), p3_d(p_cur), &
607 z1_d, z2_d, z3_d, p1_d(p_prev), p2_d(p_prev), p3_d(p_prev), beta, n)
608
609 call ax%compute_vector(w1, w2, w3, &
610 p1(1, p_cur), p2(1, p_cur), p3(1, p_cur), coef, x%msh, x%Xh)
611 call gs_h%op(w1, n, gs_op_add, this%gs_event1)
612 call device_event_sync(this%gs_event1)
613 call blstx%apply(w1, n)
614 call gs_h%op(w2, n, gs_op_add, this%gs_event2)
615 call device_event_sync(this%gs_event2)
616 call blsty%apply(w2, n)
617 call gs_h%op(w3, n, gs_op_add, this%gs_event3)
618 call device_event_sync(this%gs_event3)
619 call blstz%apply(w3, n)
620
621 call device_fusedcg_cpld_part1(w1_d, w2_d, w3_d, p1_d(p_cur), &
622 p2_d(p_cur), p3_d(p_cur), tmp_d, n)
623
624 pap = device_glsc2(tmp_d, coef%mult_d, n)
625
626 alpha(p_cur) = rtz1 / pap
627 rtr = device_fusedcg_cpld_part2(r1_d, r2_d, r3_d, coef%mult_d, &
628 w1_d, w2_d, w3_d, alpha_d, alpha(p_cur), p_cur, n)
629 rnorm = sqrt(rtr)*norm_fac
630 call this%monitor_iter(iter, rnorm)
631 if ((p_cur .eq. device_fusedcg_cpld_p_space) .or. &
632 (rnorm .lt. this%abs_tol) .or. iter .eq. max_iter) then
633 call device_fusedcg_cpld_update_x(x%x_d, y%x_d, z%x_d, &
634 p1_d_d, p2_d_d, p3_d_d, alpha_d, p_cur, n)
635 p_prev = p_cur
636 p_cur = 1
637 if (rnorm .lt. this%abs_tol) exit
638 else
639 p_prev = p_cur
640 p_cur = p_cur + 1
641 end if
642 end do
643 call this%monitor_stop()
644 ksp_results%res_final = rnorm
645 ksp_results%iter = iter
646 ksp_results%converged = this%is_converged(iter, rnorm)
647
648 end associate
649
651
!> Uncoupled solve entry point required by the Krylov base type.
!! This solver only handles coupled three-component systems (see
!! solve_coupled), so this routine always reports an error via neko_error.
!! The result is still given defined values so the function never returns
!! an undefined monitor.
function fusedcg_cpld_device_solve(this, Ax, x, f, n, coef, blst, &
     gs_h, niter) result(ksp_results)
  class(fusedcg_cpld_device_t), intent(inout) :: this
  class(ax_t), intent(in) :: ax
  type(field_t), intent(inout) :: x
  integer, intent(in) :: n
  real(kind=rp), dimension(n), intent(in) :: f
  type(coef_t), intent(inout) :: coef
  type(bc_list_t), intent(inout) :: blst
  type(gs_t), intent(inout) :: gs_h
  type(ksp_monitor_t) :: ksp_results
  integer, optional, intent(in) :: niter

  ! Throw an error: the uncoupled interface is not supported by this solver
  call neko_error('The cpldcg solver is only defined for coupled solves')

  ! Fallback result, using working-precision literals
  ksp_results%res_start = 0.0_rp
  ksp_results%res_final = 0.0_rp
  ksp_results%iter = 0
  ksp_results%converged = .false.

end function fusedcg_cpld_device_solve
674
675end module fusedcg_cpld_device
__device__ T solve(const T u, const T y, const T guess, const T nu, const T kappa, const T B)
void hip_fusedcg_cpld_update_x(void *x1, void *x2, void *x3, void *p1, void *p2, void *p3, void *alpha, int *p_cur, int *n)
void hip_fusedcg_cpld_update_p(void *p1, void *p2, void *p3, void *z1, void *z2, void *z3, void *po1, void *po2, void *po3, real *beta, int *n)
real hip_fusedcg_cpld_part2(void *a1, void *a2, void *a3, void *b, void *c1, void *c2, void *c3, void *alpha_d, real *alpha, int *p_cur, int *n)
void hip_fusedcg_cpld_part1(void *a1, void *a2, void *a3, void *b1, void *b2, void *b3, void *tmp, int *n)
Return the device pointer for an associated Fortran array.
Definition device.F90:96
Map a Fortran array to a device (allocate and associate)
Definition device.F90:72
Copy data between host and device (or device and device)
Definition device.F90:66
Defines a Matrix-vector product.
Definition ax.f90:34
Defines a list of bc_t.
Definition bc_list.f90:34
Coefficients.
Definition coef.f90:34
Definition comm.F90:1
type(mpi_datatype), public mpi_real_precision
MPI type for working precision of REAL types.
Definition comm.F90:50
integer, public pe_size
MPI size of communicator.
Definition comm.F90:58
type(mpi_comm), public neko_comm
MPI communicator.
Definition comm.F90:42
subroutine, public device_rzero(a_d, n, strm)
Zero a real vector.
subroutine, public device_copy(a_d, b_d, n, strm)
Copy a vector $a = b$.
real(kind=rp) function, public device_glsc3(a_d, b_d, c_d, n, strm)
Weighted inner product $a^T b c$.
real(kind=rp) function, public device_glsc2(a_d, b_d, n, strm)
Weighted inner product $a^T b$.
Device abstraction, common interface for various accelerators.
Definition device.F90:34
subroutine, public device_event_sync(event)
Synchronize an event.
Definition device.F90:1309
integer, parameter, public host_to_device
Definition device.F90:47
subroutine, public device_free(x_d)
Deallocate memory on the device.
Definition device.F90:214
subroutine, public device_event_destroy(event)
Destroy a device event.
Definition device.F90:1274
subroutine, public device_alloc(x_d, s)
Allocate memory on the device.
Definition device.F90:187
subroutine, public device_event_create(event, flags)
Create a device event queue.
Definition device.F90:1244
Defines a field.
Definition field.f90:34
Defines a fused Conjugate Gradient method for accelerators.
subroutine device_fusedcg_cpld_update_x(x1_d, x2_d, x3_d, p1_d, p2_d, p3_d, alpha, p_cur, n)
type(ksp_monitor_t) function fusedcg_cpld_device_solve(this, ax, x, f, n, coef, blst, gs_h, niter)
Uncoupled solve; not supported by the fused coupled PCG solver (always raises an error).
subroutine fusedcg_cpld_device_free(this)
Deallocate a fused coupled PCG solver.
subroutine device_fusedcg_cpld_update_p(p1_d, p2_d, p3_d, z1_d, z2_d, z3_d, po1_d, po2_d, po3_d, beta, n)
subroutine fusedcg_cpld_device_init(this, n, max_iter, m, rel_tol, abs_tol, monitor)
Initialise a fused PCG solver.
real(kind=rp) function device_fusedcg_cpld_part2(a1_d, a2_d, a3_d, b_d, c1_d, c2_d, c3_d, alpha_d, alpha, p_cur, n)
type(ksp_monitor_t) function, dimension(3) fusedcg_cpld_device_solve_coupled(this, ax, x, y, z, fx, fy, fz, n, coef, blstx, blsty, blstz, gs_h, niter)
Fused coupled PCG solve.
integer, parameter device_fusedcg_cpld_p_space
subroutine device_fusedcg_cpld_part1(a1_d, a2_d, a3_d, b1_d, b2_d, b3_d, tmp_d, n)
Gather-scatter.
Implements the base abstract type for Krylov solvers plus helper types.
Definition krylov.f90:34
integer, parameter, public ksp_max_iter
Maximum number of iters.
Definition krylov.f90:51
Definition math.f90:60
real(kind=rp) function, public glsc3(a, b, c, n)
Weighted inner product $a^T b c$.
Definition math.f90:1026
subroutine, public copy(a, b, n)
Copy a vector $a = b$.
Definition math.f90:255
subroutine, public rzero(a, n)
Zero a real vector.
Definition math.f90:211
integer, parameter, public c_rp
Definition num_types.f90:13
integer, parameter, public rp
Global precision used in computations.
Definition num_types.f90:12
Krylov preconditioner.
Definition precon.f90:34
Utilities.
Definition utils.f90:35
Base type for a matrix-vector product providing $Ax$.
Definition ax.f90:43
A list of allocatable `bc_t`. Follows the standard interface of lists.
Definition bc_list.f90:48
Coefficients defined on a given (mesh, $X_h$) tuple. Arrays use indices (i,j,k,e): element e,...
Definition coef.f90:55
Fused preconditioned conjugate gradient method.
Type for storing initial and final residuals in a Krylov solver.
Definition krylov.f90:56
Base abstract type for a canonical Krylov method, solving $Ax = f$.
Definition krylov.f90:73
Defines a canonical Krylov preconditioner.
Definition precon.f90:40