Neko 1.99.1
A portable framework for high-order spectral element flow simulations
Loading...
Searching...
No Matches
gmres_device.F90
Go to the documentation of this file.
1! Copyright (c) 2022-2024, The Neko Authors
2! All rights reserved.
3!
4! Redistribution and use in source and binary forms, with or without
5! modification, are permitted provided that the following conditions
6! are met:
7!
8! * Redistributions of source code must retain the above copyright
9! notice, this list of conditions and the following disclaimer.
10!
11! * Redistributions in binary form must reproduce the above
12! copyright notice, this list of conditions and the following
13! disclaimer in the documentation and/or other materials provided
14! with the distribution.
15!
16! * Neither the name of the authors nor the names of its
17! contributors may be used to endorse or promote products derived
18! from this software without specific prior written permission.
19!
20! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21! "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22! LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
23! FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
24! COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
25! INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
26! BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27! LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28! CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29! LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
30! ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31! POSSIBILITY OF SUCH DAMAGE.
32!
36 use krylov, only : ksp_t, ksp_monitor_t
37 use precon, only : pc_t
38 use ax_product, only : ax_t
39 use num_types, only: rp, c_rp
40 use field, only : field_t
41 use coefs, only : coef_t
42 use gather_scatter, only : gs_t, gs_op_add
43 use bc_list, only : bc_list_t
45 use math, only : rone, rzero, abscmp
50 use device
51 use utils, only : neko_error
53 use mpi_f08, only : mpi_in_place, mpi_sum, mpi_allreduce
54 use, intrinsic :: iso_c_binding, only : c_ptr, c_null_ptr, c_loc, &
55 c_associated, c_int, c_size_t, c_sizeof
56 implicit none
57 private
58
!> Standard preconditioned generalized minimal residual method
!! (accelerator version).
!!
!! Each host work array has a device counterpart with the suffix `_d`.
!! For the two-dimensional arrays one device pointer per column is kept in
!! a host table (`z_d`, `h_d`, `v_d`), and a copy of that table living in
!! device memory is referenced by the `_d_d` handle.
type, public, extends(ksp_t) :: gmres_device_t
  !> Number of columns allocated for `z`, `v` and `h`; the solver restarts
  !! after this many inner iterations.
  integer :: m_restart = 30
  !> Host work vectors: `w` and `r` have length n, `c` has length m_restart.
  real(kind=rp), allocatable, dimension(:) :: w, c, r
  !> Host work matrices: `z` and `v` are (n, m_restart), `h` is
  !! (m_restart, m_restart).
  real(kind=rp), allocatable, dimension(:,:) :: z, h, v
  !> `s` has length m_restart, `gam` has length m_restart + 1.
  real(kind=rp), allocatable, dimension(:) :: s, gam
  !> Device pointers mapped to `w`, `c`, `r`, `s` and `gam`.
  type(c_ptr) :: w_d = c_null_ptr, c_d = c_null_ptr, r_d = c_null_ptr
  type(c_ptr) :: s_d = c_null_ptr, gam_d = c_null_ptr
  !> Host tables holding one device pointer per column of `z`, `h`, `v`.
  type(c_ptr), allocatable, dimension(:) :: z_d, h_d, v_d
  !> Device-resident copies of the pointer tables above.
  type(c_ptr) :: z_d_d = c_null_ptr, h_d_d = c_null_ptr, v_d_d = c_null_ptr
  !> Event handed to the gather-scatter operation and synchronised on
  !! before the boundary conditions are applied.
  type(c_ptr) :: gs_event = c_null_ptr
contains
  !> Initialise a standard GMRES solver.
  procedure, pass(this) :: init => gmres_device_init
  !> Deallocate a standard GMRES solver.
  procedure, pass(this) :: free => gmres_device_free
  !> Standard GMRES solve.
  procedure, pass(this) :: solve => gmres_device_solve
  !> Standard GMRES coupled solve.
  procedure, pass(this) :: solve_coupled => gmres_device_solve_coupled
