Neko 1.99.3
A portable framework for high-order spectral element flow simulations
Loading...
Searching...
No Matches
gmres_device.F90
Go to the documentation of this file.
1! Copyright (c) 2022-2024, The Neko Authors
2! All rights reserved.
3!
4! Redistribution and use in source and binary forms, with or without
5! modification, are permitted provided that the following conditions
6! are met:
7!
8! * Redistributions of source code must retain the above copyright
9! notice, this list of conditions and the following disclaimer.
10!
11! * Redistributions in binary form must reproduce the above
12! copyright notice, this list of conditions and the following
13! disclaimer in the documentation and/or other materials provided
14! with the distribution.
15!
16! * Neither the name of the authors nor the names of its
17! contributors may be used to endorse or promote products derived
18! from this software without specific prior written permission.
19!
20! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21! "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22! LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
23! FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
24! COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
25! INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
26! BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27! LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28! CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29! LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
30! ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31! POSSIBILITY OF SUCH DAMAGE.
32!
36 use krylov, only : ksp_t, ksp_monitor_t
37 use precon, only : pc_t
38 use ax_product, only : ax_t
39 use num_types, only: rp, c_rp
40 use field, only : field_t
41 use coefs, only : coef_t
42 use gather_scatter, only : gs_t, gs_op_add
43 use bc_list, only : bc_list_t
45 use math, only : rone, rzero, abscmp
50 use device
51 use utils, only : neko_error
53 use mpi_f08, only : mpi_in_place, mpi_sum, mpi_allreduce
54 use, intrinsic :: iso_c_binding, only : c_ptr, c_null_ptr, c_loc, &
55 c_associated, c_int, c_size_t, c_sizeof
56 implicit none
57 private
58
!> Standard preconditioned generalized minimal residual method (device
!! variant). Host work arrays are mirrored on the accelerator through the
!! `*_d` pointers that gmres_device_init creates with device_map.
type, public, extends(ksp_t) :: gmres_device_t
  !> Restart length: number of basis vectors built before a restart.
  integer :: m_restart = 30
  !> Work vector that receives operator applications (A*z).
  real(kind=rp), dimension(:), allocatable :: w
  !> Givens rotation cosines; reused for the update coefficients.
  real(kind=rp), dimension(:), allocatable :: c
  !> Residual vector.
  real(kind=rp), dimension(:), allocatable :: r
  !> Preconditioned basis vectors, one per column.
  real(kind=rp), dimension(:,:), allocatable :: z
  !> Hessenberg matrix, triangularised in place by Givens rotations.
  real(kind=rp), dimension(:,:), allocatable :: h
  !> Krylov basis vectors, one per column.
  real(kind=rp), dimension(:,:), allocatable :: v
  !> Givens rotation sines.
  real(kind=rp), dimension(:), allocatable :: s
  !> Rotated right-hand side of the small least-squares problem.
  real(kind=rp), dimension(:), allocatable :: gam
  !> Device mirror of w (c_d, r_d, s_d and gam_d below mirror c, r, s, gam).
  type(c_ptr) :: w_d = c_null_ptr
  type(c_ptr) :: c_d = c_null_ptr
  type(c_ptr) :: r_d = c_null_ptr
  type(c_ptr) :: s_d = c_null_ptr
  type(c_ptr) :: gam_d = c_null_ptr
  !> Host arrays holding the device pointer of each column of z, h and v.
  type(c_ptr), dimension(:), allocatable :: z_d
  type(c_ptr), dimension(:), allocatable :: h_d
  type(c_ptr), dimension(:), allocatable :: v_d
  !> Device-side copies of z_d, h_d and v_d (device arrays of pointers).
  type(c_ptr) :: z_d_d = c_null_ptr
  type(c_ptr) :: h_d_d = c_null_ptr
  type(c_ptr) :: v_d_d = c_null_ptr
  !> Event used to wait for gather-scatter operations to complete.
  type(c_ptr) :: gs_event = c_null_ptr
