Loading [MathJax]/extensions/tex2jax.js
Neko 0.9.99
A portable framework for high-order spectral element flow simulations
All Classes Namespaces Files Functions Variables Typedefs Enumerator Macros Pages
gmres_device.F90
Go to the documentation of this file.
1! Copyright (c) 2022-2024, The Neko Authors
2! All rights reserved.
3!
4! Redistribution and use in source and binary forms, with or without
5! modification, are permitted provided that the following conditions
6! are met:
7!
8! * Redistributions of source code must retain the above copyright
9! notice, this list of conditions and the following disclaimer.
10!
11! * Redistributions in binary form must reproduce the above
12! copyright notice, this list of conditions and the following
13! disclaimer in the documentation and/or other materials provided
14! with the distribution.
15!
16! * Neither the name of the authors nor the names of its
17! contributors may be used to endorse or promote products derived
18! from this software without specific prior written permission.
19!
20! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21! "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22! LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
23! FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
24! COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
25! INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
26! BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27! LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28! CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29! LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
30! ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31! POSSIBILITY OF SUCH DAMAGE.
32!
36 use krylov, only : ksp_t, ksp_monitor_t
37 use precon, only : pc_t
38 use ax_product, only : ax_t
39 use num_types, only: rp, c_rp
40 use field, only : field_t
41 use coefs, only : coef_t
42 use gather_scatter, only : gs_t, gs_op_add
43 use bc_list, only : bc_list_t
45 use math, only : rone, rzero, abscmp
50 use device
51 use utils, only : neko_error
52 use comm, only : neko_comm, mpi_in_place, mpi_sum, mpi_real_precision, &
53 mpi_allreduce, pe_size
54 use, intrinsic :: iso_c_binding, only : c_ptr, c_null_ptr, c_loc, &
55 c_associated, c_int, c_size_t, c_sizeof
56 implicit none
57 private
58
60 type, public, extends(ksp_t) :: gmres_device_t
61 integer :: m_restart
62 real(kind=rp), allocatable :: w(:)
63 real(kind=rp), allocatable :: c(:)
64 real(kind=rp), allocatable :: r(:)
65 real(kind=rp), allocatable :: z(:,:)
66 real(kind=rp), allocatable :: h(:,:)
67 real(kind=rp), allocatable :: v(:,:)
68 real(kind=rp), allocatable :: s(:)
69 real(kind=rp), allocatable :: gam(:)
70 type(c_ptr) :: w_d = c_null_ptr
71 type(c_ptr) :: c_d = c_null_ptr
72 type(c_ptr) :: r_d = c_null_ptr
73 type(c_ptr) :: s_d = c_null_ptr
74 type(c_ptr) :: gam_d = c_null_ptr
75 type(c_ptr), allocatable :: z_d(:), h_d(:), v_d(:)
76 type(c_ptr) :: z_d_d = c_null_ptr
77 type(c_ptr) :: h_d_d = c_null_ptr
78 type(c_ptr) :: v_d_d = c_null_ptr
79 type(c_ptr) :: gs_event = c_null_ptr
80 contains
81 procedure, pass(this) :: init => gmres_device_init
82 procedure, pass(this) :: free => gmres_device_free
83 procedure, pass(this) :: solve => gmres_device_solve
84 procedure, pass(this) :: solve_coupled => gmres_device_solve_coupled
85 end type gmres_device_t
86
#ifdef HAVE_HIP
  ! C entry point of the HIP implementation used by device_gmres_part2.
  ! NOTE(review): judging by the OpenCL fallback in gmres_device_solve it
  ! presumably orthogonalises w against v(1..j) using h and returns the
  ! local squared weighted norm of w -- confirm against the HIP source.
  interface
     real(c_rp) function hip_gmres_part2(w_d, v_d_d, h_d, mult_d, j, n) &
          bind(c, name = 'hip_gmres_part2')
       use, intrinsic :: iso_c_binding, only : c_ptr, c_int
       import c_rp
       implicit none
       type(c_ptr), value :: w_d
       type(c_ptr), value :: v_d_d
       type(c_ptr), value :: h_d
       type(c_ptr), value :: mult_d
       integer(c_int) :: j
       integer(c_int) :: n
     end function hip_gmres_part2
  end interface
