Neko 1.99.1
A portable framework for high-order spectral element flow simulations
Loading...
Searching...
No Matches
fusedcg_cpld_device.F90
Go to the documentation of this file.
1! Copyright (c) 2021-2024, The Neko Authors
2! All rights reserved.
3!
4! Redistribution and use in source and binary forms, with or without
5! modification, are permitted provided that the following conditions
6! are met:
7!
8! * Redistributions of source code must retain the above copyright
9! notice, this list of conditions and the following disclaimer.
10!
11! * Redistributions in binary form must reproduce the above
12! copyright notice, this list of conditions and the following
13! disclaimer in the documentation and/or other materials provided
14! with the distribution.
15!
16! * Neither the name of the authors nor the names of its
17! contributors may be used to endorse or promote products derived
18! from this software without specific prior written permission.
19!
20! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21! "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22! LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
23! FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
24! COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
25! INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
26! BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27! LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28! CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29! LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
30! ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31! POSSIBILITY OF SUCH DAMAGE.
32!
36 use precon, only : pc_t
37 use ax_product, only : ax_t
38 use num_types, only: rp, c_rp
39 use field, only : field_t
40 use coefs, only : coef_t
41 use gather_scatter, only : gs_t, gs_op_add
42 use bc_list, only : bc_list_t
43 use math, only : glsc3, rzero, copy, abscmp
45 use device
46 use utils, only : neko_error
48 use mpi_f08, only : mpi_in_place, mpi_allreduce, &
49 mpi_sum
50 use, intrinsic :: iso_c_binding, only : c_ptr, c_null_ptr, &
51 c_associated, c_size_t, c_sizeof, c_int, c_loc
52 implicit none
53 private
54
!> Number of search-direction vectors stored per velocity component.
!! The coupled solve applies the accumulated solution update once this
!! many directions have been filled (or on convergence / last iteration).
55 integer, parameter :: device_fusedcg_cpld_p_space = 10
56
!> Fused preconditioned conjugate gradient solver for coupled
!! (three-component) systems, with every work vector mapped to the device.
type, public, extends(ksp_t) :: fusedcg_cpld_device_t
  !> Host storage for the operator applied to the current search direction.
  real(kind=rp), allocatable :: w1(:)
  real(kind=rp), allocatable :: w2(:)
  real(kind=rp), allocatable :: w3(:)
  !> Host storage for the residual of each component.
  real(kind=rp), allocatable :: r1(:)
  real(kind=rp), allocatable :: r2(:)
  real(kind=rp), allocatable :: r3(:)
  !> Host storage for the preconditioned residual of each component.
  real(kind=rp), allocatable :: z1(:)
  real(kind=rp), allocatable :: z2(:)
  real(kind=rp), allocatable :: z3(:)
  !> Scratch vector written by the fused "part 1" kernel.
  real(kind=rp), allocatable :: tmp(:)
  !> Search directions, one column per stored direction
  !! (second extent is device_fusedcg_cpld_p_space).
  real(kind=rp), allocatable :: p1(:,:)
  real(kind=rp), allocatable :: p2(:,:)
  real(kind=rp), allocatable :: p3(:,:)
  !> Step length belonging to each stored search direction.
  real(kind=rp), allocatable :: alpha(:)
  !> Device pointers associated with the host arrays of the same name.
  type(c_ptr) :: w1_d = c_null_ptr
  type(c_ptr) :: w2_d = c_null_ptr
  type(c_ptr) :: w3_d = c_null_ptr
  type(c_ptr) :: r1_d = c_null_ptr
  type(c_ptr) :: r2_d = c_null_ptr
  type(c_ptr) :: r3_d = c_null_ptr
  type(c_ptr) :: z1_d = c_null_ptr
  type(c_ptr) :: z2_d = c_null_ptr
  type(c_ptr) :: z3_d = c_null_ptr
  type(c_ptr) :: alpha_d = c_null_ptr
  !> Device-resident copies of the pointer tables p1_d, p2_d and p3_d.
  type(c_ptr) :: p1_d_d = c_null_ptr
  type(c_ptr) :: p2_d_d = c_null_ptr
  type(c_ptr) :: p3_d_d = c_null_ptr
  type(c_ptr) :: tmp_d = c_null_ptr
  !> Host-side tables with one device pointer per search-direction column.
  type(c_ptr), allocatable :: p1_d(:)
  type(c_ptr), allocatable :: p2_d(:)
  type(c_ptr), allocatable :: p3_d(:)
  !> Device events used to wait for the gather-scatter of w1, w2 and w3.
  type(c_ptr) :: gs_event1 = c_null_ptr
  type(c_ptr) :: gs_event2 = c_null_ptr
  type(c_ptr) :: gs_event3 = c_null_ptr
