Neko 1.99.3
A portable framework for high-order spectral element flow simulations
Loading...
Searching...
No Matches
ax_helm_device.F90
Go to the documentation of this file.
1! Copyright (c) 2021-2025, The Neko Authors
2! All rights reserved.
3!
4! Redistribution and use in source and binary forms, with or without
5! modification, are permitted provided that the following conditions
6! are met:
7!
8! * Redistributions of source code must retain the above copyright
9! notice, this list of conditions and the following disclaimer.
10!
11! * Redistributions in binary form must reproduce the above
12! copyright notice, this list of conditions and the following
13! disclaimer in the documentation and/or other materials provided
14! with the distribution.
15!
16! * Neither the name of the authors nor the names of its
17! contributors may be used to endorse or promote products derived
18! from this software without specific prior written permission.
19!
20! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21! "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22! LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
23! FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
24! COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
25! INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
26! BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27! LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28! CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29! LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
30! ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31! POSSIBILITY OF SUCH DAMAGE.
32!
34 use ax_helm, only : ax_helm_t
35 use num_types, only : rp
36 use coefs, only : coef_t
37 use space, only : space_t
38 use mesh, only : mesh_t
39 use device_math, only : device_addcol4
40 use device, only : device_get_ptr
41 use num_types, only : rp
42 use, intrinsic :: iso_c_binding, only : c_ptr, c_int
43 implicit none
44 private
45
!> Helmholtz matrix-vector product evaluated on an accelerator.
!! The type-bound procedures forward to the HIP, CUDA, OpenCL or Metal
!! kernels selected by the preprocessor when this module is built.
type, extends(ax_helm_t), public :: ax_helm_device_t
contains
  !> Scalar product, implemented by ax_helm_device_compute.
  procedure, nopass :: compute => ax_helm_device_compute
  !> Three-component product, implemented by ax_helm_device_compute_vector.
  procedure, pass(this) :: compute_vector => ax_helm_device_compute_vector
end type ax_helm_device_t
51
! Bindings to the backend-specific C kernels. Exactly one accelerator
! backend is selected at preprocessing time. Device arrays are passed as
! opaque pointers by value; the integer sizes are passed by reference,
! matching the `int *` parameters of the C entry points.
#ifdef HAVE_HIP
interface
  !> Stiffness part of the scalar Helmholtz operator (h1, geometric
  !! factors G11..G23, derivative matrices); the optional h2 term is
  !! added separately by the caller.
  subroutine hip_ax_helm(w_d, u_d, &
       dx_d, dy_d, dz_d, dxt_d, dyt_d, dzt_d, &
       h1_d, g11_d, g22_d, g33_d, g12_d, g13_d, g23_d, nelv, lx) &
       bind(c, name='hip_ax_helm')
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int
    type(c_ptr), value :: w_d, u_d
    type(c_ptr), value :: dx_d, dy_d, dz_d
    type(c_ptr), value :: dxt_d, dyt_d, dzt_d
    type(c_ptr), value :: h1_d, g11_d, g22_d, g33_d, g12_d, g13_d, g23_d
    integer(c_int), intent(in) :: nelv, lx
  end subroutine hip_ax_helm

  !> Same operator applied to the three components (u, v, w),
  !! writing (au, av, aw).
  subroutine hip_ax_helm_vector(au_d, av_d, aw_d, u_d, v_d, w_d, &
       dx_d, dy_d, dz_d, dxt_d, dyt_d, dzt_d, &
       h1_d, g11_d, g22_d, g33_d, g12_d, g13_d, g23_d, nelv, lx) &
       bind(c, name='hip_ax_helm_vector')
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int
    type(c_ptr), value :: au_d, av_d, aw_d
    type(c_ptr), value :: u_d, v_d, w_d
    type(c_ptr), value :: dx_d, dy_d, dz_d
    type(c_ptr), value :: dxt_d, dyt_d, dzt_d
    type(c_ptr), value :: h1_d, g11_d, g22_d, g33_d, g12_d, g13_d, g23_d
    integer(c_int), intent(in) :: nelv, lx
  end subroutine hip_ax_helm_vector

  !> Fused h2 term for the vector operator; used instead of three
  !! device_addcol4 calls (one per component) of length n.
  subroutine hip_ax_helm_vector_part2(au_d, av_d, aw_d, u_d, v_d, w_d, &
       h2_d, B_d, n) bind(c, name='hip_ax_helm_vector_part2')
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int
    type(c_ptr), value :: au_d, av_d, aw_d
    type(c_ptr), value :: u_d, v_d, w_d
    type(c_ptr), value :: h2_d, B_d
    integer(c_int), intent(in) :: n
  end subroutine hip_ax_helm_vector_part2
