Neko 1.99.2
A portable framework for high-order spectral element flow simulations
Loading...
Searching...
No Matches
cg_cpld_device.f90
Go to the documentation of this file.
1! Copyright (c) 2025, The Neko Authors
2! All rights reserved.
3!
4! Redistribution and use in source and binary forms, with or without
5! modification, are permitted provided that the following conditions
6! are met:
7!
8! * Redistributions of source code must retain the above copyright
9! notice, this list of conditions and the following disclaimer.
10!
11! * Redistributions in binary form must reproduce the above
12! copyright notice, this list of conditions and the following
13! disclaimer in the documentation and/or other materials provided
14! with the distribution.
15!
16! * Neither the name of the authors nor the names of its
17! contributors may be used to endorse or promote products derived
18! from this software without specific prior written permission.
19!
20! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21! "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22! LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
23! FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
24! COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
25! INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
26! BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27! LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28! CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29! LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
30! ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31! POSSIBILITY OF SUCH DAMAGE.
32!
35 use num_types, only: rp
37 use precon, only : pc_t
38 use ax_product, only : ax_t
39 use field, only : field_t
40 use coefs, only : coef_t
41 use gather_scatter, only : gs_t, gs_op_add
42 use bc_list, only : bc_list_t
43 use math, only : abscmp
44 use device
48 use utils, only : neko_error
49 use operators, only : rotate_cyc
50 use, intrinsic :: iso_c_binding, only : c_ptr, c_null_ptr, c_associated
51 implicit none
52
!> Device based coupled preconditioned conjugate gradient method.
!!
!! Every host work array has a matching device handle (`*_d`) that is
!! associated with it in `init` and released in `free`.
type, public, extends(ksp_t) :: cg_cpld_device_t
  !> Operator applied to the search direction, one array per component.
  real(kind=rp), allocatable :: w1(:), w2(:), w3(:)
  !> Residual, one array per component.
  real(kind=rp), allocatable :: r1(:), r2(:), r3(:)
  !> Search direction, one array per component.
  real(kind=rp), allocatable :: p1(:), p2(:), p3(:)
  !> Preconditioned residual, one array per component.
  real(kind=rp), allocatable :: z1(:), z2(:), z3(:)
  !> Scratch array receiving the pointwise three-component dot products.
  real(kind=rp), allocatable :: tmp(:)

  !> Device handles for `w1`, `w2`, `w3`.
  type(c_ptr) :: w1_d = c_null_ptr, w2_d = c_null_ptr, w3_d = c_null_ptr
  !> Device handles for `r1`, `r2`, `r3`.
  type(c_ptr) :: r1_d = c_null_ptr, r2_d = c_null_ptr, r3_d = c_null_ptr
  !> Device handles for `p1`, `p2`, `p3`.
  type(c_ptr) :: p1_d = c_null_ptr, p2_d = c_null_ptr, p3_d = c_null_ptr
  !> Device handles for `z1`, `z2`, `z3`.
  type(c_ptr) :: z1_d = c_null_ptr, z2_d = c_null_ptr, z3_d = c_null_ptr
  !> Device handle for `tmp`.
  type(c_ptr) :: tmp_d = c_null_ptr
  !> Device event handed to the gather-scatter operations and synced on.
  type(c_ptr) :: gs_event = c_null_ptr
contains
  !> Allocate and map the work arrays, set tolerances and the preconditioner.
  procedure, pass(this) :: init => cg_cpld_device_init
  !> Release host arrays, device buffers and the event.
  procedure, pass(this) :: free => cg_cpld_device_free
  !> Scalar solve entry point; not supported, always raises an error.
  procedure, pass(this) :: solve => cg_cpld_device_nop
  !> Coupled three-component solve.
  procedure, pass(this) :: solve_coupled => cg_cpld_device_solve