contains
  !> Allocate, map and initialise all solver storage.
  procedure, pass(this) :: init => fusedcg_cpld_device_init
  !> Release all host and device storage.
  procedure, pass(this) :: free => fusedcg_cpld_device_free
  !> Scalar solve; not supported by this solver (always raises an error).
  procedure, pass(this) :: solve => fusedcg_cpld_device_solve
  !> Coupled three-component solve.
  procedure, pass(this) :: solve_coupled => fusedcg_cpld_device_solve_coupled
end type fusedcg_cpld_device_t
99
#ifdef HAVE_CUDA
  ! C bindings for the CUDA kernels behind the device_fusedcg_cpld_*
  ! wrappers. Device pointers are passed by value; scalar arguments are
  ! passed by reference.
  interface
    subroutine cuda_fusedcg_cpld_part1(a1_d, a2_d, a3_d, &
         b1_d, b2_d, b3_d, tmp_d, n) &
         bind(c, name='cuda_fusedcg_cpld_part1')
      use, intrinsic :: iso_c_binding, only : c_ptr, c_int
      implicit none
      type(c_ptr), value :: a1_d, a2_d, a3_d
      type(c_ptr), value :: b1_d, b2_d, b3_d
      type(c_ptr), value :: tmp_d
      integer(c_int) :: n
    end subroutine cuda_fusedcg_cpld_part1

    subroutine cuda_fusedcg_cpld_update_p(p1_d, p2_d, p3_d, &
         z1_d, z2_d, z3_d, po1_d, po2_d, po3_d, beta, n) &
         bind(c, name='cuda_fusedcg_cpld_update_p')
      use, intrinsic :: iso_c_binding, only : c_ptr, c_int
      import :: c_rp
      implicit none
      type(c_ptr), value :: p1_d, p2_d, p3_d
      type(c_ptr), value :: z1_d, z2_d, z3_d
      type(c_ptr), value :: po1_d, po2_d, po3_d
      real(c_rp) :: beta
      integer(c_int) :: n
    end subroutine cuda_fusedcg_cpld_update_p

    subroutine cuda_fusedcg_cpld_update_x(x1_d, x2_d, x3_d, &
         p1_d, p2_d, p3_d, alpha, p_cur, n) &
         bind(c, name='cuda_fusedcg_cpld_update_x')
      use, intrinsic :: iso_c_binding, only : c_ptr, c_int
      implicit none
      type(c_ptr), value :: x1_d, x2_d, x3_d
      type(c_ptr), value :: p1_d, p2_d, p3_d
      type(c_ptr), value :: alpha
      integer(c_int) :: p_cur, n
    end subroutine cuda_fusedcg_cpld_update_x

    function cuda_fusedcg_cpld_part2(a1_d, a2_d, a3_d, b_d, &
         c1_d, c2_d, c3_d, alpha_d, alpha, p_cur, n) result(res) &
         bind(c, name='cuda_fusedcg_cpld_part2')
      use, intrinsic :: iso_c_binding, only : c_ptr, c_int
      import :: c_rp
      implicit none
      type(c_ptr), value :: a1_d, a2_d, a3_d, b_d
      type(c_ptr), value :: c1_d, c2_d, c3_d, alpha_d
      real(c_rp) :: alpha
      integer(c_int) :: n, p_cur
      real(c_rp) :: res
    end function cuda_fusedcg_cpld_part2
  end interface
#elif HAVE_HIP
  ! C bindings for the HIP kernels; same argument conventions as the
  ! CUDA set.
  interface
    subroutine hip_fusedcg_cpld_part1(a1_d, a2_d, a3_d, &
         b1_d, b2_d, b3_d, tmp_d, n) &
         bind(c, name='hip_fusedcg_cpld_part1')
      use, intrinsic :: iso_c_binding, only : c_ptr, c_int
      implicit none
      type(c_ptr), value :: a1_d, a2_d, a3_d
      type(c_ptr), value :: b1_d, b2_d, b3_d
      type(c_ptr), value :: tmp_d
      integer(c_int) :: n
    end subroutine hip_fusedcg_cpld_part1

    subroutine hip_fusedcg_cpld_update_p(p1_d, p2_d, p3_d, &
         z1_d, z2_d, z3_d, po1_d, po2_d, po3_d, beta, n) &
         bind(c, name='hip_fusedcg_cpld_update_p')
      use, intrinsic :: iso_c_binding, only : c_ptr, c_int
      import :: c_rp
      implicit none
      type(c_ptr), value :: p1_d, p2_d, p3_d
      type(c_ptr), value :: z1_d, z2_d, z3_d
      type(c_ptr), value :: po1_d, po2_d, po3_d
      real(c_rp) :: beta
      integer(c_int) :: n
    end subroutine hip_fusedcg_cpld_update_p

    subroutine hip_fusedcg_cpld_update_x(x1_d, x2_d, x3_d, &
         p1_d, p2_d, p3_d, alpha, p_cur, n) &
         bind(c, name='hip_fusedcg_cpld_update_x')
      use, intrinsic :: iso_c_binding, only : c_ptr, c_int
      implicit none
      type(c_ptr), value :: x1_d, x2_d, x3_d
      type(c_ptr), value :: p1_d, p2_d, p3_d
      type(c_ptr), value :: alpha
      integer(c_int) :: p_cur, n
    end subroutine hip_fusedcg_cpld_update_x

    function hip_fusedcg_cpld_part2(a1_d, a2_d, a3_d, b_d, &
         c1_d, c2_d, c3_d, alpha_d, alpha, p_cur, n) result(res) &
         bind(c, name='hip_fusedcg_cpld_part2')
      use, intrinsic :: iso_c_binding, only : c_ptr, c_int
      import :: c_rp
      implicit none
      type(c_ptr), value :: a1_d, a2_d, a3_d, b_d
      type(c_ptr), value :: c1_d, c2_d, c3_d, alpha_d
      real(c_rp) :: alpha
      integer(c_int) :: n, p_cur
      real(c_rp) :: res
    end function hip_fusedcg_cpld_part2
  end interface
