Neko 1.99.1
A portable framework for high-order spectral element flow simulations
Loading...
Searching...
No Matches
gmres_device.F90
Go to the documentation of this file.
1! Copyright (c) 2022-2024, The Neko Authors
2! All rights reserved.
3!
4! Redistribution and use in source and binary forms, with or without
5! modification, are permitted provided that the following conditions
6! are met:
7!
8! * Redistributions of source code must retain the above copyright
9! notice, this list of conditions and the following disclaimer.
10!
11! * Redistributions in binary form must reproduce the above
12! copyright notice, this list of conditions and the following
13! disclaimer in the documentation and/or other materials provided
14! with the distribution.
15!
16! * Neither the name of the authors nor the names of its
17! contributors may be used to endorse or promote products derived
18! from this software without specific prior written permission.
19!
20! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21! "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22! LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
23! FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
24! COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
25! INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
26! BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27! LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28! CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29! LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
30! ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31! POSSIBILITY OF SUCH DAMAGE.
32!
36 use krylov, only : ksp_t, ksp_monitor_t
37 use precon, only : pc_t
38 use ax_product, only : ax_t
39 use num_types, only: rp, c_rp
40 use field, only : field_t
41 use coefs, only : coef_t
42 use gather_scatter, only : gs_t, gs_op_add
43 use bc_list, only : bc_list_t
45 use math, only : rone, rzero, abscmp
50 use device
51 use utils, only : neko_error
53 use mpi_f08, only : mpi_in_place, mpi_sum, mpi_allreduce
54 use, intrinsic :: iso_c_binding, only : c_ptr, c_null_ptr, c_loc, &
55 c_associated, c_int, c_size_t, c_sizeof
56 implicit none
57 private
58
!> Standard preconditioned generalized minimal residual method (device backends).
!! Host work arrays are mapped to device buffers (the *_d components) in init.
60 type, public, extends(ksp_t) :: gmres_device_t
61 integer :: m_restart = 30 ! Restart length; sizes c, s, gam and the columns of z, v, h.
62 real(kind=rp), allocatable :: w(:) ! Work vector (length n) receiving operator applications.
63 real(kind=rp), allocatable :: c(:) ! Givens cosines; reused for the back-substituted coefficients.
64 real(kind=rp), allocatable :: r(:) ! Residual vector (length n).
65 real(kind=rp), allocatable :: z(:,:) ! Preconditioned basis, column j holds M^{-1} v_j.
66 real(kind=rp), allocatable :: h(:,:) ! Hessenberg coefficients, m_restart x m_restart.
67 real(kind=rp), allocatable :: v(:,:) ! Krylov basis vectors, n x m_restart.
68 real(kind=rp), allocatable :: s(:) ! Givens sines.
69 real(kind=rp), allocatable :: gam(:) ! Rotated right-hand side (m_restart + 1 entries).
! Device buffers mirroring the host arrays above.
70 type(c_ptr) :: w_d = c_null_ptr
71 type(c_ptr) :: c_d = c_null_ptr
72 type(c_ptr) :: r_d = c_null_ptr
73 type(c_ptr) :: s_d = c_null_ptr
74 type(c_ptr) :: gam_d = c_null_ptr
! One device pointer per column of z, h and v.
75 type(c_ptr), allocatable :: z_d(:), h_d(:), v_d(:)
! Device-resident copies of the z_d, h_d and v_d pointer tables.
76 type(c_ptr) :: z_d_d = c_null_ptr
77 type(c_ptr) :: h_d_d = c_null_ptr
78 type(c_ptr) :: v_d_d = c_null_ptr
! Event passed to the gather-scatter and synchronised on after each operator application.
79 type(c_ptr) :: gs_event = c_null_ptr
80 contains
81 procedure, pass(this) :: init => gmres_device_init
82 procedure, pass(this) :: free => gmres_device_free
83 procedure, pass(this) :: solve => gmres_device_solve
84 procedure, pass(this) :: solve_coupled => gmres_device_solve_coupled
85 end type gmres_device_t
86
#ifdef HAVE_HIP
  interface
     !> C entry point of the HIP backend for the second GMRES
     !! orthogonalisation phase; returns a scalar of kind c_rp.
     !! j and n are passed by reference (C side takes int pointers).
     real(c_rp) function hip_gmres_part2(w_d, v_d_d, h_d, mult_d, j, n) &
          bind(c, name = 'hip_gmres_part2')
       use, intrinsic :: iso_c_binding, only : c_ptr, c_int
       import c_rp
       implicit none
       type(c_ptr), value :: w_d
       type(c_ptr), value :: v_d_d
       type(c_ptr), value :: h_d
       type(c_ptr), value :: mult_d
       integer(c_int) :: j
       integer(c_int) :: n
     end function hip_gmres_part2
  end interface
