Loading [MathJax]/jax/output/HTML-CSS/config.js
Neko 0.9.99
A portable framework for high-order spectral element flow simulations
All Classes Namespaces Files Functions Variables Typedefs Enumerator Macros Pages
tree_amg.f90
Go to the documentation of this file.
1! Copyright (c) 2024, The Neko Authors
2! All rights reserved.
3!
4! Redistribution and use in source and binary forms, with or without
5! modification, are permitted provided that the following conditions
6! are met:
7!
8! * Redistributions of source code must retain the above copyright
9! notice, this list of conditions and the following disclaimer.
10!
11! * Redistributions in binary form must reproduce the above
12! copyright notice, this list of conditions and the following
13! disclaimer in the documentation and/or other materials provided
14! with the distribution.
15!
16! * Neither the name of the authors nor the names of its
17! contributors may be used to endorse or promote products derived
18! from this software without specific prior written permission.
19!
20! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21! "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22! LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
23! FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
24! COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
25! INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
26! BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27! LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28! CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29! LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
30! ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31! POSSIBILITY OF SUCH DAMAGE.
32!
35 use num_types, only : rp
36 use utils, only : neko_error
37 use math, only : rzero, col2
40 use coefs, only : coef_t
41 use mesh, only : mesh_t
42 use space, only : space_t
43 use ax_product, only: ax_t
44 use bc_list, only: bc_list_t
45 use gather_scatter, only : gs_t, gs_op_add
48 use, intrinsic :: iso_c_binding
49 implicit none
50 private
51
!> TreeAMG tree node: one unknown on a coarse level together with the
!! finer-level dofs it aggregates and the weights used to move data
!! between the two levels.
type, private :: tamg_node_t
  !> Leaf flag; defaults to .true.
  logical :: isleaf = .true.
  !> Index of this node in its level's vector; -1 until initialised
  integer :: gid = -1
  !> Level this node belongs to; -1 until set
  integer :: lvl = -1
  !> Number of finer-level dofs aggregated by this node
  integer :: ndofs = 0
  !> Finer-level dof indices aggregated by this node
  integer, allocatable :: dofs(:)
  !> Node coordinates (presumably spatial x, y, z; not used in this module)
  real(rp) :: xyz(3)
  !> Restriction weights, one per entry of dofs
  real(rp), allocatable :: interp_r(:)
  !> Prolongation weights, one per entry of dofs
  real(rp), allocatable :: interp_p(:)
end type tamg_node_t
63
!> Per-level TreeAMG data: the level's nodes, a flattened copy of the
!! node-to-dof layout, work arrays sized to the next finer level, and the
!! index maps used by the host and device operators.
type, private :: tamg_lvl_t
  !> Level index
  integer :: lvl = -1
  !> Number of nodes on this level
  integer :: nnodes = 0
  !> Nodes of this level
  type(tamg_node_t), allocatable :: nodes(:)
  !> Number of dofs on the next finer level (length of the work arrays)
  integer :: fine_lvl_dofs = 0
  !> Finer-level input work array
  real(rp), allocatable :: wrk_in(:)
  !> Device mirror of wrk_in
  type(c_ptr) :: wrk_in_d = c_null_ptr
  !> Finer-level output work array
  real(rp), allocatable :: wrk_out(:)
  !> Device mirror of wrk_out
  type(c_ptr) :: wrk_out_d = c_null_ptr
  !> Finest-level dof -> node index on this level (allocated with lower bound 0)
  integer, allocatable :: map_finest2lvl(:)
  !> Device mirror of map_finest2lvl
  type(c_ptr) :: map_finest2lvl_d = c_null_ptr
  ! Flattened node layout: the dofs of node n are
  ! nodes_dofs(nodes_ptr(n) : nodes_ptr(n+1)-1), and nodes_gid(n) is its gid.
  integer, allocatable :: nodes_ptr(:)
  type(c_ptr) :: nodes_ptr_d = c_null_ptr
  integer, allocatable :: nodes_gid(:)
  type(c_ptr) :: nodes_gid_d = c_null_ptr
  integer, allocatable :: nodes_dofs(:)
  type(c_ptr) :: nodes_dofs_d = c_null_ptr
  !> Fine-to-coarse dof map used as the mask of the device restriction
  !! and prolongation (allocated with lower bound 0)
  integer, allocatable :: map_f2c(:)
  !> Device mirror of map_f2c
  type(c_ptr) :: map_f2c_d = c_null_ptr
  ! Design note: a second array the length of nodes_dofs holding the parent
  ! node gid of each entry (the per-dof analogue of nodes_gid) would let some
  ! loops run once over nodes_dofs instead of visiting each node and its
  ! nodes_ptr(n) : nodes_ptr(n+1)-1 range.