#endif
197
198contains
199
!> Forward the fused "part 1" kernel call to the configured device backend.
!! Takes two triples of device vectors (a1..a3, b1..b3), a device scratch
!! vector tmp_d and the vector length n; aborts through neko_error when the
!! build has neither HIP nor CUDA enabled.
subroutine device_fusedcg_cpld_part1(a1_d, a2_d, a3_d, &
     b1_d, b2_d, b3_d, tmp_d, n)
  type(c_ptr), value :: a1_d, a2_d, a3_d
  type(c_ptr), value :: b1_d, b2_d, b3_d
  type(c_ptr), value :: tmp_d
  integer(c_int) :: n
#ifdef HAVE_HIP
  call hip_fusedcg_cpld_part1(a1_d, a2_d, a3_d, &
       b1_d, b2_d, b3_d, tmp_d, n)
#elif HAVE_CUDA
  call cuda_fusedcg_cpld_part1(a1_d, a2_d, a3_d, &
       b1_d, b2_d, b3_d, tmp_d, n)
#else
  call neko_error('No device backend configured')
#endif
end subroutine device_fusedcg_cpld_part1
213
!> Forward the search-direction update kernel to the configured device
!! backend. p1..p3 are the current directions, z1..z3 the preconditioned
!! residuals, po1..po3 the previous directions and beta the scalar
!! coefficient; aborts through neko_error when no backend is compiled in.
subroutine device_fusedcg_cpld_update_p(p1_d, p2_d, p3_d, z1_d, z2_d, z3_d, &
     po1_d, po2_d, po3_d, beta, n)
  type(c_ptr), value :: p1_d, p2_d, p3_d
  type(c_ptr), value :: z1_d, z2_d, z3_d
  type(c_ptr), value :: po1_d, po2_d, po3_d
  real(c_rp) :: beta
  integer(c_int) :: n
#ifdef HAVE_HIP
  call hip_fusedcg_cpld_update_p(p1_d, p2_d, p3_d, &
       z1_d, z2_d, z3_d, &
       po1_d, po2_d, po3_d, beta, n)
#elif HAVE_CUDA
  call cuda_fusedcg_cpld_update_p(p1_d, p2_d, p3_d, &
       z1_d, z2_d, z3_d, &
       po1_d, po2_d, po3_d, beta, n)
#else
  call neko_error('No device backend configured')
#endif
end subroutine device_fusedcg_cpld_update_p
230
!> Forward the solution update kernel to the configured device backend.
!! x1..x3 are the solution vectors, p1..p3 and alpha are device pointers
!! handed on unchanged, and p_cur is the number of stored directions to
!! apply; aborts through neko_error when no backend is compiled in.
subroutine device_fusedcg_cpld_update_x(x1_d, x2_d, x3_d, &
     p1_d, p2_d, p3_d, alpha, p_cur, n)
  type(c_ptr), value :: x1_d, x2_d, x3_d
  type(c_ptr), value :: p1_d, p2_d, p3_d
  type(c_ptr), value :: alpha
  integer(c_int) :: p_cur, n
#ifdef HAVE_HIP
  call hip_fusedcg_cpld_update_x(x1_d, x2_d, x3_d, p1_d, p2_d, p3_d, &
       alpha, p_cur, n)
#elif HAVE_CUDA
  call cuda_fusedcg_cpld_update_x(x1_d, x2_d, x3_d, p1_d, p2_d, p3_d, &
       alpha, p_cur, n)
#else
  call neko_error('No device backend configured')
