Neko  0.9.99
A portable framework for high-order spectral element flow simulations
gmres_device.F90
Go to the documentation of this file.
1 ! Copyright (c) 2022-2024, The Neko Authors
2 ! All rights reserved.
3 !
4 ! Redistribution and use in source and binary forms, with or without
5 ! modification, are permitted provided that the following conditions
6 ! are met:
7 !
8 ! * Redistributions of source code must retain the above copyright
9 ! notice, this list of conditions and the following disclaimer.
10 !
11 ! * Redistributions in binary form must reproduce the above
12 ! copyright notice, this list of conditions and the following
13 ! disclaimer in the documentation and/or other materials provided
14 ! with the distribution.
15 !
16 ! * Neither the name of the authors nor the names of its
17 ! contributors may be used to endorse or promote products derived
18 ! from this software without specific prior written permission.
19 !
20 ! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21 ! "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22 ! LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
23 ! FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
24 ! COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
25 ! INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
26 ! BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27 ! LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28 ! CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29 ! LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
30 ! ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31 ! POSSIBILITY OF SUCH DAMAGE.
32 !
35  use krylov, only : ksp_t, ksp_monitor_t
36  use precon, only : pc_t
37  use ax_product, only : ax_t
38  use num_types, only: rp, c_rp
39  use field, only : field_t
40  use coefs, only : coef_t
41  use gather_scatter, only : gs_t, gs_op_add
42  use bc, only : bc_list_t, bc_list_apply
44  use math, only : rone, rzero, abscmp
49  use device
50  use comm
51  use, intrinsic :: iso_c_binding
52  implicit none
53  private
54 
!> Standard preconditioned GMRES solver, accelerator (device) variant.
!! Host work arrays are paired with device pointers (the *_d components);
!! the pairing and sizes are established in gmres_device_init.
type, public, extends(ksp_t) :: gmres_device_t
   !> Restart length (number of Krylov vectors per cycle).
   integer :: m_restart
   ! Host work storage: w and r hold n entries, c and s hold m_restart
   ! entries, gam holds m_restart + 1 entries; z and v are n x m_restart and
   ! h is m_restart x m_restart (as allocated in init).
   real(kind=rp), allocatable :: w(:), c(:), r(:)
   real(kind=rp), allocatable :: z(:,:), h(:,:), v(:,:)
   real(kind=rp), allocatable :: s(:), gam(:)
   ! Device mirrors of the host vectors above.
   type(c_ptr) :: w_d = c_null_ptr, c_d = c_null_ptr, r_d = c_null_ptr
   type(c_ptr) :: s_d = c_null_ptr, gam_d = c_null_ptr
   ! One device pointer per column of z, h and v.
   type(c_ptr), allocatable :: z_d(:), h_d(:), v_d(:)
   ! Device-resident copies of the z_d, h_d and v_d pointer tables.
   type(c_ptr) :: z_d_d = c_null_ptr, h_d_d = c_null_ptr, v_d_d = c_null_ptr
   !> Event used to synchronise the gather-scatter operation.
   type(c_ptr) :: gs_event = c_null_ptr
 contains
   procedure, pass(this) :: init => gmres_device_init
   procedure, pass(this) :: free => gmres_device_free
   procedure, pass(this) :: solve => gmres_device_solve
   procedure, pass(this) :: solve_coupled => gmres_device_solve_coupled
end type gmres_device_t
82 
#ifdef HAVE_HIP
interface
   !> HIP implementation backing device_gmres_part2.
   real(c_rp) function hip_gmres_part2(w_d, v_d_d, h_d, mult_d, j, n) &
        bind(c, name = 'hip_gmres_part2')
     use, intrinsic :: iso_c_binding
     import :: c_rp
     implicit none
     type(c_ptr), value :: w_d
     type(c_ptr), value :: v_d_d
     type(c_ptr), value :: h_d
     type(c_ptr), value :: mult_d
     ! No VALUE attribute: j and n are passed by reference.
     integer(c_int) :: j
     integer(c_int) :: n
   end function hip_gmres_part2
end interface
#elif HAVE_CUDA

interface
   !> CUDA implementation backing device_gmres_part2.
   real(c_rp) function cuda_gmres_part2(w_d, v_d_d, h_d, mult_d, j, n) &
        bind(c, name = 'cuda_gmres_part2')
     use, intrinsic :: iso_c_binding
     import :: c_rp
     implicit none
     type(c_ptr), value :: w_d
     type(c_ptr), value :: v_d_d
     type(c_ptr), value :: h_d
     type(c_ptr), value :: mult_d
     ! No VALUE attribute: j and n are passed by reference (int * on the C side).
     integer(c_int) :: j
     integer(c_int) :: n
   end function cuda_gmres_part2