end type tamg_lvl_t
92
!> TreeAMG hierarchy: the stack of coarse levels plus the finest-level
!! operator and discretisation objects the hierarchy is built on.
type, public :: tamg_hierarchy_t
  !> Number of levels stored in lvl(:). Level 0 (the finest) is represented
  !! by ax/gs_h/coef/blst rather than by an entry of lvl.
  integer :: nlvls = 0
  !> Level data, lvl(1:nlvls)
  type(tamg_lvl_t), allocatable :: lvl(:)

  ! Finest-level objects, associated in tamg_init. They default to null so
  ! that associated() is well defined before init is called.
  !> Finest-level matrix-vector product
  class(ax_t), pointer :: ax => null()
  !> Mesh
  type(mesh_t), pointer :: msh => null()
  !> Function space
  type(space_t), pointer :: xh => null()
  !> Coefficients
  type(coef_t), pointer :: coef => null()
  !> Gather-scatter handle
  type(gs_t), pointer :: gs_h => null()
  !> Boundary conditions applied after the finest-level product
  type(bc_list_t), pointer :: blst => null()

contains
  procedure, pass(this) :: init => tamg_init
  procedure, pass(this) :: matvec => tamg_matvec
  procedure, pass(this) :: matvec_impl => tamg_matvec_impl
  procedure, pass(this) :: interp_f2c => tamg_restriction_operator
  procedure, pass(this) :: interp_c2f => tamg_prolongation_operator
  procedure, pass(this) :: interp_f2c_d => tamg_device_restriction_operator
  procedure, pass(this) :: interp_c2f_d => tamg_device_prolongation_operator
  procedure, pass(this) :: device_matvec => tamg_device_matvec_flat_impl
end type tamg_hierarchy_t
118
120
121contains
122
!> Initialisation of the TreeAMG hierarchy.
!! Stores pointers to the finest-level operator and discretisation objects,
!! then allocates one level descriptor per requested level. Each descriptor
!! gets a finest-to-level map indexed 0:ndofs, which is also mapped to the
!! device on a device backend. The maps themselves are not filled here.
!! @param ax     Finest-level matrix-vector product.
!! @param Xh     Function space.
!! @param coef   Coefficients; coef%dof%size() sizes the finest-to-level maps.
!! @param msh    Mesh.
!! @param gs_h   Gather-scatter handle.
!! @param nlvls  Number of levels to allocate; must be at least 2.
!! @param blst   Boundary condition list applied after the finest-level product.
subroutine tamg_init(this, ax, Xh, coef, msh, gs_h, nlvls, blst)
  class(tamg_hierarchy_t), target, intent(inout) :: this
  class(ax_t), target, intent(in) :: ax
  type(space_t), target, intent(in) :: Xh
  type(coef_t), target, intent(in) :: coef
  type(mesh_t), target, intent(in) :: msh
  type(gs_t), target, intent(in) :: gs_h
  integer, intent(in) :: nlvls
  type(bc_list_t), target, intent(in) :: blst
  integer :: i, ndofs

  ! Validate the request before modifying any part of `this`.
  if (nlvls .lt. 2) then
    call neko_error("Need to request at least two multigrid levels.")
  end if

  this%ax => ax
  this%msh => msh
  this%Xh => Xh
  this%coef => coef
  this%gs_h => gs_h
  this%blst => blst

  this%nlvls = nlvls
  allocate(this%lvl(this%nlvls))

  ! Loop-invariant finest-level dof count.
  ndofs = coef%dof%size()
  do i = 1, this%nlvls
    ! Lower bound 0, so the device copy holds ndofs + 1 entries.
    allocate(this%lvl(i)%map_finest2lvl(0:ndofs))
    if (neko_bcknd_device .eq. 1) then
      call device_map(this%lvl(i)%map_finest2lvl, &
           this%lvl(i)%map_finest2lvl_d, ndofs + 1)
    end if
  end do