#endif
end subroutine device_fusedcg_cpld_update_x
245
!> Forward the fused "part 2" kernel to the configured device backend and
!! return its scalar result.
!! @param a1_d,a2_d,a3_d Device vectors passed as the kernel's `a` triple.
!! @param b_d Device vector passed as the kernel's `b` argument.
!! @param c1_d,c2_d,c3_d Device vectors passed as the kernel's `c` triple.
!! @param alpha_d Device pointer to the step-length array.
!! @param alpha Step length of the current search direction.
!! @param p_cur Index of the current search direction.
!! @param n Length of the vectors.
!! @return Value returned by the kernel; summed over all ranks with
!! mpi_allreduce when built without HAVE_DEVICE_MPI and pe_size > 1.
function device_fusedcg_cpld_part2(a1_d, a2_d, a3_d, b_d, &
     c1_d, c2_d, c3_d, alpha_d, alpha, p_cur, n) result(res)
  type(c_ptr), value :: a1_d, a2_d, a3_d, b_d
  type(c_ptr), value :: c1_d, c2_d, c3_d, alpha_d
  real(c_rp) :: alpha
  ! Same kind as the bound C interface and the sibling wrappers.
  integer(c_int) :: n, p_cur
  real(kind=rp) :: res
  integer :: ierr

  ! Keep the result defined even on the no-backend path below.
  res = 0.0_rp
#ifdef HAVE_HIP
  res = hip_fusedcg_cpld_part2(a1_d, a2_d, a3_d, b_d, &
       c1_d, c2_d, c3_d, alpha_d, alpha, p_cur, n)
#elif HAVE_CUDA
  res = cuda_fusedcg_cpld_part2(a1_d, a2_d, a3_d, b_d, &
       c1_d, c2_d, c3_d, alpha_d, alpha, p_cur, n)
#else
  call neko_error('No device backend configured')
#endif

#ifndef HAVE_DEVICE_MPI
  ! Without device-aware MPI the global sum is done on the host value.
  if (pe_size .gt. 1) then
    call mpi_allreduce(mpi_in_place, res, 1, &
         mpi_real_precision, mpi_sum, neko_comm, ierr)
  end if
#endif

end function device_fusedcg_cpld_part2
272
!> Initialise a fused PCG solver for coupled systems.
!! Any previous state is released first, then all work vectors are
!! allocated on the host and mapped to the device.
!! @param n Length of each work vector.
!! @param max_iter Maximum number of iterations, forwarded to ksp_init.
!! @param M Optional preconditioner; the solver keeps a pointer to it.
!! @param rel_tol Optional relative tolerance, forwarded to ksp_init.
!! @param abs_tol Optional absolute tolerance, forwarded to ksp_init.
!! @param monitor Optional monitoring flag, forwarded to ksp_init.
subroutine fusedcg_cpld_device_init(this, n, max_iter, M, &
     rel_tol, abs_tol, monitor)
  class(fusedcg_cpld_device_t), target, intent(inout) :: this
  class(pc_t), optional, intent(in), target :: M
  integer, intent(in) :: n
  integer, intent(in) :: max_iter
  real(kind=rp), optional, intent(in) :: rel_tol
  real(kind=rp), optional, intent(in) :: abs_tol
  logical, optional, intent(in) :: monitor
  type(c_ptr) :: ptr
  integer(c_size_t) :: p_size
  integer :: i

  call this%free()

  allocate(this%w1(n))
  allocate(this%w2(n))
  allocate(this%w3(n))
  allocate(this%r1(n))
  allocate(this%r2(n))
  allocate(this%r3(n))
  allocate(this%z1(n))
  allocate(this%z2(n))
  allocate(this%z3(n))
  allocate(this%tmp(n))
  allocate(this%p1(n, device_fusedcg_cpld_p_space))
  allocate(this%p2(n, device_fusedcg_cpld_p_space))
  allocate(this%p3(n, device_fusedcg_cpld_p_space))
  allocate(this%p1_d(device_fusedcg_cpld_p_space))
  allocate(this%p2_d(device_fusedcg_cpld_p_space))
  allocate(this%p3_d(device_fusedcg_cpld_p_space))
  allocate(this%alpha(device_fusedcg_cpld_p_space))

  if (present(M)) then
    this%M => M
  end if

  call device_map(this%w1, this%w1_d, n)
  call device_map(this%w2, this%w2_d, n)
  call device_map(this%w3, this%w3_d, n)
  call device_map(this%r1, this%r1_d, n)
  call device_map(this%r2, this%r2_d, n)
  call device_map(this%r3, this%r3_d, n)
  call device_map(this%z1, this%z1_d, n)
  call device_map(this%z2, this%z2_d, n)
  call device_map(this%z3, this%z3_d, n)
  call device_map(this%tmp, this%tmp_d, n)
  call device_map(this%alpha, this%alpha_d, device_fusedcg_cpld_p_space)

  ! Map every search-direction column separately so that each one has its
  ! own device pointer in the p?_d tables.
  do i = 1, device_fusedcg_cpld_p_space
    this%p1_d(i) = c_null_ptr
    call device_map(this%p1(:,i), this%p1_d(i), n)

    this%p2_d(i) = c_null_ptr
    call device_map(this%p2(:,i), this%p2_d(i), n)

    this%p3_d(i) = c_null_ptr
    call device_map(this%p3(:,i), this%p3_d(i), n)
  end do

  ! Copy the pointer tables themselves to the device.
  p_size = c_sizeof(c_null_ptr) * (device_fusedcg_cpld_p_space)
  call device_alloc(this%p1_d_d, p_size)
  call device_alloc(this%p2_d_d, p_size)
  call device_alloc(this%p3_d_d, p_size)
  ptr = c_loc(this%p1_d)
  call device_memcpy(ptr, this%p1_d_d, p_size, &
       host_to_device, sync=.false.)
  ptr = c_loc(this%p2_d)
  call device_memcpy(ptr, this%p2_d_d, p_size, &
       host_to_device, sync=.false.)
  ptr = c_loc(this%p3_d)
  call device_memcpy(ptr, this%p3_d_d, p_size, &
       host_to_device, sync=.false.)

  ! An absent optional dummy may be passed straight on to another optional
  ! dummy, so a single call covers every combination of the three
  ! optional settings.
  call this%ksp_init(max_iter, rel_tol, abs_tol, monitor = monitor)

  call device_event_create(this%gs_event1, 2)
  call device_event_create(this%gs_event2, 2)
  call device_event_create(this%gs_event3, 2)