end interface
#elif defined(HAVE_CUDA)
interface
  !> Stiffness part of the scalar Helmholtz operator (h1, geometric
  !! factors G11..G23, derivative matrices); the optional h2 term is
  !! added separately by the caller.
  subroutine cuda_ax_helm(w_d, u_d, &
       dx_d, dy_d, dz_d, dxt_d, dyt_d, dzt_d, &
       h1_d, g11_d, g22_d, g33_d, g12_d, g13_d, g23_d, nelv, lx) &
       bind(c, name='cuda_ax_helm')
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int
    type(c_ptr), value :: w_d, u_d
    type(c_ptr), value :: dx_d, dy_d, dz_d
    type(c_ptr), value :: dxt_d, dyt_d, dzt_d
    type(c_ptr), value :: h1_d, g11_d, g22_d, g33_d, g12_d, g13_d, g23_d
    integer(c_int), intent(in) :: nelv, lx
  end subroutine cuda_ax_helm

  !> Same operator applied to the three components (u, v, w),
  !! writing (au, av, aw).
  subroutine cuda_ax_helm_vector(au_d, av_d, aw_d, u_d, v_d, w_d, &
       dx_d, dy_d, dz_d, dxt_d, dyt_d, dzt_d, &
       h1_d, g11_d, g22_d, g33_d, g12_d, g13_d, g23_d, nelv, lx) &
       bind(c, name='cuda_ax_helm_vector')
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int
    type(c_ptr), value :: au_d, av_d, aw_d
    type(c_ptr), value :: u_d, v_d, w_d
    type(c_ptr), value :: dx_d, dy_d, dz_d
    type(c_ptr), value :: dxt_d, dyt_d, dzt_d
    type(c_ptr), value :: h1_d, g11_d, g22_d, g33_d, g12_d, g13_d, g23_d
    integer(c_int), intent(in) :: nelv, lx
  end subroutine cuda_ax_helm_vector

  !> Fused h2 term for the vector operator; used instead of three
  !! device_addcol4 calls (one per component) of length n.
  subroutine cuda_ax_helm_vector_part2(au_d, av_d, aw_d, u_d, v_d, w_d, &
       h2_d, B_d, n) bind(c, name='cuda_ax_helm_vector_part2')
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int
    type(c_ptr), value :: au_d, av_d, aw_d
    type(c_ptr), value :: u_d, v_d, w_d
    type(c_ptr), value :: h2_d, B_d
    integer(c_int), intent(in) :: n
  end subroutine cuda_ax_helm_vector_part2
