Neko  0.9.0
A portable framework for high-order spectral element flow simulations
fusedcg_cpld_device.F90
Go to the documentation of this file.
1 ! Copyright (c) 2021-2024, The Neko Authors
2 ! All rights reserved.
3 !
4 ! Redistribution and use in source and binary forms, with or without
5 ! modification, are permitted provided that the following conditions
6 ! are met:
7 !
8 ! * Redistributions of source code must retain the above copyright
9 ! notice, this list of conditions and the following disclaimer.
10 !
11 ! * Redistributions in binary form must reproduce the above
12 ! copyright notice, this list of conditions and the following
13 ! disclaimer in the documentation and/or other materials provided
14 ! with the distribution.
15 !
16 ! * Neither the name of the authors nor the names of its
17 ! contributors may be used to endorse or promote products derived
18 ! from this software without specific prior written permission.
19 !
20 ! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21 ! "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22 ! LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
23 ! FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
24 ! COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
25 ! INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
26 ! BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27 ! LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28 ! CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29 ! LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
30 ! ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31 ! POSSIBILITY OF SUCH DAMAGE.
32 !
36  use precon, only : pc_t
37  use ax_product, only : ax_t
38  use num_types, only: rp, c_rp
39  use field, only : field_t
40  use coefs, only : coef_t
41  use gather_scatter, only : gs_t, gs_op_add
42  use bc, only : bc_list_t, bc_list_apply
43  use math, only : glsc3, rzero, copy, abscmp
45  use device
46  use comm
47  implicit none
48  private
49 
!> Number of slots in the ring of search directions kept on the device.
!! The solution update kernel runs whenever the ring is full.
integer, parameter :: device_fusedcg_cpld_p_space = 10

!> Fused preconditioned conjugate gradient method for a coupled system of
!! three components, with the vector work done on an accelerator.
type, public, extends(ksp_t) :: fusedcg_cpld_device_t
   !> Per-component host work vectors (each is device-mapped in init).
   real(kind=rp), allocatable :: w1(:), w2(:), w3(:)
   real(kind=rp), allocatable :: r1(:), r2(:), r3(:)
   real(kind=rp), allocatable :: z1(:), z2(:), z3(:)
   real(kind=rp), allocatable :: tmp(:)
   !> Search directions, one column per ring slot.
   real(kind=rp), allocatable :: p1(:,:), p2(:,:), p3(:,:)
   !> Step length computed for each buffered search direction.
   real(kind=rp), allocatable :: alpha(:)
   !> Device counterparts of the host vectors above.
   type(c_ptr) :: w1_d = c_null_ptr, w2_d = c_null_ptr, w3_d = c_null_ptr
   type(c_ptr) :: r1_d = c_null_ptr, r2_d = c_null_ptr, r3_d = c_null_ptr
   type(c_ptr) :: z1_d = c_null_ptr, z2_d = c_null_ptr, z3_d = c_null_ptr
   type(c_ptr) :: alpha_d = c_null_ptr
   !> Device-resident copies of the per-slot direction pointer tables.
   type(c_ptr) :: p1_d_d = c_null_ptr, p2_d_d = c_null_ptr, &
        p3_d_d = c_null_ptr
   type(c_ptr) :: tmp_d = c_null_ptr
   !> Host tables holding the device pointer of every ring slot.
   type(c_ptr), allocatable :: p1_d(:), p2_d(:), p3_d(:)
   !> Events used to wait for the three gather-scatter operations.
   type(c_ptr) :: gs_event1 = c_null_ptr, gs_event2 = c_null_ptr, &
        gs_event3 = c_null_ptr
 contains
   procedure, pass(this) :: init => fusedcg_cpld_device_init
   procedure, pass(this) :: free => fusedcg_cpld_device_free
   procedure, pass(this) :: solve => fusedcg_cpld_device_solve
   procedure, pass(this) :: solve_coupled => &
        fusedcg_cpld_device_solve_coupled