#elif HAVE_CUDA
  interface
     !> C entry point of the CUDA backend for the second GMRES
     !! orthogonalisation phase; returns a scalar of kind c_rp.
     !! j and n are passed by reference (C side takes int pointers).
     real(c_rp) function cuda_gmres_part2(w_d, v_d_d, h_d, mult_d, j, n) &
          bind(c, name = 'cuda_gmres_part2')
       use, intrinsic :: iso_c_binding, only : c_ptr, c_int
       import c_rp
       implicit none
       type(c_ptr), value :: w_d
       type(c_ptr), value :: v_d_d
       type(c_ptr), value :: h_d
       type(c_ptr), value :: mult_d
       integer(c_int) :: j
       integer(c_int) :: n
     end function cuda_gmres_part2
  end interface
#endif
111
112contains
113
!> Second phase of the device GMRES orthogonalisation step.
!! Dispatches to the configured backend (HIP or CUDA) with the work vector
!! @a w_d, the table of basis-vector pointers @a v_d_d, the projection
!! coefficients @a h_d and the multiplicity weights @a mult_d. The caller
!! uses the returned value as alpha2, the squared weighted norm of the
!! orthogonalised work vector.
!! @param j Number of basis vectors involved.
!! @param n Number of local degrees of freedom.
!! @return Globally summed scalar (summed over ranks here when
!!         HAVE_DEVICE_MPI is not defined).
function device_gmres_part2(w_d, v_d_d, h_d, mult_d, j, n) result(alpha)
  type(c_ptr), value :: h_d, w_d, v_d_d, mult_d
  integer(c_int), intent(in) :: j, n
  real(c_rp) :: alpha
  integer :: ierr

  ! Give the result a defined value on every path, including the
  ! no-backend error branch below.
  alpha = 0.0_c_rp
#ifdef HAVE_HIP
  alpha = hip_gmres_part2(w_d, v_d_d, h_d, mult_d, j, n)
#elif HAVE_CUDA
  alpha = cuda_gmres_part2(w_d, v_d_d, h_d, mult_d, j, n)
#else
  call neko_error('No device backend configured')
#endif

#ifndef HAVE_DEVICE_MPI
  ! Without device-aware MPI the reduction across ranks is done on the host
  ! (presumably the kernel only returns the rank-local contribution).
  if (pe_size .gt. 1) then
     call mpi_allreduce(mpi_in_place, alpha, 1, &
          mpi_real_precision, mpi_sum, neko_comm, ierr)
  end if
#endif