end interface
#elif defined(HAVE_OPENCL)
interface
  !> Stiffness part of the scalar Helmholtz operator (h1, geometric
  !! factors G11..G23, derivative matrices); the optional h2 term is
  !! added separately by the caller.
  subroutine opencl_ax_helm(w_d, u_d, &
       dx_d, dy_d, dz_d, dxt_d, dyt_d, dzt_d, &
       h1_d, g11_d, g22_d, g33_d, g12_d, g13_d, g23_d, nelv, lx) &
       bind(c, name='opencl_ax_helm')
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int
    type(c_ptr), value :: w_d, u_d
    type(c_ptr), value :: dx_d, dy_d, dz_d
    type(c_ptr), value :: dxt_d, dyt_d, dzt_d
    type(c_ptr), value :: h1_d, g11_d, g22_d, g33_d, g12_d, g13_d, g23_d
    integer(c_int), intent(in) :: nelv, lx
  end subroutine opencl_ax_helm

  !> Same operator applied to the three components (u, v, w),
  !! writing (au, av, aw).
  subroutine opencl_ax_helm_vector(au_d, av_d, aw_d, u_d, v_d, w_d, &
       dx_d, dy_d, dz_d, dxt_d, dyt_d, dzt_d, &
       h1_d, g11_d, g22_d, g33_d, g12_d, g13_d, g23_d, nelv, lx) &
       bind(c, name='opencl_ax_helm_vector')
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int
    type(c_ptr), value :: au_d, av_d, aw_d
    type(c_ptr), value :: u_d, v_d, w_d
    type(c_ptr), value :: dx_d, dy_d, dz_d
    type(c_ptr), value :: dxt_d, dyt_d, dzt_d
    type(c_ptr), value :: h1_d, g11_d, g22_d, g33_d, g12_d, g13_d, g23_d
    integer(c_int), intent(in) :: nelv, lx
  end subroutine opencl_ax_helm_vector

  !> Fused h2 term for the vector operator.
  !! NOTE(review): declared here, but the OpenCL path of
  !! ax_helm_device_compute_vector uses device_addcol4 instead; confirm
  !! the C backend actually provides this symbol before calling it.
  subroutine opencl_ax_helm_vector_part2(au_d, av_d, aw_d, u_d, v_d, w_d, &
       h2_d, B_d, n) bind(c, name='opencl_ax_helm_vector_part2')
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int
    type(c_ptr), value :: au_d, av_d, aw_d
    type(c_ptr), value :: u_d, v_d, w_d
    type(c_ptr), value :: h2_d, B_d
    integer(c_int), intent(in) :: n
  end subroutine opencl_ax_helm_vector_part2
end interface
#elif defined(HAVE_METAL)
interface
  !> Stiffness part of the scalar Helmholtz operator (h1, geometric
  !! factors G11..G23, derivative matrices); the optional h2 term is
  !! added separately by the caller.
  subroutine metal_ax_helm(w_d, u_d, &
       dx_d, dy_d, dz_d, dxt_d, dyt_d, dzt_d, &
       h1_d, g11_d, g22_d, g33_d, g12_d, g13_d, g23_d, nelv, lx) &
       bind(c, name='metal_ax_helm')
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int
    type(c_ptr), value :: w_d, u_d
    type(c_ptr), value :: dx_d, dy_d, dz_d
    type(c_ptr), value :: dxt_d, dyt_d, dzt_d
    type(c_ptr), value :: h1_d, g11_d, g22_d, g33_d, g12_d, g13_d, g23_d
    integer(c_int), intent(in) :: nelv, lx
  end subroutine metal_ax_helm

  !> Same operator applied to the three components (u, v, w),
  !! writing (au, av, aw).
  subroutine metal_ax_helm_vector(au_d, av_d, aw_d, u_d, v_d, w_d, &
       dx_d, dy_d, dz_d, dxt_d, dyt_d, dzt_d, &
       h1_d, g11_d, g22_d, g33_d, g12_d, g13_d, g23_d, nelv, lx) &
       bind(c, name='metal_ax_helm_vector')
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int
    type(c_ptr), value :: au_d, av_d, aw_d
    type(c_ptr), value :: u_d, v_d, w_d
    type(c_ptr), value :: dx_d, dy_d, dz_d
    type(c_ptr), value :: dxt_d, dyt_d, dzt_d
    type(c_ptr), value :: h1_d, g11_d, g22_d, g33_d, g12_d, g13_d, g23_d
    integer(c_int), intent(in) :: nelv, lx
  end subroutine metal_ax_helm_vector

  !> Fused h2 term for the vector operator; used instead of three
  !! device_addcol4 calls (one per component) of length n.
  subroutine metal_ax_helm_vector_part2(au_d, av_d, aw_d, u_d, v_d, w_d, &
       h2_d, B_d, n) bind(c, name='metal_ax_helm_vector_part2')
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int
    type(c_ptr), value :: au_d, av_d, aw_d
    type(c_ptr), value :: u_d, v_d, w_d
    type(c_ptr), value :: h2_d, B_d
    integer(c_int), intent(in) :: n
  end subroutine metal_ax_helm_vector_part2