contains
  procedure, pass(this) :: init => gmres_device_init
  procedure, pass(this) :: free => gmres_device_free
  procedure, pass(this) :: solve => gmres_device_solve
  procedure, pass(this) :: solve_coupled => gmres_device_solve_coupled
end type gmres_device_t
86
#ifdef HAVE_HIP
  interface
     !> HIP driver for the second GMRES orthogonalisation step.
     !! NOTE(review): from its use in gmres_device_solve this appears to
     !! remove the h_d components from w_d and return the (rank-local)
     !! weighted squared norm of w -- confirm against the kernel source.
     real(c_rp) function hip_gmres_part2(w_d, v_d_d, h_d, mult_d, j, n) &
          bind(c, name = 'hip_gmres_part2')
       use, intrinsic :: iso_c_binding, only : c_ptr, c_int
       import :: c_rp
       implicit none
       type(c_ptr), value :: w_d
       type(c_ptr), value :: v_d_d
       type(c_ptr), value :: h_d
       type(c_ptr), value :: mult_d
       integer(c_int) :: j
       integer(c_int) :: n
     end function hip_gmres_part2
  end interface
#elif HAVE_CUDA

  interface
     !> CUDA driver for the second GMRES orthogonalisation step.
     !! NOTE(review): same presumed contract as the HIP variant -- confirm
     !! against the kernel source.
     real(c_rp) function cuda_gmres_part2(w_d, v_d_d, h_d, mult_d, j, n) &
          bind(c, name = 'cuda_gmres_part2')
       use, intrinsic :: iso_c_binding, only : c_ptr, c_int
       import :: c_rp
       implicit none
       type(c_ptr), value :: w_d
       type(c_ptr), value :: v_d_d
       type(c_ptr), value :: h_d
       type(c_ptr), value :: mult_d
       integer(c_int) :: j
       integer(c_int) :: n
     end function cuda_gmres_part2
  end interface
#endif
111
112contains
113
!> Second GMRES orthogonalisation step on the device (HIP or CUDA).
!! Dispatches to the configured backend kernel and, when device-aware MPI
!! is not available, sums the kernel result over all ranks.
!! NOTE(review): from its use in gmres_device_solve the kernel appears to
!! remove the h_d components from w_d and return the weighted squared norm
!! of w -- confirm against the kernel source.
!! @param w_d Device pointer to the work vector.
!! @param v_d_d Device array of device pointers to the Krylov basis.
!! @param h_d Device pointer to the current Hessenberg column.
!! @param mult_d Device pointer to the multiplicity weights.
!! @param j Current column index (number of basis vectors involved).
!! @param n Local vector length.
!! @return Globally reduced kernel result (the solver takes its sqrt).
function device_gmres_part2(w_d, v_d_d, h_d, mult_d, j, n) result(alpha)
  type(c_ptr), value :: h_d, w_d, v_d_d, mult_d
  integer(c_int), intent(in) :: j, n
  real(c_rp) :: alpha
#ifndef HAVE_DEVICE_MPI
  integer :: ierr
#endif

  ! Give the result a defined value on every path, including the
  ! no-backend branch where neko_error is called instead of a kernel.
  alpha = 0.0_c_rp
#ifdef HAVE_HIP
  alpha = hip_gmres_part2(w_d, v_d_d, h_d, mult_d, j, n)
#elif HAVE_CUDA
  alpha = cuda_gmres_part2(w_d, v_d_d, h_d, mult_d, j, n)
#else
  call neko_error('No device backend configured')
#endif

#ifndef HAVE_DEVICE_MPI
  ! Without device-aware MPI the kernel result is rank-local; reduce it.
  if (pe_size .gt. 1) then
    call mpi_allreduce(mpi_in_place, alpha, 1, &
         mpi_real_precision, mpi_sum, neko_comm, ierr)
  end if
#endif