end subroutine tamg_init
164
!> Initialisation of one TreeAMG level.
!! Records the level index and sizes. Allocates:
!!   - the node array;
!!   - the flattened node layout (nodes_ptr, nodes_gid, nodes_dofs);
!!   - the fine-to-coarse map map_f2c(0:ndofs);
!!   - the two work arrays of length ndofs.
!! On a device backend, map_f2c and the work arrays are mapped to the device
!! and the device work arrays are zero-filled. The host work arrays are not
!! initialised here.
!! @param tamg_lvl Level to initialise.
!! @param lvl      Level index.
!! @param nnodes   Number of nodes on this level.
!! @param ndofs    Number of dofs on the next finer level.
subroutine tamg_lvl_init(tamg_lvl, lvl, nnodes, ndofs)
  type(tamg_lvl_t), intent(inout) :: tamg_lvl
  integer, intent(in) :: lvl
  integer, intent(in) :: nnodes
  integer, intent(in) :: ndofs

  tamg_lvl%lvl = lvl
  tamg_lvl%nnodes = nnodes
  tamg_lvl%fine_lvl_dofs = ndofs

  ! Host-side storage
  allocate(tamg_lvl%nodes(nnodes))
  allocate(tamg_lvl%nodes_ptr(nnodes + 1))
  allocate(tamg_lvl%nodes_gid(nnodes))
  allocate(tamg_lvl%nodes_dofs(ndofs))
  allocate(tamg_lvl%map_f2c(0:ndofs))
  allocate(tamg_lvl%wrk_in(ndofs))
  allocate(tamg_lvl%wrk_out(ndofs))

  ! Device mirrors. map_f2c carries ndofs + 1 entries because of its 0 index.
  if (neko_bcknd_device .eq. 1) then
    call device_map(tamg_lvl%map_f2c, tamg_lvl%map_f2c_d, ndofs + 1)
    call device_map(tamg_lvl%wrk_in, tamg_lvl%wrk_in_d, ndofs)
    call device_cfill(tamg_lvl%wrk_in_d, 0.0_rp, ndofs)
    call device_map(tamg_lvl%wrk_out, tamg_lvl%wrk_out_d, ndofs)
    call device_cfill(tamg_lvl%wrk_out_d, 0.0_rp, ndofs)
  end if
end subroutine tamg_lvl_init
197
!> Initialisation of a TreeAMG tree node.
!! Sets the node's gid and dof count, then allocates its dof list and its
!! restriction/prolongation weights. The dof list is filled with -1 and the
!! weights with 1.
!! @param node  Node to initialise.
!! @param gid   Index of the node in its level's vector.
!! @param ndofs Number of finer-level dofs the node aggregates.
subroutine tamg_node_init(node, gid, ndofs)
  type(tamg_node_t), intent(inout) :: node
  integer, intent(in) :: gid
  integer, intent(in) :: ndofs

  node%gid = gid
  node%ndofs = ndofs

  allocate(node%dofs(ndofs), node%interp_r(ndofs), node%interp_p(ndofs))
  ! -1 marks dof slots not yet assigned (presumably filled in by the caller).
  node%dofs = -1
  ! Unit weights for both transfer directions.
  node%interp_r = 1.0_rp
  node%interp_p = 1.0_rp