end type gmres_device_t
86
! C bindings for the backend-specific kernel used by device_gmres_part2.
! `j` and `n` carry no `value` attribute, so they are passed by reference.
! NOTE(review): the kernel presumably removes the projections stored in h_d
! from w_d and returns the squared weighted norm of the remainder; confirm
! against the backend source.
#ifdef HAVE_HIP
interface
  function hip_gmres_part2(w_d, v_d_d, h_d, mult_d, j, n) result(alpha2) &
       bind(c, name = 'hip_gmres_part2')
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int
    import :: c_rp
    implicit none
    type(c_ptr), value :: h_d, w_d, v_d_d, mult_d
    integer(c_int) :: j, n
    real(c_rp) :: alpha2
  end function hip_gmres_part2
end interface
#elif HAVE_CUDA

interface
  function cuda_gmres_part2(w_d, v_d_d, h_d, mult_d, j, n) result(alpha2) &
       bind(c, name = 'cuda_gmres_part2')
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int
    import :: c_rp
    implicit none
    type(c_ptr), value :: h_d, w_d, v_d_d, mult_d
    integer(c_int) :: j, n
    real(c_rp) :: alpha2
  end function cuda_gmres_part2
end interface
#endif
111
112contains
113
!> Backend dispatcher for the `gmres_part2` device kernel.
!!
!! Calls the HIP or CUDA kernel selected at compile time and aborts through
!! neko_error when neither backend is configured. Unless the build defines
!! HAVE_DEVICE_MPI, the value returned by the kernel is summed over all
!! ranks of neko_comm when more than one rank is running.
!!
!! @param w_d    Device pointer to the vector being orthogonalised.
!! @param v_d_d  Device-resident table of device pointers to the basis.
!! @param h_d    Device pointer to the current column of `h`.
!! @param mult_d Device pointer to the weights (`coef%mult_d` at the call
!!               site).
!! @param j      Number of basis vectors to use.
!! @param n      Vector length.
!! @return The (globally reduced) value produced by the kernel; the caller
!!         takes its square root.
function device_gmres_part2(w_d, v_d_d, h_d, mult_d, j, n) result(alpha)
  type(c_ptr), value :: w_d
  type(c_ptr), value :: v_d_d
  type(c_ptr), value :: h_d
  type(c_ptr), value :: mult_d
  integer(c_int) :: j, n
  real(c_rp) :: alpha
  integer :: mpi_err

#ifdef HAVE_HIP
  alpha = hip_gmres_part2(w_d, v_d_d, h_d, mult_d, j, n)
#elif HAVE_CUDA
  alpha = cuda_gmres_part2(w_d, v_d_d, h_d, mult_d, j, n)
#else
  call neko_error('No device backend configured')
#endif

#ifndef HAVE_DEVICE_MPI
  ! The kernel result is rank-local here; combine it across ranks.
  if (pe_size > 1) then
    call mpi_allreduce(mpi_in_place, alpha, 1, &
         mpi_real_precision, mpi_sum, neko_comm, mpi_err)
  end if
#endif