end function device_gmres_part2
135
!> Initialise a standard GMRES solver.
!! Releases any previous state, allocates the host work arrays and maps
!! them to the device, then uploads the per-column device pointers of z, v
!! and h into device arrays (z_d_d, v_d_d, h_d_d).
!! @param n Local vector length.
!! @param max_iter Maximum number of iterations.
!! @param M Optional preconditioner; an identity preconditioner is used
!!        when it is absent.
!! @param rel_tol Optional relative tolerance, forwarded to ksp_init.
!! @param abs_tol Optional absolute tolerance, forwarded to ksp_init.
!! @param monitor Optional monitoring flag, forwarded to ksp_init.
subroutine gmres_device_init(this, n, max_iter, M, rel_tol, abs_tol, monitor)
  class(gmres_device_t), target, intent(inout) :: this
  integer, intent(in) :: n
  integer, intent(in) :: max_iter
  class(pc_t), optional, intent(in), target :: M
  real(kind=rp), optional, intent(in) :: rel_tol
  real(kind=rp), optional, intent(in) :: abs_tol
  logical, optional, intent(in) :: monitor
  ! this%M keeps pointing at the fallback preconditioner after this routine
  ! returns, so it needs static lifetime (SAVE); a plain local would leave
  ! this%M dangling. The object is shared by every solver that falls back
  ! to it. NOTE(review): assumes device_ident_t holds no per-solver state.
  type(device_ident_t), target, save :: M_ident
  type(c_ptr) :: ptr
  integer(c_size_t) :: z_size
  integer :: i

  call this%free()

  if (present(M)) then
    this%M => M
  else
    this%M => M_ident
  end if

  allocate(this%w(n))
  allocate(this%r(n))
  call device_map(this%w, this%w_d, n)
  call device_map(this%r, this%r_d, n)

  allocate(this%c(this%m_restart))
  allocate(this%s(this%m_restart))
  allocate(this%gam(this%m_restart + 1))
  call device_map(this%c, this%c_d, this%m_restart)
  call device_map(this%s, this%s_d, this%m_restart)
  call device_map(this%gam, this%gam_d, this%m_restart+1)

  allocate(this%z(n, this%m_restart))
  allocate(this%v(n, this%m_restart))
  allocate(this%h(this%m_restart, this%m_restart))
  allocate(this%z_d(this%m_restart))
  allocate(this%v_d(this%m_restart))
  allocate(this%h_d(this%m_restart))
  ! Map every column separately so each has its own device pointer.
  do i = 1, this%m_restart
    this%z_d(i) = c_null_ptr
    call device_map(this%z(:,i), this%z_d(i), n)

    this%v_d(i) = c_null_ptr
    call device_map(this%v(:,i), this%v_d(i), n)

    this%h_d(i) = c_null_ptr
    call device_map(this%h(:,i), this%h_d(i), this%m_restart)
  end do

  ! Upload the arrays of column pointers so kernels can address all columns.
  z_size = c_sizeof(c_null_ptr) * (this%m_restart)
  call device_alloc(this%z_d_d, z_size)
  call device_alloc(this%v_d_d, z_size)
  call device_alloc(this%h_d_d, z_size)
  ptr = c_loc(this%z_d)
  call device_memcpy(ptr, this%z_d_d, z_size, &
       host_to_device, sync = .false.)
  ptr = c_loc(this%v_d)
  call device_memcpy(ptr, this%v_d_d, z_size, &
       host_to_device, sync = .false.)
  ptr = c_loc(this%h_d)
  call device_memcpy(ptr, this%h_d_d, z_size, &
       host_to_device, sync = .false.)

  ! Absent optional arguments are forwarded as absent to ksp_init's
  ! optional dummies, so one call covers every combination.
  call this%ksp_init(max_iter, rel_tol = rel_tol, abs_tol = abs_tol, &
       monitor = monitor)

  call device_event_create(this%gs_event, 2)