end interface
#endif
213
214contains
215
!> Device evaluation of the scalar Helmholtz operator, w = A u.
!!
!! The accelerator kernel selected at build time (HIP, CUDA, OpenCL or
!! Metal) evaluates the part of the operator built from the derivative
!! matrices of `Xh`, `coef%h1` and the geometric factors G11..G23. If
!! `coef%ifh2` is set, h2 * B * u is then accumulated into `w` with
!! device_addcol4.
!! @param w     Output array; must be associated with device memory.
!! @param u     Input array; must be associated with device memory.
!! @param coef  Coefficients providing the device copies of h1, h2, B
!!              and the geometric factors.
!! @param msh   Mesh, supplies the number of elements.
!! @param Xh    Function space, supplies lx and the derivative matrices.
!! @note When none of HAVE_HIP, HAVE_CUDA, HAVE_OPENCL or HAVE_METAL is
!!       defined, no kernel is launched and only the h2 term is applied.
subroutine ax_helm_device_compute(w, u, coef, msh, Xh)
  type(mesh_t), intent(in) :: msh
  type(space_t), intent(in) :: Xh
  type(coef_t), intent(in) :: coef
  real(kind=rp), intent(inout) :: w(Xh%lx, Xh%ly, Xh%lz, msh%nelv)
  real(kind=rp), intent(in) :: u(Xh%lx, Xh%ly, Xh%lz, msh%nelv)
  type(c_ptr) :: u_dev, w_dev

  ! Look up the device memory associated with the two host arrays.
  u_dev = device_get_ptr(u)
  w_dev = device_get_ptr(w)

#ifdef HAVE_HIP
  call hip_ax_helm(w_dev, u_dev, &
       Xh%dx_d, Xh%dy_d, Xh%dz_d, &
       Xh%dxt_d, Xh%dyt_d, Xh%dzt_d, &
       coef%h1_d, &
       coef%G11_d, coef%G22_d, coef%G33_d, &
       coef%G12_d, coef%G13_d, coef%G23_d, &
       msh%nelv, Xh%lx)
#elif HAVE_CUDA
  call cuda_ax_helm(w_dev, u_dev, &
       Xh%dx_d, Xh%dy_d, Xh%dz_d, &
       Xh%dxt_d, Xh%dyt_d, Xh%dzt_d, &
       coef%h1_d, &
       coef%G11_d, coef%G22_d, coef%G33_d, &
       coef%G12_d, coef%G13_d, coef%G23_d, &
       msh%nelv, Xh%lx)
#elif HAVE_OPENCL
  call opencl_ax_helm(w_dev, u_dev, &
       Xh%dx_d, Xh%dy_d, Xh%dz_d, &
       Xh%dxt_d, Xh%dyt_d, Xh%dzt_d, &
       coef%h1_d, &
       coef%G11_d, coef%G22_d, coef%G33_d, &
       coef%G12_d, coef%G13_d, coef%G23_d, &
       msh%nelv, Xh%lx)
#elif HAVE_METAL
  call metal_ax_helm(w_dev, u_dev, &
       Xh%dx_d, Xh%dy_d, Xh%dz_d, &
       Xh%dxt_d, Xh%dyt_d, Xh%dzt_d, &
       coef%h1_d, &
       coef%G11_d, coef%G22_d, coef%G33_d, &
       coef%G12_d, coef%G13_d, coef%G23_d, &
       msh%nelv, Xh%lx)
#endif

  ! Optional h2 term: w <- w + h2 * B * u (B presumably the mass matrix).
  if (coef%ifh2) call device_addcol4(w_dev, coef%h2_d, coef%B_d, u_dev, &
       coef%dof%size())