end function device_gmres_part2
135
!> Initialise a standard GMRES solver.
!!
!! Releases any previous state, allocates the host work arrays, maps them
!! to device memory, uploads the per-column device pointer tables and
!! initialises the base Krylov solver.
!!
!! @param n        Length of the solution and work vectors.
!! @param max_iter Maximum number of iterations.
!! @param M        Optional preconditioner; an identity preconditioner is
!!                 used when absent.
!! @param rel_tol  Optional relative tolerance, forwarded to ksp_init.
!! @param abs_tol  Optional absolute tolerance, forwarded to ksp_init.
!! @param monitor  Optional monitoring flag, forwarded to ksp_init.
subroutine gmres_device_init(this, n, max_iter, M, rel_tol, abs_tol, monitor)
  class(gmres_device_t), target, intent(inout) :: this
  integer, intent(in) :: n
  integer, intent(in) :: max_iter
  class(pc_t), optional, intent(in), target :: M
  real(kind=rp), optional, intent(in) :: rel_tol
  real(kind=rp), optional, intent(in) :: abs_tol
  logical, optional, intent(in) :: monitor
  ! `save` keeps the fallback preconditioner alive after this routine
  ! returns; without it this%M would point at a local that no longer
  ! exists. The instance is shared by every solver that omits M.
  ! NOTE(review): assumes device_ident_t holds no per-solver state - confirm.
  type(device_ident_t), target, save :: M_ident
  type(c_ptr) :: ptr
  integer(c_size_t) :: table_bytes
  integer :: i

  call this%free()

  if (present(M)) then
    this%M => M
  else
    this%M => M_ident
  end if

  ! Vectors of length n.
  allocate(this%w(n))
  allocate(this%r(n))
  call device_map(this%w, this%w_d, n)
  call device_map(this%r, this%r_d, n)

  ! Small per-restart arrays.
  allocate(this%c(this%m_restart))
  allocate(this%s(this%m_restart))
  allocate(this%gam(this%m_restart + 1))
  call device_map(this%c, this%c_d, this%m_restart)
  call device_map(this%s, this%s_d, this%m_restart)
  call device_map(this%gam, this%gam_d, this%m_restart + 1)

  ! Column-wise storage: every column gets its own device mapping.
  allocate(this%z(n, this%m_restart))
  allocate(this%v(n, this%m_restart))
  allocate(this%h(this%m_restart, this%m_restart))
  allocate(this%z_d(this%m_restart))
  allocate(this%v_d(this%m_restart))
  allocate(this%h_d(this%m_restart))
  do i = 1, this%m_restart
    this%z_d(i) = c_null_ptr
    call device_map(this%z(:,i), this%z_d(i), n)

    this%v_d(i) = c_null_ptr
    call device_map(this%v(:,i), this%v_d(i), n)

    this%h_d(i) = c_null_ptr
    call device_map(this%h(:,i), this%h_d(i), this%m_restart)
  end do

  ! Upload the tables of column pointers so that device kernels can walk
  ! all columns from a single handle. All three tables have m_restart
  ! entries and therefore the same size in bytes.
  table_bytes = c_sizeof(c_null_ptr) * (this%m_restart)
  call device_alloc(this%z_d_d, table_bytes)
  call device_alloc(this%v_d_d, table_bytes)
  call device_alloc(this%h_d_d, table_bytes)
  ptr = c_loc(this%z_d)
  call device_memcpy(ptr, this%z_d_d, table_bytes, &
       host_to_device, sync = .false.)
  ptr = c_loc(this%v_d)
  call device_memcpy(ptr, this%v_d_d, table_bytes, &
       host_to_device, sync = .false.)
  ptr = c_loc(this%h_d)
  call device_memcpy(ptr, this%h_d_d, table_bytes, &
       host_to_device, sync = .false.)

  ! Absent optional arguments may be passed on to optional dummies, so a
  ! single call covers every combination of supplied tolerances/monitor.
  call this%ksp_init(max_iter, rel_tol = rel_tol, abs_tol = abs_tol, &
       monitor = monitor)

  call device_event_create(this%gs_event, 2)

