Neko 1.99.2
A portable framework for high-order spectral element flow simulations
Loading...
Searching...
No Matches
fusedcg_cpld_device.F90
Go to the documentation of this file.
1! Copyright (c) 2021-2024, The Neko Authors
2! All rights reserved.
3!
4! Redistribution and use in source and binary forms, with or without
5! modification, are permitted provided that the following conditions
6! are met:
7!
8! * Redistributions of source code must retain the above copyright
9! notice, this list of conditions and the following disclaimer.
10!
11! * Redistributions in binary form must reproduce the above
12! copyright notice, this list of conditions and the following
13! disclaimer in the documentation and/or other materials provided
14! with the distribution.
15!
16! * Neither the name of the authors nor the names of its
17! contributors may be used to endorse or promote products derived
18! from this software without specific prior written permission.
19!
20! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21! "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22! LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
23! FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
24! COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
25! INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
26! BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27! LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28! CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29! LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
30! ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31! POSSIBILITY OF SUCH DAMAGE.
32!
36 use precon, only : pc_t
37 use ax_product, only : ax_t
38 use num_types, only: rp, c_rp
39 use field, only : field_t
40 use coefs, only : coef_t
41 use gather_scatter, only : gs_t, gs_op_add
42 use bc_list, only : bc_list_t
43 use math, only : glsc3, rzero, copy, abscmp
45 use device
46 use utils, only : neko_error
48 use mpi_f08, only : mpi_in_place, mpi_allreduce, &
49 mpi_sum
50 use operators, only : rotate_cyc
51 use, intrinsic :: iso_c_binding, only : c_ptr, c_null_ptr, &
52 c_associated, c_size_t, c_sizeof, c_int, c_loc
53 implicit none
54 private
55
!> Number of search-direction slots kept per component. It is the second
!! extent of p1/p2/p3, the length of alpha, and the value of p_cur at which
!! the coupled solve flushes the accumulated directions into the solution.
56 integer, parameter :: device_fusedcg_cpld_p_space = 10
57
!> Fused preconditioned conjugate gradient method for coupled
!! (three-component) solves on an accelerator.
type, public, extends(ksp_t) :: fusedcg_cpld_device_t
  ! Host arrays, one per solution component. They are mapped to the device
  ! in init; the coupled solve uses them as:
  !   w* - result of the matrix-vector product applied to the current p*
  !   r* - residual (initialised from the right-hand side)
  !   z* - output of the preconditioner applied to r*
  real(kind=rp), allocatable :: w1(:)
  real(kind=rp), allocatable :: w2(:)
  real(kind=rp), allocatable :: w3(:)
  real(kind=rp), allocatable :: r1(:)
  real(kind=rp), allocatable :: r2(:)
  real(kind=rp), allocatable :: r3(:)
  real(kind=rp), allocatable :: z1(:)
  real(kind=rp), allocatable :: z2(:)
  real(kind=rp), allocatable :: z3(:)
  ! Scratch vector filled by the "part1" kernel before each reduction.
  real(kind=rp), allocatable :: tmp(:)
  ! Search directions, shape (n, device_fusedcg_cpld_p_space).
  real(kind=rp), allocatable :: p1(:,:)
  real(kind=rp), allocatable :: p2(:,:)
  real(kind=rp), allocatable :: p3(:,:)
  ! One step length per search-direction slot.
  real(kind=rp), allocatable :: alpha(:)
  ! Device pointers associated with the host arrays above.
  type(c_ptr) :: w1_d = c_null_ptr
  type(c_ptr) :: w2_d = c_null_ptr
  type(c_ptr) :: w3_d = c_null_ptr
  type(c_ptr) :: r1_d = c_null_ptr
  type(c_ptr) :: r2_d = c_null_ptr
  type(c_ptr) :: r3_d = c_null_ptr
  type(c_ptr) :: z1_d = c_null_ptr
  type(c_ptr) :: z2_d = c_null_ptr
  type(c_ptr) :: z3_d = c_null_ptr
  type(c_ptr) :: alpha_d = c_null_ptr
  ! Device-side copies of the pointer tables p1_d/p2_d/p3_d.
  type(c_ptr) :: p1_d_d = c_null_ptr
  type(c_ptr) :: p2_d_d = c_null_ptr
  type(c_ptr) :: p3_d_d = c_null_ptr
  type(c_ptr) :: tmp_d = c_null_ptr
  ! Host-side tables with one device pointer per column of p1/p2/p3.
  type(c_ptr), allocatable :: p1_d(:)
  type(c_ptr), allocatable :: p2_d(:)
  type(c_ptr), allocatable :: p3_d(:)
  ! Events passed to the gather-scatter operation of each component.
  type(c_ptr) :: gs_event1 = c_null_ptr
  type(c_ptr) :: gs_event2 = c_null_ptr
  type(c_ptr) :: gs_event3 = c_null_ptr