end type cg_cpld_device_t
95
96contains
97
!> Initialise a device based coupled PCG solver.
!!
!! Any previous state is released first, so the routine may be called
!! repeatedly on the same object.
!!
!! @param n Number of entries in each work vector.
!! @param max_iter Maximum number of iterations, forwarded to `ksp_init`.
!! @param M Preconditioner (optional); `this%M` is pointed at it when present.
!! @param rel_tol Relative tolerance (optional), forwarded to `ksp_init`.
!! @param abs_tol Absolute tolerance (optional), forwarded to `ksp_init`.
!! @param monitor Monitoring switch (optional), forwarded to `ksp_init`.
subroutine cg_cpld_device_init(this, n, max_iter, M, rel_tol, abs_tol, monitor)
  class(cg_cpld_device_t), target, intent(inout) :: this
  class(pc_t), optional, intent(in), target :: M
  integer, intent(in) :: n
  integer, intent(in) :: max_iter
  real(kind=rp), optional, intent(in) :: rel_tol
  real(kind=rp), optional, intent(in) :: abs_tol
  logical, optional, intent(in) :: monitor

  call this%free()

  ! Host side work vectors, one triple per CG quantity plus a scratch array.
  allocate(this%w1(n), this%w2(n), this%w3(n))
  allocate(this%r1(n), this%r2(n), this%r3(n))
  allocate(this%p1(n), this%p2(n), this%p3(n))
  allocate(this%z1(n), this%z2(n), this%z3(n))
  allocate(this%tmp(n))

  ! Allocate the device counterparts and associate them with the host arrays.
  call device_map(this%tmp, this%tmp_d, n)
  call device_map(this%z1, this%z1_d, n)
  call device_map(this%z2, this%z2_d, n)
  call device_map(this%z3, this%z3_d, n)
  call device_map(this%p1, this%p1_d, n)
  call device_map(this%p2, this%p2_d, n)
  call device_map(this%p3, this%p3_d, n)
  call device_map(this%r1, this%r1_d, n)
  call device_map(this%r2, this%r2_d, n)
  call device_map(this%r3, this%r3_d, n)
  call device_map(this%w1, this%w1_d, n)
  call device_map(this%w2, this%w2_d, n)
  call device_map(this%w3, this%w3_d, n)

  if (present(M)) this%M => M

  ! An absent optional dummy may be passed on as an actual argument to
  ! another optional dummy, so every combination of present/absent
  ! tolerances and monitor flag is covered by this single call.
  call this%ksp_init(max_iter, rel_tol = rel_tol, abs_tol = abs_tol, &
       monitor = monitor)

  ! NOTE(review): the meaning of the flag value 2 is defined by the device
  ! layer and is not visible from this file - confirm before changing it.
  call device_event_create(this%gs_event, 2)
end subroutine cg_cpld_device_init
162
!> Deallocate a device based coupled PCG solver.
!!
!! Releases the base-class state, the host work arrays, the device buffers
!! and the gather-scatter event. Every released handle is reset to
!! `c_null_ptr`, so calling the routine more than once (including the
!! implicit call at the start of `init`) never releases a handle twice.
subroutine cg_cpld_device_free(this)
  class(cg_cpld_device_t), intent(inout) :: this

  call this%ksp_free()

  if (allocated(this%w1)) deallocate(this%w1)
  if (allocated(this%w2)) deallocate(this%w2)
  if (allocated(this%w3)) deallocate(this%w3)

  if (allocated(this%r1)) deallocate(this%r1)
  if (allocated(this%r2)) deallocate(this%r2)
  if (allocated(this%r3)) deallocate(this%r3)

  if (allocated(this%p1)) deallocate(this%p1)
  if (allocated(this%p2)) deallocate(this%p2)
  if (allocated(this%p3)) deallocate(this%p3)

  if (allocated(this%z1)) deallocate(this%z1)
  if (allocated(this%z2)) deallocate(this%z2)
  if (allocated(this%z3)) deallocate(this%z3)

  if (allocated(this%tmp)) deallocate(this%tmp)

  nullify(this%M)

  call release_device_buffer(this%w1_d)
  call release_device_buffer(this%w2_d)
  call release_device_buffer(this%w3_d)

  call release_device_buffer(this%r1_d)
  call release_device_buffer(this%r2_d)
  call release_device_buffer(this%r3_d)

  call release_device_buffer(this%p1_d)
  call release_device_buffer(this%p2_d)
  call release_device_buffer(this%p3_d)

  call release_device_buffer(this%z1_d)
  call release_device_buffer(this%z2_d)
  call release_device_buffer(this%z3_d)

  call release_device_buffer(this%tmp_d)

  if (c_associated(this%gs_event)) then
     call device_event_destroy(this%gs_event)
     ! Reset explicitly so a later free() cannot destroy the event again.
     this%gs_event = c_null_ptr
  end if