end function device_gmres_part2
135
!> Initialise a standard GMRES solver.
!! Allocates the host work arrays, maps them to device buffers and builds
!! the device-side tables of per-column pointers.
!! @param n Number of local degrees of freedom.
!! @param max_iter Maximum number of iterations.
!! @param M Optional preconditioner; the device identity preconditioner is
!!        used when it is absent.
!! @param rel_tol Optional relative tolerance, forwarded to ksp_init.
!! @param abs_tol Optional absolute tolerance, forwarded to ksp_init.
!! @param monitor Optional monitoring flag, forwarded to ksp_init.
subroutine gmres_device_init(this, n, max_iter, M, rel_tol, abs_tol, monitor)
  class(gmres_device_t), target, intent(inout) :: this
  integer, intent(in) :: n
  integer, intent(in) :: max_iter
  class(pc_t), optional, intent(in), target :: M
  real(kind=rp), optional, intent(in) :: rel_tol
  real(kind=rp), optional, intent(in) :: abs_tol
  logical, optional, intent(in) :: monitor
  ! this%M keeps pointing at the fallback preconditioner after this routine
  ! returns, so it needs SAVE; a plain local would leave a dangling pointer.
  type(device_ident_t), target, save :: M_ident
  type(c_ptr) :: ptr
  integer(c_size_t) :: ptr_table_size
  integer :: i

  call this%free()

  if (present(M)) then
     this%M => M
  else
     this%M => M_ident
  end if

  ! Work vectors of length n and their device buffers.
  allocate(this%w(n))
  allocate(this%r(n))
  call device_map(this%w, this%w_d, n)
  call device_map(this%r, this%r_d, n)

  ! Givens rotation data and the rotated right-hand side.
  allocate(this%c(this%m_restart))
  allocate(this%s(this%m_restart))
  allocate(this%gam(this%m_restart + 1))
  call device_map(this%c, this%c_d, this%m_restart)
  call device_map(this%s, this%s_d, this%m_restart)
  call device_map(this%gam, this%gam_d, this%m_restart + 1)

  ! Basis vectors (v), preconditioned basis (z) and Hessenberg matrix (h);
  ! every column gets its own device buffer.
  allocate(this%z(n, this%m_restart))
  allocate(this%v(n, this%m_restart))
  allocate(this%h(this%m_restart, this%m_restart))
  allocate(this%z_d(this%m_restart))
  allocate(this%v_d(this%m_restart))
  allocate(this%h_d(this%m_restart))
  do i = 1, this%m_restart
     this%z_d(i) = c_null_ptr
     call device_map(this%z(:,i), this%z_d(i), n)

     this%v_d(i) = c_null_ptr
     call device_map(this%v(:,i), this%v_d(i), n)

     this%h_d(i) = c_null_ptr
     call device_map(this%h(:,i), this%h_d(i), this%m_restart)
  end do

  ! Copy the tables of column pointers to the device; v_d_d and z_d_d are
  ! handed to the batched device kernels during the solve.
  ptr_table_size = c_sizeof(c_null_ptr) * (this%m_restart)
  call device_alloc(this%z_d_d, ptr_table_size)
  call device_alloc(this%v_d_d, ptr_table_size)
  call device_alloc(this%h_d_d, ptr_table_size)
  ptr = c_loc(this%z_d)
  call device_memcpy(ptr, this%z_d_d, ptr_table_size, &
       host_to_device, sync = .false.)
  ptr = c_loc(this%v_d)
  call device_memcpy(ptr, this%v_d_d, ptr_table_size, &
       host_to_device, sync = .false.)
  ptr = c_loc(this%h_d)
  call device_memcpy(ptr, this%h_d_d, ptr_table_size, &
       host_to_device, sync = .false.)

  ! An absent optional dummy may be passed on to an optional dummy and is
  ! then absent there too, so one call covers every combination of
  ! rel_tol, abs_tol and monitor.
  call this%ksp_init(max_iter, rel_tol, abs_tol, monitor = monitor)

  call device_event_create(this%gs_event, 2)

end subroutine gmres_device_init
223
!> Deallocate a standard GMRES solver.
!! Releases the host work arrays, their device buffers, the per-column
!! pointer tables (host arrays and their device copies) and the
!! gather-scatter event. Every release is guarded, so the routine can be
!! called on a solver that was never initialised (init itself calls it first).
subroutine gmres_device_free(this)
  class(gmres_device_t), intent(inout) :: this
  integer :: i

  call this%ksp_free()

  if (allocated(this%w)) then
     deallocate(this%w)
  end if

  if (allocated(this%c)) then
     deallocate(this%c)
  end if

  if (allocated(this%r)) then
     deallocate(this%r)
  end if

  if (allocated(this%z)) then
     deallocate(this%z)
  end if

  if (allocated(this%h)) then
     deallocate(this%h)
  end if

  if (allocated(this%v)) then
     deallocate(this%v)
  end if

  if (allocated(this%s)) then
     deallocate(this%s)
  end if

  if (allocated(this%gam)) then
     deallocate(this%gam)
  end if

  ! Per-column device buffers, then the host tables holding their addresses.
  ! init allocates these tables again, so they must not stay allocated here.
  if (allocated(this%v_d)) then
     do i = 1, size(this%v_d)
        if (c_associated(this%v_d(i))) then
           call device_free(this%v_d(i))
        end if
     end do
     deallocate(this%v_d)
  end if

  if (allocated(this%z_d)) then
     do i = 1, size(this%z_d)
        if (c_associated(this%z_d(i))) then
           call device_free(this%z_d(i))
        end if
     end do
     deallocate(this%z_d)
  end if

  if (allocated(this%h_d)) then
     do i = 1, size(this%h_d)
        if (c_associated(this%h_d(i))) then
           call device_free(this%h_d(i))
        end if
     end do
     deallocate(this%h_d)
  end if

  ! Device copies of the pointer tables, allocated with device_alloc in init.
  if (c_associated(this%z_d_d)) then
     call device_free(this%z_d_d)
  end if
  if (c_associated(this%v_d_d)) then
     call device_free(this%v_d_d)
  end if
  if (c_associated(this%h_d_d)) then
     call device_free(this%h_d_d)
  end if

  if (c_associated(this%gam_d)) then
     call device_free(this%gam_d)
  end if
  if (c_associated(this%w_d)) then
     call device_free(this%w_d)
  end if
  if (c_associated(this%c_d)) then
     call device_free(this%c_d)
  end if
  if (c_associated(this%r_d)) then
     call device_free(this%r_d)
  end if
  if (c_associated(this%s_d)) then
     call device_free(this%s_d)
  end if

  ! The preconditioner is not owned by the solver; only drop the reference.
  nullify(this%M)

  if (c_associated(this%gs_event)) then
     call device_event_destroy(this%gs_event)
  end if