contains
  procedure, pass(this) :: init => fusedcg_cpld_device_init
  procedure, pass(this) :: free => fusedcg_cpld_device_free
  procedure, pass(this) :: solve => fusedcg_cpld_device_solve
  procedure, pass(this) :: solve_coupled => fusedcg_cpld_device_solve_coupled
end type fusedcg_cpld_device_t
100
#ifdef HAVE_CUDA
!> C bindings to the CUDA kernels used by the coupled fused CG solver.
!! Device buffers are passed by value as C pointers; scalars are passed
!! by reference.
interface
  subroutine cuda_fusedcg_cpld_part1(a1_d, a2_d, a3_d, &
       b1_d, b2_d, b3_d, tmp_d, n) &
       bind(c, name='cuda_fusedcg_cpld_part1')
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int
    implicit none
    type(c_ptr), value :: a1_d, a2_d, a3_d
    type(c_ptr), value :: b1_d, b2_d, b3_d
    type(c_ptr), value :: tmp_d
    integer(c_int) :: n
  end subroutine cuda_fusedcg_cpld_part1

  subroutine cuda_fusedcg_cpld_update_p(p1_d, p2_d, p3_d, &
       z1_d, z2_d, z3_d, po1_d, po2_d, po3_d, beta, n) &
       bind(c, name='cuda_fusedcg_cpld_update_p')
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int
    import :: c_rp
    implicit none
    type(c_ptr), value :: p1_d, p2_d, p3_d
    type(c_ptr), value :: z1_d, z2_d, z3_d
    type(c_ptr), value :: po1_d, po2_d, po3_d
    real(c_rp) :: beta
    integer(c_int) :: n
  end subroutine cuda_fusedcg_cpld_update_p

  subroutine cuda_fusedcg_cpld_update_x(x1_d, x2_d, x3_d, &
       p1_d, p2_d, p3_d, alpha, p_cur, n) &
       bind(c, name='cuda_fusedcg_cpld_update_x')
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int
    implicit none
    type(c_ptr), value :: x1_d, x2_d, x3_d
    type(c_ptr), value :: p1_d, p2_d, p3_d
    type(c_ptr), value :: alpha
    integer(c_int) :: p_cur, n
  end subroutine cuda_fusedcg_cpld_update_x

  real(c_rp) function cuda_fusedcg_cpld_part2(a1_d, a2_d, a3_d, b_d, &
       c1_d, c2_d, c3_d, alpha_d, alpha, p_cur, n) &
       bind(c, name='cuda_fusedcg_cpld_part2')
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int
    import :: c_rp
    implicit none
    type(c_ptr), value :: a1_d, a2_d, a3_d, b_d
    type(c_ptr), value :: c1_d, c2_d, c3_d
    type(c_ptr), value :: alpha_d
    real(c_rp) :: alpha
    integer(c_int) :: n, p_cur
  end function cuda_fusedcg_cpld_part2
end interface
#elif HAVE_HIP
!> C bindings to the HIP kernels used by the coupled fused CG solver.
!! Device buffers are passed by value as C pointers; scalars are passed
!! by reference.
interface
  subroutine hip_fusedcg_cpld_part1(a1_d, a2_d, a3_d, &
       b1_d, b2_d, b3_d, tmp_d, n) &
       bind(c, name='hip_fusedcg_cpld_part1')
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int
    implicit none
    type(c_ptr), value :: a1_d, a2_d, a3_d
    type(c_ptr), value :: b1_d, b2_d, b3_d
    type(c_ptr), value :: tmp_d
    integer(c_int) :: n
  end subroutine hip_fusedcg_cpld_part1

  subroutine hip_fusedcg_cpld_update_p(p1_d, p2_d, p3_d, &
       z1_d, z2_d, z3_d, po1_d, po2_d, po3_d, beta, n) &
       bind(c, name='hip_fusedcg_cpld_update_p')
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int
    import :: c_rp
    implicit none
    type(c_ptr), value :: p1_d, p2_d, p3_d
    type(c_ptr), value :: z1_d, z2_d, z3_d
    type(c_ptr), value :: po1_d, po2_d, po3_d
    real(c_rp) :: beta
    integer(c_int) :: n
  end subroutine hip_fusedcg_cpld_update_p

  subroutine hip_fusedcg_cpld_update_x(x1_d, x2_d, x3_d, &
       p1_d, p2_d, p3_d, alpha, p_cur, n) &
       bind(c, name='hip_fusedcg_cpld_update_x')
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int
    implicit none
    type(c_ptr), value :: x1_d, x2_d, x3_d
    type(c_ptr), value :: p1_d, p2_d, p3_d
    type(c_ptr), value :: alpha
    integer(c_int) :: p_cur, n
  end subroutine hip_fusedcg_cpld_update_x

  real(c_rp) function hip_fusedcg_cpld_part2(a1_d, a2_d, a3_d, b_d, &
       c1_d, c2_d, c3_d, alpha_d, alpha, p_cur, n) &
       bind(c, name='hip_fusedcg_cpld_part2')
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int
    import :: c_rp
    implicit none
    type(c_ptr), value :: a1_d, a2_d, a3_d, b_d
    type(c_ptr), value :: c1_d, c2_d, c3_d
    type(c_ptr), value :: alpha_d
    real(c_rp) :: alpha
    integer(c_int) :: n, p_cur
  end function hip_fusedcg_cpld_part2