end subroutine tamg_node_init
216
!> Matrix-vector product through the TreeAMG hierarchy: vec_out = A * vec_in.
!! Thin wrapper that starts the recursive traversal (tamg_matvec_impl) at
!! level this%nlvls. The product is evaluated on level lvl_out.
!! The flat variant tamg_matvec_flat_impl is an alternative that is not used here.
!! @param vec_out Result vector on level lvl_out.
!! @param vec_in  Input vector on level lvl_out. When lvl_out = 0 it is
!!                modified in place by gather-scatter and multiplicity scaling.
!! @param lvl_out Level on which the product is evaluated (0 = finest).
recursive subroutine tamg_matvec(this, vec_out, vec_in, lvl_out)
  class(tamg_hierarchy_t), intent(inout) :: this
  real(kind=rp), intent(inout) :: vec_out(:)
  real(kind=rp), intent(inout) :: vec_in(:)
  integer, intent(in) :: lvl_out

  call this%matvec_impl(vec_out, vec_in, this%nlvls, lvl_out)
end subroutine tamg_matvec
232
!> Recursive implementation of the TreeAMG matrix-vector product.
!! Computes vec_out = A_{lvl_out} * vec_in, where A_{lvl_out} is the
!! finest-level operator wrapped in prolongation/restriction through
!! levels lvl_out, ..., 1.
!! - Levels coarser than lvl_out are passed through without work.
!! - From lvl_out downwards, the input is prolongated to the next finer level
!!   with interp_p weights, the product is computed recursively there, and the
!!   result is restricted back with interp_r weights.
!! @param vec_out Result on level lvl (nnodes entries when lvl > 0).
!! @param vec_in  Input on level lvl. When lvl = 0 it is modified in place.
!! @param lvl     Current level of the recursion (0 = finest).
!! @param lvl_out Level on which the product is requested.
recursive subroutine tamg_matvec_impl(this, vec_out, vec_in, lvl, lvl_out)
  class(tamg_hierarchy_t), intent(inout) :: this
  real(kind=rp), intent(inout) :: vec_out(:)
  real(kind=rp), intent(inout) :: vec_in(:)
  integer, intent(in) :: lvl
  integer, intent(in) :: lvl_out
  integer :: i, n, inode

  if (lvl .eq. 0) then
    ! Finest level: gather-scatter sum of vec_in scaled by coef%mult, then the
    ! local operator, gather-scatter assembly of the result, and BCs.
    n = size(vec_in)
    call this%gs_h%op(vec_in, n, gs_op_add)
    call col2(vec_in, this%coef%mult(1,1,1,1), n)
    call this%ax%compute(vec_out, vec_in, this%coef, this%msh, this%Xh)
    call this%gs_h%op(vec_out, n, gs_op_add)
    call this%blst%apply(vec_out, n)
  else if (lvl_out .lt. lvl) then
    ! This level is coarser than the requested one: descend without work.
    ! (The former third branch raising an error was unreachable, since
    ! lvl_out .ge. lvl and lvl_out .lt. lvl cover every case.)
    call this%matvec_impl(vec_out, vec_in, lvl - 1, lvl_out)
  else
    ! lvl_out >= lvl: prolongate, recurse on the finer level, restrict.
    associate(wrk_in => this%lvl(lvl)%wrk_in, wrk_out => this%lvl(lvl)%wrk_out)
      n = this%lvl(lvl)%fine_lvl_dofs
      call rzero(wrk_in, n)
      call rzero(wrk_out, n)
      call rzero(vec_out, this%lvl(lvl)%nnodes)

      ! Prolongation: wrk_in = P * vec_in
      do inode = 1, this%lvl(lvl)%nnodes
        associate(node => this%lvl(lvl)%nodes(inode))
          do i = 1, node%ndofs
            wrk_in(node%dofs(i)) = wrk_in(node%dofs(i)) + &
                 vec_in(node%gid) * node%interp_p(i)
          end do
        end associate
      end do

      call this%matvec_impl(wrk_out, wrk_in, lvl - 1, lvl_out)

      ! Restriction: vec_out = R * wrk_out
      do inode = 1, this%lvl(lvl)%nnodes
        associate(node => this%lvl(lvl)%nodes(inode))
          do i = 1, node%ndofs
            vec_out(node%gid) = vec_out(node%gid) + &
                 wrk_out(node%dofs(i)) * node%interp_r(i)
          end do
        end associate
      end do
    end associate
  end if