end type fusedcg_cpld_device_t
94 
#ifdef HAVE_CUDA
! C entry points of the CUDA backend. Device buffers are passed by value
! as c_ptr; scalar arguments are passed by reference.
interface
   subroutine cuda_fusedcg_cpld_part1(a1_d, a2_d, a3_d, &
        b1_d, b2_d, b3_d, tmp_d, n) bind(c, name='cuda_fusedcg_cpld_part1')
     use, intrinsic :: iso_c_binding
     implicit none
     type(c_ptr), value :: a1_d, a2_d, a3_d, b1_d, b2_d, b3_d, tmp_d
     integer(c_int) :: n
   end subroutine cuda_fusedcg_cpld_part1

   subroutine cuda_fusedcg_cpld_update_p(p1_d, p2_d, p3_d, z1_d, z2_d, &
        z3_d, po1_d, po2_d, po3_d, beta, n) &
        bind(c, name='cuda_fusedcg_cpld_update_p')
     use, intrinsic :: iso_c_binding
     import c_rp
     implicit none
     type(c_ptr), value :: p1_d, p2_d, p3_d, z1_d, z2_d, z3_d
     type(c_ptr), value :: po1_d, po2_d, po3_d
     real(c_rp) :: beta
     integer(c_int) :: n
   end subroutine cuda_fusedcg_cpld_update_p

   subroutine cuda_fusedcg_cpld_update_x(x1_d, x2_d, x3_d, p1_d, p2_d, &
        p3_d, alpha, p_cur, n) bind(c, name='cuda_fusedcg_cpld_update_x')
     use, intrinsic :: iso_c_binding
     implicit none
     type(c_ptr), value :: x1_d, x2_d, x3_d, p1_d, p2_d, p3_d, alpha
     integer(c_int) :: p_cur, n
   end subroutine cuda_fusedcg_cpld_update_x

   real(c_rp) function cuda_fusedcg_cpld_part2(a1_d, a2_d, a3_d, b_d, &
        c1_d, c2_d, c3_d, alpha_d, alpha, p_cur, n) &
        bind(c, name='cuda_fusedcg_cpld_part2')
     use, intrinsic :: iso_c_binding
     import c_rp
     implicit none
     type(c_ptr), value :: a1_d, a2_d, a3_d, b_d
     type(c_ptr), value :: c1_d, c2_d, c3_d, alpha_d
     real(c_rp) :: alpha
     integer(c_int) :: n, p_cur
   end function cuda_fusedcg_cpld_part2
end interface
#elif HAVE_HIP
! C entry points of the HIP backend. Device buffers are passed by value
! as c_ptr; scalar arguments are passed by reference.
interface
   subroutine hip_fusedcg_cpld_part1(a1_d, a2_d, a3_d, &
        b1_d, b2_d, b3_d, tmp_d, n) bind(c, name='hip_fusedcg_cpld_part1')
     use, intrinsic :: iso_c_binding
     implicit none
     type(c_ptr), value :: a1_d, a2_d, a3_d, b1_d, b2_d, b3_d, tmp_d
     integer(c_int) :: n
   end subroutine hip_fusedcg_cpld_part1

   subroutine hip_fusedcg_cpld_update_p(p1_d, p2_d, p3_d, z1_d, z2_d, &
        z3_d, po1_d, po2_d, po3_d, beta, n) &
        bind(c, name='hip_fusedcg_cpld_update_p')
     use, intrinsic :: iso_c_binding
     import c_rp
     implicit none
     type(c_ptr), value :: p1_d, p2_d, p3_d, z1_d, z2_d, z3_d
     type(c_ptr), value :: po1_d, po2_d, po3_d
     real(c_rp) :: beta
     integer(c_int) :: n
   end subroutine hip_fusedcg_cpld_update_p

   subroutine hip_fusedcg_cpld_update_x(x1_d, x2_d, x3_d, p1_d, p2_d, &
        p3_d, alpha, p_cur, n) bind(c, name='hip_fusedcg_cpld_update_x')
     use, intrinsic :: iso_c_binding
     implicit none
     type(c_ptr), value :: x1_d, x2_d, x3_d, p1_d, p2_d, p3_d, alpha
     integer(c_int) :: p_cur, n
   end subroutine hip_fusedcg_cpld_update_x

   real(c_rp) function hip_fusedcg_cpld_part2(a1_d, a2_d, a3_d, b_d, &
        c1_d, c2_d, c3_d, alpha_d, alpha, p_cur, n) &
        bind(c, name='hip_fusedcg_cpld_part2')
     use, intrinsic :: iso_c_binding
     import c_rp
     implicit none
     type(c_ptr), value :: a1_d, a2_d, a3_d, b_d
     type(c_ptr), value :: c1_d, c2_d, c3_d, alpha_d
     real(c_rp) :: alpha
     integer(c_int) :: n, p_cur
   end function hip_fusedcg_cpld_part2
end interface
#endif
192 
193 contains
194 
!> Backend dispatcher for the fused "part 1" kernel.
!! Forwards every argument unchanged to the HIP or CUDA entry point and
!! raises an error when neither backend is compiled in. Judging from how
!! the solver reduces tmp_d afterwards, the kernel presumably stores the
!! elementwise sum a1*b1 + a2*b2 + a3*b3 in tmp_d -- confirm in the kernel.
subroutine device_fusedcg_cpld_part1(a1_d, a2_d, a3_d, &
     b1_d, b2_d, b3_d, tmp_d, n)
  type(c_ptr), value :: tmp_d
  type(c_ptr), value :: a1_d, a2_d, a3_d
  type(c_ptr), value :: b1_d, b2_d, b3_d
  integer(c_int) :: n