end subroutine ax_helm_device_compute
258
!> Device evaluation of the Helmholtz operator on a three-component field.
!!
!! A single backend kernel (HIP, CUDA, OpenCL or Metal, chosen at build
!! time) computes (au, av, aw) from (u, v, w) using the derivative
!! matrices of `Xh`, `coef%h1` and the geometric factors G11..G23.
!! When `coef%ifh2` is set, h2 * B * (u, v, w) is accumulated into
!! (au, av, aw): HIP, CUDA and Metal use a fused "part2" kernel, while
!! every other configuration issues one device_addcol4 per component.
!! @param this  Passed-object dummy (its state is not read).
!! @param au    Output, first component; must be associated with device memory.
!! @param av    Output, second component; must be associated with device memory.
!! @param aw    Output, third component; must be associated with device memory.
!! @param u     Input, first component; must be associated with device memory.
!! @param v     Input, second component; must be associated with device memory.
!! @param w     Input, third component; must be associated with device memory.
!! @param coef  Coefficients providing the device copies of h1, h2, B
!!              and the geometric factors.
!! @param msh   Mesh, supplies the number of elements.
!! @param Xh    Function space, supplies lx and the derivative matrices.
subroutine ax_helm_device_compute_vector(this, au, av, aw, &
     u, v, w, coef, msh, Xh)
  class(ax_helm_device_t), intent(in) :: this
  type(space_t), intent(in) :: Xh
  type(mesh_t), intent(in) :: msh
  type(coef_t), intent(in) :: coef
  real(kind=rp), intent(inout) :: au(Xh%lx, Xh%ly, Xh%lz, msh%nelv)
  real(kind=rp), intent(inout) :: av(Xh%lx, Xh%ly, Xh%lz, msh%nelv)
  real(kind=rp), intent(inout) :: aw(Xh%lx, Xh%ly, Xh%lz, msh%nelv)
  real(kind=rp), intent(in) :: u(Xh%lx, Xh%ly, Xh%lz, msh%nelv)
  real(kind=rp), intent(in) :: v(Xh%lx, Xh%ly, Xh%lz, msh%nelv)
  real(kind=rp), intent(in) :: w(Xh%lx, Xh%ly, Xh%lz, msh%nelv)
  type(c_ptr) :: u_dev, v_dev, w_dev
  type(c_ptr) :: au_dev, av_dev, aw_dev

  ! Device memory of the inputs, then of the outputs.
  u_dev = device_get_ptr(u)
  v_dev = device_get_ptr(v)
  w_dev = device_get_ptr(w)
  au_dev = device_get_ptr(au)
  av_dev = device_get_ptr(av)
  aw_dev = device_get_ptr(aw)

#ifdef HAVE_HIP
  call hip_ax_helm_vector(au_dev, av_dev, aw_dev, u_dev, v_dev, w_dev, &
       Xh%dx_d, Xh%dy_d, Xh%dz_d, &
       Xh%dxt_d, Xh%dyt_d, Xh%dzt_d, &
       coef%h1_d, &
       coef%G11_d, coef%G22_d, coef%G33_d, &
       coef%G12_d, coef%G13_d, coef%G23_d, &
       msh%nelv, Xh%lx)
#elif HAVE_CUDA
  call cuda_ax_helm_vector(au_dev, av_dev, aw_dev, u_dev, v_dev, w_dev, &
       Xh%dx_d, Xh%dy_d, Xh%dz_d, &
       Xh%dxt_d, Xh%dyt_d, Xh%dzt_d, &
       coef%h1_d, &
       coef%G11_d, coef%G22_d, coef%G33_d, &
       coef%G12_d, coef%G13_d, coef%G23_d, &
       msh%nelv, Xh%lx)
#elif HAVE_OPENCL
  call opencl_ax_helm_vector(au_dev, av_dev, aw_dev, u_dev, v_dev, w_dev, &
       Xh%dx_d, Xh%dy_d, Xh%dz_d, &
       Xh%dxt_d, Xh%dyt_d, Xh%dzt_d, &
       coef%h1_d, &
       coef%G11_d, coef%G22_d, coef%G33_d, &
       coef%G12_d, coef%G13_d, coef%G23_d, &
       msh%nelv, Xh%lx)