end interface
#endif
107 
108 contains
109 
!> Backend dispatch for the fused GMRES orthogonalisation step.
!! gmres_device_solve uses this in place of its explicit OpenCL loop of
!! device_add2s2 / device_glsc3 calls, so the kernel is expected to remove
!! the h-weighted v components from w and return the mult-weighted squared
!! norm of the result (kernel body not visible here -- TODO confirm).
!! @param w_d Device vector w.
!! @param v_d_d Device table of pointers to the Krylov vectors.
!! @param h_d Device column of Hessenberg coefficients.
!! @param mult_d Device multiplicity weights.
!! @param j Number of Krylov vectors involved.
!! @param n Vector length.
!! @return Weighted squared norm, summed over ranks when HAVE_DEVICE_MPI is
!!         not defined.
function device_gmres_part2(w_d, v_d_d, h_d, mult_d, j, n) result(w_norm2)
   type(c_ptr), value :: w_d
   type(c_ptr), value :: v_d_d
   type(c_ptr), value :: h_d
   type(c_ptr), value :: mult_d
   integer(c_int) :: j
   integer(c_int) :: n
   real(c_rp) :: w_norm2
   integer :: mpi_err

#if defined(HAVE_HIP)
   w_norm2 = hip_gmres_part2(w_d, v_d_d, h_d, mult_d, j, n)
#elif HAVE_CUDA
   w_norm2 = cuda_gmres_part2(w_d, v_d_d, h_d, mult_d, j, n)
#else
   call neko_error('No device backend configured')
#endif

#ifndef HAVE_DEVICE_MPI
   ! Combine the rank-local partial sums on the host. Presumably the
   ! HAVE_DEVICE_MPI build reduces elsewhere -- not visible here.
   if (pe_size .gt. 1) then
      call mpi_allreduce(mpi_in_place, w_norm2, 1, &
           mpi_real_precision, mpi_sum, neko_comm, mpi_err)
   end if
#endif

end function device_gmres_part2
131 
!> Initialise a standard GMRES solver (device version).
!! @param n Length of the solution/work vectors.
!! @param max_iter Maximum number of iterations.
!! @param M Optional preconditioner; a device identity preconditioner is
!!        used when absent.
!! @param m_restart Optional restart length (default 30, must be >= 1).
!! @param rel_tol Optional relative tolerance, forwarded to ksp_init.
!! @param abs_tol Optional absolute tolerance, forwarded to ksp_init.
!! @param monitor Optional monitoring flag, forwarded to ksp_init.
subroutine gmres_device_init(this, n, max_iter, M, m_restart, &
     rel_tol, abs_tol, monitor)
   class(gmres_device_t), target, intent(inout) :: this
   integer, intent(in) :: n
   integer, intent(in) :: max_iter
   class(pc_t), optional, intent(in), target :: M
   integer, optional, intent(in) :: m_restart
   real(kind=rp), optional, intent(in) :: rel_tol
   real(kind=rp), optional, intent(in) :: abs_tol
   logical, optional, intent(in) :: monitor
   ! Fallback preconditioner. this%M keeps pointing at it after this routine
   ! returns, so it must have the SAVE attribute: a plain local would leave
   ! this%M dangling. The object is shared by all solver instances, which
   ! assumes device_ident_t carries no per-solver state -- TODO confirm.
   type(device_ident_t), target, save :: M_ident
   type(c_ptr) :: ptr
   integer(c_size_t) :: z_size
   integer :: i

   ! Release any previous state first, while this%m_restart still describes
   ! the arrays that are currently allocated.
   call this%free()

   if (present(m_restart)) then
      this%m_restart = m_restart
   else
      this%m_restart = 30
   end if

   if (this%m_restart .lt. 1) then
      call neko_error('GMRES restart length must be positive')
   end if

   if (present(M)) then
      this%M => M
   else
      this%M => M_ident
   end if

   allocate(this%w(n))
   allocate(this%r(n))
   call device_map(this%w, this%w_d, n)
   call device_map(this%r, this%r_d, n)

   allocate(this%c(this%m_restart))
   allocate(this%s(this%m_restart))
   allocate(this%gam(this%m_restart + 1))
   call device_map(this%c, this%c_d, this%m_restart)
   call device_map(this%s, this%s_d, this%m_restart)
   call device_map(this%gam, this%gam_d, this%m_restart + 1)

   ! Map every column of z, v and h separately so that each column has its
   ! own device pointer.
   allocate(this%z(n, this%m_restart))
   allocate(this%v(n, this%m_restart))
   allocate(this%h(this%m_restart, this%m_restart))
   allocate(this%z_d(this%m_restart))
   allocate(this%v_d(this%m_restart))
   allocate(this%h_d(this%m_restart))
   do i = 1, this%m_restart
      this%z_d(i) = c_null_ptr
      call device_map(this%z(:,i), this%z_d(i), n)

      this%v_d(i) = c_null_ptr
      call device_map(this%v(:,i), this%v_d(i), n)

      this%h_d(i) = c_null_ptr
      call device_map(this%h(:,i), this%h_d(i), this%m_restart)
   end do

   ! Upload the per-column pointer tables so device kernels can index them.
   z_size = c_sizeof(c_null_ptr) * (this%m_restart)
   call device_alloc(this%z_d_d, z_size)
   call device_alloc(this%v_d_d, z_size)
   call device_alloc(this%h_d_d, z_size)
   ptr = c_loc(this%z_d)
   call device_memcpy(ptr, this%z_d_d, z_size, &
        host_to_device, sync = .false.)
   ptr = c_loc(this%v_d)
   call device_memcpy(ptr, this%v_d_d, z_size, &
        host_to_device, sync = .false.)
   ptr = c_loc(this%h_d)
   call device_memcpy(ptr, this%h_d_d, z_size, &
        host_to_device, sync = .false.)

   ! Absent optional arguments stay absent when passed on, so ksp_init sees
   ! exactly the arguments the caller supplied.
   call this%ksp_init(max_iter, rel_tol = rel_tol, abs_tol = abs_tol, &
        monitor = monitor)

   call device_event_create(this%gs_event, 2)