#if defined(HAVE_HIP)
  call hip_fusedcg_cpld_part1(a1_d, a2_d, a3_d, &
       b1_d, b2_d, b3_d, tmp_d, n)
#elif HAVE_CUDA
  call cuda_fusedcg_cpld_part1(a1_d, a2_d, a3_d, &
       b1_d, b2_d, b3_d, tmp_d, n)
#else
  call neko_error('No device backend configured')
#endif
end subroutine device_fusedcg_cpld_part1
208 
!> Backend dispatcher for the search-direction update kernel.
!! Writes the new direction into (p1_d, p2_d, p3_d) from the preconditioned
!! residual (z1_d, z2_d, z3_d) and the previous direction
!! (po1_d, po2_d, po3_d); presumably p = z + beta*po -- confirm in the
!! kernel. Aborts if no device backend is compiled in.
subroutine device_fusedcg_cpld_update_p(p1_d, p2_d, p3_d, z1_d, z2_d, z3_d, &
     po1_d, po2_d, po3_d, beta, n)
  type(c_ptr), value :: po1_d, po2_d, po3_d
  type(c_ptr), value :: z1_d, z2_d, z3_d
  type(c_ptr), value :: p1_d, p2_d, p3_d
  real(c_rp) :: beta
  integer(c_int) :: n

#if defined(HAVE_HIP)
  call hip_fusedcg_cpld_update_p(p1_d, p2_d, p3_d, &
       z1_d, z2_d, z3_d, &
       po1_d, po2_d, po3_d, beta, n)
#elif HAVE_CUDA
  call cuda_fusedcg_cpld_update_p(p1_d, p2_d, p3_d, &
       z1_d, z2_d, z3_d, &
       po1_d, po2_d, po3_d, beta, n)
#else
  call neko_error('No device backend configured')
#endif
end subroutine device_fusedcg_cpld_update_p
225 
!> Backend dispatcher for the fused solution-update kernel.
!! In the coupled solve, p1_d/p2_d/p3_d are the device-resident tables of
!! direction pointers and alpha is the device array of step lengths. The
!! kernel presumably adds the first p_cur buffered directions, scaled by
!! their alphas, to x1/x2/x3 -- confirm in the kernel.
!! Aborts if no device backend is compiled in.
subroutine device_fusedcg_cpld_update_x(x1_d, x2_d, x3_d, &
     p1_d, p2_d, p3_d, alpha, p_cur, n)
  type(c_ptr), value :: x1_d, x2_d, x3_d
  type(c_ptr), value :: p1_d, p2_d, p3_d
  type(c_ptr), value :: alpha
  integer(c_int) :: p_cur, n

#if defined(HAVE_HIP)
  call hip_fusedcg_cpld_update_x(x1_d, x2_d, x3_d, &
       p1_d, p2_d, p3_d, &
       alpha, p_cur, n)
#elif HAVE_CUDA
  call cuda_fusedcg_cpld_update_x(x1_d, x2_d, x3_d, &
       p1_d, p2_d, p3_d, &
       alpha, p_cur, n)
#else
  call neko_error('No device backend configured')
#endif
end subroutine device_fusedcg_cpld_update_x
240 
!> Fused "part 2" kernel followed by a global sum of its scalar result.
!! The backend call receives the step length alpha for slot p_cur together
!! with the device alpha table (alpha_d). When device-aware MPI is not
!! available, the returned value is summed across all ranks here. The
!! kernel presumably returns only a rank-local partial sum -- confirm.
!! @param n     Vector length (C int, matching the backend prototypes).
!! @param p_cur Current ring slot (C int, matching the backend prototypes).
!! @return The (globally summed) scalar produced by the kernel; in the
!!         coupled solve this is the weighted residual inner product.
function device_fusedcg_cpld_part2(a1_d, a2_d, a3_d, b_d, &
     c1_d, c2_d, c3_d, alpha_d, alpha, p_cur, n) result(res)
  type(c_ptr), value :: a1_d, a2_d, a3_d, b_d
  type(c_ptr), value :: c1_d, c2_d, c3_d, alpha_d
  real(c_rp) :: alpha
  ! The backend dummies are integer(c_int) passed by reference; declare
  ! the same kind here (as the sibling wrappers do) so the reference is
  ! valid even when default integer differs from C int.
  integer(c_int) :: n, p_cur
  real(kind=rp) :: res
  integer :: ierr

  ! Keep the result defined on the no-backend path as well.
  res = 0.0_rp
#ifdef HAVE_HIP
  res = hip_fusedcg_cpld_part2(a1_d, a2_d, a3_d, b_d, &
       c1_d, c2_d, c3_d, alpha_d, alpha, p_cur, n)