end subroutine fusedcg_cpld_device_init
369
!> Deallocate a fused PCG solver for coupled systems.
!! Releases the base-class state, every host work array, every device
!! allocation, the per-column pointer tables and the gather-scatter
!! events, and drops the preconditioner pointer. Safe to call on an
!! object that was never initialised.
subroutine fusedcg_cpld_device_free(this)
  class(fusedcg_cpld_device_t), intent(inout) :: this
  integer :: i

  call this%ksp_free()

  ! Host work arrays.
  if (allocated(this%w1)) deallocate(this%w1)
  if (allocated(this%w2)) deallocate(this%w2)
  if (allocated(this%w3)) deallocate(this%w3)
  if (allocated(this%r1)) deallocate(this%r1)
  if (allocated(this%r2)) deallocate(this%r2)
  if (allocated(this%r3)) deallocate(this%r3)
  if (allocated(this%z1)) deallocate(this%z1)
  if (allocated(this%z2)) deallocate(this%z2)
  if (allocated(this%z3)) deallocate(this%z3)
  if (allocated(this%tmp)) deallocate(this%tmp)
  if (allocated(this%alpha)) deallocate(this%alpha)
  if (allocated(this%p1)) deallocate(this%p1)
  if (allocated(this%p2)) deallocate(this%p2)
  if (allocated(this%p3)) deallocate(this%p3)

  ! Device counterparts of the work vectors.
  if (c_associated(this%w1_d)) call device_free(this%w1_d)
  if (c_associated(this%w2_d)) call device_free(this%w2_d)
  if (c_associated(this%w3_d)) call device_free(this%w3_d)
  if (c_associated(this%r1_d)) call device_free(this%r1_d)
  if (c_associated(this%r2_d)) call device_free(this%r2_d)
  if (c_associated(this%r3_d)) call device_free(this%r3_d)
  if (c_associated(this%z1_d)) call device_free(this%z1_d)
  if (c_associated(this%z2_d)) call device_free(this%z2_d)
  if (c_associated(this%z3_d)) call device_free(this%z3_d)
  if (c_associated(this%alpha_d)) call device_free(this%alpha_d)
  if (c_associated(this%tmp_d)) call device_free(this%tmp_d)

  ! Per-column device pointers. The tables themselves are deallocated as
  ! well, otherwise the allocate statements in init would fail when the
  ! solver is initialised a second time.
  if (allocated(this%p1_d)) then
    do i = 1, size(this%p1_d)
      if (c_associated(this%p1_d(i))) then
        call device_free(this%p1_d(i))
      end if
    end do
    deallocate(this%p1_d)
  end if

  if (allocated(this%p2_d)) then
    do i = 1, size(this%p2_d)
      if (c_associated(this%p2_d(i))) then
        call device_free(this%p2_d(i))
      end if
    end do
    deallocate(this%p2_d)
  end if

  if (allocated(this%p3_d)) then
    do i = 1, size(this%p3_d)
      if (c_associated(this%p3_d(i))) then
        call device_free(this%p3_d(i))
      end if
    end do
    deallocate(this%p3_d)
  end if

  ! Device-resident copies of the pointer tables.
  if (c_associated(this%p1_d_d)) call device_free(this%p1_d_d)
  if (c_associated(this%p2_d_d)) call device_free(this%p2_d_d)
  if (c_associated(this%p3_d_d)) call device_free(this%p3_d_d)

  nullify(this%M)

  if (c_associated(this%gs_event1)) then
    call device_event_destroy(this%gs_event1)
  end if

  if (c_associated(this%gs_event2)) then
    call device_event_destroy(this%gs_event2)
  end if

  if (c_associated(this%gs_event3)) then
    call device_event_destroy(this%gs_event3)
  end if