end subroutine tamg_matvec_impl
296
297
!> Flat (non-recursive) TreeAMG matrix-vector product.
!! Intended for piecewise-constant transfer, where an index map can go
!! directly to the finest level. The steps are:
!!   1. copy the level-lvl_out vector onto the finest level through
!!      map_finest2lvl;
!!   2. apply the finest-level operator on lvl(1)'s work arrays;
!!   3. sum the result back onto the level-lvl_out nodes.
!! Currently not called by tamg_matvec.
!! @param vec_out  Result on level lvl_out.
!! @param vec_in   Input on level lvl_out; modified in place when lvl_out = 0.
!! @param lvl_blah Unused; kept so the signature mirrors tamg_matvec_impl.
!! @param lvl_out  Level on which the product is evaluated (0 = finest).
recursive subroutine tamg_matvec_flat_impl(this, vec_out, vec_in, lvl_blah, lvl_out)
  class(tamg_hierarchy_t), intent(inout) :: this
  real(kind=rp), intent(inout) :: vec_out(:)
  real(kind=rp), intent(inout) :: vec_in(:)
  integer, intent(in) :: lvl_blah
  integer, intent(in) :: lvl_out
  integer :: idof, nfine, node_id

  ! Finest level requested: apply the operator straight to the input.
  if (lvl_out .eq. 0) then
    call apply_finest(vec_out, vec_in)
    return
  end if

  associate(fine_in => this%lvl(1)%wrk_in, fine_out => this%lvl(1)%wrk_out)
    nfine = size(fine_in)
    call rzero(fine_out, nfine)
    call rzero(vec_out, this%lvl(lvl_out)%nnodes)

    ! Each finest dof takes the value of the level-lvl_out node it maps to.
    do idof = 1, nfine
      node_id = this%lvl(lvl_out)%map_finest2lvl(idof)
      fine_in(idof) = vec_in(node_id)
    end do

    call apply_finest(fine_out, fine_in)

    ! Accumulate the finest-level result onto the owning nodes.
    do idof = 1, nfine
      node_id = this%lvl(lvl_out)%map_finest2lvl(idof)
      vec_out(node_id) = vec_out(node_id) + fine_out(idof)
    end do
  end associate

contains

  !> Finest-level operator on a pair of finest-level vectors. In order:
  !!   1. gather-scatter sum of x_in, scaled by coef%mult;
  !!   2. local Ax;
  !!   3. gather-scatter assembly of x_out;
  !!   4. boundary conditions on x_out.
  subroutine apply_finest(x_out, x_in)
    real(kind=rp), intent(inout) :: x_out(:)
    real(kind=rp), intent(inout) :: x_in(:)
    integer :: nx

    nx = size(x_in)
    call this%gs_h%op(x_in, nx, gs_op_add)
    call col2(x_in, this%coef%mult(1,1,1,1), nx)
    call this%ax%compute(x_out, x_in, this%coef, this%msh, this%Xh)
    call this%gs_h%op(x_out, nx, gs_op_add)
    call this%blst%apply(x_out, nx)
  end subroutine apply_finest

end subroutine tamg_matvec_flat_impl
352
353
354
!> Restriction operator for TreeAMG: vec_out = R * vec_in.
!! Each node of level lvl accumulates the finer-level entries it aggregates,
!! weighted by the node's interp_r coefficients.
!! A flattened variant could instead loop over
!! nodes_dofs(nodes_ptr(n) : nodes_ptr(n+1)-1).
!! @param vec_out Level-lvl vector indexed by node gid; overwritten.
!! @param vec_in  Finer-level vector indexed by the nodes' dof ids.
!! @param lvl     Level whose nodes define R.
subroutine tamg_restriction_operator(this, vec_out, vec_in, lvl)
  class(tamg_hierarchy_t), intent(inout) :: this
  real(kind=rp), intent(inout) :: vec_out(:)
  real(kind=rp), intent(inout) :: vec_in(:)
  integer, intent(in) :: lvl
  integer :: i, n

  ! Kind-correct zero; 0d0 is a double-precision literal whatever rp is.
  vec_out = 0.0_rp
  do n = 1, this%lvl(lvl)%nnodes
    associate(node => this%lvl(lvl)%nodes(n))
      do i = 1, node%ndofs
        vec_out(node%gid) = vec_out(node%gid) + &
             vec_in(node%dofs(i)) * node%interp_r(i)
      end do
    end associate
  end do
