Neko 1.99.3
A portable framework for high-order spectral element flow simulations
Loading...
Searching...
No Matches
cg_cpld_device.f90
Go to the documentation of this file.
1! Copyright (c) 2025, The Neko Authors
2! All rights reserved.
3!
4! Redistribution and use in source and binary forms, with or without
5! modification, are permitted provided that the following conditions
6! are met:
7!
8! * Redistributions of source code must retain the above copyright
9! notice, this list of conditions and the following disclaimer.
10!
11! * Redistributions in binary form must reproduce the above
12! copyright notice, this list of conditions and the following
13! disclaimer in the documentation and/or other materials provided
14! with the distribution.
15!
16! * Neither the name of the authors nor the names of its
17! contributors may be used to endorse or promote products derived
18! from this software without specific prior written permission.
19!
20! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21! "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22! LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
23! FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
24! COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
25! INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
26! BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27! LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28! CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29! LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
30! ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31! POSSIBILITY OF SUCH DAMAGE.
32!
35 use num_types, only: rp
37 use precon, only : pc_t
38 use ax_product, only : ax_t
39 use field, only : field_t
40 use coefs, only : coef_t
41 use gather_scatter, only : gs_t, gs_op_add
42 use bc_list, only : bc_list_t
43 use math, only : abscmp
44 use device
48 use utils, only : neko_error
49 use operators, only : rotate_cyc
50 use, intrinsic :: iso_c_binding, only : c_ptr, c_null_ptr, c_associated
51 implicit none
52
!> Device based coupled preconditioned conjugate gradient method.
!!
!! Holds one host array and one matching device handle per work
!! vector; the `1/2/3` suffix selects the vector component.
type, public, extends(ksp_t) :: cg_cpld_device_t
  !> Operator output: filled by `Ax%compute_vector` during the solve.
  real(kind=rp), allocatable :: w1(:), w2(:), w3(:)
  !> Residual: initialised from the right-hand side in the solve.
  real(kind=rp), allocatable :: r1(:), r2(:), r3(:)
  !> Search direction, updated from `z` each iteration.
  real(kind=rp), allocatable :: p1(:), p2(:), p3(:)
  !> Output of the preconditioner solve applied to the residual.
  real(kind=rp), allocatable :: z1(:), z2(:), z3(:)
  !> Scratch vector receiving the pointwise result of `device_vdot3`.
  real(kind=rp), allocatable :: tmp(:)

  !> Device handles for `w1`, `w2`, `w3` (null until mapped).
  type(c_ptr) :: w1_d = c_null_ptr, w2_d = c_null_ptr, w3_d = c_null_ptr
  !> Device handles for `r1`, `r2`, `r3` (null until mapped).
  type(c_ptr) :: r1_d = c_null_ptr, r2_d = c_null_ptr, r3_d = c_null_ptr
  !> Device handles for `p1`, `p2`, `p3` (null until mapped).
  type(c_ptr) :: p1_d = c_null_ptr, p2_d = c_null_ptr, p3_d = c_null_ptr
  !> Device handles for `z1`, `z2`, `z3` (null until mapped).
  type(c_ptr) :: z1_d = c_null_ptr, z2_d = c_null_ptr, z3_d = c_null_ptr
  !> Device handle for `tmp` (null until mapped).
  type(c_ptr) :: tmp_d = c_null_ptr

  !> Device event handed to the gather-scatter operation and synced on.
  type(c_ptr) :: gs_event = c_null_ptr
contains
  !> Allocate and map the work vectors, set up the base solver.
  procedure, pass(this) :: init => cg_cpld_device_init
  !> Release everything acquired by `init`.
  procedure, pass(this) :: free => cg_cpld_device_free
  !> Scalar solve entry point: not supported, raises an error.
  procedure, pass(this) :: solve => cg_cpld_device_nop
  !> Coupled three-component solve.
  procedure, pass(this) :: solve_coupled => cg_cpld_device_solve