contains

  !> Free one device buffer if it is associated and reset its handle.
  subroutine release_device_buffer(x_d)
    type(c_ptr), intent(inout) :: x_d

    if (c_associated(x_d)) then
       call device_free(x_d)
       ! Harmless if device_free already reset the handle.
       x_d = c_null_ptr
    end if
  end subroutine release_device_buffer

end subroutine cg_cpld_device_free
280
!> Placeholder for the scalar `solve` binding.
!!
!! This solver only supports coupled solves; the routine reports an error
!! through `neko_error` whenever it is called. The arguments exist only to
!! match the scalar solve interface and are not used.
!!
!! @return ksp_results Monitor record with zero final residual and zero
!! iterations, assigned after the error call so the result is defined.
function cg_cpld_device_nop(this, Ax, x, f, n, coef, blst, gs_h, niter) &
     result(ksp_results)
  class(cg_cpld_device_t), intent(inout) :: this
  class(ax_t), intent(in) :: Ax
  type(field_t), intent(inout) :: x
  integer, intent(in) :: n
  real(kind=rp), dimension(n), intent(in) :: f
  type(coef_t), intent(inout) :: coef
  type(bc_list_t), intent(inout) :: blst
  type(gs_t), intent(inout) :: gs_h
  type(ksp_monitor_t) :: ksp_results
  integer, optional, intent(in) :: niter

  ! Throw an error
  call neko_error('The cpldcg solver is only defined for coupled solves')

  ksp_results%res_final = 0.0_rp
  ksp_results%iter = 0
