Loading [MathJax]/extensions/tex2jax.js
Neko 0.9.99
A portable framework for high-order spectral element flow simulations
All Classes Namespaces Files Functions Variables Typedefs Enumerator Macros Pages
fusedcg_cpld_device.F90
Go to the documentation of this file.
1! Copyright (c) 2021-2024, The Neko Authors
2! All rights reserved.
3!
4! Redistribution and use in source and binary forms, with or without
5! modification, are permitted provided that the following conditions
6! are met:
7!
8! * Redistributions of source code must retain the above copyright
9! notice, this list of conditions and the following disclaimer.
10!
11! * Redistributions in binary form must reproduce the above
12! copyright notice, this list of conditions and the following
13! disclaimer in the documentation and/or other materials provided
14! with the distribution.
15!
16! * Neither the name of the authors nor the names of its
17! contributors may be used to endorse or promote products derived
18! from this software without specific prior written permission.
19!
20! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21! "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22! LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
23! FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
24! COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
25! INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
26! BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27! LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28! CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29! LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
30! ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31! POSSIBILITY OF SUCH DAMAGE.
32!
36 use precon, only : pc_t
37 use ax_product, only : ax_t
38 use num_types, only: rp, c_rp
39 use field, only : field_t
40 use coefs, only : coef_t
41 use gather_scatter, only : gs_t, gs_op_add
42 use bc_list, only : bc_list_t
43 use math, only : glsc3, rzero, copy, abscmp
45 use device
46 use utils, only : neko_error
47 use comm, only : neko_comm, mpi_in_place, mpi_allreduce, &
49 use, intrinsic :: iso_c_binding, only : c_ptr, c_null_ptr, &
50 c_associated, c_size_t, c_sizeof, c_int, c_loc
51 implicit none
52 private
53
! Number of search directions kept per component before the accumulated
! updates are applied to the solution: the solver flushes its direction
! buffer into x/y/z when p_cur reaches this value.
54 integer, parameter :: device_fusedcg_cpld_p_space = 10
55
! Device-accelerated fused PCG solver for three coupled right-hand sides.
! The host arrays below are mapped to device memory in init; the matching
! *_d components hold the corresponding device addresses.
57 type, public, extends(ksp_t) :: fusedcg_cpld_device_t
! Operator applied to the current search direction (one vector per component)
58 real(kind=rp), allocatable :: w1(:)
59 real(kind=rp), allocatable :: w2(:)
60 real(kind=rp), allocatable :: w3(:)
! Residuals (initialised from the right-hand sides)
61 real(kind=rp), allocatable :: r1(:)
62 real(kind=rp), allocatable :: r2(:)
63 real(kind=rp), allocatable :: r3(:)
! Preconditioned residuals
64 real(kind=rp), allocatable :: z1(:)
65 real(kind=rp), allocatable :: z2(:)
66 real(kind=rp), allocatable :: z3(:)
! Scratch vector written by the fused part1 kernel before each reduction
67 real(kind=rp), allocatable :: tmp(:)
! Search-direction history, shape (n, device_fusedcg_cpld_p_space):
! column j holds stored direction j
68 real(kind=rp), allocatable :: p1(:,:)
69 real(kind=rp), allocatable :: p2(:,:)
70 real(kind=rp), allocatable :: p3(:,:)
! Step length for each stored direction
71 real(kind=rp), allocatable :: alpha(:)
! Device addresses of the work vectors above
72 type(c_ptr) :: w1_d = c_null_ptr
73 type(c_ptr) :: w2_d = c_null_ptr
74 type(c_ptr) :: w3_d = c_null_ptr
75 type(c_ptr) :: r1_d = c_null_ptr
76 type(c_ptr) :: r2_d = c_null_ptr
77 type(c_ptr) :: r3_d = c_null_ptr
78 type(c_ptr) :: z1_d = c_null_ptr
79 type(c_ptr) :: z2_d = c_null_ptr
80 type(c_ptr) :: z3_d = c_null_ptr
81 type(c_ptr) :: alpha_d = c_null_ptr
! Device-resident copies of the p1_d/p2_d/p3_d pointer tables (uploaded in init)
82 type(c_ptr) :: p1_d_d = c_null_ptr
83 type(c_ptr) :: p2_d_d = c_null_ptr
84 type(c_ptr) :: p3_d_d = c_null_ptr
85 type(c_ptr) :: tmp_d = c_null_ptr
! Host tables of per-direction device addresses, one entry per column of p1..p3
86 type(c_ptr), allocatable :: p1_d(:)
87 type(c_ptr), allocatable :: p2_d(:)
88 type(c_ptr), allocatable :: p3_d(:)
! Events used to synchronise the gather-scatter of w1, w2 and w3
89 type(c_ptr) :: gs_event1 = c_null_ptr
90 type(c_ptr) :: gs_event2 = c_null_ptr
91 type(c_ptr) :: gs_event3 = c_null_ptr
92 contains
93 procedure, pass(this) :: init => fusedcg_cpld_device_init
94 procedure, pass(this) :: free => fusedcg_cpld_device_free
95 procedure, pass(this) :: solve => fusedcg_cpld_device_solve
96 procedure, pass(this) :: solve_coupled => fusedcg_cpld_device_solve_coupled
98
99#ifdef HAVE_CUDA
! C bindings to the CUDA implementations of the fused coupled-CG kernels.
! Scalars declared without the value attribute (n, p_cur, beta, alpha) are
! passed by reference, matching pointer arguments such as `int *n` and
! `real *beta` in the C prototypes.
100 interface
! Fused pointwise kernel over three components; writes its result to tmp_d.
! NOTE(review): presumably tmp = a1*b1 + a2*b2 + a3*b3 -- confirm against
! the device kernel, which is not part of this file.
101 subroutine cuda_fusedcg_cpld_part1(a1_d, a2_d, a3_d, &
102 b1_d, b2_d, b3_d, tmp_d, n) bind(c, name='cuda_fusedcg_cpld_part1')
103 use, intrinsic :: iso_c_binding
104 import c_rp
105 implicit none
106 type(c_ptr), value :: a1_d, a2_d, a3_d, b1_d, b2_d, b3_d, tmp_d
107 integer(c_int) :: n
108 end subroutine cuda_fusedcg_cpld_part1
109 end interface
110
! Builds the new search direction p1..p3 from z1..z3 and the previous
! direction po1..po3 with coefficient beta.
! NOTE(review): presumably p = z + beta*po -- confirm against the kernel.
111 interface
112 subroutine cuda_fusedcg_cpld_update_p(p1_d, p2_d, p3_d, z1_d, z2_d, z3_d, &
113 po1_d, po2_d, po3_d, beta, n) bind(c, name='cuda_fusedcg_cpld_update_p')
114 use, intrinsic :: iso_c_binding
115 import c_rp
116 implicit none
117 type(c_ptr), value :: p1_d, p2_d, p3_d, z1_d, z2_d, z3_d
118 type(c_ptr), value :: po1_d, po2_d, po3_d
119 real(c_rp) :: beta
120 integer(c_int) :: n
121 end subroutine cuda_fusedcg_cpld_update_p
122 end interface
123
! Updates x1..x3 from the first p_cur stored directions. The solver passes
! device tables of direction pointers as p1_d..p3_d and the device array of
! step lengths as alpha.
! NOTE(review): presumably x += sum_j alpha(j) * p(:, j) -- confirm.
124 interface
125 subroutine cuda_fusedcg_cpld_update_x(x1_d, x2_d, x3_d, p1_d, p2_d, p3_d, &
126 alpha, p_cur, n) bind(c, name='cuda_fusedcg_cpld_update_x')
127 use, intrinsic :: iso_c_binding
128 implicit none
129 type(c_ptr), value :: x1_d, x2_d, x3_d, p1_d, p2_d, p3_d, alpha
130 integer(c_int) :: p_cur, n
131 end subroutine cuda_fusedcg_cpld_update_x
132 end interface
133
! Fused residual update and reduction; the solver uses its return value as
! the squared residual (rtr). Receives the current step length both as the
! scalar alpha and through alpha_d/p_cur.
! NOTE(review): presumably r -= alpha*c, alpha_d(p_cur) = alpha, and the
! b-weighted sum of r.r is returned -- confirm against the kernel.
134 interface
135 real(c_rp) function cuda_fusedcg_cpld_part2(a1_d, a2_d, a3_d, b_d, &
136 c1_d, c2_d, c3_d, alpha_d, alpha, p_cur, n) &
137 bind(c, name='cuda_fusedcg_cpld_part2')
138 use, intrinsic :: iso_c_binding
139 import c_rp
140 implicit none
141 type(c_ptr), value :: a1_d, a2_d, a3_d, b_d
142 type(c_ptr), value :: c1_d, c2_d, c3_d, alpha_d
143 real(c_rp) :: alpha
144 integer(c_int) :: n, p_cur
145 end function cuda_fusedcg_cpld_part2
146 end interface
147#elif HAVE_HIP
! C bindings to the HIP implementations of the same four kernels; argument
! lists and passing conventions are identical to the CUDA bindings.
148 interface
149 subroutine hip_fusedcg_cpld_part1(a1_d, a2_d, a3_d, &
150 b1_d, b2_d, b3_d, tmp_d, n) bind(c, name='hip_fusedcg_cpld_part1')
151 use, intrinsic :: iso_c_binding
152 import c_rp
153 implicit none
154 type(c_ptr), value :: a1_d, a2_d, a3_d, b1_d, b2_d, b3_d, tmp_d
155 integer(c_int) :: n
156 end subroutine hip_fusedcg_cpld_part1
157 end interface
158
159 interface
160 subroutine hip_fusedcg_cpld_update_p(p1_d, p2_d, p3_d, z1_d, z2_d, z3_d, &
161 po1_d, po2_d, po3_d, beta, n) bind(c, name='hip_fusedcg_cpld_update_p')
162 use, intrinsic :: iso_c_binding
163 import c_rp
164 implicit none
165 type(c_ptr), value :: p1_d, p2_d, p3_d, z1_d, z2_d, z3_d
166 type(c_ptr), value :: po1_d, po2_d, po3_d
167 real(c_rp) :: beta
168 integer(c_int) :: n
169 end subroutine hip_fusedcg_cpld_update_p
170 end interface
171
172 interface
173 subroutine hip_fusedcg_cpld_update_x(x1_d, x2_d, x3_d, p1_d, p2_d, p3_d, &
174 alpha, p_cur, n) bind(c, name='hip_fusedcg_cpld_update_x')
175 use, intrinsic :: iso_c_binding
176 implicit none
177 type(c_ptr), value :: x1_d, x2_d, x3_d, p1_d, p2_d, p3_d, alpha
178 integer(c_int) :: p_cur, n
179 end subroutine hip_fusedcg_cpld_update_x
180 end interface
181
182 interface
183 real(c_rp) function hip_fusedcg_cpld_part2(a1_d, a2_d, a3_d, b_d, &
184 c1_d, c2_d, c3_d, alpha_d, alpha, p_cur, n) &
185 bind(c, name='hip_fusedcg_cpld_part2')
186 use, intrinsic :: iso_c_binding
187 import c_rp
188 implicit none
189 type(c_ptr), value :: a1_d, a2_d, a3_d, b_d
190 type(c_ptr), value :: c1_d, c2_d, c3_d, alpha_d
191 real(c_rp) :: alpha
192 integer(c_int) :: n, p_cur
193 end function hip_fusedcg_cpld_part2
194 end interface
195#endif
196
197contains
198
!> Backend-neutral launcher for the fused part1 kernel.
!! Forwards to the HIP or CUDA implementation selected at build time and
!! reports an error through neko_error when no device backend is compiled in.
!! The kernel writes its pointwise result into tmp_d.
subroutine device_fusedcg_cpld_part1(a1_d, a2_d, a3_d, &
     b1_d, b2_d, b3_d, tmp_d, n)
  type(c_ptr), value :: a1_d, a2_d, a3_d
  type(c_ptr), value :: b1_d, b2_d, b3_d
  type(c_ptr), value :: tmp_d
  integer(c_int) :: n