end subroutine gmres_device_init
223
!> Deallocate a standard GMRES solver.
!!
!! Releases the base solver state, the host work arrays, every device
!! allocation referenced by the solver (including the per-column pointers
!! and their device-resident tables), the preconditioner association and
!! the gather-scatter event. Safe to call on a solver that was never
!! initialised, and leaves the object ready for another call to init.
subroutine gmres_device_free(this)
  class(gmres_device_t), intent(inout) :: this
  integer :: i

  call this%ksp_free()

  ! Host work arrays.
  if (allocated(this%w)) deallocate(this%w)
  if (allocated(this%c)) deallocate(this%c)
  if (allocated(this%r)) deallocate(this%r)
  if (allocated(this%z)) deallocate(this%z)
  if (allocated(this%h)) deallocate(this%h)
  if (allocated(this%v)) deallocate(this%v)
  if (allocated(this%s)) deallocate(this%s)
  if (allocated(this%gam)) deallocate(this%gam)

  ! Per-column device pointers. The loops run over the actual table size,
  ! and the tables themselves are released afterwards so that init can
  ! allocate them again.
  if (allocated(this%v_d)) then
    do i = 1, size(this%v_d)
      if (c_associated(this%v_d(i))) then
        call device_free(this%v_d(i))
      end if
    end do
    deallocate(this%v_d)
  end if

  if (allocated(this%z_d)) then
    do i = 1, size(this%z_d)
      if (c_associated(this%z_d(i))) then
        call device_free(this%z_d(i))
      end if
    end do
    deallocate(this%z_d)
  end if

  if (allocated(this%h_d)) then
    do i = 1, size(this%h_d)
      if (c_associated(this%h_d(i))) then
        call device_free(this%h_d(i))
      end if
    end do
    deallocate(this%h_d)
  end if

  ! Remaining device allocations.
  if (c_associated(this%gam_d)) then
    call device_free(this%gam_d)
  end if
  if (c_associated(this%w_d)) then
    call device_free(this%w_d)
  end if
  if (c_associated(this%c_d)) then
    call device_free(this%c_d)
  end if
  if (c_associated(this%r_d)) then
    call device_free(this%r_d)
  end if
  if (c_associated(this%s_d)) then
    call device_free(this%s_d)
  end if
  if (c_associated(this%z_d_d)) then
    call device_free(this%z_d_d)
  end if
  if (c_associated(this%v_d_d)) then
    call device_free(this%v_d_d)
  end if
  if (c_associated(this%h_d_d)) then
    call device_free(this%h_d_d)
  end if

  nullify(this%M)

  if (c_associated(this%gs_event)) then
    call device_event_destroy(this%gs_event)
  end if