end function cg_cpld_device_nop
300
!> Coupled PCG solve for a three-component system.
!!
!! The solution fields are zeroed on the device, so the iteration always
!! starts from a zero initial guess and the initial residual is the right
!! hand side itself.
!!
!! @param Ax Operator; its `compute_vector` is applied to the search direction.
!! @param x,y,z Solution components, overwritten on the device.
!! @param fx,fy,fz Right hand side components; their device pointers are
!! looked up with `device_get_ptr`.
!! @param n Number of entries per component.
!! @param coef Coefficients; `mult_d`, `binv_d` and `volume` are used.
!! @param blstx,blsty,blstz Boundary condition lists applied to the three
!! components of the operator result.
!! @param gs_h Gather-scatter handle used to sum the operator result.
!! @param niter Iteration limit for this call (optional); defaults to
!! `this%max_iter`.
!! @return ksp_results Three monitor records, all filled with the same
!! start residual, final residual, iteration count and convergence flag.
function cg_cpld_device_solve(this, Ax, x, y, z, fx, fy, fz, &
     n, coef, blstx, blsty, blstz, gs_h, niter) result(ksp_results)
  class(cg_cpld_device_t), intent(inout) :: this
  class(ax_t), intent(in) :: Ax
  type(field_t), intent(inout) :: x
  type(field_t), intent(inout) :: y
  type(field_t), intent(inout) :: z
  integer, intent(in) :: n
  real(kind=rp), dimension(n), intent(in) :: fx
  real(kind=rp), dimension(n), intent(in) :: fy
  real(kind=rp), dimension(n), intent(in) :: fz
  type(coef_t), intent(inout) :: coef
  type(bc_list_t), intent(inout) :: blstx
  type(bc_list_t), intent(inout) :: blsty
  type(bc_list_t), intent(inout) :: blstz
  type(gs_t), intent(inout) :: gs_h
  type(ksp_monitor_t), dimension(3) :: ksp_results
  integer, optional, intent(in) :: niter
  integer :: iter, max_iter
  real(kind=rp) :: rnorm, rtr, rtz2, rtz1
  real(kind=rp) :: beta, pap, alpha, alphm, norm_fac
  integer, parameter :: gdim = 3
  type(c_ptr) :: fx_d
  type(c_ptr) :: fy_d
  type(c_ptr) :: fz_d

  fx_d = device_get_ptr(fx)
  fy_d = device_get_ptr(fy)
  fz_d = device_get_ptr(fz)

  if (present(niter)) then
     max_iter = niter
  else
     max_iter = this%max_iter
  end if
  ! Residual norms are scaled by the inverse of coef%volume.
  norm_fac = 1.0_rp / coef%volume

  associate(p1_d => this%p1_d, p2_d => this%p2_d, p3_d => this%p3_d, &
       z1_d => this%z1_d, z2_d => this%z2_d, z3_d => this%z3_d, &
       r1_d => this%r1_d, r2_d => this%r2_d, r3_d => this%r3_d, &
       w1_d => this%w1_d, w2_d => this%w2_d, w3_d => this%w3_d, &
       tmp_d => this%tmp_d)

    ! rtz1 = 1 keeps the first beta = rtz1 / rtz2 well defined; beta is
    ! overwritten with zero on the first iteration anyway.
    rtz1 = 1.0_rp
    call device_rzero(x%x_d, n)
    call device_rzero(y%x_d, n)
    call device_rzero(z%x_d, n)
    call device_rzero(p1_d, n)
    call device_rzero(p2_d, n)
    call device_rzero(p3_d, n)
    call device_rzero(z1_d, n)
    call device_rzero(z2_d, n)
    call device_rzero(z3_d, n)
    ! Zero initial guess: the residual starts as the right hand side.
    call device_copy(r1_d, fx_d, n)
    call device_copy(r2_d, fy_d, n)
    call device_copy(r3_d, fz_d, n)
    call device_vdot3(tmp_d, r1_d, r2_d, r3_d, r1_d, r2_d, r3_d, n)

    rtr = device_glsc3(tmp_d, coef%mult_d, coef%binv_d, n)
    rnorm = sqrt(rtr) * norm_fac
    ksp_results%res_start = rnorm
    ksp_results%res_final = rnorm
    ksp_results%iter = 0
    ! Nothing to do for a vanishing right hand side.
    if (abscmp(rnorm, 0.0_rp)) then
       ksp_results%converged = .true.
       return
    end if

    call this%monitor_start('device_cpldCG')
    do iter = 1, max_iter
       ! Apply the preconditioner to each residual component.
       call this%M%solve(this%z1, this%r1, n)
       call this%M%solve(this%z2, this%r2, n)
       call this%M%solve(this%z3, this%r3, n)
       rtz2 = rtz1

       call device_vdot3(tmp_d, z1_d, z2_d, z3_d, r1_d, r2_d, r3_d, n)

       rtz1 = device_glsc2(tmp_d, coef%mult_d, n)

       beta = rtz1 / rtz2
       if (iter .eq. 1) beta = 0.0_rp
       ! Update the search direction from z and the previous direction.
       call device_add2s1(p1_d, z1_d, beta, n)
       call device_add2s1(p2_d, z2_d, beta, n)
       call device_add2s1(p3_d, z3_d, beta, n)

       ! w = A p for all three components at once.
       call Ax%compute_vector(this%w1, this%w2, this%w3, &
            this%p1, this%p2, this%p3, coef, x%msh, x%Xh)

       ! NOTE(review): rotate_cyc is called with flag 1 before and flag 0
       ! after the gather-scatter; presumably a forward/backward rotation
       ! for cyclic boundaries - confirm against the operators module.
       call rotate_cyc(this%w1, this%w2, this%w3, 1, coef)
       call gs_h%op(this%w1, n, gs_op_add, this%gs_event)
       call device_event_sync(this%gs_event)
       call gs_h%op(this%w2, n, gs_op_add, this%gs_event)
       call device_event_sync(this%gs_event)
       call gs_h%op(this%w3, n, gs_op_add, this%gs_event)
       call device_event_sync(this%gs_event)
       call rotate_cyc(this%w1, this%w2, this%w3, 0, coef)

       call blstx%apply(this%w1, n)
       call blsty%apply(this%w2, n)
       call blstz%apply(this%w3, n)

       call device_vdot3(tmp_d, w1_d, w2_d, w3_d, p1_d, p2_d, p3_d, n)

       pap = device_glsc2(tmp_d, coef%mult_d, n)

       alpha = rtz1 / pap
       alphm = -alpha
       ! Solution update with +alpha * p, residual update with -alpha * w.
       call device_opadd2cm(x%x_d, y%x_d, z%x_d, &
            p1_d, p2_d, p3_d, alpha, n, gdim)
       call device_opadd2cm(r1_d, r2_d, r3_d, &
            w1_d, w2_d, w3_d, alphm, n, gdim)
       call device_vdot3(tmp_d, r1_d, r2_d, r3_d, r1_d, r2_d, r3_d, n)

       rtr = device_glsc3(tmp_d, coef%mult_d, coef%binv_d, n)
       rnorm = sqrt(rtr) * norm_fac
       call this%monitor_iter(iter, rnorm)
       if (rnorm .lt. this%abs_tol) then
          exit
       end if
    end do
  end associate
  call this%monitor_stop()
  ksp_results%res_final = rnorm
  ! NOTE(review): if the loop runs to completion without hitting the
  ! tolerance, the do variable equals max_iter + 1 here and that value is
  ! what gets reported - confirm this is what is_converged expects.
  ksp_results%iter = iter
  ksp_results%converged = this%is_converged(iter, rnorm)