end type cg_cpld_device_t
95
96contains
97
!> Initialise a device based coupled PCG solver.
!!
!! Any previous state is released first, then the thirteen work
!! vectors are allocated on the host and mapped to the device.
!! @param n Length of every work vector.
!! @param max_iter Iteration limit, forwarded to `ksp_init`.
!! @param M Preconditioner; stored by pointer when present.
!! @param rel_tol Relative tolerance, forwarded to `ksp_init` if present.
!! @param abs_tol Absolute tolerance, forwarded to `ksp_init` if present.
!! @param monitor Monitoring flag, forwarded to `ksp_init` if present.
subroutine cg_cpld_device_init(this, n, max_iter, M, rel_tol, abs_tol, monitor)
  class(cg_cpld_device_t), target, intent(inout) :: this
  class(pc_t), optional, intent(in), target :: M
  integer, intent(in) :: n
  integer, intent(in) :: max_iter
  real(kind=rp), optional, intent(in) :: rel_tol
  real(kind=rp), optional, intent(in) :: abs_tol
  logical, optional, intent(in) :: monitor

  ! Start from a clean slate so that re-initialisation does not leak.
  call this%free()

  allocate(this%w1(n), this%w2(n), this%w3(n))
  allocate(this%r1(n), this%r2(n), this%r3(n))
  allocate(this%p1(n), this%p2(n), this%p3(n))
  allocate(this%z1(n), this%z2(n), this%z3(n))
  allocate(this%tmp(n))

  call device_map(this%tmp, this%tmp_d, n)
  call device_map(this%z1, this%z1_d, n)
  call device_map(this%z2, this%z2_d, n)
  call device_map(this%z3, this%z3_d, n)
  call device_map(this%p1, this%p1_d, n)
  call device_map(this%p2, this%p2_d, n)
  call device_map(this%p3, this%p3_d, n)
  call device_map(this%r1, this%r1_d, n)
  call device_map(this%r2, this%r2_d, n)
  call device_map(this%r3, this%r3_d, n)
  call device_map(this%w1, this%w1_d, n)
  call device_map(this%w2, this%w2_d, n)
  call device_map(this%w3, this%w3_d, n)

  if (present(M)) then
    this%M => M
  end if

  ! An absent optional dummy may be passed on as the actual argument of
  ! another optional dummy, so a single keyword call covers every
  ! combination of rel_tol / abs_tol / monitor being present or absent.
  call this%ksp_init(max_iter, rel_tol = rel_tol, abs_tol = abs_tol, &
       monitor = monitor)

  ! NOTE(review): the meaning of the flag value 2 is defined by
  ! device_event_create and is not visible here - confirm.
  call device_event_create(this%gs_event, 2)
end subroutine cg_cpld_device_init
162
!> Deallocate a device based coupled PCG solver.
!!
!! Releases the base solver state, every host/device work vector pair,
!! the preconditioner pointer and the gather-scatter event. Safe to
!! call on an object that was never initialised or was already freed.
subroutine cg_cpld_device_free(this)
  class(cg_cpld_device_t), intent(inout) :: this

  call this%ksp_free()

  call release(this%w1, this%w1_d)
  call release(this%w2, this%w2_d)
  call release(this%w3, this%w3_d)

  call release(this%r1, this%r1_d)
  call release(this%r2, this%r2_d)
  call release(this%r3, this%r3_d)

  call release(this%p1, this%p1_d)
  call release(this%p2, this%p2_d)
  call release(this%p3, this%p3_d)

  call release(this%z1, this%z1_d)
  call release(this%z2, this%z2_d)
  call release(this%z3, this%z3_d)

  call release(this%tmp, this%tmp_d)

  nullify(this%M)

  if (c_associated(this%gs_event)) then
    call device_event_destroy(this%gs_event)
    ! Reset explicitly so a second free() cannot destroy the event
    ! twice, whether or not the device layer clears the handle.
    this%gs_event = c_null_ptr
  end if

contains

  !> Unmap (when mapped) and deallocate one host/device buffer pair.
  subroutine release(buf, buf_d)
    real(kind=rp), allocatable, intent(inout) :: buf(:)
    type(c_ptr), intent(inout) :: buf_d

    ! Nothing to do for a buffer that was never allocated.
    if (.not. allocated(buf)) return

    if (c_associated(buf_d)) then
      call device_unmap(buf, buf_d)
    end if
    ! Keep the "null handle means unmapped" invariant locally.
    buf_d = c_null_ptr
    deallocate(buf)
  end subroutine release