#elif HAVE_METAL
  call metal_ax_helm_vector(au_dev, av_dev, aw_dev, u_dev, v_dev, w_dev, &
       Xh%dx_d, Xh%dy_d, Xh%dz_d, &
       Xh%dxt_d, Xh%dyt_d, Xh%dzt_d, &
       coef%h1_d, &
       coef%G11_d, coef%G22_d, coef%G33_d, &
       coef%G12_d, coef%G13_d, coef%G23_d, &
       msh%nelv, Xh%lx)
#endif

  ! Nothing more to do unless the h2 term is active.
  if (.not. coef%ifh2) return

#ifdef HAVE_HIP
  call hip_ax_helm_vector_part2(au_dev, av_dev, aw_dev, &
       u_dev, v_dev, w_dev, coef%h2_d, coef%B_d, coef%dof%size())
#elif HAVE_CUDA
  call cuda_ax_helm_vector_part2(au_dev, av_dev, aw_dev, &
       u_dev, v_dev, w_dev, coef%h2_d, coef%B_d, coef%dof%size())
#elif HAVE_METAL
  call metal_ax_helm_vector_part2(au_dev, av_dev, aw_dev, &
       u_dev, v_dev, w_dev, coef%h2_d, coef%B_d, coef%dof%size())
#else
  ! Generic path, also taken by OpenCL builds.
  ! NOTE(review): an opencl_ax_helm_vector_part2 interface is declared in
  ! this module but unused here; confirm whether the OpenCL backend
  ! provides that fused kernel.
  call device_addcol4(au_dev, coef%h2_d, coef%B_d, u_dev, coef%dof%size())
  call device_addcol4(av_dev, coef%h2_d, coef%B_d, v_dev, coef%dof%size())
  call device_addcol4(aw_dev, coef%h2_d, coef%B_d, w_dev, coef%dof%size())
#endif

end subroutine ax_helm_device_compute_vector
326
327end module ax_helm_device
void opencl_ax_helm_vector(void *au, void *av, void *aw, void *u, void *v, void *w, void *dx, void *dy, void *dz, void *dxt, void *dyt, void *dzt, void *h1, void *g11, void *g22, void *g33, void *g12, void *g13, void *g23, int *nelv, int *lx)
Definition ax_helm.c:258
void opencl_ax_helm(void *w, void *u, void *dx, void *dy, void *dz, void *dxt, void *dyt, void *dzt, void *h1, void *g11, void *g22, void *g33, void *g12, void *g13, void *g23, int *nelv, int *lx)
Definition ax_helm.c:57
void cuda_ax_helm_vector_part2(void *au, void *av, void *aw, void *u, void *v, void *w, void *h2, void *B, int *n)
Definition ax_helm.cu:254
void cuda_ax_helm(void *w, void *u, void *dx, void *dy, void *dz, void *dxt, void *dyt, void *dzt, void *h1, void *g11, void *g22, void *g33, void *g12, void *g13, void *g23, int *nelv, int *lx)
Definition ax_helm.cu:63
void cuda_ax_helm_vector(void *au, void *av, void *aw, void *u, void *v, void *w, void *dx, void *dy, void *dz, void *dxt, void *dyt, void *dzt, void *h1, void *g11, void *g22, void *g33, void *g12, void *g13, void *g23, int *nelv, int *lx)
Definition ax_helm.cu:184
Return the device pointer for an associated Fortran array.
Definition device.F90:108
subroutine ax_helm_device_compute(w, u, coef, msh, xh)
subroutine ax_helm_device_compute_vector(this, au, av, aw, u, v, w, coef, msh, xh)
Coefficients.
Definition coef.f90:34
subroutine, public device_addcol4(a_d, b_d, c_d, d_d, n, strm)
Returns \( a = a + b \cdot c \cdot d \).
Device abstraction, common interface for various accelerators.
Definition device.F90:34
Defines a mesh.
Definition mesh.f90:34
integer, parameter, public rp
Global precision used in computations.
Definition num_types.f90:12
Defines a function space.
Definition space.f90:34
Matrix-vector product for a Helmholtz problem.
Definition ax_helm.f90:44
Coefficients defined on a given (mesh, \( X_h \)) tuple. Arrays use indices (i,j,k,e): element e,...
Definition coef.f90:63
The function space for the SEM solution fields.
Definition space.f90:63