#elif HAVE_CUDA
  ! C entry point of the CUDA implementation (gmres_aux.cu) used by
  ! device_gmres_part2; same contract as the HIP variant.
  interface
     real(c_rp) function cuda_gmres_part2(w_d, v_d_d, h_d, mult_d, j, n) &
          bind(c, name = 'cuda_gmres_part2')
       use, intrinsic :: iso_c_binding, only : c_ptr, c_int
       import c_rp
       implicit none
       type(c_ptr), value :: w_d
       type(c_ptr), value :: v_d_d
       type(c_ptr), value :: h_d
       type(c_ptr), value :: mult_d
       integer(c_int) :: j
       integer(c_int) :: n
     end function cuda_gmres_part2
  end interface
#endif
111
112contains
113
114 function device_gmres_part2(w_d, v_d_d, h_d, mult_d, j, n) result(alpha)
115 type(c_ptr), value :: h_d, w_d, v_d_d, mult_d
116 integer(c_int) :: j, n
117 real(c_rp) :: alpha
118 integer :: ierr
119#ifdef HAVE_HIP
120 alpha = hip_gmres_part2(w_d, v_d_d, h_d, mult_d, j, n)
121#elif HAVE_CUDA
122 alpha = cuda_gmres_part2(w_d, v_d_d, h_d, mult_d, j, n)
123#else
124 call neko_error('No device backend configured')
125#endif
126
127#ifndef HAVE_DEVICE_MPI
128 if (pe_size .gt. 1) then
129 call mpi_allreduce(mpi_in_place, alpha, 1, &
130 mpi_real_precision, mpi_sum, neko_comm, ierr)
131 end if
132#endif
133
134 end function device_gmres_part2
135
137 subroutine gmres_device_init(this, n, max_iter, M, m_restart, &
138 rel_tol, abs_tol, monitor)
139 class(gmres_device_t), target, intent(inout) :: this
140 integer, intent(in) :: n
141 integer, intent(in) :: max_iter
142 class(pc_t), optional, intent(in), target :: M
143 integer, optional, intent(in) :: m_restart
144 real(kind=rp), optional, intent(in) :: rel_tol
145 real(kind=rp), optional, intent(in) :: abs_tol
146 logical, optional, intent(in) :: monitor
147 type(device_ident_t), target :: M_ident
148 type(c_ptr) :: ptr
149 integer(c_size_t) :: z_size
150 integer :: i
151
152 if (present(m_restart)) then
153 this%m_restart = m_restart
154 else
155 this%m_restart = 30
156 end if
157
158
159 call this%free()
160
161 if (present(m)) then
162 this%M => m
163 else
164 this%M => m_ident
165 end if
166
167 allocate(this%w(n))
168 allocate(this%r(n))
169 call device_map(this%w, this%w_d, n)
170 call device_map(this%r, this%r_d, n)
171
172 allocate(this%c(this%m_restart))
173 allocate(this%s(this%m_restart))
174 allocate(this%gam(this%m_restart + 1))
175 call device_map(this%c, this%c_d, this%m_restart)
176 call device_map(this%s, this%s_d, this%m_restart)
177 call device_map(this%gam, this%gam_d, this%m_restart+1)
178
179 allocate(this%z(n, this%m_restart))
180 allocate(this%v(n, this%m_restart))
181 allocate(this%h(this%m_restart, this%m_restart))
182 allocate(this%z_d(this%m_restart))
183 allocate(this%v_d(this%m_restart))
184 allocate(this%h_d(this%m_restart))
185 do i = 1, this%m_restart
186 this%z_d(i) = c_null_ptr
187 call device_map(this%z(:,i), this%z_d(i), n)
188
189 this%v_d(i) = c_null_ptr
190 call device_map(this%v(:,i), this%v_d(i), n)
191
192 this%h_d(i) = c_null_ptr
193 call device_map(this%h(:,i), this%h_d(i), this%m_restart)
194 end do
195
196 z_size = c_sizeof(c_null_ptr) * (this%m_restart)
197 call device_alloc(this%z_d_d, z_size)
198 call device_alloc(this%v_d_d, z_size)
199 call device_alloc(this%h_d_d, z_size)
200 ptr = c_loc(this%z_d)
201 call device_memcpy(ptr, this%z_d_d, z_size, &
202 host_to_device, sync = .false.)
203 ptr = c_loc(this%v_d)
204 call device_memcpy(ptr, this%v_d_d, z_size, &
205 host_to_device, sync = .false.)
206 ptr = c_loc(this%h_d)
207 call device_memcpy(ptr, this%h_d_d, z_size, &
208 host_to_device, sync = .false.)
209
210
211 if (present(rel_tol) .and. present(abs_tol) .and. present(monitor)) then
212 call this%ksp_init(max_iter, rel_tol, abs_tol, monitor = monitor)
213 else if (present(rel_tol) .and. present(abs_tol)) then
214 call this%ksp_init(max_iter, rel_tol, abs_tol)
215 else if (present(monitor) .and. present(abs_tol)) then
216 call this%ksp_init(max_iter, abs_tol = abs_tol, monitor = monitor)
217 else if (present(rel_tol) .and. present(monitor)) then
218 call this%ksp_init(max_iter, rel_tol, monitor = monitor)
219 else if (present(rel_tol)) then
220 call this%ksp_init(max_iter, rel_tol = rel_tol)
221 else if (present(abs_tol)) then
222 call this%ksp_init(max_iter, abs_tol = abs_tol)
223 else if (present(monitor)) then
224 call this%ksp_init(max_iter, monitor = monitor)
225 else
226 call this%ksp_init(max_iter)
227 end if
228
229 call device_event_create(this%gs_event, 2)
230
231 end subroutine gmres_device_init
232
234 subroutine gmres_device_free(this)
235 class(gmres_device_t), intent(inout) :: this
236 integer :: i
237
238 call this%ksp_free()
239
240 if (allocated(this%w)) then
241 deallocate(this%w)
242 end if
243
244 if (allocated(this%c)) then
245 deallocate(this%c)
246 end if
247
248 if (allocated(this%r)) then
249 deallocate(this%r)
250 end if
251
252 if (allocated(this%z)) then
253 deallocate(this%z)
254 end if
255
256 if (allocated(this%h)) then
257 deallocate(this%h)
258 end if
259
260 if (allocated(this%v)) then
261 deallocate(this%v)
262 end if
263
264 if (allocated(this%s)) then
265 deallocate(this%s)
266 end if
267 if (allocated(this%gam)) then
268 deallocate(this%gam)
269 end if
270
271 if (allocated(this%v_d)) then
272 do i = 1, this%m_restart
273 if (c_associated(this%v_d(i))) then
274 call device_free(this%v_d(i))
275 end if
276 end do
277 end if
278
279 if (allocated(this%z_d)) then
280 do i = 1, this%m_restart
281 if (c_associated(this%z_d(i))) then
282 call device_free(this%z_d(i))
283 end if
284 end do
285 end if
286 if (allocated(this%h_d)) then
287 do i = 1, this%m_restart
288 if (c_associated(this%h_d(i))) then
289 call device_free(this%h_d(i))
290 end if
291 end do
292 end if
293
294
295
296 if (c_associated(this%gam_d)) then
297 call device_free(this%gam_d)
298 end if
299 if (c_associated(this%w_d)) then
300 call device_free(this%w_d)
301 end if
302 if (c_associated(this%c_d)) then
303 call device_free(this%c_d)
304 end if
305 if (c_associated(this%r_d)) then
306 call device_free(this%r_d)
307 end if
308 if (c_associated(this%s_d)) then
309 call device_free(this%s_d)
310 end if
311
312 nullify(this%M)
313
314 if (c_associated(this%gs_event)) then
315 call device_event_destroy(this%gs_event)
316 end if
317
318 end subroutine gmres_device_free
319
321 function gmres_device_solve(this, Ax, x, f, n, coef, blst, gs_h, niter) &
322 result(ksp_results)
323 class(gmres_device_t), intent(inout) :: this
324 class(ax_t), intent(in) :: ax
325 type(field_t), intent(inout) :: x
326 integer, intent(in) :: n
327 real(kind=rp), dimension(n), intent(in) :: f
328 type(coef_t), intent(inout) :: coef
329 type(bc_list_t), intent(inout) :: blst
330 type(gs_t), intent(inout) :: gs_h
331 type(ksp_monitor_t) :: ksp_results
332 integer, optional, intent(in) :: niter
333 integer :: iter, max_iter
334 integer :: i, j, k
335 real(kind=rp) :: rnorm, alpha, temp, lr, alpha2, norm_fac
336 logical :: conv
337 type(c_ptr) :: f_d
338
339 f_d = device_get_ptr(f)
340
341 conv = .false.
342 iter = 0
343
344 if (present(niter)) then
345 max_iter = niter
346 else
347 max_iter = this%max_iter
348 end if
349
350 associate(w => this%w, c => this%c, r => this%r, z => this%z, h => this%h, &
351 v => this%v, s => this%s, gam => this%gam, v_d => this%v_d, &
352 w_d => this%w_d, r_d => this%r_d, h_d => this%h_d, &
353 v_d_d => this%v_d_d, x_d => x%x_d, z_d_d => this%z_d_d, &
354 c_d => this%c_d)
355
356 norm_fac = 1.0_rp / sqrt(coef%volume)
357 call rzero(gam, this%m_restart + 1)
358 call rone(s, this%m_restart)
359 call rone(c, this%m_restart)
360 call rzero(h, this%m_restart * this%m_restart)
361 call device_rzero(x%x_d, n)
362 call device_rzero(this%gam_d, this%m_restart + 1)
363 call device_rone(this%s_d, this%m_restart)
364 call device_rone(this%c_d, this%m_restart)
365
366 call rzero(this%h, this%m_restart**2)
367 ! do j = 1, this%m_restart
368 ! call device_rzero(h_d(j), this%m_restart)
369 ! end do
370
371 call this%monitor_start('GMRES')
372 do while (.not. conv .and. iter .lt. max_iter)
373
374 if (iter .eq. 0) then
375 call device_copy(r_d, f_d, n)
376 else
377 call device_copy(r_d, f_d, n)
378 call ax%compute(w, x%x, coef, x%msh, x%Xh)
379 call gs_h%op(w, n, gs_op_add, this%gs_event)
380 call device_event_sync(this%gs_event)
381 call blst%apply_scalar(w, n)
382 call device_sub2(r_d, w_d, n)
383 end if
384
385 gam(1) = sqrt(device_glsc3(r_d, r_d, coef%mult_d, n))
386 if (iter .eq. 0) then
387 ksp_results%res_start = gam(1) * norm_fac
388 end if
389
390 if (abscmp(gam(1), 0.0_rp)) return
391
392 rnorm = 0.0_rp
393 temp = 1.0_rp / gam(1)
394 call device_cmult2(v_d(1), r_d, temp, n)
395 do j = 1, this%m_restart
396 iter = iter+1
397
398 call this%M%solve(z(1,j), v(1,j), n)
399
400 call ax%compute(w, z(1,j), coef, x%msh, x%Xh)
401 call gs_h%op(w, n, gs_op_add, this%gs_event)
402 call device_event_sync(this%gs_event)
403 call blst%apply_scalar(w, n)
404
405 if (neko_bcknd_opencl .eq. 1) then
406 do i = 1, j
407 h(i,j) = device_glsc3(w_d, v_d(i), coef%mult_d, n)
408
409 call device_add2s2(w_d, v_d(i), -h(i,j), n)
410
411 alpha2 = device_glsc3(w_d, w_d, coef%mult_d, n)
412 end do
413 else
414 call device_glsc3_many(h(1,j), w_d, v_d_d, coef%mult_d, j, n)
415
416 call device_memcpy(h(:,j), h_d(j), j, &
417 host_to_device, sync = .false.)
418
419 alpha2 = device_gmres_part2(w_d, v_d_d, h_d(j), &
420 coef%mult_d, j, n)
421
422 end if
423
424 alpha = sqrt(alpha2)
425 do i = 1, j-1
426 temp = h(i,j)
427 h(i,j) = c(i)*temp + s(i) * h(i+1,j)
428 h(i+1,j) = -s(i)*temp + c(i) * h(i+1,j)
429 end do
430
431 rnorm = 0.0_rp
432 if (abscmp(alpha, 0.0_rp)) then
433 conv = .true.
434 exit
435 end if
436
437 lr = sqrt(h(j,j) * h(j,j) + alpha2)
438 temp = 1.0_rp / lr
439 c(j) = h(j,j) * temp
440 s(j) = alpha * temp
441 h(j,j) = lr
442 call device_memcpy(h(:,j), h_d(j), j, &
443 host_to_device, sync = .false.)
444 gam(j+1) = -s(j) * gam(j)
445 gam(j) = c(j) * gam(j)
446
447 rnorm = abs(gam(j+1)) * norm_fac
448 call this%monitor_iter(iter, rnorm)
449 if (rnorm .lt. this%abs_tol) then
450 conv = .true.
451 exit
452 end if
453
454 if (iter + 1 .gt. max_iter) exit
455
456 if (j .lt. this%m_restart) then
457 temp = 1.0_rp / alpha
458 call device_cmult2(v_d(j+1), w_d, temp, n)
459 end if
460
461 end do
462
463 j = min(j, this%m_restart)
464 do k = j, 1, -1
465 temp = gam(k)
466 do i = j, k+1, -1
467 temp = temp - h(k,i) * c(i)
468 end do
469 c(k) = temp / h(k,k)
470 end do
471
472 if (neko_bcknd_opencl .eq. 1) then
473 do i = 1, j
474 call device_add2s2(x_d, this%z_d(i), c(i), n)
475 end do
476 else
477 call device_memcpy(c, c_d, j, host_to_device, sync = .false.)
478 call device_add2s2_many(x_d, z_d_d, c_d, j, n)
479 end if
480 end do
481
482 end associate
483 call this%monitor_stop()
484 ksp_results%res_final = rnorm
485 ksp_results%iter = iter
486 ksp_results%converged = this%is_converged(iter, rnorm)
487
488 end function gmres_device_solve
489
491 function gmres_device_solve_coupled(this, Ax, x, y, z, fx, fy, fz, &
492 n, coef, blstx, blsty, blstz, gs_h, niter) result(ksp_results)
493 class(gmres_device_t), intent(inout) :: this
494 class(ax_t), intent(in) :: ax
495 type(field_t), intent(inout) :: x
496 type(field_t), intent(inout) :: y
497 type(field_t), intent(inout) :: z
498 integer, intent(in) :: n
499 real(kind=rp), dimension(n), intent(in) :: fx
500 real(kind=rp), dimension(n), intent(in) :: fy
501 real(kind=rp), dimension(n), intent(in) :: fz
502 type(coef_t), intent(inout) :: coef
503 type(bc_list_t), intent(inout) :: blstx
504 type(bc_list_t), intent(inout) :: blsty
505 type(bc_list_t), intent(inout) :: blstz
506 type(gs_t), intent(inout) :: gs_h
507 type(ksp_monitor_t), dimension(3) :: ksp_results
508 integer, optional, intent(in) :: niter
509
510 ksp_results(1) = this%solve(ax, x, fx, n, coef, blstx, gs_h, niter)
511 ksp_results(2) = this%solve(ax, y, fy, n, coef, blsty, gs_h, niter)
512 ksp_results(3) = this%solve(ax, z, fz, n, coef, blstz, gs_h, niter)
513
514 end function gmres_device_solve_coupled
515
516end module gmres_device
real cuda_gmres_part2(void *w, void *v, void *h, void *mult, int *j, int *n)
Definition gmres_aux.cu:62
Defines a Matrix-vector product.
Definition ax.f90:34
Defines a list of bc_t.
Definition bc_list.f90:34
Coefficients.
Definition coef.f90:34
Definition comm.F90:1
type(mpi_comm) neko_comm
MPI communicator.
Definition comm.F90:38
type(mpi_datatype) mpi_real_precision
MPI type for working precision of REAL types.
Definition comm.F90:45
integer pe_size
MPI size of communicator.
Definition comm.F90:53
Identity Krylov preconditioner for accelerators.
subroutine, public device_add2s1(a_d, b_d, c1, n)
subroutine, public device_rzero(a_d, n)
Zero a real vector.
subroutine, public device_rone(a_d, n)
Set all elements to one.
subroutine, public device_add2s2(a_d, b_d, c1, n)
Vector addition with scalar multiplication \( a = a + c_1 b \) (multiplication on second argument)
subroutine, public device_cmult2(a_d, b_d, c, n)
Multiplication by constant c: \( a = c \cdot b \).
subroutine, public device_add2s2_many(y_d, x_d_d, a_d, j, n)
real(kind=rp) function, public device_glsc3(a_d, b_d, c_d, n)
Weighted inner product \( \sum_i a_i b_i c_i \).
subroutine, public device_copy(a_d, b_d, n)
Copy a vector \( a = b \).
subroutine, public device_glsc3_many(h, w_d, v_d_d, mult_d, j, n)
subroutine, public device_sub2(a_d, b_d, n)
Vector subtraction \( a = a - b \).
Device abstraction, common interface for various accelerators.
Definition device.F90:34
Defines a field.
Definition field.f90:34
Gather-scatter.
Defines various GMRES methods.
real(c_rp) function device_gmres_part2(w_d, v_d_d, h_d, mult_d, j, n)
type(ksp_monitor_t) function gmres_device_solve(this, ax, x, f, n, coef, blst, gs_h, niter)
Standard GMRES solve.
subroutine gmres_device_init(this, n, max_iter, m, m_restart, rel_tol, abs_tol, monitor)
Initialise a standard GMRES solver.
subroutine gmres_device_free(this)
Deallocate a standard GMRES solver.
type(ksp_monitor_t) function, dimension(3) gmres_device_solve_coupled(this, ax, x, y, z, fx, fy, fz, n, coef, blstx, blsty, blstz, gs_h, niter)
Standard GMRES coupled solve.
Implements the base abstract type for Krylov solvers plus helper types.
Definition krylov.f90:34
Definition math.f90:60
subroutine, public rone(a, n)
Set all elements to one.
Definition math.f90:227
subroutine, public rzero(a, n)
Zero a real vector.
Definition math.f90:194
Build configurations.
integer, parameter neko_bcknd_opencl
integer, parameter, public c_rp
Definition num_types.f90:13
integer, parameter, public rp
Global precision used in computations.
Definition num_types.f90:12
Krylov preconditioner.
Definition precon.f90:34
Utilities.
Definition utils.f90:35
Base type for a matrix-vector product providing \( Ax \).
Definition ax.f90:43
A list of allocatable `bc_t`. Follows the standard interface of lists.
Definition bc_list.f90:47
Coefficients defined on a given (mesh, \( X_h \)) tuple. Arrays use indices (i,j,k,e): element e,...
Definition coef.f90:55
Defines a canonical Krylov preconditioner for accelerators.
Standard preconditioned generalized minimal residual method.
Type for storing initial and final residuals in a Krylov solver.
Definition krylov.f90:56
Base abstract type for a canonical Krylov method, solving \( Ax = f \).
Definition krylov.f90:68
Defines a canonical Krylov preconditioner.
Definition precon.f90:40