end subroutine cg_cpld_device_free
267
!> Placeholder bound to the scalar `solve` entry point.
!!
!! This solver only supports coupled solves, so the call reports an
!! error through `neko_error`. None of the arguments are used.
!! @return A monitor with zero final residual and zero iterations;
!! only relevant if `neko_error` returns to the caller.
function cg_cpld_device_nop(this, Ax, x, f, n, coef, blst, gs_h, niter) &
     result(ksp_results)
  class(cg_cpld_device_t), intent(inout) :: this
  class(ax_t), intent(in) :: Ax
  type(field_t), intent(inout) :: x
  integer, intent(in) :: n
  real(kind=rp), dimension(n), intent(in) :: f
  type(coef_t), intent(inout) :: coef
  type(bc_list_t), intent(inout) :: blst
  type(gs_t), intent(inout) :: gs_h
  type(ksp_monitor_t) :: ksp_results
  integer, optional, intent(in) :: niter

  ! Throw an error
  call neko_error('The cpldcg solver is only defined for coupled solves')

  ! Give the function result a defined value in the working precision.
  ksp_results%res_final = 0.0_rp
  ksp_results%iter = 0
end function cg_cpld_device_nop
287
!> Coupled PCG solve for a three-component system on the device.
!!
!! The solution fields are zeroed on entry, so the iteration always
!! starts from a zero initial guess with the residual equal to the
!! right-hand side. The loop stops when the scaled residual norm drops
!! below `this%abs_tol` or after `max_iter` iterations.
!! @param Ax Operator; its `compute_vector` is applied to the search
!! direction.
!! @param x, y, z Solution fields, overwritten.
!! @param fx, fy, fz Right-hand sides; their device pointers are looked
!! up with `device_get_ptr`.
!! @param n Vector length.
!! @param coef Coefficients (`volume` and `mult_d` are read here).
!! @param blstx, blsty, blstz Boundary condition lists applied to the
!! operator output, one per component.
!! @param gs_h Gather-scatter handle.
!! @param niter Optional iteration limit overriding `this%max_iter`.
!! @return Three monitors, all holding the same coupled residual data.
function cg_cpld_device_solve(this, Ax, x, y, z, fx, fy, fz, &
     n, coef, blstx, blsty, blstz, gs_h, niter) result(ksp_results)
  class(cg_cpld_device_t), intent(inout) :: this
  class(ax_t), intent(in) :: Ax
  type(field_t), intent(inout) :: x
  type(field_t), intent(inout) :: y
  type(field_t), intent(inout) :: z
  integer, intent(in) :: n
  real(kind=rp), dimension(n), intent(in) :: fx
  real(kind=rp), dimension(n), intent(in) :: fy
  real(kind=rp), dimension(n), intent(in) :: fz
  type(coef_t), intent(inout) :: coef
  type(bc_list_t), intent(inout) :: blstx
  type(bc_list_t), intent(inout) :: blsty
  type(bc_list_t), intent(inout) :: blstz
  type(gs_t), intent(inout) :: gs_h
  type(ksp_monitor_t), dimension(3) :: ksp_results
  integer, optional, intent(in) :: niter
  integer :: iter, max_iter
  real(kind=rp) :: rnorm, rtr, rtz2, rtz1
  real(kind=rp) :: beta, pap, alpha, alphm, norm_fac
  integer, parameter :: gdim = 3
  type(c_ptr) :: fx_d
  type(c_ptr) :: fy_d
  type(c_ptr) :: fz_d

  fx_d = device_get_ptr(fx)
  fy_d = device_get_ptr(fy)
  fz_d = device_get_ptr(fz)

  if (present(niter)) then
    max_iter = niter
  else
    max_iter = this%max_iter
  end if
  ! Residual norms are reported scaled by 1/sqrt(volume).
  norm_fac = 1.0_rp / sqrt(coef%volume)

  associate(p1_d => this%p1_d, p2_d => this%p2_d, p3_d => this%p3_d, &
       z1_d => this%z1_d, z2_d => this%z2_d, z3_d => this%z3_d, &
       r1_d => this%r1_d, r2_d => this%r2_d, r3_d => this%r3_d, &
       w1_d => this%w1_d, w2_d => this%w2_d, w3_d => this%w3_d, &
       tmp_d => this%tmp_d)

    ! rtz1 = 1 keeps the first beta = rtz1 / rtz2 well defined; that
    ! beta is discarded anyway (forced to zero at iter == 1).
    rtz1 = 1.0_rp

    ! Zero initial guess, zero search direction, residual r = f.
    call device_rzero(x%x_d, n)
    call device_rzero(y%x_d, n)
    call device_rzero(z%x_d, n)
    call device_rzero(p1_d, n)
    call device_rzero(p2_d, n)
    call device_rzero(p3_d, n)
    call device_rzero(z1_d, n)
    call device_rzero(z2_d, n)
    call device_rzero(z3_d, n)
    call device_copy(r1_d, fx_d, n)
    call device_copy(r2_d, fy_d, n)
    call device_copy(r3_d, fz_d, n)

    ! Initial residual norm: pointwise r.r into tmp, then the inner
    ! product of tmp with coef%mult_d.
    call device_vdot3(tmp_d, r1_d, r2_d, r3_d, r1_d, r2_d, r3_d, n)
    rtr = device_glsc2(tmp_d, coef%mult_d, n)
    rnorm = sqrt(rtr) * norm_fac
    ksp_results%res_start = rnorm
    ksp_results%res_final = rnorm
    ksp_results%iter = 0

    ! A zero right-hand side is already solved by the zero guess.
    if (abscmp(rnorm, 0.0_rp)) then
      ksp_results%converged = .true.
      return
    end if

    call this%monitor_start('device_cpldCG')
    do iter = 1, max_iter
      ! Preconditioner solve, component by component: z from r.
      call this%M%solve(this%z1, this%r1, n)
      call this%M%solve(this%z2, this%r2, n)
      call this%M%solve(this%z3, this%r3, n)
      rtz2 = rtz1

      call device_vdot3(tmp_d, z1_d, z2_d, z3_d, r1_d, r2_d, r3_d, n)
      rtz1 = device_glsc2(tmp_d, coef%mult_d, n)

      beta = rtz1 / rtz2
      if (iter .eq. 1) beta = 0.0_rp
      ! Search direction update from z and beta.
      ! NOTE(review): presumably p = z + beta * p - confirm against
      ! device_add2s1.
      call device_add2s1(p1_d, z1_d, beta, n)
      call device_add2s1(p2_d, z2_d, beta, n)
      call device_add2s1(p3_d, z3_d, beta, n)

      ! w = A p for all three components at once.
      call Ax%compute_vector(this%w1, this%w2, this%w3, &
           this%p1, this%p2, this%p3, coef, x%msh, x%Xh)

      ! Gather-scatter (add) each component, bracketed by the cyclic
      ! rotation (flag 1 before, flag 0 after). Every operation is
      ! followed by a sync on the shared event before the next one.
      call rotate_cyc(w1_d, w2_d, w3_d, 1, coef)
      call gs_h%op(this%w1, n, gs_op_add, this%gs_event)
      call device_event_sync(this%gs_event)
      call gs_h%op(this%w2, n, gs_op_add, this%gs_event)
      call device_event_sync(this%gs_event)
      call gs_h%op(this%w3, n, gs_op_add, this%gs_event)
      call device_event_sync(this%gs_event)
      call rotate_cyc(w1_d, w2_d, w3_d, 0, coef)

      call blstx%apply(this%w1, n)
      call blsty%apply(this%w2, n)
      call blstz%apply(this%w3, n)

      call device_vdot3(tmp_d, w1_d, w2_d, w3_d, p1_d, p2_d, p3_d, n)
      pap = device_glsc2(tmp_d, coef%mult_d, n)

      alpha = rtz1 / pap
      alphm = -alpha
      ! Solution and residual updates with +alpha and -alpha.
      ! NOTE(review): presumably x += alpha * p and r -= alpha * w -
      ! confirm against device_opadd2cm.
      call device_opadd2cm(x%x_d, y%x_d, z%x_d, &
           p1_d, p2_d, p3_d, alpha, n, gdim)
      call device_opadd2cm(r1_d, r2_d, r3_d, &
           w1_d, w2_d, w3_d, alphm, n, gdim)

      call device_vdot3(tmp_d, r1_d, r2_d, r3_d, r1_d, r2_d, r3_d, n)
      rtr = device_glsc2(tmp_d, coef%mult_d, n)
      rnorm = sqrt(rtr) * norm_fac
      call this%monitor_iter(iter, rnorm)
      if (rnorm .lt. this%abs_tol) then
        exit
      end if
    end do
  end associate
  call this%monitor_stop()
  ksp_results%res_final = rnorm
  ! If the loop ran to completion, iter is max_iter + 1 here.
  ksp_results%iter = iter
  ksp_results%converged = this%is_converged(iter, rnorm)