end subroutine gmres_device_init
228 
!> Deallocate a standard GMRES solver (device version).
!! Every release is guarded by allocated()/c_associated(), so this is safe
!! on a solver that was never initialised. It leaves the solver ready to be
!! re-initialised: all allocatable components, including the pointer tables,
!! are deallocated.
subroutine gmres_device_free(this)
   class(gmres_device_t), intent(inout) :: this
   integer :: i

   call this%ksp_free()

   ! Per-column device buffers. Iterate over the actual array extent rather
   ! than this%m_restart, which may no longer match these arrays.
   if (allocated(this%v_d)) then
      do i = 1, size(this%v_d)
         if (c_associated(this%v_d(i))) then
            call device_free(this%v_d(i))
         end if
      end do
      deallocate(this%v_d)
   end if

   if (allocated(this%z_d)) then
      do i = 1, size(this%z_d)
         if (c_associated(this%z_d(i))) then
            call device_free(this%z_d(i))
         end if
      end do
      deallocate(this%z_d)
   end if

   if (allocated(this%h_d)) then
      do i = 1, size(this%h_d)
         if (c_associated(this%h_d(i))) then
            call device_free(this%h_d(i))
         end if
      end do
      deallocate(this%h_d)
   end if

   ! Device copies of the pointer tables (allocated with device_alloc in init).
   if (c_associated(this%z_d_d)) then
      call device_free(this%z_d_d)
   end if
   if (c_associated(this%v_d_d)) then
      call device_free(this%v_d_d)
   end if
   if (c_associated(this%h_d_d)) then
      call device_free(this%h_d_d)
   end if

   ! Device mirrors of the host vectors.
   if (c_associated(this%gam_d)) then
      call device_free(this%gam_d)
   end if
   if (c_associated(this%w_d)) then
      call device_free(this%w_d)
   end if
   if (c_associated(this%c_d)) then
      call device_free(this%c_d)
   end if
   if (c_associated(this%r_d)) then
      call device_free(this%r_d)
   end if
   if (c_associated(this%s_d)) then
      call device_free(this%s_d)
   end if

   ! Host work arrays.
   if (allocated(this%w)) then
      deallocate(this%w)
   end if
   if (allocated(this%c)) then
      deallocate(this%c)
   end if
   if (allocated(this%r)) then
      deallocate(this%r)
   end if
   if (allocated(this%z)) then
      deallocate(this%z)
   end if
   if (allocated(this%h)) then
      deallocate(this%h)
   end if
   if (allocated(this%v)) then
      deallocate(this%v)
   end if
   if (allocated(this%s)) then
      deallocate(this%s)
   end if
   if (allocated(this%gam)) then
      deallocate(this%gam)
   end if

   nullify(this%M)

   if (c_associated(this%gs_event)) then
      call device_event_destroy(this%gs_event)
   end if