end interface
#endif
198
199contains
200
!> Forward the fused "part1" kernel call to the configured device backend
!! (HIP or CUDA); abort through neko_error when neither is compiled in.
!! @param a1_d, a2_d, a3_d Device pointers of the first vector triple.
!! @param b1_d, b2_d, b3_d Device pointers of the second vector triple.
!! @param tmp_d Device scratch vector handed to the kernel.
!! @param n Vector length passed on to the kernel.
subroutine device_fusedcg_cpld_part1(a1_d, a2_d, a3_d, &
     b1_d, b2_d, b3_d, tmp_d, n)
  type(c_ptr), value :: a1_d, a2_d, a3_d
  type(c_ptr), value :: b1_d, b2_d, b3_d
  type(c_ptr), value :: tmp_d
  integer(c_int) :: n
#ifdef HAVE_HIP
  call hip_fusedcg_cpld_part1(a1_d, a2_d, a3_d, &
       b1_d, b2_d, b3_d, tmp_d, n)
#elif HAVE_CUDA
  call cuda_fusedcg_cpld_part1(a1_d, a2_d, a3_d, &
       b1_d, b2_d, b3_d, tmp_d, n)
#else
  call neko_error('No device backend configured')
#endif
end subroutine device_fusedcg_cpld_part1
214
!> Forward the search-direction update kernel call to the configured
!! device backend (HIP or CUDA); abort through neko_error when neither is
!! compiled in.
!! @param p1_d, p2_d, p3_d Device pointers of the current direction slot.
!! @param z1_d, z2_d, z3_d Device pointers of the z vectors.
!! @param po1_d, po2_d, po3_d Device pointers of the previous direction slot.
!! @param beta Scalar passed on to the kernel.
!! @param n Vector length passed on to the kernel.
subroutine device_fusedcg_cpld_update_p(p1_d, p2_d, p3_d, z1_d, z2_d, z3_d, &
     po1_d, po2_d, po3_d, beta, n)
  type(c_ptr), value :: p1_d, p2_d, p3_d
  type(c_ptr), value :: z1_d, z2_d, z3_d
  type(c_ptr), value :: po1_d, po2_d, po3_d
  real(c_rp) :: beta
  integer(c_int) :: n
#ifdef HAVE_HIP
  call hip_fusedcg_cpld_update_p(p1_d, p2_d, p3_d, &
       z1_d, z2_d, z3_d, po1_d, po2_d, po3_d, beta, n)
#elif HAVE_CUDA
  call cuda_fusedcg_cpld_update_p(p1_d, p2_d, p3_d, &
       z1_d, z2_d, z3_d, po1_d, po2_d, po3_d, beta, n)
#else
  call neko_error('No device backend configured')
#endif
end subroutine device_fusedcg_cpld_update_p
231
!> Forward the solution update kernel call to the configured device backend
!! (HIP or CUDA); abort through neko_error when neither is compiled in.
!! @param x1_d, x2_d, x3_d Device pointers of the solution components.
!! @param p1_d, p2_d, p3_d Device pointers handed to the kernel for the
!! search directions (the coupled solve passes the device pointer tables).
!! @param alpha Device pointer of the step-length array.
!! @param p_cur Index of the last filled direction slot.
!! @param n Vector length passed on to the kernel.
subroutine device_fusedcg_cpld_update_x(x1_d, x2_d, x3_d, &
     p1_d, p2_d, p3_d, alpha, p_cur, n)
  type(c_ptr), value :: x1_d, x2_d, x3_d
  type(c_ptr), value :: p1_d, p2_d, p3_d
  type(c_ptr), value :: alpha
  integer(c_int) :: p_cur, n
#ifdef HAVE_HIP
  call hip_fusedcg_cpld_update_x(x1_d, x2_d, x3_d, p1_d, p2_d, p3_d, &
       alpha, p_cur, n)
#elif HAVE_CUDA
  call cuda_fusedcg_cpld_update_x(x1_d, x2_d, x3_d, p1_d, p2_d, p3_d, &
       alpha, p_cur, n)
#else
  call neko_error('No device backend configured')