end function cg_cpld_device_solve
418
419end module cg_cpld_device
__device__ T solve(const T u, const T y, const T guess, const T nu, const T kappa, const T B)
Return the device pointer for an associated Fortran array.
Definition device.F90:108
Map a Fortran array to a device (allocate and associate)
Definition device.F90:78
Unmap a Fortran array from a device (deassociate and free)
Definition device.F90:84
Apply cyclic boundary condition to a vector field.
Defines a Matrix-vector product.
Definition ax.f90:34
Defines a list of bc_t.
Definition bc_list.f90:34
Defines a coupled Conjugate Gradient method for accelerators.
type(ksp_monitor_t) function, dimension(3) cg_cpld_device_solve(this, ax, x, y, z, fx, fy, fz, n, coef, blstx, blsty, blstz, gs_h, niter)
Standard PCG solve.
type(ksp_monitor_t) function cg_cpld_device_nop(this, ax, x, f, n, coef, blst, gs_h, niter)
subroutine cg_cpld_device_free(this)
Deallocate a device based PCG solver.
subroutine cg_cpld_device_init(this, n, max_iter, m, rel_tol, abs_tol, monitor)
Initialise a device based PCG solver.
Coefficients.
Definition coef.f90:34
subroutine, public device_add2s1(a_d, b_d, c1, n, strm)
subroutine, public device_rzero(a_d, n, strm)
Zero a real vector.
subroutine, public device_vdot3(dot_d, u1_d, u2_d, u3_d, v1_d, v2_d, v3_d, n, strm)
Compute a dot product $dot = u \cdot v$ (3-d version) assuming vector components $u = (u_1, u_2, u_3)$ etc.
subroutine, public device_copy(a_d, b_d, n, strm)
Copy a vector $a = b$.
real(kind=rp) function, public device_glsc2(a_d, b_d, n, strm)
Weighted inner product $a^T b$.
subroutine, public device_opadd2cm(a1_d, a2_d, a3_d, b1_d, b2_d, b3_d, c, n, gdim)
Device abstraction, common interface for various accelerators.
Definition device.F90:34
subroutine, public device_event_sync(event)
Synchronize an event.
Definition device.F90:1594
subroutine, public device_event_destroy(event)
Destroy a device event.
Definition device.F90:1550
subroutine, public device_event_create(event, flags)
Create a device event queue.
Definition device.F90:1516
Defines a field.
Definition field.f90:34
Gather-scatter.
Implements the base abstract type for Krylov solvers plus helper types.
Definition krylov.f90:34
integer, parameter, public ksp_max_iter
Maximum number of iters.
Definition krylov.f90:51
Definition math.f90:60
integer, parameter, public rp
Global precision used in computations.
Definition num_types.f90:12
Operators.
Definition operators.f90:34
Krylov preconditioner.
Definition precon.f90:34
Utilities.
Definition utils.f90:35
Base type for a matrix-vector product providing $Ax$.
Definition ax.f90:43
A list of allocatable `bc_t`. Follows the standard interface of lists.
Definition bc_list.f90:49
Device based coupled preconditioned conjugate gradient method.
Coefficients defined on a given (mesh, $X_h$) tuple. Arrays use indices (i,j,k,e): element e,...
Definition coef.f90:63
Type for storing initial and final residuals in a Krylov solver.
Definition krylov.f90:56
Base abstract type for a canonical Krylov method, solving $Ax = f$.
Definition krylov.f90:73
Defines a canonical Krylov preconditioner.
Definition precon.f90:40