#elif HAVE_CUDA
  res = cuda_fusedcg_cpld_part2(a1_d, a2_d, a3_d, b_d, &
       c1_d, c2_d, c3_d, alpha_d, alpha, p_cur, n)
#else
  call neko_error('No device backend configured')
#endif

#ifndef HAVE_DEVICE_MPI
  if (pe_size .gt. 1) then
     call mpi_allreduce(mpi_in_place, res, 1, &
          mpi_real_precision, mpi_sum, neko_comm, ierr)
  end if
#endif

end function device_fusedcg_cpld_part2
267 
!> Initialise a fused coupled PCG solver.
!! Any previous state is released first via free().
!! @param n        Length of each component vector.
!! @param max_iter Maximum number of iterations.
!! @param M        Optional preconditioner.
!! @param rel_tol  Optional relative tolerance, forwarded to ksp_init.
!! @param abs_tol  Optional absolute tolerance, forwarded to ksp_init.
!! @param monitor  Optional monitoring flag, forwarded to ksp_init.
subroutine fusedcg_cpld_device_init(this, n, max_iter, M, &
     rel_tol, abs_tol, monitor)
  class(fusedcg_cpld_device_t), target, intent(inout) :: this
  class(pc_t), optional, intent(inout), target :: M
  integer, intent(in) :: n
  integer, intent(in) :: max_iter
  real(kind=rp), optional, intent(inout) :: rel_tol
  real(kind=rp), optional, intent(inout) :: abs_tol
  logical, optional, intent(in) :: monitor
  type(c_ptr) :: ptr
  integer(c_size_t) :: p_size
  integer :: i

  call this%free()

  allocate(this%w1(n))
  allocate(this%w2(n))
  allocate(this%w3(n))
  allocate(this%r1(n))
  allocate(this%r2(n))
  allocate(this%r3(n))
  allocate(this%z1(n))
  allocate(this%z2(n))
  allocate(this%z3(n))
  allocate(this%tmp(n))
  allocate(this%p1(n, device_fusedcg_cpld_p_space))
  allocate(this%p2(n, device_fusedcg_cpld_p_space))
  allocate(this%p3(n, device_fusedcg_cpld_p_space))
  allocate(this%p1_d(device_fusedcg_cpld_p_space))
  allocate(this%p2_d(device_fusedcg_cpld_p_space))
  allocate(this%p3_d(device_fusedcg_cpld_p_space))
  allocate(this%alpha(device_fusedcg_cpld_p_space))

  if (present(M)) then
     this%M => M
  end if

  call device_map(this%w1, this%w1_d, n)
  call device_map(this%w2, this%w2_d, n)
  call device_map(this%w3, this%w3_d, n)
  call device_map(this%r1, this%r1_d, n)
  call device_map(this%r2, this%r2_d, n)
  call device_map(this%r3, this%r3_d, n)
  call device_map(this%z1, this%z1_d, n)
  call device_map(this%z2, this%z2_d, n)
  call device_map(this%z3, this%z3_d, n)
  call device_map(this%tmp, this%tmp_d, n)
  call device_map(this%alpha, this%alpha_d, device_fusedcg_cpld_p_space)

  ! One device buffer per ring slot. p1/p2/p3 have exactly
  ! device_fusedcg_cpld_p_space columns and p*_d that many entries, so
  ! the loop must stop there (an upper bound of p_space+1 indexed past
  ! the end of both arrays).
  do i = 1, device_fusedcg_cpld_p_space
     this%p1_d(i) = c_null_ptr
     call device_map(this%p1(:,i), this%p1_d(i), n)

     this%p2_d(i) = c_null_ptr
     call device_map(this%p2(:,i), this%p2_d(i), n)

     this%p3_d(i) = c_null_ptr
     call device_map(this%p3(:,i), this%p3_d(i), n)
  end do

  ! Copy the per-slot pointer tables to the device. The coupled solve
  ! passes these device copies to device_fusedcg_cpld_update_x.
  p_size = c_sizeof(c_null_ptr) * (device_fusedcg_cpld_p_space)
  call device_alloc(this%p1_d_d, p_size)
  call device_alloc(this%p2_d_d, p_size)
  call device_alloc(this%p3_d_d, p_size)
  ptr = c_loc(this%p1_d)
  call device_memcpy(ptr, this%p1_d_d, p_size, &
       host_to_device, sync=.false.)
  ptr = c_loc(this%p2_d)
  call device_memcpy(ptr, this%p2_d_d, p_size, &
       host_to_device, sync=.false.)
  ptr = c_loc(this%p3_d)
  call device_memcpy(ptr, this%p3_d_d, p_size, &
       host_to_device, sync=.false.)

  ! An absent optional dummy may be passed on to an optional dummy, so a
  ! single call covers every combination of rel_tol/abs_tol/monitor that
  ! the former if/else-if chain enumerated by hand.
  call this%ksp_init(max_iter, rel_tol = rel_tol, abs_tol = abs_tol, &
       monitor = monitor)

  call device_event_create(this%gs_event1, 2)
  call device_event_create(this%gs_event2, 2)
  call device_event_create(this%gs_event3, 2)