#if defined(HAVE_HIP)
  call hip_fusedcg_cpld_part1(a1_d, a2_d, a3_d, &
       b1_d, b2_d, b3_d, tmp_d, n)
#elif HAVE_CUDA
  call cuda_fusedcg_cpld_part1(a1_d, a2_d, a3_d, &
       b1_d, b2_d, b3_d, tmp_d, n)
#else
  call neko_error('No device backend configured')
#endif

end subroutine device_fusedcg_cpld_part1
212
!> Backend-neutral launcher for the search-direction update kernel.
!! Builds p1..p3 from z1..z3 and the previous directions po1..po3 using the
!! coefficient beta (presumably p = z + beta*po; the kernel is not shown).
!! Raises neko_error when no device backend is compiled in.
subroutine device_fusedcg_cpld_update_p(p1_d, p2_d, p3_d, z1_d, z2_d, z3_d, &
     po1_d, po2_d, po3_d, beta, n)
  type(c_ptr), value :: p1_d, p2_d, p3_d
  type(c_ptr), value :: z1_d, z2_d, z3_d
  type(c_ptr), value :: po1_d, po2_d, po3_d
  real(c_rp) :: beta
  integer(c_int) :: n

#if defined(HAVE_HIP)
  call hip_fusedcg_cpld_update_p(p1_d, p2_d, p3_d, &
       z1_d, z2_d, z3_d, &
       po1_d, po2_d, po3_d, beta, n)