#endif
end subroutine device_fusedcg_cpld_update_x
246
!> Forward the fused "part2" kernel call to the configured device backend
!! (HIP or CUDA) and return its scalar result. Unless HAVE_DEVICE_MPI is
!! defined, the result is summed over all ranks when pe_size > 1.
!! @param a1_d, a2_d, a3_d Device pointers of the first vector triple.
!! @param b_d Device pointer of a single vector handed to the kernel.
!! @param c1_d, c2_d, c3_d Device pointers of the second vector triple.
!! @param alpha_d Device pointer of the step-length array.
!! @param alpha Scalar step length passed on to the kernel.
!! @param p_cur Index of the current direction slot.
!! @param n Vector length passed on to the kernel.
!! @return The value returned by the backend kernel (after the reduction).
function device_fusedcg_cpld_part2(a1_d, a2_d, a3_d, b_d, &
     c1_d, c2_d, c3_d, alpha_d, alpha, p_cur, n) result(res)
  type(c_ptr), value :: a1_d, a2_d, a3_d, b_d
  type(c_ptr), value :: c1_d, c2_d, c3_d, alpha_d
  real(c_rp) :: alpha
  ! Same kind as the bind(c) interfaces and the sibling wrappers.
  integer(c_int) :: n, p_cur
  real(kind=rp) :: res
  integer :: ierr

  ! Keep the result defined even on the no-backend path below.
  res = 0.0_rp
#ifdef HAVE_HIP
  res = hip_fusedcg_cpld_part2(a1_d, a2_d, a3_d, b_d, &
       c1_d, c2_d, c3_d, alpha_d, alpha, p_cur, n)
#elif HAVE_CUDA
  res = cuda_fusedcg_cpld_part2(a1_d, a2_d, a3_d, b_d, &
       c1_d, c2_d, c3_d, alpha_d, alpha, p_cur, n)
#else
  call neko_error('No device backend configured')
#endif

#ifndef HAVE_DEVICE_MPI
  if (pe_size .gt. 1) then
    call mpi_allreduce(mpi_in_place, res, 1, &
         mpi_real_precision, mpi_sum, neko_comm, ierr)
  end if
#endif

end function device_fusedcg_cpld_part2
273
!> Initialise a fused PCG solver.
!! Releases any previous state, allocates the host work arrays, maps them
!! to the device, uploads the tables of search-direction device pointers,
!! initialises the base Krylov state and creates the gather-scatter events.
!! @param n Length of each work vector.
!! @param max_iter Maximum number of iterations, forwarded to ksp_init.
!! @param M Optional preconditioner; stored as a pointer when present.
!! @param rel_tol Optional relative tolerance, forwarded to ksp_init.
!! @param abs_tol Optional absolute tolerance, forwarded to ksp_init.
!! @param monitor Optional monitoring flag, forwarded to ksp_init.
subroutine fusedcg_cpld_device_init(this, n, max_iter, M, &
     rel_tol, abs_tol, monitor)
  class(fusedcg_cpld_device_t), target, intent(inout) :: this
  class(pc_t), optional, intent(in), target :: M
  integer, intent(in) :: n
  integer, intent(in) :: max_iter
  real(kind=rp), optional, intent(in) :: rel_tol
  real(kind=rp), optional, intent(in) :: abs_tol
  logical, optional, intent(in) :: monitor
  type(c_ptr) :: ptr
  integer(c_size_t) :: p_size
  integer :: i

  call this%free()

  allocate(this%w1(n))
  allocate(this%w2(n))
  allocate(this%w3(n))
  allocate(this%r1(n))
  allocate(this%r2(n))
  allocate(this%r3(n))
  allocate(this%z1(n))
  allocate(this%z2(n))
  allocate(this%z3(n))
  allocate(this%tmp(n))
  allocate(this%p1(n, device_fusedcg_cpld_p_space))
  allocate(this%p2(n, device_fusedcg_cpld_p_space))
  allocate(this%p3(n, device_fusedcg_cpld_p_space))
  allocate(this%p1_d(device_fusedcg_cpld_p_space))
  allocate(this%p2_d(device_fusedcg_cpld_p_space))
  allocate(this%p3_d(device_fusedcg_cpld_p_space))
  allocate(this%alpha(device_fusedcg_cpld_p_space))

  if (present(M)) then
    this%M => M
  end if

  call device_map(this%w1, this%w1_d, n)
  call device_map(this%w2, this%w2_d, n)
  call device_map(this%w3, this%w3_d, n)
  call device_map(this%r1, this%r1_d, n)
  call device_map(this%r2, this%r2_d, n)
  call device_map(this%r3, this%r3_d, n)
  call device_map(this%z1, this%z1_d, n)
  call device_map(this%z2, this%z2_d, n)
  call device_map(this%z3, this%z3_d, n)
  call device_map(this%tmp, this%tmp_d, n)
  call device_map(this%alpha, this%alpha_d, device_fusedcg_cpld_p_space)

  ! Map every search-direction column separately so that each slot has its
  ! own device pointer. The table entries are reset first because a freshly
  ! allocated c_ptr array is not default-initialised.
  do i = 1, device_fusedcg_cpld_p_space
    this%p1_d(i) = c_null_ptr
    call device_map(this%p1(:,i), this%p1_d(i), n)

    this%p2_d(i) = c_null_ptr
    call device_map(this%p2(:,i), this%p2_d(i), n)

    this%p3_d(i) = c_null_ptr
    call device_map(this%p3(:,i), this%p3_d(i), n)
  end do

  ! Upload the three pointer tables so that a kernel can index the slots.
  p_size = c_sizeof(c_null_ptr) * (device_fusedcg_cpld_p_space)
  call device_alloc(this%p1_d_d, p_size)
  call device_alloc(this%p2_d_d, p_size)
  call device_alloc(this%p3_d_d, p_size)
  ptr = c_loc(this%p1_d)
  call device_memcpy(ptr, this%p1_d_d, p_size, &
       host_to_device, sync=.false.)
  ptr = c_loc(this%p2_d)
  call device_memcpy(ptr, this%p2_d_d, p_size, &
       host_to_device, sync=.false.)
  ptr = c_loc(this%p3_d)
  call device_memcpy(ptr, this%p3_d_d, p_size, &
       host_to_device, sync=.false.)

  ! An absent optional actual argument is treated as absent by the callee,
  ! so all combinations of rel_tol/abs_tol/monitor are covered by one call.
  call this%ksp_init(max_iter, rel_tol = rel_tol, abs_tol = abs_tol, &
       monitor = monitor)

  call device_event_create(this%gs_event1, 2)
  call device_event_create(this%gs_event2, 2)
  call device_event_create(this%gs_event3, 2)