end subroutine gmres_device_free
315 
!> Standard restarted GMRES solve (device version).
!! The incoming contents of x are discarded: x%x_d is zeroed and the
!! solution is accumulated there.
!! @param Ax Matrix-vector product operator.
!! @param x Solution field.
!! @param f Right-hand side; its device mirror is looked up with device_get_ptr.
!! @param n Number of degrees of freedom.
!! @param coef Coefficients; coef%volume scales the reported residuals and
!!        coef%mult_d weights the inner products.
!! @param blst Boundary conditions applied after every operator application.
!! @param gs_h Gather-scatter handle.
!! @param niter Optional iteration limit overriding this%max_iter.
!! @return Scaled initial/final residuals and the iteration count. These are
!!         now set on every exit path, including a zero initial residual.
function gmres_device_solve(this, Ax, x, f, n, coef, blst, gs_h, niter) &
     result(ksp_results)
   class(gmres_device_t), intent(inout) :: this
   class(ax_t), intent(in) :: ax
   type(field_t), intent(inout) :: x
   integer, intent(in) :: n
   real(kind=rp), dimension(n), intent(in) :: f
   type(coef_t), intent(inout) :: coef
   type(bc_list_t), intent(in) :: blst
   type(gs_t), intent(inout) :: gs_h
   type(ksp_monitor_t) :: ksp_results
   integer, optional, intent(in) :: niter
   integer :: iter, max_iter
   integer :: i, j, k
   real(kind=rp) :: rnorm, alpha, temp, lr, alpha2, norm_fac
   logical :: conv
   type(c_ptr) :: f_d

   f_d = device_get_ptr(f)

   conv = .false.
   iter = 0

   if (present(niter)) then
      max_iter = niter
   else
      max_iter = this%max_iter
   end if

   associate(w => this%w, c => this%c, z => this%z, h => this%h, &
        v => this%v, s => this%s, gam => this%gam, v_d => this%v_d, &
        w_d => this%w_d, r_d => this%r_d, h_d => this%h_d, &
        v_d_d => this%v_d_d, x_d => x%x_d, z_d_d => this%z_d_d, &
        c_d => this%c_d)

     ! Residuals are reported scaled by 1/sqrt(volume).
     norm_fac = 1.0_rp / sqrt(coef%volume)
     call rzero(gam, this%m_restart + 1)
     call rone(s, this%m_restart)
     call rone(c, this%m_restart)
     call rzero(h, this%m_restart * this%m_restart)
     ! The solution is built up from zero.
     call device_rzero(x%x_d, n)
     call device_rzero(this%gam_d, this%m_restart + 1)
     call device_rone(this%s_d, this%m_restart)
     call device_rone(this%c_d, this%m_restart)

     call this%monitor_start('GMRES')
     do while (.not. conv .and. iter .lt. max_iter)

        ! Residual r = f - A x. On the first cycle x is zero, so r = f.
        call device_copy(r_d, f_d, n)
        if (iter .ne. 0) then
           call ax%compute(w, x%x, coef, x%msh, x%Xh)
           call gs_h%op(w, n, gs_op_add, this%gs_event)
           call device_event_sync(this%gs_event)
           call bc_list_apply(blst, w, n)
           call device_sub2(r_d, w_d, n)
        end if

        gam(1) = sqrt(device_glsc3(r_d, r_d, coef%mult_d, n))
        if (iter .eq. 0) then
           ksp_results%res_start = gam(1) * norm_fac
        end if

        ! Zero residual: the current x is already the solution. Leave the
        ! loop rather than returning, so the monitor is closed and the final
        ! residual and iteration count are recorded.
        if (abscmp(gam(1), 0.0_rp)) then
           rnorm = 0.0_rp
           conv = .true.
           exit
        end if

        rnorm = 0.0_rp
        temp = 1.0_rp / gam(1)
        call device_cmult2(v_d(1), r_d, temp, n)
        do j = 1, this%m_restart
           iter = iter + 1

           ! z_j = M^{-1} v_j, then w = A z_j.
           call this%M%solve(z(1,j), v(1,j), n)

           call ax%compute(w, z(1,j), coef, x%msh, x%Xh)
           call gs_h%op(w, n, gs_op_add, this%gs_event)
           call device_event_sync(this%gs_event)
           call bc_list_apply(blst, w, n)

           ! Orthogonalise w against v_1..v_j. h(1:j,j) receives the
           ! coefficients and alpha2 the weighted squared norm of what is left.
           ! device_gmres_part2 fuses the update and the norm into one call.
           if (neko_bcknd_opencl .eq. 1) then
              do i = 1, j
                 h(i,j) = device_glsc3(w_d, v_d(i), coef%mult_d, n)
                 call device_add2s2(w_d, v_d(i), -h(i,j), n)
              end do
              ! Only the norm after the final update is used.
              alpha2 = device_glsc3(w_d, w_d, coef%mult_d, n)
           else
              call device_glsc3_many(h(1,j), w_d, v_d_d, coef%mult_d, j, n)

              call device_memcpy(h(:,j), h_d(j), j, &
                   host_to_device, sync = .false.)

              alpha2 = device_gmres_part2(w_d, v_d_d, h_d(j), &
                   coef%mult_d, j, n)
           end if

           alpha = sqrt(alpha2)
           ! Apply the previously computed Givens rotations to column j.
           do i = 1, j - 1
              temp = h(i,j)
              h(i,j) = c(i) * temp + s(i) * h(i+1,j)
              h(i+1,j) = -s(i) * temp + c(i) * h(i+1,j)
           end do

           rnorm = 0.0_rp
           ! Nothing left of w after orthogonalisation: stop the cycle here.
           if (abscmp(alpha, 0.0_rp)) then
              conv = .true.
              exit
           end if

           ! New Givens rotation that eliminates the subdiagonal entry alpha.
           lr = sqrt(h(j,j) * h(j,j) + alpha2)
           temp = 1.0_rp / lr
           c(j) = h(j,j) * temp
           s(j) = alpha * temp
           h(j,j) = lr
           call device_memcpy(h(:,j), h_d(j), j, &
                host_to_device, sync = .false.)
           gam(j+1) = -s(j) * gam(j)
           gam(j) = c(j) * gam(j)

           rnorm = abs(gam(j+1)) * norm_fac
           call this%monitor_iter(iter, rnorm)
           if (rnorm .lt. this%abs_tol) then
              conv = .true.
              exit
           end if

           if (iter + 1 .gt. max_iter) exit

           if (j .lt. this%m_restart) then
              temp = 1.0_rp / alpha
              call device_cmult2(v_d(j+1), w_d, temp, n)
           end if

        end do

        ! A completed inner loop leaves j = m_restart + 1.
        ! Back substitution on the triangular system; the coefficients
        ! overwrite c. Then x = x + sum_i c(i) * z_i.
        j = min(j, this%m_restart)
        do k = j, 1, -1
           temp = gam(k)
           do i = j, k + 1, -1
              temp = temp - h(k,i) * c(i)
           end do
           c(k) = temp / h(k,k)
        end do

        if (neko_bcknd_opencl .eq. 1) then
           do i = 1, j
              call device_add2s2(x_d, this%z_d(i), c(i), n)
           end do
        else
           call device_memcpy(c, c_d, j, host_to_device, sync = .false.)
           call device_add2s2_many(x_d, z_d_d, c_d, j, n)
        end if
     end do

   end associate
   call this%monitor_stop()
   ksp_results%res_final = rnorm
   ksp_results%iter = iter