#elif HAVE_CUDA
  call cuda_fusedcg_cpld_update_p(p1_d, p2_d, p3_d, &
       z1_d, z2_d, z3_d, &
       po1_d, po2_d, po3_d, beta, n)
#else
  call neko_error('No device backend configured')
#endif

end subroutine device_fusedcg_cpld_update_p
229
!> Backend-neutral launcher for the solution update kernel.
!! Updates x1_d..x3_d from the first p_cur stored search directions; the
!! solver passes device tables of direction pointers as p1_d..p3_d and the
!! device array of step lengths as alpha. Raises neko_error when no device
!! backend is compiled in.
subroutine device_fusedcg_cpld_update_x(x1_d, x2_d, x3_d, &
     p1_d, p2_d, p3_d, alpha, p_cur, n)
  type(c_ptr), value :: x1_d, x2_d, x3_d
  type(c_ptr), value :: p1_d, p2_d, p3_d
  type(c_ptr), value :: alpha
  integer(c_int) :: p_cur
  integer(c_int) :: n

#if defined(HAVE_HIP)
  call hip_fusedcg_cpld_update_x(x1_d, x2_d, x3_d, &
       p1_d, p2_d, p3_d, alpha, p_cur, n)
#elif HAVE_CUDA
  call cuda_fusedcg_cpld_update_x(x1_d, x2_d, x3_d, &
       p1_d, p2_d, p3_d, alpha, p_cur, n)