end subroutine fusedcg_cpld_device_init
364 
!> Deallocate a fused coupled PCG solver.
!! Releases the host work arrays, their device buffers, the per-slot
!! direction buffers and pointer tables (host and device copies), and the
!! gather-scatter events. Every release is guarded by an allocated() or
!! c_associated() check.
subroutine fusedcg_cpld_device_free(this)
  class(fusedcg_cpld_device_t), intent(inout) :: this
  integer :: i

  call this%ksp_free()

  if (allocated(this%w1)) deallocate(this%w1)
  if (allocated(this%w2)) deallocate(this%w2)
  if (allocated(this%w3)) deallocate(this%w3)
  if (allocated(this%r1)) deallocate(this%r1)
  if (allocated(this%r2)) deallocate(this%r2)
  if (allocated(this%r3)) deallocate(this%r3)
  if (allocated(this%z1)) deallocate(this%z1)
  if (allocated(this%z2)) deallocate(this%z2)
  if (allocated(this%z3)) deallocate(this%z3)
  if (allocated(this%tmp)) deallocate(this%tmp)
  if (allocated(this%alpha)) deallocate(this%alpha)
  if (allocated(this%p1)) deallocate(this%p1)
  if (allocated(this%p2)) deallocate(this%p2)
  if (allocated(this%p3)) deallocate(this%p3)

  if (c_associated(this%w1_d)) call device_free(this%w1_d)
  if (c_associated(this%w2_d)) call device_free(this%w2_d)
  if (c_associated(this%w3_d)) call device_free(this%w3_d)
  if (c_associated(this%r1_d)) call device_free(this%r1_d)
  if (c_associated(this%r2_d)) call device_free(this%r2_d)
  if (c_associated(this%r3_d)) call device_free(this%r3_d)
  if (c_associated(this%z1_d)) call device_free(this%z1_d)
  if (c_associated(this%z2_d)) call device_free(this%z2_d)
  if (c_associated(this%z3_d)) call device_free(this%z3_d)
  if (c_associated(this%alpha_d)) call device_free(this%alpha_d)
  if (c_associated(this%tmp_d)) call device_free(this%tmp_d)

  ! Per-slot direction buffers, then the host tables that held them.
  ! The tables must be deallocated as well: init calls free() and then
  ! allocates p1_d/p2_d/p3_d again, which fails if they are still allocated.
  if (allocated(this%p1_d)) then
     do i = 1, size(this%p1_d)
        if (c_associated(this%p1_d(i))) call device_free(this%p1_d(i))
     end do
     deallocate(this%p1_d)
  end if

  if (allocated(this%p2_d)) then
     do i = 1, size(this%p2_d)
        if (c_associated(this%p2_d(i))) call device_free(this%p2_d(i))
     end do
     deallocate(this%p2_d)
  end if

  if (allocated(this%p3_d)) then
     do i = 1, size(this%p3_d)
        if (c_associated(this%p3_d(i))) call device_free(this%p3_d(i))
     end do
     deallocate(this%p3_d)
  end if

  ! Device copies of the pointer tables created with device_alloc in init;
  ! previously these were never released.
  if (c_associated(this%p1_d_d)) call device_free(this%p1_d_d)
  if (c_associated(this%p2_d_d)) call device_free(this%p2_d_d)
  if (c_associated(this%p3_d_d)) call device_free(this%p3_d_d)

  nullify(this%M)

  if (c_associated(this%gs_event1)) call device_event_destroy(this%gs_event1)
  if (c_associated(this%gs_event2)) call device_event_destroy(this%gs_event2)
  if (c_associated(this%gs_event3)) call device_event_destroy(this%gs_event3)