end subroutine gmres_device_free
310
!> Standard GMRES solve.
!! Restarted GMRES (restart length this%m_restart) with the preconditioner
!! applied on the right: z_j = M^{-1} v_j, and the solution is updated with
!! the z vectors. The initial guess is zero (x is cleared on the device).
!! @param Ax Matrix-vector product.
!! @param x Solution field; its device data is overwritten.
!! @param f Right-hand side; its device buffer is looked up with device_get_ptr.
!! @param n Number of local degrees of freedom.
!! @param coef Coefficients (mult weights and domain volume are used).
!! @param blst Boundary conditions applied after each operator application.
!! @param gs_h Gather-scatter handle.
!! @param niter Optional iteration cap overriding this%max_iter.
!! @return Monitor with start/final residuals, iteration count, convergence.
function gmres_device_solve(this, Ax, x, f, n, coef, blst, gs_h, niter) &
     result(ksp_results)
  class(gmres_device_t), intent(inout) :: this
  class(ax_t), intent(in) :: ax
  type(field_t), intent(inout) :: x
  integer, intent(in) :: n
  real(kind=rp), dimension(n), intent(in) :: f
  type(coef_t), intent(inout) :: coef
  type(bc_list_t), intent(inout) :: blst
  type(gs_t), intent(inout) :: gs_h
  type(ksp_monitor_t) :: ksp_results
  integer, optional, intent(in) :: niter
  integer :: iter, max_iter
  integer :: i, j, k
  real(kind=rp) :: rnorm, alpha, temp, lr, alpha2, norm_fac
  logical :: conv
  type(c_ptr) :: f_d

  f_d = device_get_ptr(f)

  conv = .false.
  iter = 0
  rnorm = 0.0_rp

  if (present(niter)) then
     max_iter = niter
  else
     max_iter = this%max_iter
  end if

  associate(w => this%w, c => this%c, z => this%z, h => this%h, &
       v => this%v, s => this%s, gam => this%gam, v_d => this%v_d, &
       w_d => this%w_d, r_d => this%r_d, h_d => this%h_d, &
       v_d_d => this%v_d_d, x_d => x%x_d, z_d_d => this%z_d_d, &
       c_d => this%c_d)

    norm_fac = 1.0_rp / sqrt(coef%volume)
    call rzero(gam, this%m_restart + 1)
    call rone(s, this%m_restart)
    call rone(c, this%m_restart)
    call rzero(h, this%m_restart * this%m_restart)
    call device_rzero(x%x_d, n)
    call device_rzero(this%gam_d, this%m_restart + 1)
    call device_rone(this%s_d, this%m_restart)
    call device_rone(this%c_d, this%m_restart)

    call this%monitor_start('GMRES')
    do while (.not. conv .and. iter .lt. max_iter)

       ! Residual r = f - A x; on the very first pass x is zero, so r = f.
       call device_copy(r_d, f_d, n)
       if (iter .ne. 0) then
          call ax%compute(w, x%x, coef, x%msh, x%Xh)
          call gs_h%op(w, n, gs_op_add, this%gs_event)
          call device_event_sync(this%gs_event)
          call blst%apply_scalar(w, n)
          call device_sub2(r_d, w_d, n)
       end if

       gam(1) = sqrt(device_glsc3(r_d, r_d, coef%mult_d, n))
       if (iter .eq. 0) then
          ksp_results%res_start = gam(1) * norm_fac
       end if

       if (abscmp(gam(1), 0.0_rp)) exit

       rnorm = 0.0_rp
       temp = 1.0_rp / gam(1)
       call device_cmult2(v_d(1), r_d, temp, n)
       do j = 1, this%m_restart
          iter = iter + 1

          ! z_j = M^{-1} v_j, then w = A z_j.
          call this%M%solve(z(1,j), v(1,j), n)

          call ax%compute(w, z(1,j), coef, x%msh, x%Xh)
          call gs_h%op(w, n, gs_op_add, this%gs_event)
          call device_event_sync(this%gs_event)
          call blst%apply_scalar(w, n)

          if (neko_bcknd_opencl .eq. 1) then
             ! Orthogonalise w against v_1..v_j one vector at a time.
             do i = 1, j
                h(i,j) = device_glsc3(w_d, v_d(i), coef%mult_d, n)
                call device_add2s2(w_d, v_d(i), -h(i,j), n)
             end do
             ! Only the norm of the fully orthogonalised w is used, so it is
             ! evaluated once instead of after every projection (each
             ! evaluation is a global reduction).
             alpha2 = device_glsc3(w_d, w_d, coef%mult_d, n)
          else
             ! All j projections of w in one batched call, then the backend
             ! kernel presumably updates w and returns its squared norm.
             call device_glsc3_many(h(1,j), w_d, v_d_d, coef%mult_d, j, n)

             call device_memcpy(h(:,j), h_d(j), j, &
                  host_to_device, sync = .false.)

             alpha2 = device_gmres_part2(w_d, v_d_d, h_d(j), &
                  coef%mult_d, j, n)
          end if

          alpha = sqrt(alpha2)

          ! Apply the previously computed Givens rotations to column j.
          do i = 1, j - 1
             temp = h(i,j)
             h(i,j) = c(i) * temp + s(i) * h(i+1,j)
             h(i+1,j) = -s(i) * temp + c(i) * h(i+1,j)
          end do

          rnorm = 0.0_rp
          if (abscmp(alpha, 0.0_rp)) then
             ! Happy breakdown: no new direction can be generated.
             conv = .true.
             exit
          end if

          ! New rotation that annihilates the subdiagonal entry alpha.
          lr = sqrt(h(j,j) * h(j,j) + alpha2)
          temp = 1.0_rp / lr
          c(j) = h(j,j) * temp
          s(j) = alpha * temp
          h(j,j) = lr
          call device_memcpy(h(:,j), h_d(j), j, &
               host_to_device, sync = .false.)
          gam(j+1) = -s(j) * gam(j)
          gam(j) = c(j) * gam(j)

          ! Scaled residual estimate used for monitoring and convergence.
          rnorm = abs(gam(j+1)) * norm_fac
          call this%monitor_iter(iter, rnorm)
          if (rnorm .lt. this%abs_tol) then
             conv = .true.
             exit
          end if

          if (iter + 1 .gt. max_iter) exit

          if (j .lt. this%m_restart) then
             temp = 1.0_rp / alpha
             call device_cmult2(v_d(j+1), w_d, temp, n)
          end if

       end do

       ! Back substitution of the triangular system; the coefficients are
       ! stored in c. A completed do loop leaves j = m_restart + 1.
       j = min(j, this%m_restart)
       do k = j, 1, -1
          temp = gam(k)
          do i = j, k + 1, -1
             temp = temp - h(k,i) * c(i)
          end do
          c(k) = temp / h(k,k)
       end do

       ! x = x + sum_i c(i) * z_i
       if (neko_bcknd_opencl .eq. 1) then
          do i = 1, j
             call device_add2s2(x_d, this%z_d(i), c(i), n)
          end do
       else
          call device_memcpy(c, c_d, j, host_to_device, sync = .false.)
          call device_add2s2_many(x_d, z_d_d, c_d, j, n)
       end if
    end do

  end associate
  call this%monitor_stop()
  ksp_results%res_final = rnorm
  ksp_results%iter = iter
  ksp_results%converged = this%is_converged(iter, rnorm)