#else
  call neko_error('No device backend configured')
#endif

end subroutine device_fusedcg_cpld_update_x
244
!> Launch the fused part2 kernel and return its reduction result.
!! The solver uses the value as the squared residual. Unless the build has
!! device-aware MPI (HAVE_DEVICE_MPI), the per-rank value is summed over
!! neko_comm on the host when more than one rank is present.
function device_fusedcg_cpld_part2(a1_d, a2_d, a3_d, b_d, &
     c1_d, c2_d, c3_d, alpha_d, alpha, p_cur, n) result(res)
  type(c_ptr), value :: a1_d, a2_d, a3_d, b_d
  type(c_ptr), value :: c1_d, c2_d, c3_d, alpha_d
  real(c_rp) :: alpha
  integer :: n, p_cur
  real(kind=rp) :: res
  integer :: ierr_mpi

#if defined(HAVE_HIP)
  res = hip_fusedcg_cpld_part2(a1_d, a2_d, a3_d, b_d, &
       c1_d, c2_d, c3_d, alpha_d, alpha, p_cur, n)
#elif HAVE_CUDA
  res = cuda_fusedcg_cpld_part2(a1_d, a2_d, a3_d, b_d, &
       c1_d, c2_d, c3_d, alpha_d, alpha, p_cur, n)
#else
  call neko_error('No device backend configured')
#endif

#ifndef HAVE_DEVICE_MPI
  ! Host-side global sum of the partial result; unnecessary on one rank.
  if (pe_size > 1) then
     call mpi_allreduce(mpi_in_place, res, 1, &
          mpi_real_precision, mpi_sum, neko_comm, ierr_mpi)
  end if
#endif

end function device_fusedcg_cpld_part2
271
!> Initialise a fused coupled PCG solver.
!! Allocates the host work arrays and the search-direction history
!! (device_fusedcg_cpld_p_space directions per component) and maps them to
!! the device. It then uploads device-side tables of the per-direction device
!! pointers and creates one gather-scatter event per component.
!! @param n        Length of each work vector.
!! @param max_iter Maximum number of iterations, forwarded to ksp_init.
!! @param M        Optional preconditioner.
!! @param rel_tol  Optional relative tolerance, forwarded to ksp_init.
!! @param abs_tol  Optional absolute tolerance, forwarded to ksp_init.
!! @param monitor  Optional monitoring flag, forwarded to ksp_init.
subroutine fusedcg_cpld_device_init(this, n, max_iter, M, &
     rel_tol, abs_tol, monitor)
  class(fusedcg_cpld_device_t), target, intent(inout) :: this
  class(pc_t), optional, intent(in), target :: M
  integer, intent(in) :: n
  integer, intent(in) :: max_iter
  real(kind=rp), optional, intent(in) :: rel_tol
  real(kind=rp), optional, intent(in) :: abs_tol
  logical, optional, intent(in) :: monitor
  type(c_ptr) :: ptr
  integer(c_size_t) :: p_size
  integer :: i

  call this%free()

  allocate(this%w1(n), this%w2(n), this%w3(n))
  allocate(this%r1(n), this%r2(n), this%r3(n))
  allocate(this%z1(n), this%z2(n), this%z3(n))
  allocate(this%tmp(n))
  allocate(this%p1(n, device_fusedcg_cpld_p_space))
  allocate(this%p2(n, device_fusedcg_cpld_p_space))
  allocate(this%p3(n, device_fusedcg_cpld_p_space))
  allocate(this%p1_d(device_fusedcg_cpld_p_space))
  allocate(this%p2_d(device_fusedcg_cpld_p_space))
  allocate(this%p3_d(device_fusedcg_cpld_p_space))
  allocate(this%alpha(device_fusedcg_cpld_p_space))

  if (present(M)) then
     this%M => M
  end if

  call device_map(this%w1, this%w1_d, n)
  call device_map(this%w2, this%w2_d, n)
  call device_map(this%w3, this%w3_d, n)
  call device_map(this%r1, this%r1_d, n)
  call device_map(this%r2, this%r2_d, n)
  call device_map(this%r3, this%r3_d, n)
  call device_map(this%z1, this%z1_d, n)
  call device_map(this%z2, this%z2_d, n)
  call device_map(this%z3, this%z3_d, n)
  call device_map(this%tmp, this%tmp_d, n)
  call device_map(this%alpha, this%alpha_d, device_fusedcg_cpld_p_space)

  ! One device buffer per stored search direction and component.
  do i = 1, device_fusedcg_cpld_p_space
     this%p1_d(i) = c_null_ptr
     call device_map(this%p1(:,i), this%p1_d(i), n)

     this%p2_d(i) = c_null_ptr
     call device_map(this%p2(:,i), this%p2_d(i), n)

     this%p3_d(i) = c_null_ptr
     call device_map(this%p3(:,i), this%p3_d(i), n)
  end do

  ! Upload the host tables of direction pointers so that the update_x kernel
  ! can address every stored direction. The copies are asynchronous; their
  ! host sources are components of this and remain allocated afterwards.
  p_size = c_sizeof(c_null_ptr) * (device_fusedcg_cpld_p_space)
  call device_alloc(this%p1_d_d, p_size)
  call device_alloc(this%p2_d_d, p_size)
  call device_alloc(this%p3_d_d, p_size)
  ptr = c_loc(this%p1_d)
  call device_memcpy(ptr, this%p1_d_d, p_size, &
       host_to_device, sync=.false.)
  ptr = c_loc(this%p2_d)
  call device_memcpy(ptr, this%p2_d_d, p_size, &
       host_to_device, sync=.false.)
  ptr = c_loc(this%p3_d)
  call device_memcpy(ptr, this%p3_d_d, p_size, &
       host_to_device, sync=.false.)

  ! An absent optional actual argument may be passed on to an optional
  ! dummy, where it is again absent, so a single call covers every
  ! combination of rel_tol / abs_tol / monitor.
  call this%ksp_init(max_iter, rel_tol = rel_tol, abs_tol = abs_tol, &
       monitor = monitor)

  call device_event_create(this%gs_event1, 2)
  call device_event_create(this%gs_event2, 2)
  call device_event_create(this%gs_event3, 2)