end subroutine fusedcg_cpld_device_free
528
!> Fused PCG solve of a coupled three-component system on the device.
!! The solution fields are zeroed on entry, i.e. the iteration starts from
!! a zero initial guess.
!! @param Ax Operator; applied through compute_vector to the search
!! directions.
!! @param x,y,z Solution fields, overwritten.
!! @param fx,fy,fz Right-hand sides; their device pointers are looked up
!! with device_get_ptr.
!! @param n Length of each vector.
!! @param coef Coefficients (mult_d, binv_d and volume are used).
!! @param blstx,blsty,blstz Boundary-condition lists applied to the three
!! components of the operator result.
!! @param gs_h Gather-scatter handle.
!! @param niter Optional iteration limit for this call.
!! @return Three monitors holding the start/final residual, the iteration
!! count and the convergence flag.
function fusedcg_cpld_device_solve_coupled(this, Ax, x, y, z, fx, fy, fz, &
     n, coef, blstx, blsty, blstz, gs_h, niter) result(ksp_results)
  class(fusedcg_cpld_device_t), intent(inout) :: this
  class(ax_t), intent(in) :: Ax
  type(field_t), intent(inout) :: x
  type(field_t), intent(inout) :: y
  type(field_t), intent(inout) :: z
  integer, intent(in) :: n
  real(kind=rp), dimension(n), intent(in) :: fx
  real(kind=rp), dimension(n), intent(in) :: fy
  real(kind=rp), dimension(n), intent(in) :: fz
  type(coef_t), intent(inout) :: coef
  type(bc_list_t), intent(inout) :: blstx
  type(bc_list_t), intent(inout) :: blsty
  type(bc_list_t), intent(inout) :: blstz
  type(gs_t), intent(inout) :: gs_h
  type(ksp_monitor_t), dimension(3) :: ksp_results
  integer, optional, intent(in) :: niter
  integer :: iter, max_iter, p_cur, p_prev
  real(kind=rp) :: rnorm, rtr, norm_fac, rtz1, rtz2
  real(kind=rp) :: pap, beta
  type(c_ptr) :: fx_d
  type(c_ptr) :: fy_d
  type(c_ptr) :: fz_d

  fx_d = device_get_ptr(fx)
  fy_d = device_get_ptr(fy)
  fz_d = device_get_ptr(fz)

  ! NOTE(review): the fallback is the module-level ksp_max_iter, not the
  ! max_iter that init handed to ksp_init -- confirm this is intended.
  if (present(niter)) then
    max_iter = niter
  else
    max_iter = ksp_max_iter
  end if
  norm_fac = 1.0_rp / sqrt(coef%volume)

  associate(w1 => this%w1, w2 => this%w2, w3 => this%w3, r1 => this%r1, &
       r2 => this%r2, r3 => this%r3, p1 => this%p1, p2 => this%p2, &
       p3 => this%p3, z1 => this%z1, z2 => this%z2, z3 => this%z3, &
       tmp_d => this%tmp_d, alpha => this%alpha, alpha_d => this%alpha_d, &
       w1_d => this%w1_d, w2_d => this%w2_d, w3_d => this%w3_d, &
       r1_d => this%r1_d, r2_d => this%r2_d, r3_d => this%r3_d, &
       z1_d => this%z1_d, z2_d => this%z2_d, z3_d => this%z3_d, &
       p1_d => this%p1_d, p2_d => this%p2_d, p3_d => this%p3_d, &
       p1_d_d => this%p1_d_d, p2_d_d => this%p2_d_d, p3_d_d => this%p3_d_d)

    rtz1 = 1.0_rp
    ! p_prev indexes the pointer tables in the very first update_p call,
    ! so it has to be a valid slot even though beta is zero there.
    p_prev = device_fusedcg_cpld_p_space
    p_cur = 1

    ! Zero initial guess: the residual starts as the right-hand side.
    call device_rzero(x%x_d, n)
    call device_rzero(y%x_d, n)
    call device_rzero(z%x_d, n)
    call device_rzero(p1_d(1), n)
    call device_rzero(p2_d(1), n)
    call device_rzero(p3_d(1), n)
    call device_copy(r1_d, fx_d, n)
    call device_copy(r2_d, fy_d, n)
    call device_copy(r3_d, fz_d, n)

    ! part1 combines the two vector triples into tmp_d, which the weighted
    ! inner product then reduces to a scalar.
    call device_fusedcg_cpld_part1(r1_d, r2_d, r3_d, r1_d, &
         r2_d, r3_d, tmp_d, n)

    rtr = device_glsc3(tmp_d, coef%mult_d, coef%binv_d, n)

    rnorm = sqrt(rtr)*norm_fac
    ksp_results%res_start = rnorm
    ksp_results%res_final = rnorm
    ksp_results(1)%iter = 0
    ksp_results(2:3)%iter = -1
    ! Nothing to solve when the initial residual is already zero.
    if (abscmp(rnorm, 0.0_rp)) then
      ksp_results%converged = .true.
      return
    end if
    call this%monitor_start('fcpldCG')
    do iter = 1, max_iter
      ! Precondition the three residual components.
      call this%M%solve(z1, r1, n)
      call this%M%solve(z2, r2, n)
      call this%M%solve(z3, r3, n)
      rtz2 = rtz1
      call device_fusedcg_cpld_part1(z1_d, z2_d, z3_d, &
           r1_d, r2_d, r3_d, tmp_d, n)
      rtz1 = device_glsc2(tmp_d, coef%mult_d, n)

      beta = rtz1 / rtz2
      if (iter .eq. 1) beta = 0.0_rp

      call device_fusedcg_cpld_update_p(p1_d(p_cur), p2_d(p_cur), &
           p3_d(p_cur), z1_d, z2_d, z3_d, &
           p1_d(p_prev), p2_d(p_prev), p3_d(p_prev), beta, n)

      ! Apply the operator to the current search direction, then
      ! gather-scatter and apply boundary conditions per component.
      call Ax%compute_vector(w1, w2, w3, &
           p1(1, p_cur), p2(1, p_cur), p3(1, p_cur), coef, x%msh, x%Xh)
      call gs_h%op(w1, n, gs_op_add, this%gs_event1)
      call device_event_sync(this%gs_event1)
      call blstx%apply(w1, n)
      call gs_h%op(w2, n, gs_op_add, this%gs_event2)
      call device_event_sync(this%gs_event2)
      call blsty%apply(w2, n)
      call gs_h%op(w3, n, gs_op_add, this%gs_event3)
      call device_event_sync(this%gs_event3)
      call blstz%apply(w3, n)

      call device_fusedcg_cpld_part1(w1_d, w2_d, w3_d, p1_d(p_cur), &
           p2_d(p_cur), p3_d(p_cur), tmp_d, n)

      pap = device_glsc2(tmp_d, coef%mult_d, n)

      alpha(p_cur) = rtz1 / pap
      rtr = device_fusedcg_cpld_part2(r1_d, r2_d, r3_d, coef%mult_d, &
           w1_d, w2_d, w3_d, alpha_d, alpha(p_cur), p_cur, n)
      rnorm = sqrt(rtr)*norm_fac
      call this%monitor_iter(iter, rnorm)
      ! The solution is only updated when the direction buffer is full,
      ! the absolute tolerance is met, or this is the last iteration;
      ! update_x then consumes the p_cur stored directions at once.
      if ((p_cur .eq. device_fusedcg_cpld_p_space) .or. &
           (rnorm .lt. this%abs_tol) .or. iter .eq. max_iter) then
        call device_fusedcg_cpld_update_x(x%x_d, y%x_d, z%x_d, &
             p1_d_d, p2_d_d, p3_d_d, alpha_d, p_cur, n)
        p_prev = p_cur
        p_cur = 1
        if (rnorm .lt. this%abs_tol) exit
      else
        p_prev = p_cur
        p_cur = p_cur + 1
      end if
    end do
    call this%monitor_stop()
    ksp_results%res_final = rnorm
    ksp_results%iter = iter
    ksp_results%converged = this%is_converged(iter, rnorm)

  end associate