end function cg_cpld_device_solve
431
432end module cg_cpld_device
__device__ T solve(const T u, const T y, const T guess, const T nu, const T kappa, const T B)
Return the device pointer for an associated Fortran array.
Definition device.F90:101
Map a Fortran array to a device (allocate and associate)
Definition device.F90:77
Defines a Matrix-vector product.
Definition ax.f90:34
Defines a list of bc_t.
Definition bc_list.f90:34
Defines a coupled Conjugate Gradient method for accelerators.
type(ksp_monitor_t) function, dimension(3) cg_cpld_device_solve(this, ax, x, y, z, fx, fy, fz, n, coef, blstx, blsty, blstz, gs_h, niter)
Standard PCG solve.
type(ksp_monitor_t) function cg_cpld_device_nop(this, ax, x, f, n, coef, blst, gs_h, niter)
subroutine cg_cpld_device_free(this)
Deallocate a device based PCG solver.
subroutine cg_cpld_device_init(this, n, max_iter, m, rel_tol, abs_tol, monitor)
Initialise a device based PCG solver.
Coefficients.
Definition coef.f90:34
subroutine, public device_add2s1(a_d, b_d, c1, n, strm)
subroutine, public device_rzero(a_d, n, strm)
Zero a real vector.
subroutine, public device_vdot3(dot_d, u1_d, u2_d, u3_d, v1_d, v2_d, v3_d, n, strm)
Compute a dot product (3-d version) assuming vector components etc.
subroutine, public device_copy(a_d, b_d, n, strm)
Copy a vector $a = b$.
real(kind=rp) function, public device_glsc3(a_d, b_d, c_d, n, strm)
Weighted inner product $a^T b c$.
real(kind=rp) function, public device_glsc2(a_d, b_d, n, strm)
Weighted inner product $a^T b$.
subroutine, public device_opadd2cm(a1_d, a2_d, a3_d, b1_d, b2_d, b3_d, c, n, gdim)
Device abstraction, common interface for various accelerators.
Definition device.F90:34
subroutine, public device_event_sync(event)
Synchronize an event.
Definition device.F90:1314
subroutine, public device_free(x_d)
Deallocate memory on the device.
Definition device.F90:219
subroutine, public device_event_destroy(event)
Destroy a device event.
Definition device.F90:1279
subroutine, public device_event_create(event, flags)
Create a device event queue.
Definition device.F90:1249
Defines a field.
Definition field.f90:34
Gather-scatter.
Implements the base abstract type for Krylov solvers plus helper types.
Definition krylov.f90:34
integer, parameter, public ksp_max_iter
Maximum number of iters.
Definition krylov.f90:51
Definition math.f90:60
integer, parameter, public rp
Global precision used in computations.
Definition num_types.f90:12
Operators.
Definition operators.f90:34
Krylov preconditioner.
Definition precon.f90:34
Utilities.
Definition utils.f90:35
Base type for a matrix-vector product providing $Ax$.
Definition ax.f90:43
A list of allocatable `bc_t`. Follows the standard interface of lists.
Definition bc_list.f90:48
Device based coupled preconditioned conjugate gradient method.
Coefficients defined on a given (mesh, $X_h$) tuple. Arrays use indices (i,j,k,e): element e,...
Definition coef.f90:56
Type for storing initial and final residuals in a Krylov solver.
Definition krylov.f90:56
Base abstract type for a canonical Krylov method, solving $Ax = f$.
Definition krylov.f90:73
Defines a canonical Krylov preconditioner.
Definition precon.f90:40