end subroutine fusedcg_cpld_device_init
368
!> Deallocate a fused coupled PCG solver.
!! Releases the following resources:
!!  - the host work arrays and their device mirrors;
!!  - the per-direction device buffers;
!!  - the host and device tables that hold the addresses of those buffers;
!!  - the gather-scatter events.
!! init() calls this routine and then allocates every allocatable component
!! again, so each one must be left unallocated here.
subroutine fusedcg_cpld_device_free(this)
  class(fusedcg_cpld_device_t), intent(inout) :: this
  integer :: i

  call this%ksp_free()

  ! Host work arrays
  if (allocated(this%w1)) deallocate(this%w1)
  if (allocated(this%w2)) deallocate(this%w2)
  if (allocated(this%w3)) deallocate(this%w3)
  if (allocated(this%r1)) deallocate(this%r1)
  if (allocated(this%r2)) deallocate(this%r2)
  if (allocated(this%r3)) deallocate(this%r3)
  if (allocated(this%z1)) deallocate(this%z1)
  if (allocated(this%z2)) deallocate(this%z2)
  if (allocated(this%z3)) deallocate(this%z3)
  if (allocated(this%tmp)) deallocate(this%tmp)
  if (allocated(this%alpha)) deallocate(this%alpha)
  if (allocated(this%p1)) deallocate(this%p1)
  if (allocated(this%p2)) deallocate(this%p2)
  if (allocated(this%p3)) deallocate(this%p3)

  ! Device mirrors of the work arrays
  if (c_associated(this%w1_d)) call device_free(this%w1_d)
  if (c_associated(this%w2_d)) call device_free(this%w2_d)
  if (c_associated(this%w3_d)) call device_free(this%w3_d)
  if (c_associated(this%r1_d)) call device_free(this%r1_d)
  if (c_associated(this%r2_d)) call device_free(this%r2_d)
  if (c_associated(this%r3_d)) call device_free(this%r3_d)
  if (c_associated(this%z1_d)) call device_free(this%z1_d)
  if (c_associated(this%z2_d)) call device_free(this%z2_d)
  if (c_associated(this%z3_d)) call device_free(this%z3_d)
  if (c_associated(this%alpha_d)) call device_free(this%alpha_d)
  if (c_associated(this%tmp_d)) call device_free(this%tmp_d)

  ! Per-direction device buffers, then the host tables holding their
  ! addresses. The tables must be deallocated: init() allocates them again
  ! right after calling free(), which would otherwise fail.
  if (allocated(this%p1_d)) then
     do i = 1, size(this%p1_d)
        if (c_associated(this%p1_d(i))) call device_free(this%p1_d(i))
     end do
     deallocate(this%p1_d)
  end if

  if (allocated(this%p2_d)) then
     do i = 1, size(this%p2_d)
        if (c_associated(this%p2_d(i))) call device_free(this%p2_d(i))
     end do
     deallocate(this%p2_d)
  end if

  if (allocated(this%p3_d)) then
     do i = 1, size(this%p3_d)
        if (c_associated(this%p3_d(i))) call device_free(this%p3_d(i))
     end do
     deallocate(this%p3_d)
  end if

  ! Device-resident copies of the pointer tables (device_alloc in init);
  ! previously never released.
  if (c_associated(this%p1_d_d)) call device_free(this%p1_d_d)
  if (c_associated(this%p2_d_d)) call device_free(this%p2_d_d)
  if (c_associated(this%p3_d_d)) call device_free(this%p3_d_d)

  nullify(this%M)

  if (c_associated(this%gs_event1)) call device_event_destroy(this%gs_event1)
  if (c_associated(this%gs_event2)) call device_event_destroy(this%gs_event2)
  if (c_associated(this%gs_event3)) call device_event_destroy(this%gs_event3)