end subroutine fusedcg_cpld_device_init
370
!> Deallocate a fused PCG solver.
!! Releases the base Krylov state, the host work arrays, every device
!! allocation owned by the solver, the preconditioner association and the
!! gather-scatter events. Safe to call on a solver that was never
!! initialised, since every release is guarded.
subroutine fusedcg_cpld_device_free(this)
  class(fusedcg_cpld_device_t), intent(inout) :: this
  integer :: i

  call this%ksp_free()

  ! Host work arrays
  if (allocated(this%w1)) deallocate(this%w1)
  if (allocated(this%w2)) deallocate(this%w2)
  if (allocated(this%w3)) deallocate(this%w3)
  if (allocated(this%r1)) deallocate(this%r1)
  if (allocated(this%r2)) deallocate(this%r2)
  if (allocated(this%r3)) deallocate(this%r3)
  if (allocated(this%z1)) deallocate(this%z1)
  if (allocated(this%z2)) deallocate(this%z2)
  if (allocated(this%z3)) deallocate(this%z3)
  if (allocated(this%tmp)) deallocate(this%tmp)
  if (allocated(this%alpha)) deallocate(this%alpha)
  if (allocated(this%p1)) deallocate(this%p1)
  if (allocated(this%p2)) deallocate(this%p2)
  if (allocated(this%p3)) deallocate(this%p3)

  ! Device buffers of the work arrays
  if (c_associated(this%w1_d)) call device_free(this%w1_d)
  if (c_associated(this%w2_d)) call device_free(this%w2_d)
  if (c_associated(this%w3_d)) call device_free(this%w3_d)
  if (c_associated(this%r1_d)) call device_free(this%r1_d)
  if (c_associated(this%r2_d)) call device_free(this%r2_d)
  if (c_associated(this%r3_d)) call device_free(this%r3_d)
  if (c_associated(this%z1_d)) call device_free(this%z1_d)
  if (c_associated(this%z2_d)) call device_free(this%z2_d)
  if (c_associated(this%z3_d)) call device_free(this%z3_d)
  if (c_associated(this%alpha_d)) call device_free(this%alpha_d)
  if (c_associated(this%tmp_d)) call device_free(this%tmp_d)

  ! Per-slot device buffers of the search directions. The host tables are
  ! deallocated as well: init allocates them unconditionally right after
  ! calling free, which would fail on a still-allocated array.
  if (allocated(this%p1_d)) then
    do i = 1, size(this%p1_d)
      if (c_associated(this%p1_d(i))) call device_free(this%p1_d(i))
    end do
    deallocate(this%p1_d)
  end if

  if (allocated(this%p2_d)) then
    do i = 1, size(this%p2_d)
      if (c_associated(this%p2_d(i))) call device_free(this%p2_d(i))
    end do
    deallocate(this%p2_d)
  end if

  if (allocated(this%p3_d)) then
    do i = 1, size(this%p3_d)
      if (c_associated(this%p3_d(i))) call device_free(this%p3_d(i))
    end do
    deallocate(this%p3_d)
  end if

  ! Device-side copies of the pointer tables
  if (c_associated(this%p1_d_d)) call device_free(this%p1_d_d)
  if (c_associated(this%p2_d_d)) call device_free(this%p2_d_d)
  if (c_associated(this%p3_d_d)) call device_free(this%p3_d_d)

  nullify(this%M)

  if (c_associated(this%gs_event1)) call device_event_destroy(this%gs_event1)
  if (c_associated(this%gs_event2)) call device_event_destroy(this%gs_event2)
  if (c_associated(this%gs_event3)) call device_event_destroy(this%gs_event3)