end subroutine gmres_device_init
223
!> Deallocate a standard GMRES solver.
!! Unmaps and deallocates the host work arrays, deallocates the host arrays
!! of column pointers, frees the device arrays of column pointers, nullifies
!! the preconditioner pointer and destroys the gather-scatter event.
subroutine gmres_device_free(this)
  class(gmres_device_t), intent(inout) :: this
  integer :: i

  call this%ksp_free()

  if (allocated(this%w)) then
    if (c_associated(this%w_d)) then
      call device_unmap(this%w, this%w_d)
    end if
    deallocate(this%w)
  end if

  if (allocated(this%c)) then
    if (c_associated(this%c_d)) then
      call device_unmap(this%c, this%c_d)
    end if
    deallocate(this%c)
  end if

  if (allocated(this%r)) then
    if (c_associated(this%r_d)) then
      call device_unmap(this%r, this%r_d)
    end if
    deallocate(this%r)
  end if

  ! Column-wise mappings: loop over the pointer array actually allocated
  ! rather than the (publicly mutable) m_restart component.
  if (allocated(this%z)) then
    if (allocated(this%z_d)) then
      do i = 1, size(this%z_d)
        if (c_associated(this%z_d(i))) then
          call device_unmap(this%z(:,i), this%z_d(i))
        end if
      end do
    end if
    deallocate(this%z)
  end if

  if (allocated(this%h)) then
    if (allocated(this%h_d)) then
      do i = 1, size(this%h_d)
        if (c_associated(this%h_d(i))) then
          call device_unmap(this%h(:,i), this%h_d(i))
        end if
      end do
    end if
    deallocate(this%h)
  end if

  if (allocated(this%v)) then
    if (allocated(this%v_d)) then
      do i = 1, size(this%v_d)
        if (c_associated(this%v_d(i))) then
          call device_unmap(this%v(:,i), this%v_d(i))
        end if
      end do
    end if
    deallocate(this%v)
  end if

  ! The host arrays of column pointers must be released too; otherwise a
  ! later init (which calls free first) fails on allocate(this%z_d(...)).
  if (allocated(this%z_d)) deallocate(this%z_d)
  if (allocated(this%h_d)) deallocate(this%h_d)
  if (allocated(this%v_d)) deallocate(this%v_d)

  if (allocated(this%s)) then
    if (c_associated(this%s_d)) then
      call device_unmap(this%s, this%s_d)
    end if
    deallocate(this%s)
  end if
  if (allocated(this%gam)) then
    if (c_associated(this%gam_d)) then
      call device_unmap(this%gam, this%gam_d)
    end if
    deallocate(this%gam)
  end if

  if (c_associated(this%z_d_d)) then
    call device_free(this%z_d_d)
  end if
  if (c_associated(this%v_d_d)) then
    call device_free(this%v_d_d)
  end if
  if (c_associated(this%h_d_d)) then
    call device_free(this%h_d_d)
  end if

  nullify(this%M)

  if (c_associated(this%gs_event)) then
    call device_event_destroy(this%gs_event)
    ! Make a repeated free a no-op for the event handle.
    this%gs_event = c_null_ptr
  end if