end subroutine tamg_restriction_operator
384
!> Prolongation operator for TreeAMG: vec_out = P * vec_in.
!! Each node of level lvl adds its value, weighted by the node's interp_p
!! coefficients, to every finer-level dof it aggregates.
!! A flattened variant could instead loop over
!! nodes_dofs(nodes_ptr(n) : nodes_ptr(n+1)-1).
!! @param vec_out Finer-level vector indexed by the nodes' dof ids; overwritten.
!! @param vec_in  Level-lvl vector indexed by node gid.
!! @param lvl     Level whose nodes define P.
subroutine tamg_prolongation_operator(this, vec_out, vec_in, lvl)
  class(tamg_hierarchy_t), intent(inout) :: this
  real(kind=rp), intent(inout) :: vec_out(:)
  real(kind=rp), intent(inout) :: vec_in(:)
  integer, intent(in) :: lvl
  integer :: i, n

  ! Kind-correct zero; 0d0 is a double-precision literal whatever rp is.
  vec_out = 0.0_rp
  do n = 1, this%lvl(lvl)%nnodes
    associate(node => this%lvl(lvl)%nodes(n))
      do i = 1, node%ndofs
        vec_out(node%dofs(i)) = vec_out(node%dofs(i)) + &
             vec_in(node%gid) * node%interp_p(i)
      end do
    end associate
  end do
end subroutine tamg_prolongation_operator
414
415
!> Device version of the flat TreeAMG matrix-vector product.
!! For lvl_out = 0, the finest-level operator is applied directly to vec_in.
!! Otherwise:
!!   1. the level-lvl_out vector is copied onto the finest level through
!!      map_finest2lvl_d;
!!   2. the finest-level operator is applied on lvl(1)'s work arrays;
!!   3. the result is reduced back onto the level-lvl_out nodes.
!! @param vec_out   Host result array; passed to ax/gs/bc calls in the
!!                  lvl_out = 0 path (presumably device-mapped).
!! @param vec_in    Host input array; passed to ax in the lvl_out = 0 path.
!! @param vec_out_d Device pointer of the result vector.
!! @param vec_in_d  Device pointer of the input vector.
!! @param lvl_out   Level on which the product is evaluated (0 = finest).
subroutine tamg_device_matvec_flat_impl(this, vec_out, vec_in, vec_out_d, vec_in_d, lvl_out)
  class(tamg_hierarchy_t), intent(inout) :: this
  real(kind=rp), intent(inout) :: vec_out(:)
  real(kind=rp), intent(inout) :: vec_in(:)
  type(c_ptr) :: vec_out_d
  type(c_ptr) :: vec_in_d
  integer, intent(in) :: lvl_out
  integer :: n, lvl

  lvl = lvl_out
  n = this%lvl(1)%fine_lvl_dofs
  if (lvl .eq. 0) then
    call this%ax%compute(vec_out, vec_in, this%coef, this%msh, this%Xh)
    call this%gs_h%op(vec_out, n, gs_op_add, glb_cmd_event)
    call device_stream_wait_event(glb_cmd_queue, glb_cmd_event, 0)
    call this%blst%apply(vec_out, n)
  else
    associate(wrk_in_d => this%lvl(1)%wrk_in_d, wrk_out_d => this%lvl(1)%wrk_out_d)
      ! Copy the level-lvl vector onto the finest level.
      call device_masked_red_copy(wrk_in_d, vec_in_d, &
           this%lvl(lvl)%map_finest2lvl_d, this%lvl(lvl)%nnodes, n)

      ! Pass glb_cmd_event to the gather-scatter, as the lvl == 0 branch does,
      ! so the following stream wait is tied to this operation. Previously
      ! the event was waited on without being handed to the gs call.
      call this%gs_h%op(this%lvl(1)%wrk_in, n, gs_op_add, glb_cmd_event)
      call device_stream_wait_event(glb_cmd_queue, glb_cmd_event, 0)
      call device_col2(wrk_in_d, this%coef%mult_d, n)

      ! Finest-level operator, assembly and boundary conditions.
      call this%ax%compute(this%lvl(1)%wrk_out, this%lvl(1)%wrk_in, &
           this%coef, this%msh, this%Xh)
      call this%gs_h%op(this%lvl(1)%wrk_out, n, gs_op_add, glb_cmd_event)
      call device_stream_wait_event(glb_cmd_queue, glb_cmd_event, 0)
      call this%blst%apply(this%lvl(1)%wrk_out, n)

      ! Reduce the finest-level result back onto the level-lvl nodes.
      call device_rzero(vec_out_d, this%lvl(lvl)%nnodes)
      call device_masked_atomic_reduction(vec_out_d, wrk_out_d, &
           this%lvl(lvl)%map_finest2lvl_d, this%lvl(lvl)%nnodes, n)
      !TODO: swap n and m
    end associate
  end if