end function gmres_device_solve
481
!> Standard GMRES coupled solve.
!! Runs the single-field solve three times in sequence, once per velocity
!! component, sharing the operator, coefficients and gather-scatter handle.
!! Each component uses its own right-hand side and boundary-condition list.
!! @return One monitor per component, ordered x, y, z.
function gmres_device_solve_coupled(this, Ax, x, y, z, fx, fy, fz, &
     n, coef, blstx, blsty, blstz, gs_h, niter) result(ksp_results)
  class(gmres_device_t), intent(inout) :: this
  class(ax_t), intent(in) :: ax
  type(field_t), intent(inout) :: x, y, z
  integer, intent(in) :: n
  real(kind=rp), dimension(n), intent(in) :: fx, fy, fz
  type(coef_t), intent(inout) :: coef
  type(bc_list_t), intent(inout) :: blstx, blsty, blstz
  type(gs_t), intent(inout) :: gs_h
  type(ksp_monitor_t), dimension(3) :: ksp_results
  integer, optional, intent(in) :: niter

  ! An absent niter is passed on as absent.
  ksp_results(1) = this%solve(ax, x, fx, n, coef, blstx, gs_h, niter = niter)
  ksp_results(2) = this%solve(ax, y, fy, n, coef, blsty, gs_h, niter = niter)
  ksp_results(3) = this%solve(ax, z, fz, n, coef, blstz, gs_h, niter = niter)