end subroutine fusedcg_cpld_device_free
511 
!> Fused PCG solve of a coupled three-component system.
!! The three components are iterated together: every inner product sums
!! over x, y and z, so a single alpha and beta are shared by all three.
!! Search directions are buffered in a ring of device_fusedcg_cpld_p_space
!! slots. They are folded into the solution with one fused update when the
!! ring is full, on convergence (rnorm < abs_tol), or at the last iteration.
!! @param fx, fy, fz Right-hand sides (host arrays with device copies).
!! @param niter Optional iteration cap; defaults to ksp_max_iter.
!! @return One monitor entry per component. After iterating, all three
!!         carry the same residual and iteration count. On the early exit
!!         for a zero initial residual, entries 2 and 3 report iter = -1.
function fusedcg_cpld_device_solve_coupled(this, Ax, x, y, z, fx, fy, fz, &
     n, coef, blstx, blsty, blstz, gs_h, niter) result(ksp_results)
  class(fusedcg_cpld_device_t), intent(inout) :: this
  class(ax_t), intent(inout) :: ax
  type(field_t), intent(inout) :: x
  type(field_t), intent(inout) :: y
  type(field_t), intent(inout) :: z
  integer, intent(in) :: n
  real(kind=rp), dimension(n), intent(inout) :: fx
  real(kind=rp), dimension(n), intent(inout) :: fy
  real(kind=rp), dimension(n), intent(inout) :: fz
  type(coef_t), intent(inout) :: coef
  type(bc_list_t), intent(inout) :: blstx
  type(bc_list_t), intent(inout) :: blsty
  type(bc_list_t), intent(inout) :: blstz
  type(gs_t), intent(inout) :: gs_h
  type(ksp_monitor_t), dimension(3) :: ksp_results
  integer, optional, intent(in) :: niter
  integer :: iter, max_iter, p_cur, p_prev
  real(kind=rp) :: rnorm, rtr, norm_fac, rtz1, rtz2
  real(kind=rp) :: pap, beta
  type(c_ptr) :: fx_d
  type(c_ptr) :: fy_d
  type(c_ptr) :: fz_d

  fx_d = device_get_ptr(fx)
  fy_d = device_get_ptr(fy)
  fz_d = device_get_ptr(fz)

  if (present(niter)) then
     max_iter = niter
  else
     max_iter = ksp_max_iter
  end if
  norm_fac = 1.0_rp / sqrt(coef%volume)

  associate(w1 => this%w1, w2 => this%w2, w3 => this%w3, r1 => this%r1, &
       r2 => this%r2, r3 => this%r3, p1 => this%p1, p2 => this%p2, &
       p3 => this%p3, z1 => this%z1, z2 => this%z2, z3 => this%z3, &
       tmp_d => this%tmp_d, alpha => this%alpha, alpha_d => this%alpha_d, &
       w1_d => this%w1_d, w2_d => this%w2_d, w3_d => this%w3_d, &
       r1_d => this%r1_d, r2_d => this%r2_d, r3_d => this%r3_d, &
       z1_d => this%z1_d, z2_d => this%z2_d, z3_d => this%z3_d, &
       p1_d => this%p1_d, p2_d => this%p2_d, p3_d => this%p3_d, &
       p1_d_d => this%p1_d_d, p2_d_d => this%p2_d_d, p3_d_d => this%p3_d_d)

    rtz1 = 1.0_rp
    ! p_prev must be set before the first update_p reads p*_d(p_prev).
    ! The ring predecessor of slot 1 is the last slot, which matches what
    ! the wrap-around branch below assigns.
    p_prev = device_fusedcg_cpld_p_space
    p_cur = 1

    call device_rzero(x%x_d, n)
    call device_rzero(y%x_d, n)
    call device_rzero(z%x_d, n)
    call device_rzero(p1_d(1), n)
    call device_rzero(p2_d(1), n)
    call device_rzero(p3_d(1), n)
    ! beta is forced to 0 on the first iteration, but 0*NaN is still NaN.
    ! Nothing here initialises the p_prev slot's device memory, so zero it.
    call device_rzero(p1_d(p_prev), n)
    call device_rzero(p2_d(p_prev), n)
    call device_rzero(p3_d(p_prev), n)
    call device_copy(r1_d, fx_d, n)
    call device_copy(r2_d, fy_d, n)
    call device_copy(r3_d, fz_d, n)

    call device_fusedcg_cpld_part1(r1_d, r2_d, r3_d, r1_d, &
         r2_d, r3_d, tmp_d, n)

    ! NOTE(review): the starting residual is weighted by mult and binv,
    ! while the per-iteration residual from device_fusedcg_cpld_part2 gets
    ! only mult_d -- confirm the two norms are meant to differ.
    rtr = device_glsc3(tmp_d, coef%mult_d, coef%binv_d, n)

    rnorm = sqrt(rtr)*norm_fac
    ksp_results%res_start = rnorm
    ksp_results%res_final = rnorm
    ksp_results(1)%iter = 0
    ksp_results(2:3)%iter = -1
    if(abscmp(rnorm, 0.0_rp)) return
    call this%monitor_start('fcpldCG')
    do iter = 1, max_iter
       call this%M%solve(z1, r1, n)
       call this%M%solve(z2, r2, n)
       call this%M%solve(z3, r3, n)
       rtz2 = rtz1
       call device_fusedcg_cpld_part1(z1_d, z2_d, z3_d, &
            r1_d, r2_d, r3_d, tmp_d, n)
       rtz1 = device_glsc2(tmp_d, coef%mult_d, n)

       beta = rtz1 / rtz2
       if (iter .eq. 1) beta = 0.0_rp

       call device_fusedcg_cpld_update_p(p1_d(p_cur), p2_d(p_cur), &
            p3_d(p_cur), z1_d, z2_d, z3_d, &
            p1_d(p_prev), p2_d(p_prev), p3_d(p_prev), beta, n)

       call ax%compute_vector(w1, w2, w3, &
            p1(1, p_cur), p2(1, p_cur), p3(1, p_cur), coef, x%msh, x%Xh)
       call gs_h%op(w1, n, gs_op_add, this%gs_event1)
       call gs_h%op(w2, n, gs_op_add, this%gs_event2)
       call gs_h%op(w3, n, gs_op_add, this%gs_event3)
       call device_event_sync(this%gs_event1)
       call device_event_sync(this%gs_event2)
       call device_event_sync(this%gs_event3)
       call bc_list_apply(blstx, w1, n)
       call bc_list_apply(blsty, w2, n)
       call bc_list_apply(blstz, w3, n)

       call device_fusedcg_cpld_part1(w1_d, w2_d, w3_d, p1_d(p_cur), &
            p2_d(p_cur), p3_d(p_cur), tmp_d, n)

       pap = device_glsc2(tmp_d, coef%mult_d, n)

       alpha(p_cur) = rtz1 / pap
       rtr = device_fusedcg_cpld_part2(r1_d, r2_d, r3_d, coef%mult_d, &
            w1_d, w2_d, w3_d, alpha_d, alpha(p_cur), p_cur, n)
       rnorm = sqrt(rtr)*norm_fac
       call this%monitor_iter(iter, rnorm)
       ! Flush the buffered directions into x/y/z when the ring is full,
       ! on convergence, or on the final iteration.
       if ((p_cur .eq. device_fusedcg_cpld_p_space) .or. &
            (rnorm .lt. this%abs_tol) .or. iter .eq. max_iter) then
          call device_fusedcg_cpld_update_x(x%x_d, y%x_d, z%x_d, &
               p1_d_d, p2_d_d, p3_d_d, alpha_d, p_cur, n)
          p_prev = p_cur
          p_cur = 1
          if (rnorm .lt. this%abs_tol) exit
       else
          p_prev = p_cur
          p_cur = p_cur + 1
       end if
    end do
    call this%monitor_stop()
    ksp_results%res_final = rnorm
    ! A DO loop that runs to completion leaves iter = max_iter + 1;
    ! report the number of iterations actually performed.
    ksp_results%iter = min(iter, max_iter)

  end associate