end function gmres_device_solve
484 
!> Coupled GMRES solve for a three-component system (device version).
!! The components are solved independently and in order (x, then y, then z),
!! sharing the operator, coefficients, gather-scatter handle and iteration limit.
!! @return One ksp_monitor_t per component, in x, y, z order.
function gmres_device_solve_coupled(this, Ax, x, y, z, fx, fy, fz, &
     n, coef, blstx, blsty, blstz, gs_h, niter) result(ksp_results)
   class(gmres_device_t), intent(inout) :: this
   class(ax_t), intent(in) :: ax
   type(field_t), intent(inout) :: x
   type(field_t), intent(inout) :: y
   type(field_t), intent(inout) :: z
   integer, intent(in) :: n
   real(kind=rp), dimension(n), intent(in) :: fx
   real(kind=rp), dimension(n), intent(in) :: fy
   real(kind=rp), dimension(n), intent(in) :: fz
   type(coef_t), intent(inout) :: coef
   type(bc_list_t), intent(in) :: blstx
   type(bc_list_t), intent(in) :: blsty
   type(bc_list_t), intent(in) :: blstz
   type(gs_t), intent(inout) :: gs_h
   type(ksp_monitor_t), dimension(3) :: ksp_results
   integer, optional, intent(in) :: niter

   ! niter may be absent; it is forwarded as-is to the scalar solve.
   ksp_results(1) = this%solve(Ax = ax, x = x, f = fx, n = n, coef = coef, &
        blst = blstx, gs_h = gs_h, niter = niter)
   ksp_results(2) = this%solve(Ax = ax, x = y, f = fy, n = n, coef = coef, &
        blst = blsty, gs_h = gs_h, niter = niter)
   ksp_results(3) = this%solve(Ax = ax, x = z, f = fz, n = n, coef = coef, &
        blst = blstz, gs_h = gs_h, niter = niter)