end subroutine gmres_device_free
319
!> Standard GMRES solve.
!!
!! Runs restarted, preconditioned GMRES with restart length
!! this%m_restart. The solution field is zeroed on entry, so any incoming
!! initial guess is discarded. Residual norms are weighted by coef%mult_d
!! and scaled by 1 / sqrt(coef%volume); the inner loop stops once the
!! scaled residual estimate drops below this%abs_tol.
!!
!! @param Ax    Matrix-vector product.
!! @param x     Solution field; overwritten.
!! @param f     Right-hand side of length n; its device pointer is looked
!!              up with device_get_ptr.
!! @param n     Vector length.
!! @param coef  Coefficients (supplies `mult_d` and `volume`).
!! @param blst  Boundary conditions applied after every operator
!!              evaluation.
!! @param gs_h  Gather-scatter handle.
!! @param niter Optional iteration cap overriding this%max_iter.
!! @return Monitor record with the initial and final scaled residual, the
!!         iteration count and the convergence flag.
function gmres_device_solve(this, Ax, x, f, n, coef, blst, gs_h, niter) &
     result(ksp_results)
  class(gmres_device_t), intent(inout) :: this
  class(ax_t), intent(in) :: Ax
  type(field_t), intent(inout) :: x
  integer, intent(in) :: n
  real(kind=rp), dimension(n), intent(in) :: f
  type(coef_t), intent(inout) :: coef
  type(bc_list_t), intent(inout) :: blst
  type(gs_t), intent(inout) :: gs_h
  type(ksp_monitor_t) :: ksp_results
  integer, optional, intent(in) :: niter
  integer :: iter, max_iter
  integer :: i, j, k
  real(kind=rp) :: rnorm, alpha, temp, lr, alpha2, norm_fac
  logical :: conv
  type(c_ptr) :: f_d

  f_d = device_get_ptr(f)

  conv = .false.
  iter = 0
  rnorm = 0.0_rp

  if (present(niter)) then
    max_iter = niter
  else
    max_iter = this%max_iter
  end if

  associate(w => this%w, c => this%c, r => this%r, z => this%z, h => this%h, &
       v => this%v, s => this%s, gam => this%gam, v_d => this%v_d, &
       w_d => this%w_d, r_d => this%r_d, h_d => this%h_d, &
       v_d_d => this%v_d_d, x_d => x%x_d, z_d_d => this%z_d_d, &
       c_d => this%c_d)

    norm_fac = 1.0_rp / sqrt(coef%volume)

    ! Reset host and device state. The device copies of the columns of h
    ! are not zeroed here; each column is uploaded before it is used.
    call rzero(gam, this%m_restart + 1)
    call rone(s, this%m_restart)
    call rone(c, this%m_restart)
    call rzero(h, this%m_restart * this%m_restart)
    call device_rzero(x%x_d, n)
    call device_rzero(this%gam_d, this%m_restart + 1)
    call device_rone(this%s_d, this%m_restart)
    call device_rone(this%c_d, this%m_restart)

    call this%monitor_start('GMRES')
    do while (.not. conv .and. iter .lt. max_iter)

      ! Residual r = f - A x. Before the first cycle x is zero, so the
      ! operator evaluation is skipped and r = f.
      call device_copy(r_d, f_d, n)
      if (iter .gt. 0) then
        call ax%compute(w, x%x, coef, x%msh, x%Xh)
        call gs_h%op(w, n, gs_op_add, this%gs_event)
        call device_event_sync(this%gs_event)
        call blst%apply_scalar(w, n)
        call device_sub2(r_d, w_d, n)
      end if

      gam(1) = sqrt(device_glsc3(r_d, r_d, coef%mult_d, n))
      if (iter .eq. 0) then
        ksp_results%res_start = gam(1) * norm_fac
      end if

      ! Zero residual: nothing left to solve for.
      if (abscmp(gam(1), 0.0_rp)) exit

      rnorm = 0.0_rp
      temp = 1.0_rp / gam(1)
      ! First basis vector v_1 = r / ||r||.
      call device_cmult2(v_d(1), r_d, temp, n)
      do j = 1, this%m_restart
        iter = iter+1

        ! z_j = M^{-1} v_j, w = A z_j.
        call this%M%solve(z(1,j), v(1,j), n)

        call ax%compute(w, z(1,j), coef, x%msh, x%Xh)
        call gs_h%op(w, n, gs_op_add, this%gs_event)
        call device_event_sync(this%gs_event)
        call blst%apply_scalar(w, n)

        if (neko_bcknd_opencl .eq. 1) then
          ! Orthogonalise w against v_1..v_j one vector at a time.
          do i = 1, j
            h(i,j) = device_glsc3(w_d, v_d(i), coef%mult_d, n)

            call device_add2s2(w_d, v_d(i), -h(i,j), n)
          end do
          ! Only the norm of the fully orthogonalised w is needed, so the
          ! reduction is done once after the loop.
          alpha2 = device_glsc3(w_d, w_d, coef%mult_d, n)
        else
          ! All j inner products in one call, then upload them so the
          ! backend kernel can use them.
          call device_glsc3_many(h(1,j), w_d, v_d_d, coef%mult_d, j, n)

          call device_memcpy(h(:,j), h_d(j), j, &
               host_to_device, sync = .false.)

          alpha2 = device_gmres_part2(w_d, v_d_d, h_d(j), &
               coef%mult_d, j, n)

        end if

        alpha = sqrt(alpha2)
        ! Apply the previously computed rotations to the new column.
        do i = 1, j-1
          temp = h(i,j)
          h(i,j) = c(i)*temp + s(i) * h(i+1,j)
          h(i+1,j) = -s(i)*temp + c(i) * h(i+1,j)
        end do

        rnorm = 0.0_rp
        ! w vanished after orthogonalisation: the basis cannot be extended.
        if (abscmp(alpha, 0.0_rp)) then
          conv = .true.
          exit
        end if

        ! New rotation eliminating the sub-diagonal entry alpha.
        lr = sqrt(h(j,j) * h(j,j) + alpha2)
        temp = 1.0_rp / lr
        c(j) = h(j,j) * temp
        s(j) = alpha * temp
        h(j,j) = lr
        call device_memcpy(h(:,j), h_d(j), j, &
             host_to_device, sync = .false.)
        gam(j+1) = -s(j) * gam(j)
        gam(j) = c(j) * gam(j)

        rnorm = abs(gam(j+1)) * norm_fac
        call this%monitor_iter(iter, rnorm)
        if (rnorm .lt. this%abs_tol) then
          conv = .true.
          exit
        end if

        if (iter + 1 .gt. max_iter) exit

        if (j .lt. this%m_restart) then
          temp = 1.0_rp / alpha
          call device_cmult2(v_d(j+1), w_d, temp, n)
        end if

      end do

      ! After a complete inner loop j equals m_restart + 1.
      j = min(j, this%m_restart)
      ! Back substitution; c is reused to hold the update coefficients.
      do k = j, 1, -1
        temp = gam(k)
        do i = j, k+1, -1
          temp = temp - h(k,i) * c(i)
        end do
        c(k) = temp / h(k,k)
      end do

      ! x = x + sum_i c_i z_i.
      if (neko_bcknd_opencl .eq. 1) then
        do i = 1, j
          call device_add2s2(x_d, this%z_d(i), c(i), n)
        end do
      else
        call device_memcpy(c, c_d, j, host_to_device, sync = .false.)
        call device_add2s2_many(x_d, z_d_d, c_d, j, n)
      end if
    end do

  end associate
  call this%monitor_stop()
  ksp_results%res_final = rnorm
  ksp_results%iter = iter
  ksp_results%converged = this%is_converged(iter, rnorm)