end function fusedcg_cpld_device_solve_coupled
!> Uncoupled solve, which this solver does not implement.
!! Always calls neko_error. The result is still assigned so that it is
!! fully defined should neko_error ever return.
function fusedcg_cpld_device_solve(this, Ax, x, f, n, coef, blst, &
     gs_h, niter) result(ksp_results)
  class(fusedcg_cpld_device_t), intent(inout) :: this
  class(ax_t), intent(inout) :: ax
  type(field_t), intent(inout) :: x
  integer, intent(in) :: n
  real(kind=rp), dimension(n), intent(inout) :: f
  type(coef_t), intent(inout) :: coef
  type(bc_list_t), intent(inout) :: blst
  type(gs_t), intent(inout) :: gs_h
  type(ksp_monitor_t) :: ksp_results
  integer, optional, intent(in) :: niter

  ! Throw an error
  call neko_error('The cpldcg solver is only defined for coupled solves')

  ! Kind-suffixed literals keep the working precision rp.
  ksp_results%res_start = 0.0_rp
  ksp_results%res_final = 0.0_rp
  ksp_results%iter = 0

end function fusedcg_cpld_device_solve
664 
665 end module fusedcg_cpld_device
666 
667 
void hip_fusedcg_cpld_update_x(void *x1, void *x2, void *x3, void *p1, void *p2, void *p3, void *alpha, int *p_cur, int *n)
void hip_fusedcg_cpld_update_p(void *p1, void *p2, void *p3, void *z1, void *z2, void *z3, void *po1, void *po2, void *po3, real *beta, int *n)
real hip_fusedcg_cpld_part2(void *a1, void *a2, void *a3, void *b, void *c1, void *c2, void *c3, void *alpha_d, real *alpha, int *p_cur, int *n)
void hip_fusedcg_cpld_part1(void *a1, void *a2, void *a3, void *b1, void *b2, void *b3, void *tmp, int *n)
Return the device pointer for an associated Fortran array.
Definition: device.F90:81
Map a Fortran array to a device (allocate and associate)
Definition: device.F90:57
Copy data between host and device (or device and device)
Definition: device.F90:51
Defines a Matrix-vector product.
Definition: ax.f90:34
Defines a boundary condition.
Definition: bc.f90:34
Coefficients.
Definition: coef.f90:34
Definition: comm.F90:1
type(mpi_comm) neko_comm
MPI communicator.
Definition: comm.F90:16
type(mpi_datatype) mpi_real_precision
MPI type for working precision of REAL types.
Definition: comm.F90:23
integer pe_size
MPI size of communicator.
Definition: comm.F90:31
subroutine, public device_rzero(a_d, n)
Zero a real vector.
real(kind=rp) function, public device_glsc2(a_d, b_d, n)
Weighted inner product $\sum_i a_i b_i$.
real(kind=rp) function, public device_glsc3(a_d, b_d, c_d, n)
Weighted inner product $\sum_i a_i b_i c_i$.
subroutine, public device_copy(a_d, b_d, n)
Copy a vector $a = b$.
Definition: device_math.F90:76
Device abstraction, common interface for various accelerators.
Definition: device.F90:34
subroutine, public device_event_sync(event)
Synchronize an event.
Definition: device.F90:1229
integer, parameter, public host_to_device
Definition: device.F90:47
subroutine, public device_free(x_d)
Deallocate memory on the device.
Definition: device.F90:185
subroutine, public device_event_destroy(event)
Destroy a device event.
Definition: device.F90:1194
subroutine, public device_alloc(x_d, s)
Allocate memory on the device.
Definition: device.F90:164
subroutine, public device_event_create(event, flags)
Create a device event queue.
Definition: device.F90:1164
Defines a field.
Definition: field.f90:34
Defines a fused Conjugate Gradient method for accelerators.
subroutine device_fusedcg_cpld_update_x(x1_d, x2_d, x3_d, p1_d, p2_d, p3_d, alpha, p_cur, n)
subroutine fusedcg_cpld_device_init(this, n, max_iter, M, rel_tol, abs_tol, monitor)
Initialise a fused PCG solver.
subroutine fusedcg_cpld_device_free(this)
Deallocate a fused coupled PCG solver.
subroutine device_fusedcg_cpld_update_p(p1_d, p2_d, p3_d, z1_d, z2_d, z3_d, po1_d, po2_d, po3_d, beta, n)
type(ksp_monitor_t) function fusedcg_cpld_device_solve(this, Ax, x, f, n, coef, blst, gs_h, niter)
Uncoupled solve; not supported by the fused coupled PCG solver (always raises an error).
type(ksp_monitor_t) function, dimension(3) fusedcg_cpld_device_solve_coupled(this, Ax, x, y, z, fx, fy, fz, n, coef, blstx, blsty, blstz, gs_h, niter)
Fused PCG solve of a coupled three-component system.
real(kind=rp) function device_fusedcg_cpld_part2(a1_d, a2_d, a3_d, b_d, c1_d, c2_d, c3_d, alpha_d, alpha, p_cur, n)
integer, parameter device_fusedcg_cpld_p_space
subroutine device_fusedcg_cpld_part1(a1_d, a2_d, a3_d, b1_d, b2_d, b3_d, tmp_d, n)
Gather-scatter.
Implements the base abstract type for Krylov solvers plus helper types.
Definition: krylov.f90:34
integer, parameter, public ksp_max_iter
Maximum number of iters.
Definition: krylov.f90:51
Definition: math.f90:60
real(kind=rp) function, public glsc3(a, b, c, n)
Weighted inner product $\sum_i a_i b_i c_i$.
Definition: math.f90:895
subroutine, public copy(a, b, n)
Copy a vector $a = b$.
Definition: math.f90:239
subroutine, public rzero(a, n)
Zero a real vector.
Definition: math.f90:195
integer, parameter, public c_rp
Definition: num_types.f90:13
integer, parameter, public rp
Global precision used in computations.
Definition: num_types.f90:12
Krylov preconditioner.
Definition: precon.f90:34
Base type for a matrix-vector product providing $Ax$.
Definition: ax.f90:43
A list of boundary conditions.
Definition: bc.f90:104
Coefficients defined on a given (mesh, $X_h$) tuple. Arrays use indices (i,j,k,e): element e,...
Definition: coef.f90:55
Fused preconditioned conjugate gradient method.
Type for storing initial and final residuals in a Krylov solver.
Definition: krylov.f90:56
Base abstract type for a canonical Krylov method, solving $Ax = f$.
Definition: krylov.f90:66
Defines a canonical Krylov preconditioner.
Definition: precon.f90:40