end function gmres_device_solve_coupled
507
508end module gmres_device
__device__ T solve(const T u, const T y, const T guess, const T nu, const T kappa, const T B)
real cuda_gmres_part2(void *w, void *v, void *h, void *mult, int *j, int *n)
Definition gmres_aux.cu:62
Defines a Matrix-vector product.
Definition ax.f90:34
Defines a list of bc_t.
Definition bc_list.f90:34
Coefficients.
Definition coef.f90:34
Definition comm.F90:1
type(mpi_datatype), public mpi_real_precision
MPI type for working precision of REAL types.
Definition comm.F90:50
integer, public pe_size
MPI size of communicator.
Definition comm.F90:58
type(mpi_comm), public neko_comm
MPI communicator.
Definition comm.F90:42
Identity Krylov preconditioner for accelerators.
subroutine, public device_add2s1(a_d, b_d, c1, n, strm)
subroutine, public device_add2s2_many(y_d, x_d_d, a_d, j, n, strm)
subroutine, public device_add2s2(a_d, b_d, c1, n, strm)
Vector addition with scalar multiplication (multiplication on first argument)
subroutine, public device_rzero(a_d, n, strm)
Zero a real vector.
subroutine, public device_rone(a_d, n, strm)
Set all elements to one.
subroutine, public device_glsc3_many(h, w_d, v_d_d, mult_d, j, n, strm)
subroutine, public device_sub2(a_d, b_d, n, strm)
Vector subtraction $a = a - b$.
subroutine, public device_copy(a_d, b_d, n, strm)
Copy a vector $a = b$.
real(kind=rp) function, public device_glsc3(a_d, b_d, c_d, n, strm)
Weighted inner product $a^T b c$.
subroutine, public device_cmult2(a_d, b_d, c, n, strm)
Multiplication by constant c: $a = c \cdot b$.
Device abstraction, common interface for various accelerators.
Definition device.F90:34
Defines a field.
Definition field.f90:34
Gather-scatter.
Defines various GMRES methods.
real(c_rp) function device_gmres_part2(w_d, v_d_d, h_d, mult_d, j, n)
subroutine gmres_device_init(this, n, max_iter, m, rel_tol, abs_tol, monitor)
Initialise a standard GMRES solver.
type(ksp_monitor_t) function gmres_device_solve(this, ax, x, f, n, coef, blst, gs_h, niter)
Standard GMRES solve.
subroutine gmres_device_free(this)
Deallocate a standard GMRES solver.
type(ksp_monitor_t) function, dimension(3) gmres_device_solve_coupled(this, ax, x, y, z, fx, fy, fz, n, coef, blstx, blsty, blstz, gs_h, niter)
Standard GMRES coupled solve.
Implements the base abstract type for Krylov solvers plus helper types.
Definition krylov.f90:34
Definition math.f90:60
subroutine, public rone(a, n)
Set all elements to one.
Definition math.f90:244
subroutine, public rzero(a, n)
Zero a real vector.
Definition math.f90:211
Build configurations.
integer, parameter neko_bcknd_opencl
integer, parameter, public c_rp
Definition num_types.f90:13
integer, parameter, public rp
Global precision used in computations.
Definition num_types.f90:12
Krylov preconditioner.
Definition precon.f90:34
Utilities.
Definition utils.f90:35
Base type for a matrix-vector product providing $Ax$.
Definition ax.f90:43
A list of allocatable `bc_t`. Follows the standard interface of lists.
Definition bc_list.f90:48
Coefficients defined on a given (mesh, $X_h$) tuple. Arrays use indices (i,j,k,e): element e,...
Definition coef.f90:55
Defines a canonical Krylov preconditioner for accelerators.
Standard preconditioned generalized minimal residual method.
Type for storing initial and final residuals in a Krylov solver.
Definition krylov.f90:56
Base abstract type for a canonical Krylov method, solving $Ax = f$.
Definition krylov.f90:73
Defines a canonical Krylov preconditioner.
Definition precon.f90:40