end subroutine gmres_device_free
315
!> Standard GMRES solve.
!! Restarted, right-preconditioned GMRES for A x = f. The initial guess is
!! zero (x%x_d is cleared on entry). Basis vectors live on the device, while
!! the small Hessenberg / Givens least-squares problem is handled on the host.
!! @param Ax Matrix-vector product operator.
!! @param x Solution field; overwritten (any previous content is discarded).
!! @param f Right-hand side; its device copy is obtained via device_get_ptr.
!! @param n Local vector length.
!! @param coef Coefficients (multiplicity weights and domain volume).
!! @param blst Boundary conditions applied after every operator application.
!! @param gs_h Gather-scatter handle used to assemble operator results.
!! @param niter Optional iteration cap overriding this%max_iter.
!! @return Monitor with start/final residual, iteration count, convergence.
function gmres_device_solve(this, Ax, x, f, n, coef, blst, gs_h, niter) &
     result(ksp_results)
  class(gmres_device_t), intent(inout) :: this
  class(ax_t), intent(in) :: ax
  type(field_t), intent(inout) :: x
  integer, intent(in) :: n
  real(kind=rp), dimension(n), intent(in) :: f
  type(coef_t), intent(inout) :: coef
  type(bc_list_t), intent(inout) :: blst
  type(gs_t), intent(inout) :: gs_h
  type(ksp_monitor_t) :: ksp_results
  integer, optional, intent(in) :: niter
  integer :: iter, max_iter
  integer :: i, j, k
  real(kind=rp) :: rnorm, alpha, temp, lr, alpha2, norm_fac
  logical :: conv
  type(c_ptr) :: f_d

  f_d = device_get_ptr(f)

  conv = .false.
  iter = 0
  rnorm = 0.0_rp

  if (present(niter)) then
    max_iter = niter
  else
    max_iter = this%max_iter
  end if

  associate(w => this%w, c => this%c, r => this%r, z => this%z, h => this%h, &
       v => this%v, s => this%s, gam => this%gam, v_d => this%v_d, &
       w_d => this%w_d, r_d => this%r_d, h_d => this%h_d, &
       v_d_d => this%v_d_d, x_d => x%x_d, z_d_d => this%z_d_d, &
       c_d => this%c_d)

    ! Residual norms are reported scaled by 1/sqrt(volume).
    norm_fac = 1.0_rp / sqrt(coef%volume)
    ! Host-side least-squares state.
    call rzero(gam, this%m_restart + 1)
    call rone(s, this%m_restart)
    call rone(c, this%m_restart)
    call rzero(h, this%m_restart * this%m_restart)
    ! Zero initial guess and the matching device-side state.
    call device_rzero(x%x_d, n)
    call device_rzero(this%gam_d, this%m_restart + 1)
    call device_rone(this%s_d, this%m_restart)
    call device_rone(this%c_d, this%m_restart)

    call this%monitor_start('GMRES')
    do while (.not. conv .and. iter .lt. max_iter)

      ! r = f - A x; on the first pass x is zero, so r = f.
      call device_copy(r_d, f_d, n)
      if (iter .ne. 0) then
        call ax%compute(w, x%x, coef, x%msh, x%Xh)
        call gs_h%op(w, n, gs_op_add, this%gs_event)
        call device_event_sync(this%gs_event)
        call blst%apply_scalar(w, n)
        call device_sub2(r_d, w_d, n)
      end if

      gam(1) = sqrt(device_glsc3(r_d, r_d, coef%mult_d, n))
      if (iter .eq. 0) then
        ksp_results%res_start = gam(1) * norm_fac
      end if

      if (abscmp(gam(1), 0.0_rp)) exit

      rnorm = 0.0_rp
      temp = 1.0_rp / gam(1)
      call device_cmult2(v_d(1), r_d, temp, n)
      do j = 1, this%m_restart
        iter = iter+1

        ! z_j = M^{-1} v_j, then w = A z_j (assembled, BCs applied).
        call this%M%solve(z(1,j), v(1,j), n)

        call ax%compute(w, z(1,j), coef, x%msh, x%Xh)
        call gs_h%op(w, n, gs_op_add, this%gs_event)
        call device_event_sync(this%gs_event)
        call blst%apply_scalar(w, n)

        if (neko_bcknd_opencl .eq. 1 .or. neko_bcknd_metal .eq. 1) then
          ! Orthogonalise w against v_1..v_j one vector at a time.
          do i = 1, j
            h(i,j) = device_glsc3(w_d, v_d(i), coef%mult_d, n)

            call device_add2s2(w_d, v_d(i), -h(i,j), n)
          end do
          ! Only the norm of the fully orthogonalised w is needed, so the
          ! (globally reduced) inner product is computed once, not j times.
          alpha2 = device_glsc3(w_d, w_d, coef%mult_d, n)
        else
          call device_glsc3_many(h(1,j), w_d, v_d_d, coef%mult_d, j, n)

          call device_memcpy(h(:,j), h_d(j), j, &
               host_to_device, sync = .false.)

          alpha2 = device_gmres_part2(w_d, v_d_d, h_d(j), &
               coef%mult_d, j, n)

        end if

        alpha = sqrt(alpha2)
        ! Apply the previous Givens rotations to the new column.
        do i = 1, j-1
          temp = h(i,j)
          h(i,j) = c(i)*temp + s(i) * h(i+1,j)
          h(i+1,j) = -s(i)*temp + c(i) * h(i+1,j)
        end do

        rnorm = 0.0_rp
        ! w vanished after orthogonalisation: treat as converged.
        if (abscmp(alpha, 0.0_rp)) then
          conv = .true.
          exit
        end if

        ! New rotation eliminating the subdiagonal entry alpha.
        lr = sqrt(h(j,j) * h(j,j) + alpha2)
        temp = 1.0_rp / lr
        c(j) = h(j,j) * temp
        s(j) = alpha * temp
        h(j,j) = lr
        call device_memcpy(h(:,j), h_d(j), j, &
             host_to_device, sync = .false.)
        gam(j+1) = -s(j) * gam(j)
        gam(j) = c(j) * gam(j)

        rnorm = abs(gam(j+1)) * norm_fac
        call this%monitor_iter(iter, rnorm)
        if (rnorm .lt. this%abs_tol) then
          conv = .true.
          exit
        end if

        if (iter + 1 .gt. max_iter) exit

        if (j .lt. this%m_restart) then
          temp = 1.0_rp / alpha
          call device_cmult2(v_d(j+1), w_d, temp, n)
        end if

      end do

      ! Back substitution; the update coefficients are stored in c.
      j = min(j, this%m_restart)
      do k = j, 1, -1
        temp = gam(k)
        do i = j, k+1, -1
          temp = temp - h(k,i) * c(i)
        end do
        c(k) = temp / h(k,k)
      end do

      ! x = x + sum_i c(i) * z_i
      if (neko_bcknd_opencl .eq. 1 .or. neko_bcknd_metal .eq. 1) then
        do i = 1, j
          call device_add2s2(x_d, this%z_d(i), c(i), n)
        end do
      else
        call device_memcpy(c, c_d, j, host_to_device, sync = .false.)
        call device_add2s2_many(x_d, z_d_d, c_d, j, n)
      end if
    end do

  end associate
  call this%monitor_stop()
  ksp_results%res_final = rnorm
  ksp_results%iter = iter
  ksp_results%converged = this%is_converged(iter, rnorm)