end subroutine fusedcg_cpld_device_free
515
! Fused PCG solve for three coupled right-hand sides fx, fy, fz.
! x, y and z are zeroed and then accumulated with the computed corrections.
! One ksp_monitor_t is returned per component; after the loop all three
! entries receive the same final residual and iteration count.
517 function fusedcg_cpld_device_solve_coupled(this, Ax, x, y, z, fx, fy, fz, &
518 n, coef, blstx, blsty, blstz, gs_h, niter) result(ksp_results)
519 class(fusedcg_cpld_device_t), intent(inout) :: this
520 class(ax_t), intent(in) :: ax
521 type(field_t), intent(inout) :: x
522 type(field_t), intent(inout) :: y
523 type(field_t), intent(inout) :: z
524 integer, intent(in) :: n
525 real(kind=rp), dimension(n), intent(in) :: fx
526 real(kind=rp), dimension(n), intent(in) :: fy
527 real(kind=rp), dimension(n), intent(in) :: fz
528 type(coef_t), intent(inout) :: coef
529 type(bc_list_t), intent(inout) :: blstx
530 type(bc_list_t), intent(inout) :: blsty
531 type(bc_list_t), intent(inout) :: blstz
532 type(gs_t), intent(inout) :: gs_h
533 type(ksp_monitor_t), dimension(3) :: ksp_results
534 integer, optional, intent(in) :: niter
535 integer :: iter, max_iter, ierr, i, p_cur, p_prev
536 real(kind=rp) :: rnorm, rtr, norm_fac, rtz1, rtz2
537 real(kind=rp) :: pap, beta
538 type(c_ptr) :: fx_d
539 type(c_ptr) :: fy_d
540 type(c_ptr) :: fz_d
541
! Device addresses of the right-hand sides, looked up from the host arrays.
542 fx_d = device_get_ptr(fx)
543 fy_d = device_get_ptr(fy)
544 fz_d = device_get_ptr(fz)
545
546 if (present(niter)) then
547 max_iter = niter
548 else
549 max_iter = ksp_max_iter
550 end if
! Reported residual norms are sqrt(rtr) scaled by 1/sqrt(coef%volume).
551 norm_fac = 1.0_rp / sqrt(coef%volume)
552
553 associate(w1 => this%w1, w2 => this%w2, w3 => this%w3, r1 => this%r1, &
554 r2 => this%r2, r3 => this%r3, p1 => this%p1, p2 => this%p2, &
555 p3 => this%p3, z1 => this%z1, z2 => this%z2, z3 => this%z3, &
556 tmp_d => this%tmp_d, alpha => this%alpha, alpha_d => this%alpha_d, &
557 w1_d => this%w1_d, w2_d => this%w2_d, w3_d => this%w3_d, &
558 r1_d => this%r1_d, r2_d => this%r2_d, r3_d => this%r3_d, &
559 z1_d => this%z1_d, z2_d => this%z2_d, z3_d => this%z3_d, &
560 p1_d => this%p1_d, p2_d => this%p2_d, p3_d => this%p3_d, &
561 p1_d_d => this%p1_d_d, p2_d_d => this%p2_d_d, p3_d_d => this%p3_d_d)
562
! rtz1 only serves as the first divisor for beta, which is forced to 0 on
! iteration 1, so its initial value does not affect the result.
563 rtz1 = 1.0_rp
! NOTE(review): p_prev is used as an index into p1_d/p2_d/p3_d on the first
! iteration, but this listing shows no assignment to it before the loop
! (source line 564 is missing from this rendering). Confirm that the real
! file initialises it to a valid slot, e.g. device_fusedcg_cpld_p_space.
565 p_cur = 1
566
567
! Start from a zero solution and zero first direction; r = f.
568 call device_rzero(x%x_d, n)
569 call device_rzero(y%x_d, n)
570 call device_rzero(z%x_d, n)
571 call device_rzero(p1_d(1), n)
572 call device_rzero(p2_d(1), n)
573 call device_rzero(p3_d(1), n)
574 call device_copy(r1_d, fx_d, n)
575 call device_copy(r2_d, fy_d, n)
576 call device_copy(r3_d, fz_d, n)
577
! Initial residual: tmp from part1(r, r), reduced with weights mult and binv.
578 call device_fusedcg_cpld_part1(r1_d, r2_d, r3_d, r1_d, &
579 r2_d, r3_d, tmp_d, n)
580
581 rtr = device_glsc3(tmp_d, coef%mult_d, coef%binv_d, n)
582
583 rnorm = sqrt(rtr)*norm_fac
584 ksp_results%res_start = rnorm
585 ksp_results%res_final = rnorm
586 ksp_results(1)%iter = 0
587 ksp_results(2:3)%iter = -1
! NOTE(review): this early return does not assign ksp_results%converged;
! its value then depends on the defaults of ksp_monitor_t -- confirm.
588 if(abscmp(rnorm, 0.0_rp)) return
589 call this%monitor_start('fcpldCG')
590 do iter = 1, max_iter
! Precondition each component: z = M^{-1} r.
591 call this%M%solve(z1, r1, n)
592 call this%M%solve(z2, r2, n)
593 call this%M%solve(z3, r3, n)
594 rtz2 = rtz1
595 call device_fusedcg_cpld_part1(z1_d, z2_d, z3_d, &
596 r1_d, r2_d, r3_d, tmp_d, n)
597 rtz1 = device_glsc2(tmp_d, coef%mult_d, n)
598
599 beta = rtz1 / rtz2
600 if (iter .eq. 1) beta = 0.0_rp
601
! New direction in slot p_cur, built from z and the direction in slot p_prev.
602 call device_fusedcg_cpld_update_p(p1_d(p_cur), p2_d(p_cur), p3_d(p_cur), &
603 z1_d, z2_d, z3_d, p1_d(p_prev), p2_d(p_prev), p3_d(p_prev), beta, n)
604
! w = A p for the current direction. Passing p1(1, p_cur) hands column
! p_cur to compute_vector by sequence association -- presumably its dummies
! are explicit-shape or assumed-size; verify in ax_t.
605 call ax%compute_vector(w1, w2, w3, &
606 p1(1, p_cur), p2(1, p_cur), p3(1, p_cur), coef, x%msh, x%Xh)
! Gather-scatter (add) each component with its own event, then wait on all
! three before applying boundary conditions.
607 call gs_h%op(w1, n, gs_op_add, this%gs_event1)
608 call gs_h%op(w2, n, gs_op_add, this%gs_event2)
609 call gs_h%op(w3, n, gs_op_add, this%gs_event3)
610 call device_event_sync(this%gs_event1)
611 call device_event_sync(this%gs_event2)
612 call device_event_sync(this%gs_event3)
613 call blstx%apply(w1, n)
614 call blsty%apply(w2, n)
615 call blstz%apply(w3, n)
616
! pap: mult-weighted sum of the part1(w, p) result.
617 call device_fusedcg_cpld_part1(w1_d, w2_d, w3_d, p1_d(p_cur), &
618 p2_d(p_cur), p3_d(p_cur), tmp_d, n)
619
620 pap = device_glsc2(tmp_d, coef%mult_d, n)
621
! Step length for this direction; part2 returns the new squared residual.
622 alpha(p_cur) = rtz1 / pap
623 rtr = device_fusedcg_cpld_part2(r1_d, r2_d, r3_d, coef%mult_d, &
624 w1_d, w2_d, w3_d, alpha_d, alpha(p_cur), p_cur, n)
625 rnorm = sqrt(rtr)*norm_fac
626 call this%monitor_iter(iter, rnorm)
! Apply the accumulated directions to x, y, z when the direction buffer is
! full, the absolute tolerance is met, or this is the last iteration; the
! buffer then restarts at slot 1.
627 if ((p_cur .eq. device_fusedcg_cpld_p_space) .or. &
628 (rnorm .lt. this%abs_tol) .or. iter .eq. max_iter) then
629 call device_fusedcg_cpld_update_x(x%x_d, y%x_d, z%x_d, &
630 p1_d_d, p2_d_d, p3_d_d, alpha_d, p_cur, n)
631 p_prev = p_cur
632 p_cur = 1
633 if (rnorm .lt. this%abs_tol) exit
634 else
635 p_prev = p_cur
636 p_cur = p_cur + 1
637 end if
638 end do
! When the loop runs to completion without exit, iter is max_iter + 1 here.
639 call this%monitor_stop()
640 ksp_results%res_final = rnorm
641 ksp_results%iter = iter
642 ksp_results%converged = this%is_converged(iter, rnorm)
643
644 end associate
645
647
!> Single right-hand-side solve: not supported by this coupled solver.
!! Reports an error through neko_error. The result is nevertheless fully
!! assigned so that it is defined should control return here.
function fusedcg_cpld_device_solve(this, Ax, x, f, n, coef, blst, &
     gs_h, niter) result(ksp_results)
  class(fusedcg_cpld_device_t), intent(inout) :: this
  class(ax_t), intent(in) :: ax
  type(field_t), intent(inout) :: x
  integer, intent(in) :: n
  real(kind=rp), dimension(n), intent(in) :: f
  type(coef_t), intent(inout) :: coef
  type(bc_list_t), intent(inout) :: blst
  type(gs_t), intent(inout) :: gs_h
  type(ksp_monitor_t) :: ksp_results
  integer, optional, intent(in) :: niter

  ! Throw an error: only solve_coupled is implemented
  call neko_error('The cpldcg solver is only defined for coupled solves')

  ksp_results%res_start = 0.0_rp
  ksp_results%res_final = 0.0_rp
  ksp_results%iter = 0
  ksp_results%converged = .false.