end subroutine fusedcg_cpld_device_free
529
!> Fused PCG coupled solve for the three components x, y and z.
!! The solution fields are zeroed on entry. Search directions and step
!! lengths are accumulated in up to device_fusedcg_cpld_p_space slots and
!! applied to the solution in one kernel call when the slots are full, the
!! residual drops below abs_tol, or the last iteration is reached.
!! @param Ax Matrix-vector product, applied through compute_vector.
!! @param x, y, z Solution fields (overwritten).
!! @param fx, fy, fz Right-hand sides; their device pointers are looked up
!! with device_get_ptr.
!! @param n Length of each vector.
!! @param coef Coefficients (volume, mult_d and binv_d are read here).
!! @param blstx, blsty, blstz Boundary condition lists applied to w1, w2, w3.
!! @param gs_h Gather-scatter handle.
!! @param niter Optional iteration limit; defaults to ksp_max_iter.
!! @return Three monitor entries with start/final residual, iteration
!! count and convergence flag.
function fusedcg_cpld_device_solve_coupled(this, Ax, x, y, z, fx, fy, fz, &
     n, coef, blstx, blsty, blstz, gs_h, niter) result(ksp_results)
  class(fusedcg_cpld_device_t), intent(inout) :: this
  class(ax_t), intent(in) :: Ax
  type(field_t), intent(inout) :: x
  type(field_t), intent(inout) :: y
  type(field_t), intent(inout) :: z
  integer, intent(in) :: n
  real(kind=rp), dimension(n), intent(in) :: fx
  real(kind=rp), dimension(n), intent(in) :: fy
  real(kind=rp), dimension(n), intent(in) :: fz
  type(coef_t), intent(inout) :: coef
  type(bc_list_t), intent(inout) :: blstx
  type(bc_list_t), intent(inout) :: blsty
  type(bc_list_t), intent(inout) :: blstz
  type(gs_t), intent(inout) :: gs_h
  type(ksp_monitor_t), dimension(3) :: ksp_results
  integer, optional, intent(in) :: niter
  integer :: iter, max_iter, p_cur, p_prev
  real(kind=rp) :: rnorm, rtr, norm_fac, rtz1, rtz2
  real(kind=rp) :: pap, beta
  type(c_ptr) :: fx_d
  type(c_ptr) :: fy_d
  type(c_ptr) :: fz_d

  fx_d = device_get_ptr(fx)
  fy_d = device_get_ptr(fy)
  fz_d = device_get_ptr(fz)

  if (present(niter)) then
    max_iter = niter
  else
    max_iter = ksp_max_iter
  end if
  norm_fac = 1.0_rp / sqrt(coef%volume)

  associate(w1 => this%w1, w2 => this%w2, w3 => this%w3, r1 => this%r1, &
       r2 => this%r2, r3 => this%r3, p1 => this%p1, p2 => this%p2, &
       p3 => this%p3, z1 => this%z1, z2 => this%z2, z3 => this%z3, &
       tmp_d => this%tmp_d, alpha => this%alpha, alpha_d => this%alpha_d, &
       w1_d => this%w1_d, w2_d => this%w2_d, w3_d => this%w3_d, &
       r1_d => this%r1_d, r2_d => this%r2_d, r3_d => this%r3_d, &
       z1_d => this%z1_d, z2_d => this%z2_d, z3_d => this%z3_d, &
       p1_d => this%p1_d, p2_d => this%p2_d, p3_d => this%p3_d, &
       p1_d_d => this%p1_d_d, p2_d_d => this%p2_d_d, p3_d_d => this%p3_d_d)

    rtz1 = 1.0_rp
    ! p_prev must name a valid slot before the first update_p call. The
    ! first iteration uses beta = 0, so the slot only has to hold finite
    ! data; the last slot is used here and zeroed below.
    ! NOTE(review): confirm this initial value against the project history.
    p_prev = device_fusedcg_cpld_p_space
    p_cur = 1

    call device_rzero(x%x_d, n)
    call device_rzero(y%x_d, n)
    call device_rzero(z%x_d, n)
    call device_rzero(p1_d(1), n)
    call device_rzero(p2_d(1), n)
    call device_rzero(p3_d(1), n)
    ! Guard against non-finite leftovers in the slot read with beta = 0.
    call device_rzero(p1_d(p_prev), n)
    call device_rzero(p2_d(p_prev), n)
    call device_rzero(p3_d(p_prev), n)
    ! Zero initial guess, so the initial residual is the right-hand side.
    call device_copy(r1_d, fx_d, n)
    call device_copy(r2_d, fy_d, n)
    call device_copy(r3_d, fz_d, n)

    call device_fusedcg_cpld_part1(r1_d, r2_d, r3_d, r1_d, &
         r2_d, r3_d, tmp_d, n)

    rtr = device_glsc3(tmp_d, coef%mult_d, coef%binv_d, n)

    rnorm = sqrt(rtr)*norm_fac
    ksp_results%res_start = rnorm
    ksp_results%res_final = rnorm
    ksp_results(1)%iter = 0
    ksp_results(2:3)%iter = -1
    ! Nothing to solve for a vanishing initial residual.
    if (abscmp(rnorm, 0.0_rp)) then
      ksp_results%converged = .true.
      return
    end if
    call this%monitor_start('fcpldCG')
    do iter = 1, max_iter
      ! Apply the preconditioner to each residual component.
      call this%M%solve(z1, r1, n)
      call this%M%solve(z2, r2, n)
      call this%M%solve(z3, r3, n)
      rtz2 = rtz1
      call device_fusedcg_cpld_part1(z1_d, z2_d, z3_d, &
           r1_d, r2_d, r3_d, tmp_d, n)
      rtz1 = device_glsc2(tmp_d, coef%mult_d, n)

      beta = rtz1 / rtz2
      if (iter .eq. 1) beta = 0.0_rp

      call device_fusedcg_cpld_update_p(p1_d(p_cur), p2_d(p_cur), &
           p3_d(p_cur), z1_d, z2_d, z3_d, &
           p1_d(p_prev), p2_d(p_prev), p3_d(p_prev), beta, n)

      call Ax%compute_vector(w1, w2, w3, &
           p1(1, p_cur), p2(1, p_cur), p3(1, p_cur), coef, x%msh, x%Xh)

      ! Gather-scatter and boundary conditions are applied per component,
      ! bracketed by the two rotate_cyc calls.
      call rotate_cyc(w1, w2, w3, 1, coef)
      call gs_h%op(w1, n, gs_op_add, this%gs_event1)
      call device_event_sync(this%gs_event1)
      call blstx%apply(w1, n)
      call gs_h%op(w2, n, gs_op_add, this%gs_event2)
      call device_event_sync(this%gs_event2)
      call blsty%apply(w2, n)
      call gs_h%op(w3, n, gs_op_add, this%gs_event3)
      call device_event_sync(this%gs_event3)
      call blstz%apply(w3, n)
      call rotate_cyc(w1, w2, w3, 0, coef)

      call device_fusedcg_cpld_part1(w1_d, w2_d, w3_d, p1_d(p_cur), &
           p2_d(p_cur), p3_d(p_cur), tmp_d, n)

      pap = device_glsc2(tmp_d, coef%mult_d, n)

      alpha(p_cur) = rtz1 / pap
      rtr = device_fusedcg_cpld_part2(r1_d, r2_d, r3_d, coef%mult_d, &
           w1_d, w2_d, w3_d, alpha_d, alpha(p_cur), p_cur, n)
      rnorm = sqrt(rtr)*norm_fac
      call this%monitor_iter(iter, rnorm)
      ! Flush the accumulated directions into the solution when the slots
      ! are exhausted, on convergence, or on the final iteration.
      if ((p_cur .eq. device_fusedcg_cpld_p_space) .or. &
           (rnorm .lt. this%abs_tol) .or. iter .eq. max_iter) then
        call device_fusedcg_cpld_update_x(x%x_d, y%x_d, z%x_d, &
             p1_d_d, p2_d_d, p3_d_d, alpha_d, p_cur, n)
        p_prev = p_cur
        p_cur = 1
        if (rnorm .lt. this%abs_tol) exit
      else
        p_prev = p_cur
        p_cur = p_cur + 1
      end if
    end do
    call this%monitor_stop()
    ksp_results%res_final = rnorm
    ksp_results%iter = iter
    ksp_results%converged = this%is_converged(iter, rnorm)

  end associate