end subroutine tamg_device_matvec_flat_impl
454
!> Device restriction operator for TreeAMG: vec_out_d = R * vec_in_d.
!! Zeroes the first nnodes entries of vec_out_d, then reduces the
!! fine_lvl_dofs entries of vec_in_d into it using map_f2c_d as the mask.
!! The reduction presumably adds vec_in_d(j) into vec_out_d(map_f2c(j));
!! confirm against device_masked_atomic_reduction.
!! @param vec_out_d Device pointer to the level-lvl vector (nnodes entries).
!! @param vec_in_d  Device pointer to the finer-level vector.
!! @param lvl       Level whose fine-to-coarse map defines R.
subroutine tamg_device_restriction_operator(this, vec_out_d, vec_in_d, lvl)
  class(tamg_hierarchy_t), intent(inout) :: this
  type(c_ptr) :: vec_out_d
  type(c_ptr) :: vec_in_d
  integer, intent(in) :: lvl
  integer :: n, m

  n = this%lvl(lvl)%nnodes
  m = this%lvl(lvl)%fine_lvl_dofs
  call device_rzero(vec_out_d, n)
  call device_masked_atomic_reduction(vec_out_d, vec_in_d, &
       this%lvl(lvl)%map_f2c_d, n, m)
end subroutine tamg_device_restriction_operator
466
!> Device prolongation operator for TreeAMG: vec_out_d = P * vec_in_d.
!! Copies level-lvl values onto the fine_lvl_dofs finer-level entries using
!! map_f2c_d as the mask. The copy presumably sets
!! vec_out_d(j) = vec_in_d(map_f2c(j)); confirm against device_masked_red_copy.
!! @param vec_out_d Device pointer to the finer-level vector.
!! @param vec_in_d  Device pointer to the level-lvl vector (nnodes entries).
!! @param lvl       Level whose fine-to-coarse map defines P.
subroutine tamg_device_prolongation_operator(this, vec_out_d, vec_in_d, lvl)
  class(tamg_hierarchy_t), intent(inout) :: this
  type(c_ptr) :: vec_out_d
  type(c_ptr) :: vec_in_d
  integer, intent(in) :: lvl
  integer :: n, m

  n = this%lvl(lvl)%nnodes
  m = this%lvl(lvl)%fine_lvl_dofs
  call device_masked_red_copy(vec_out_d, vec_in_d, &
       this%lvl(lvl)%map_f2c_d, n, m)