end function gmres_device_solve
490
!> Standard GMRES coupled solve.
!!
!! Solves the three components one after another with the scalar solver,
!! using the same operator, coefficients and gather-scatter handle but a
!! separate right-hand side and boundary condition list per component.
!!
!! @param x, y, z             Solution fields; overwritten.
!! @param fx, fy, fz          Right-hand sides of length n.
!! @param blstx, blsty, blstz Boundary conditions for each component.
!! @param niter               Optional iteration cap passed on to each
!!                            solve.
!! @return Monitor records for the x, y and z solves, in that order.
function gmres_device_solve_coupled(this, Ax, x, y, z, fx, fy, fz, &
     n, coef, blstx, blsty, blstz, gs_h, niter) result(ksp_results)
  class(gmres_device_t), intent(inout) :: this
  class(ax_t), intent(in) :: Ax
  type(field_t), intent(inout) :: x, y, z
  integer, intent(in) :: n
  real(kind=rp), dimension(n), intent(in) :: fx, fy, fz
  type(coef_t), intent(inout) :: coef
  type(bc_list_t), intent(inout) :: blstx, blsty, blstz
  type(gs_t), intent(inout) :: gs_h
  type(ksp_monitor_t), dimension(3) :: ksp_results
  integer, optional, intent(in) :: niter

  ! x component
  ksp_results(1) = this%solve(Ax, x, fx, n, coef, blstx, gs_h, niter)
  ! y component
  ksp_results(2) = this%solve(Ax, y, fy, n, coef, blsty, gs_h, niter)
  ! z component
  ksp_results(3) = this%solve(Ax, z, fz, n, coef, blstz, gs_h, niter)