end function gmres_device_solve
486
!> Standard GMRES coupled solve.
!! The three components are solved one after another through this%solve,
!! sharing the operator, coefficients, gather-scatter handle and this
!! solver's work arrays; each component has its own boundary conditions.
!! @return One monitor record per component, in x, y, z order.
function gmres_device_solve_coupled(this, Ax, x, y, z, fx, fy, fz, &
     n, coef, blstx, blsty, blstz, gs_h, niter) result(ksp_results)
  class(gmres_device_t), intent(inout) :: this
  class(ax_t), intent(in) :: Ax
  type(field_t), intent(inout) :: x
  type(field_t), intent(inout) :: y
  type(field_t), intent(inout) :: z
  integer, intent(in) :: n
  real(kind=rp), dimension(n), intent(in) :: fx
  real(kind=rp), dimension(n), intent(in) :: fy
  real(kind=rp), dimension(n), intent(in) :: fz
  type(coef_t), intent(inout) :: coef
  type(bc_list_t), intent(inout) :: blstx
  type(bc_list_t), intent(inout) :: blsty
  type(bc_list_t), intent(inout) :: blstz
  type(gs_t), intent(inout) :: gs_h
  type(ksp_monitor_t), dimension(3) :: ksp_results
  integer, optional, intent(in) :: niter

  ! niter is forwarded as-is; if absent it stays absent in this%solve.
  ksp_results(1) = this%solve(Ax, x, fx, n, coef, blstx, gs_h, &
       niter = niter)
  ksp_results(2) = this%solve(Ax, y, fy, n, coef, blsty, gs_h, &
       niter = niter)
  ksp_results(3) = this%solve(Ax, z, fz, n, coef, blstz, gs_h, &
       niter = niter)