end function fusedcg_cpld_device_solve_coupled
663
!> Scalar solve entry point. This solver only supports coupled solves, so
!! the routine always raises an error through neko_error; the returned
!! monitor is filled with placeholder values.
function fusedcg_cpld_device_solve(this, Ax, x, f, n, coef, blst, &
     gs_h, niter) result(ksp_results)
  class(fusedcg_cpld_device_t), intent(inout) :: this
  class(ax_t), intent(in) :: Ax
  type(field_t), intent(inout) :: x
  integer, intent(in) :: n
  real(kind=rp), dimension(n), intent(in) :: f
  type(coef_t), intent(inout) :: coef
  type(bc_list_t), intent(inout) :: blst
  type(gs_t), intent(inout) :: gs_h
  integer, optional, intent(in) :: niter
  type(ksp_monitor_t) :: ksp_results

  ! Throw an error
  call neko_error('The cpldcg solver is only defined for coupled solves')

  ! Placeholder result so the function value is always defined.
  ksp_results%converged = .false.
  ksp_results%iter = 0
  ksp_results%res_final = 0.0

end function fusedcg_cpld_device_solve
686
687end module fusedcg_cpld_device
__device__ T solve(const T u, const T y, const T guess, const T nu, const T kappa, const T B)
void hip_fusedcg_cpld_update_x(void *x1, void *x2, void *x3, void *p1, void *p2, void *p3, void *alpha, int *p_cur, int *n)
void hip_fusedcg_cpld_update_p(void *p1, void *p2, void *p3, void *z1, void *z2, void *z3, void *po1, void *po2, void *po3, real *beta, int *n)
real hip_fusedcg_cpld_part2(void *a1, void *a2, void *a3, void *b, void *c1, void *c2, void *c3, void *alpha_d, real *alpha, int *p_cur, int *n)
void hip_fusedcg_cpld_part1(void *a1, void *a2, void *a3, void *b1, void *b2, void *b3, void *tmp, int *n)
Return the device pointer for an associated Fortran array.
Definition device.F90:101
Map a Fortran array to a device (allocate and associate)
Definition device.F90:77
Copy data between host and device (or device and device)
Definition device.F90:71
Defines a Matrix-vector product.
Definition ax.f90:34
Defines a list of bc_t.
Definition bc_list.f90:34
Coefficients.
Definition coef.f90:34
Definition comm.F90:1
type(mpi_datatype), public mpi_real_precision
MPI type for working precision of REAL types.
Definition comm.F90:51
integer, public pe_size
MPI size of communicator.
Definition comm.F90:59
type(mpi_comm), public neko_comm
MPI communicator.
Definition comm.F90:43
subroutine, public device_rzero(a_d, n, strm)
Zero a real vector.
subroutine, public device_copy(a_d, b_d, n, strm)
Copy a vector $a = b$.
real(kind=rp) function, public device_glsc3(a_d, b_d, c_d, n, strm)
Weighted inner product $a^T b c$.
real(kind=rp) function, public device_glsc2(a_d, b_d, n, strm)
Weighted inner product $a^T b$.
Device abstraction, common interface for various accelerators.
Definition device.F90:34
subroutine, public device_event_sync(event)
Synchronize an event.
Definition device.F90:1314
integer, parameter, public host_to_device
Definition device.F90:47
subroutine, public device_free(x_d)
Deallocate memory on the device.
Definition device.F90:219
subroutine, public device_event_destroy(event)
Destroy a device event.
Definition device.F90:1279
subroutine, public device_alloc(x_d, s)
Allocate memory on the device.
Definition device.F90:192
subroutine, public device_event_create(event, flags)
Create a device event queue.
Definition device.F90:1249
Defines a field.
Definition field.f90:34
Defines a fused Conjugate Gradient method for accelerators.
subroutine device_fusedcg_cpld_update_x(x1_d, x2_d, x3_d, p1_d, p2_d, p3_d, alpha, p_cur, n)
type(ksp_monitor_t) function fusedcg_cpld_device_solve(this, ax, x, f, n, coef, blst, gs_h, niter)
Pipelined PCG solve.
subroutine fusedcg_cpld_device_free(this)
Deallocate a pipelined PCG solver.
subroutine device_fusedcg_cpld_update_p(p1_d, p2_d, p3_d, z1_d, z2_d, z3_d, po1_d, po2_d, po3_d, beta, n)
subroutine fusedcg_cpld_device_init(this, n, max_iter, m, rel_tol, abs_tol, monitor)
Initialise a fused PCG solver.
real(kind=rp) function device_fusedcg_cpld_part2(a1_d, a2_d, a3_d, b_d, c1_d, c2_d, c3_d, alpha_d, alpha, p_cur, n)
type(ksp_monitor_t) function, dimension(3) fusedcg_cpld_device_solve_coupled(this, ax, x, y, z, fx, fy, fz, n, coef, blstx, blsty, blstz, gs_h, niter)
Pipelined PCG solve coupled solve.
integer, parameter device_fusedcg_cpld_p_space
subroutine device_fusedcg_cpld_part1(a1_d, a2_d, a3_d, b1_d, b2_d, b3_d, tmp_d, n)
Gather-scatter.
Implements the base abstract type for Krylov solvers plus helper types.
Definition krylov.f90:34
integer, parameter, public ksp_max_iter
Maximum number of iters.
Definition krylov.f90:51
Definition math.f90:60
real(kind=rp) function, public glsc3(a, b, c, n)
Weighted inner product $a^T b c$.
Definition math.f90:1067
subroutine, public copy(a, b, n)
Copy a vector $a = b$.
Definition math.f90:249
subroutine, public rzero(a, n)
Zero a real vector.
Definition math.f90:205
integer, parameter, public c_rp
Definition num_types.f90:13
integer, parameter, public rp
Global precision used in computations.
Definition num_types.f90:12
Krylov preconditioner.
Definition precon.f90:34
Utilities.
Definition utils.f90:35
Base type for a matrix-vector product providing $Ax$.
Definition ax.f90:43
A list of allocatable `bc_t`. Follows the standard interface of lists.
Definition bc_list.f90:48
Coefficients defined on a given (mesh, $X_h$) tuple. Arrays use indices (i,j,k,e): element e,...
Definition coef.f90:55
Fused preconditioned conjugate gradient method.
Type for storing initial and final residuals in a Krylov solver.
Definition krylov.f90:56
Base abstract type for a canonical Krylov method, solving $Ax = f$.
Definition krylov.f90:73
Defines a canonical Krylov preconditioner.
Definition precon.f90:40