end function gmres_device_solve_coupled
516
517end module gmres_device
__device__ T solve(const T u, const T y, const T guess, const T nu, const T kappa, const T B)
real cuda_gmres_part2(void *w, void *v, void *h, void *mult, int *j, int *n)
Definition gmres_aux.cu:62
Defines a Matrix-vector product.
Definition ax.f90:34
Defines a list of bc_t.
Definition bc_list.f90:34
Coefficients.
Definition coef.f90:34
Definition comm.F90:1
type(mpi_datatype), public mpi_real_precision
MPI type for working precision of REAL types.
Definition comm.F90:51
integer, public pe_size
MPI size of communicator.
Definition comm.F90:59
type(mpi_comm), public neko_comm
MPI communicator.
Definition comm.F90:43
Identity Krylov preconditioner for accelerators.
subroutine, public device_add2s1(a_d, b_d, c1, n, strm)
subroutine, public device_add2s2_many(y_d, x_d_d, a_d, j, n, strm)
subroutine, public device_add2s2(a_d, b_d, c1, n, strm)
Vector addition with scalar multiplication (multiplication on first argument)
subroutine, public device_rzero(a_d, n, strm)
Zero a real vector.
subroutine, public device_rone(a_d, n, strm)
Set all elements to one.
subroutine, public device_glsc3_many(h, w_d, v_d_d, mult_d, j, n, strm)
subroutine, public device_sub2(a_d, b_d, n, strm)
Vector subtraction \( a = a - b \).
subroutine, public device_copy(a_d, b_d, n, strm)
Copy a vector \( a = b \).
real(kind=rp) function, public device_glsc3(a_d, b_d, c_d, n, strm)
Weighted inner product \( a^T b c \).
subroutine, public device_cmult2(a_d, b_d, c, n, strm)
Multiplication by constant c \( a = c \cdot b \).
Device abstraction, common interface for various accelerators.
Definition device.F90:34
Defines a field.
Definition field.f90:34
Gather-scatter.
Defines various GMRES methods.
real(c_rp) function device_gmres_part2(w_d, v_d_d, h_d, mult_d, j, n)
subroutine gmres_device_init(this, n, max_iter, m, rel_tol, abs_tol, monitor)
Initialise a standard GMRES solver.
type(ksp_monitor_t) function gmres_device_solve(this, ax, x, f, n, coef, blst, gs_h, niter)
Standard GMRES solve.
subroutine gmres_device_free(this)
Deallocate a standard GMRES solver.
type(ksp_monitor_t) function, dimension(3) gmres_device_solve_coupled(this, ax, x, y, z, fx, fy, fz, n, coef, blstx, blsty, blstz, gs_h, niter)
Standard GMRES coupled solve.
Implements the base abstract type for Krylov solvers plus helper types.
Definition krylov.f90:34
Definition math.f90:60
subroutine, public rone(a, n)
Set all elements to one.
Definition math.f90:238
subroutine, public rzero(a, n)
Zero a real vector.
Definition math.f90:205
Build configurations.
integer, parameter neko_bcknd_opencl
integer, parameter, public c_rp
Definition num_types.f90:13
integer, parameter, public rp
Global precision used in computations.
Definition num_types.f90:12
Krylov preconditioner.
Definition precon.f90:34
Utilities.
Definition utils.f90:35
Base type for a matrix-vector product providing \( Ax \).
Definition ax.f90:43
A list of allocatable `bc_t`. Follows the standard interface of lists.
Definition bc_list.f90:48
Coefficients defined on a given (mesh, ) tuple. Arrays use indices (i,j,k,e): element e,...
Definition coef.f90:55
Defines a canonical Krylov preconditioner for accelerators.
Standard preconditioned generalized minimal residual method.
Type for storing initial and final residuals in a Krylov solver.
Definition krylov.f90:56
Base abstract type for a canonical Krylov method, solving \( Ax = f \).
Definition krylov.f90:73
Defines a canonical Krylov preconditioner.
Definition precon.f90:40