end subroutine tamg_device_prolongation_operator
477
478end module tree_amg
Map a Fortran array to a device (allocate and associate)
Definition device.F90:71
Defines a Matrix-vector product.
Definition ax.f90:34
Defines a list of bc_t.
Definition bc_list.f90:34
Coefficients.
Definition coef.f90:34
subroutine, public device_col2(a_d, b_d, n)
Vector multiplication \( a = a \cdot b \).
subroutine, public device_rzero(a_d, n)
Zero a real vector.
subroutine, public device_masked_red_copy(a_d, b_d, mask_d, n, m)
subroutine, public device_masked_atomic_reduction(a_d, b_d, mask_d, n, m)
subroutine, public device_cfill(a_d, c, n)
Set all elements to a constant c: \( a = c \).
Device abstraction, common interface for various accelerators.
Definition device.F90:34
subroutine, public device_free(x_d)
Deallocate memory on the device.
Definition device.F90:200
subroutine, public device_stream_wait_event(stream, event, flags)
Synchronize a device stream with an event.
Definition device.F90:1138
type(c_ptr), bind(C), public glb_cmd_queue
Global command queue.
Definition device.F90:50
type(c_ptr), bind(C), public glb_cmd_event
Event for the global command queue.
Definition device.F90:56
Gather-scatter.
Definition math.f90:60
subroutine, public col2(a, b, n)
Vector multiplication \( a = a \cdot b \).
Definition math.f90:728
subroutine, public rzero(a, n)
Zero a real vector.
Definition math.f90:194
Defines a mesh.
Definition mesh.f90:34
Build configurations.
integer, parameter neko_bcknd_device
integer, parameter, public rp
Global precision used in computations.
Definition num_types.f90:12
Defines a function space.
Definition space.f90:34
Implements the base type for TreeAMG hierarchy structure.
Definition tree_amg.f90:34
subroutine tamg_device_matvec_flat_impl(this, vec_out, vec_in, vec_out_d, vec_in_d, lvl_out)
Definition tree_amg.f90:417
recursive subroutine tamg_matvec_impl(this, vec_out, vec_in, lvl, lvl_out)
Matrix vector product using the TreeAMG hierarchy structure b=Ax done as vec_out = A * vec_in. This is...
Definition tree_amg.f90:241
subroutine, public tamg_node_init(node, gid, ndofs)
Initialization of a TreeAMG tree node.
Definition tree_amg.f90:203
subroutine tamg_restriction_operator(this, vec_out, vec_in, lvl)
Restriction operator for TreeAMG. vec_out = R * vec_in.
Definition tree_amg.f90:360
subroutine tamg_prolongation_operator(this, vec_out, vec_in, lvl)
Prolongation operator for TreeAMG. vec_out = P * vec_in.
Definition tree_amg.f90:390
subroutine, public tamg_lvl_init(tamg_lvl, lvl, nnodes, ndofs)
Initialization of a TreeAMG level.
Definition tree_amg.f90:171
subroutine tamg_device_prolongation_operator(this, vec_out_d, vec_in_d, lvl)
Definition tree_amg.f90:468
recursive subroutine tamg_matvec(this, vec_out, vec_in, lvl_out)
Wrapper for matrix vector product using the TreeAMG hierarchy structure b=Ax done as vec_out = A * ve...
Definition tree_amg.f90:224
subroutine tamg_device_restriction_operator(this, vec_out_d, vec_in_d, lvl)
Definition tree_amg.f90:456
subroutine tamg_init(this, ax, xh, coef, msh, gs_h, nlvls, blst)
Initialization of TreeAMG hierarchy.
Definition tree_amg.f90:132
recursive subroutine tamg_matvec_flat_impl(this, vec_out, vec_in, lvl_blah, lvl_out)
Ignore this. For piecewise constant, can create index map directly to finest level.
Definition tree_amg.f90:300
Utilities.
Definition utils.f90:35
Base type for a matrix-vector product providing \( Ax \).
Definition ax.f90:43
A list of allocatable `bc_t`. Follows the standard interface of lists.
Definition bc_list.f90:47
Coefficients defined on a given (mesh, \( X_h \)) tuple. Arrays use indices (i,j,k,e): element e,...
Definition coef.f90:55
The function space for the SEM solution fields.
Definition space.f90:62
Type for a TreeAMG hierarchy.
Definition tree_amg.f90:94
Type for storing TreeAMG level information.
Definition tree_amg.f90:65
Type for storing TreeAMG tree node information.
Definition tree_amg.f90:53