end function gmres_device_solve_coupled
510 
511 end module gmres_device
512 
513 
real cuda_gmres_part2(void *w, void *v, void *h, void *mult, int *j, int *n)
Definition: gmres_aux.cu:52
Defines a Matrix-vector product.
Definition: ax.f90:34
Defines a boundary condition.
Definition: bc.f90:34
Coefficients.
Definition: coef.f90:34
Definition: comm.F90:1
Identity Krylov preconditioner for accelerators.
subroutine, public device_add2s1(a_d, b_d, c1, n)
subroutine, public device_rzero(a_d, n)
Zero a real vector.
subroutine, public device_rone(a_d, n)
Set all elements to one.
subroutine, public device_add2s2(a_d, b_d, c1, n)
Vector addition with scalar multiplication $a = a + c_1 b$ (multiplication on second argument)
subroutine, public device_cmult2(a_d, b_d, c, n)
Multiplication by constant c: $a = c \, b$.
subroutine, public device_add2s2_many(y_d, x_d_d, a_d, j, n)
real(kind=rp) function, public device_glsc3(a_d, b_d, c_d, n)
Weighted inner product $\sum_i a_i b_i c_i$.
subroutine, public device_copy(a_d, b_d, n)
Copy a vector: $a = b$.
Definition: device_math.F90:76
subroutine, public device_glsc3_many(h, w_d, v_d_d, mult_d, j, n)
subroutine, public device_sub2(a_d, b_d, n)
Vector subtraction: $a = a - b$.
Device abstraction, common interface for various accelerators.
Definition: device.F90:34
Defines a field.
Definition: field.f90:34
Gather-scatter.
Defines various GMRES methods.
real(c_rp) function device_gmres_part2(w_d, v_d_d, h_d, mult_d, j, n)
type(ksp_monitor_t) function gmres_device_solve(this, Ax, x, f, n, coef, blst, gs_h, niter)
Standard GMRES solve.
type(ksp_monitor_t) function, dimension(3) gmres_device_solve_coupled(this, Ax, x, y, z, fx, fy, fz, n, coef, blstx, blsty, blstz, gs_h, niter)
Standard GMRES coupled solve.
subroutine gmres_device_init(this, n, max_iter, M, m_restart, rel_tol, abs_tol, monitor)
Initialise a standard GMRES solver.
subroutine gmres_device_free(this)
Deallocate a standard GMRES solver.
Implements the base abstract type for Krylov solvers plus helper types.
Definition: krylov.f90:34
Definition: math.f90:60
subroutine, public rone(a, n)
Set all elements to one.
Definition: math.f90:228
subroutine, public rzero(a, n)
Zero a real vector.
Definition: math.f90:195
integer, parameter, public c_rp
Definition: num_types.f90:13
integer, parameter, public rp
Global precision used in computations.
Definition: num_types.f90:12
Krylov preconditioner.
Definition: precon.f90:34
Base type for a matrix-vector product providing $Ax$.
Definition: ax.f90:43
A list of boundary conditions.
Definition: bc.f90:104
Coefficients defined on a given (mesh, $X_h$) tuple. Arrays use indices (i,j,k,e): element e,...
Definition: coef.f90:55
Defines a canonical Krylov preconditioner for accelerators.
Standard preconditioned generalized minimal residual method.
Type for storing initial and final residuals in a Krylov solver.
Definition: krylov.f90:56
Base abstract type for a canonical Krylov method, solving $Ax = f$.
Definition: krylov.f90:66
Defines a canonical Krylov preconditioner.
Definition: precon.f90:40