end function fusedcg_cpld_device_solve
670
671end module fusedcg_cpld_device
void hip_fusedcg_cpld_update_x(void *x1, void *x2, void *x3, void *p1, void *p2, void *p3, void *alpha, int *p_cur, int *n)
void hip_fusedcg_cpld_update_p(void *p1, void *p2, void *p3, void *z1, void *z2, void *z3, void *po1, void *po2, void *po3, real *beta, int *n)
real hip_fusedcg_cpld_part2(void *a1, void *a2, void *a3, void *b, void *c1, void *c2, void *c3, void *alpha_d, real *alpha, int *p_cur, int *n)
void hip_fusedcg_cpld_part1(void *a1, void *a2, void *a3, void *b1, void *b2, void *b3, void *tmp, int *n)
Return the device pointer for an associated Fortran array.
Definition device.F90:95
Map a Fortran array to a device (allocate and associate)
Definition device.F90:71
Copy data between host and device (or device and device)
Definition device.F90:65
Defines a Matrix-vector product.
Definition ax.f90:34
Defines a list of bc_t.
Definition bc_list.f90:34
Coefficients.
Definition coef.f90:34
Definition comm.F90:1
type(mpi_comm) neko_comm
MPI communicator.
Definition comm.F90:38
type(mpi_datatype) mpi_real_precision
MPI type for working precision of REAL types.
Definition comm.F90:45
integer pe_size
MPI size of communicator.
Definition comm.F90:53
subroutine, public device_rzero(a_d, n)
Zero a real vector.
real(kind=rp) function, public device_glsc2(a_d, b_d, n)
Weighted inner product \( \sum_i a_i b_i \).
real(kind=rp) function, public device_glsc3(a_d, b_d, c_d, n)
Weighted inner product \( \sum_i a_i b_i c_i \).
subroutine, public device_copy(a_d, b_d, n)
Copy a vector \( a = b \).
Device abstraction, common interface for various accelerators.
Definition device.F90:34
subroutine, public device_event_sync(event)
Synchronize an event.
Definition device.F90:1244
integer, parameter, public host_to_device
Definition device.F90:46
subroutine, public device_free(x_d)
Deallocate memory on the device.
Definition device.F90:200
subroutine, public device_event_destroy(event)
Destroy a device event.
Definition device.F90:1209
subroutine, public device_alloc(x_d, s)
Allocate memory on the device.
Definition device.F90:179
subroutine, public device_event_create(event, flags)
Create a device event queue.
Definition device.F90:1179
Defines a field.
Definition field.f90:34
Defines a fused Conjugate Gradient method for accelerators.
subroutine device_fusedcg_cpld_update_x(x1_d, x2_d, x3_d, p1_d, p2_d, p3_d, alpha, p_cur, n)
type(ksp_monitor_t) function fusedcg_cpld_device_solve(this, ax, x, f, n, coef, blst, gs_h, niter)
Single right-hand-side solve; not supported by this coupled solver (raises an error).
subroutine fusedcg_cpld_device_free(this)
Deallocate a fused coupled PCG solver.
subroutine device_fusedcg_cpld_update_p(p1_d, p2_d, p3_d, z1_d, z2_d, z3_d, po1_d, po2_d, po3_d, beta, n)
subroutine fusedcg_cpld_device_init(this, n, max_iter, m, rel_tol, abs_tol, monitor)
Initialise a fused PCG solver.
real(kind=rp) function device_fusedcg_cpld_part2(a1_d, a2_d, a3_d, b_d, c1_d, c2_d, c3_d, alpha_d, alpha, p_cur, n)
type(ksp_monitor_t) function, dimension(3) fusedcg_cpld_device_solve_coupled(this, ax, x, y, z, fx, fy, fz, n, coef, blstx, blsty, blstz, gs_h, niter)
Fused PCG solve for three coupled right-hand sides.
integer, parameter device_fusedcg_cpld_p_space
subroutine device_fusedcg_cpld_part1(a1_d, a2_d, a3_d, b1_d, b2_d, b3_d, tmp_d, n)
Gather-scatter.
Implements the base abstract type for Krylov solvers plus helper types.
Definition krylov.f90:34
integer, parameter, public ksp_max_iter
Maximum number of iters.
Definition krylov.f90:51
Definition math.f90:60
real(kind=rp) function, public glsc3(a, b, c, n)
Weighted inner product \( \sum_i a_i b_i c_i \).
Definition math.f90:894
subroutine, public copy(a, b, n)
Copy a vector \( a = b \).
Definition math.f90:238
subroutine, public rzero(a, n)
Zero a real vector.
Definition math.f90:194
integer, parameter, public c_rp
Definition num_types.f90:13
integer, parameter, public rp
Global precision used in computations.
Definition num_types.f90:12
Krylov preconditioner.
Definition precon.f90:34
Utilities.
Definition utils.f90:35
Base type for a matrix-vector product providing \( Ax \).
Definition ax.f90:43
A list of allocatable `bc_t`. Follows the standard interface of lists.
Definition bc_list.f90:47
Coefficients defined on a given (mesh, \( X_h \)) tuple. Arrays use indices (i,j,k,e): element e,...
Definition coef.f90:55
Fused preconditioned conjugate gradient method.
Type for storing initial and final residuals in a Krylov solver.
Definition krylov.f90:56
Base abstract type for a canonical Krylov method, solving \( Ax = f \).
Definition krylov.f90:68
Defines a canonical Krylov preconditioner.
Definition precon.f90:40