end function gmres_device_solve_coupled
512
513end module gmres_device
__device__ T solve(const T u, const T y, const T guess, const T nu, const T kappa, const T B)
real cuda_gmres_part2(void *w, void *v, void *h, void *mult, int *j, int *n)
Definition gmres_aux.cu:62
Defines a Matrix-vector product.
Definition ax.f90:34
Defines a list of bc_t.
Definition bc_list.f90:34
Coefficients.
Definition coef.f90:34
Definition comm.F90:1
type(mpi_datatype), public mpi_real_precision
MPI type for working precision of REAL types.
Definition comm.F90:53
integer, public pe_size
MPI size of communicator.
Definition comm.F90:61
type(mpi_comm), public neko_comm
MPI communicator.
Definition comm.F90:45
Identity Krylov preconditioner for accelerators.
subroutine, public device_add2s1(a_d, b_d, c1, n, strm)
subroutine, public device_add2s2_many(y_d, x_d_d, a_d, j, n, strm)
subroutine, public device_add2s2(a_d, b_d, c1, n, strm)
Vector addition with scalar multiplication (multiplication on first argument)
subroutine, public device_rzero(a_d, n, strm)
Zero a real vector.
subroutine, public device_rone(a_d, n, strm)
Set all elements to one.
subroutine, public device_glsc3_many(h, w_d, v_d_d, mult_d, j, n, strm)
subroutine, public device_sub2(a_d, b_d, n, strm)
Vector subtraction \( a = a - b \).
subroutine, public device_copy(a_d, b_d, n, strm)
Copy a vector \( a = b \).
real(kind=rp) function, public device_glsc3(a_d, b_d, c_d, n, strm)
Weighted inner product \( \sum_i a_i b_i c_i \).
subroutine, public device_cmult2(a_d, b_d, c, n, strm)
Multiplication by constant c: \( a = c \, b \).
Device abstraction, common interface for various accelerators.
Definition device.F90:34
Defines a field.
Definition field.f90:34
Gather-scatter.
Defines various GMRES methods.
real(c_rp) function device_gmres_part2(w_d, v_d_d, h_d, mult_d, j, n)
subroutine gmres_device_init(this, n, max_iter, m, rel_tol, abs_tol, monitor)
Initialise a standard GMRES solver.
type(ksp_monitor_t) function gmres_device_solve(this, ax, x, f, n, coef, blst, gs_h, niter)
Standard GMRES solve.
subroutine gmres_device_free(this)
Deallocate a standard GMRES solver.
type(ksp_monitor_t) function, dimension(3) gmres_device_solve_coupled(this, ax, x, y, z, fx, fy, fz, n, coef, blstx, blsty, blstz, gs_h, niter)
Standard GMRES coupled solve.
Implements the base abstract type for Krylov solvers plus helper types.
Definition krylov.f90:34
Definition math.f90:60
subroutine, public rone(a, n)
Set all elements to one.
Definition math.f90:275
subroutine, public rzero(a, n)
Zero a real vector.
Definition math.f90:233
Build configurations.
integer, parameter neko_bcknd_opencl
integer, parameter neko_bcknd_metal
integer, parameter, public c_rp
Definition num_types.f90:13
integer, parameter, public rp
Global precision used in computations.
Definition num_types.f90:12
Krylov preconditioner.
Definition precon.f90:34
Utilities.
Definition utils.f90:35
Base type for a matrix-vector product providing \( Ax \).
Definition ax.f90:43
A list of allocatable `bc_t`. Follows the standard interface of lists.
Definition bc_list.f90:49
Coefficients defined on a given (mesh, \( X_h \)) tuple. Arrays use indices (i,j,k,e): element e,...
Definition coef.f90:63
Defines a canonical Krylov preconditioner for accelerators.
Standard preconditioned generalized minimal residual method.
Type for storing initial and final residuals in a Krylov solver.
Definition krylov.f90:56
Base abstract type for a canonical Krylov method, solving \( Ax = f \).
Definition krylov.f90:73
Defines a canonical Krylov preconditioner.
Definition precon.f90:40