end function fusedcg_cpld_device_solve_coupled
667
!> Uncoupled solve, bound as the type's `solve` procedure.
!! This solver only implements the coupled variant, so the routine reports
!! an error through neko_error and fills the result with placeholder values.
!! None of the arguments are used.
function fusedcg_cpld_device_solve(this, Ax, x, f, n, coef, blst, &
     gs_h, niter) result(ksp_results)
  class(fusedcg_cpld_device_t), intent(inout) :: this
  class(ax_t), intent(in) :: Ax
  type(field_t), intent(inout) :: x
  integer, intent(in) :: n
  real(kind=rp), dimension(n), intent(in) :: f
  type(coef_t), intent(inout) :: coef
  type(bc_list_t), intent(inout) :: blst
  type(gs_t), intent(inout) :: gs_h
  integer, optional, intent(in) :: niter
  type(ksp_monitor_t) :: ksp_results

  ! Throw an error
  call neko_error('The cpldcg solver is only defined for coupled solves')

  ! Give the function result a defined value.
  ksp_results%converged = .false.
  ksp_results%iter = 0
  ksp_results%res_final = 0.0_rp

end function fusedcg_cpld_device_solve
690
691end module fusedcg_cpld_device
__device__ T solve(const T u, const T y, const T guess, const T nu, const T kappa, const T B)
void hip_fusedcg_cpld_update_x(void *x1, void *x2, void *x3, void *p1, void *p2, void *p3, void *alpha, int *p_cur, int *n)
void hip_fusedcg_cpld_update_p(void *p1, void *p2, void *p3, void *z1, void *z2, void *z3, void *po1, void *po2, void *po3, real *beta, int *n)
real hip_fusedcg_cpld_part2(void *a1, void *a2, void *a3, void *b, void *c1, void *c2, void *c3, void *alpha_d, real *alpha, int *p_cur, int *n)
void hip_fusedcg_cpld_part1(void *a1, void *a2, void *a3, void *b1, void *b2, void *b3, void *tmp, int *n)
Return the device pointer for an associated Fortran array.
Definition device.F90:101
Map a Fortran array to a device (allocate and associate)
Definition device.F90:77
Copy data between host and device (or device and device)
Definition device.F90:71
Defines a Matrix-vector product.
Definition ax.f90:34
Defines a list of bc_t.
Definition bc_list.f90:34
Coefficients.
Definition coef.f90:34
Definition comm.F90:1
type(mpi_datatype), public mpi_real_precision
MPI type for working precision of REAL types.
Definition comm.F90:51
integer, public pe_size
MPI size of communicator.
Definition comm.F90:59
type(mpi_comm), public neko_comm
MPI communicator.
Definition comm.F90:43
subroutine, public device_rzero(a_d, n, strm)
Zero a real vector.
subroutine, public device_copy(a_d, b_d, n, strm)
Copy a vector $a = b$.
real(kind=rp) function, public device_glsc3(a_d, b_d, c_d, n, strm)
Weighted inner product $a^T b c$.
real(kind=rp) function, public device_glsc2(a_d, b_d, n, strm)
Weighted inner product $a^T b$.
Device abstraction, common interface for various accelerators.
Definition device.F90:34
subroutine, public device_event_sync(event)
Synchronize an event.
Definition device.F90:1314
integer, parameter, public host_to_device
Definition device.F90:47
subroutine, public device_free(x_d)
Deallocate memory on the device.
Definition device.F90:219
subroutine, public device_event_destroy(event)
Destroy a device event.
Definition device.F90:1279
subroutine, public device_alloc(x_d, s)
Allocate memory on the device.
Definition device.F90:192
subroutine, public device_event_create(event, flags)
Create a device event queue.
Definition device.F90:1249
Defines a field.
Definition field.f90:34
Defines a fused Conjugate Gradient method for accelerators.
subroutine device_fusedcg_cpld_update_x(x1_d, x2_d, x3_d, p1_d, p2_d, p3_d, alpha, p_cur, n)
type(ksp_monitor_t) function fusedcg_cpld_device_solve(this, ax, x, f, n, coef, blst, gs_h, niter)
Fused PCG solve (uncoupled); not supported by this solver, which always raises an error.
subroutine fusedcg_cpld_device_free(this)
Deallocate a fused PCG solver.
subroutine device_fusedcg_cpld_update_p(p1_d, p2_d, p3_d, z1_d, z2_d, z3_d, po1_d, po2_d, po3_d, beta, n)
subroutine fusedcg_cpld_device_init(this, n, max_iter, m, rel_tol, abs_tol, monitor)
Initialise a fused PCG solver.
real(kind=rp) function device_fusedcg_cpld_part2(a1_d, a2_d, a3_d, b_d, c1_d, c2_d, c3_d, alpha_d, alpha, p_cur, n)
type(ksp_monitor_t) function, dimension(3) fusedcg_cpld_device_solve_coupled(this, ax, x, y, z, fx, fy, fz, n, coef, blstx, blsty, blstz, gs_h, niter)
Fused PCG coupled solve.
integer, parameter device_fusedcg_cpld_p_space
subroutine device_fusedcg_cpld_part1(a1_d, a2_d, a3_d, b1_d, b2_d, b3_d, tmp_d, n)
Gather-scatter.
Implements the base abstract type for Krylov solvers plus helper types.
Definition krylov.f90:34
integer, parameter, public ksp_max_iter
Maximum number of iters.
Definition krylov.f90:51
Definition math.f90:60
real(kind=rp) function, public glsc3(a, b, c, n)
Weighted inner product $a^T b c$.
Definition math.f90:1067
subroutine, public copy(a, b, n)
Copy a vector $a = b$.
Definition math.f90:249
subroutine, public rzero(a, n)
Zero a real vector.
Definition math.f90:205
integer, parameter, public c_rp
Definition num_types.f90:13
integer, parameter, public rp
Global precision used in computations.
Definition num_types.f90:12
Operators.
Definition operators.f90:34
Krylov preconditioner.
Definition precon.f90:34
Utilities.
Definition utils.f90:35
Base type for a matrix-vector product providing $Ax$.
Definition ax.f90:43
A list of allocatable `bc_t`. Follows the standard interface of lists.
Definition bc_list.f90:48
Coefficients defined on a given (mesh, ) tuple. Arrays use indices (i,j,k,e): element e,...
Definition coef.f90:56
Fused preconditioned conjugate gradient method.
Type for storing initial and final residuals in a Krylov solver.
Definition krylov.f90:56
Base abstract type for a canonical Krylov method, solving $Ax = f$.
Definition krylov.f90:73
Defines a canonical Krylov preconditioner.
Definition precon.f90:40