Neko 0.9.99
A portable framework for high-order spectral element flow simulations
Loading...
Searching...
No Matches
ax_helm_kernel.h
Go to the documentation of this file.
1#ifndef __MATH_AX_HELM_KERNEL_H__
2#define __MATH_AX_HELM_KERNEL_H__
3/*
4 Copyright (c) 2021-2024, The Neko Authors
5 All rights reserved.
6
7 Redistribution and use in source and binary forms, with or without
8 modification, are permitted provided that the following conditions
9 are met:
10
11 * Redistributions of source code must retain the above copyright
12 notice, this list of conditions and the following disclaimer.
13
14 * Redistributions in binary form must reproduce the above
15 copyright notice, this list of conditions and the following
16 disclaimer in the documentation and/or other materials provided
17 with the distribution.
18
19 * Neither the name of the authors nor the names of its
20 contributors may be used to endorse or promote products derived
21 from this software without specific prior written permission.
22
23 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24 "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
26 FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
27 COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
28 INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
29 BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
30 LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31 CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
32 LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
33 ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
34 POSSIBILITY OF SUCH DAMAGE.
35*/
36
/**
 * Device kernel for the element-local Helmholtz operator
 *   w = D^T (h1 * G) D u
 * with one thread block per element.
 *
 * Launch layout: gridDim.x = number of elements, blockDim = (CHUNKS, 1, 1).
 * Thread t handles the points t, t + CHUNKS, t + 2*CHUNKS, ... of its
 * element, so any CHUNKS >= 1 is valid (it must equal blockDim.x).
 *
 * Static shared memory: 6*LX*LX + 4*LX*LX*LX values of type T.
 *
 * @tparam T       floating-point type
 * @tparam LX      points per direction; each element holds LX^3 values
 * @tparam CHUNKS  threads per block (must equal blockDim.x)
 *
 * w            output, LX^3 values per element, element e at offset e*LX^3
 * u            input field, same layout as w
 * dx, dy, dz   LX x LX matrices used in the forward contraction
 * dxt, dyt, dzt  LX x LX matrices used in the second contraction;
 *              presumably the transposes of dx, dy, dz (ax_helm_kernel_kstep
 *              reads dx/dy/dz with swapped indices instead) - confirm with caller
 * h1           pointwise coefficient
 * g11..g23     per-point geometric factors; the 3x3 tensor is treated as
 *              symmetric (g12, g13, g23 are used for both off-diagonal slots)
 */
template< typename T, const int LX, const int CHUNKS >
__global__ void ax_helm_kernel_1d(T * __restrict__ w,
                                  const T * __restrict__ u,
                                  const T * __restrict__ dx,
                                  const T * __restrict__ dy,
                                  const T * __restrict__ dz,
                                  const T * __restrict__ dxt,
                                  const T * __restrict__ dyt,
                                  const T * __restrict__ dzt,
                                  const T * __restrict__ h1,
                                  const T * __restrict__ g11,
                                  const T * __restrict__ g22,
                                  const T * __restrict__ g33,
                                  const T * __restrict__ g12,
                                  const T * __restrict__ g13,
                                  const T * __restrict__ g23) {

  __shared__ T shdx[LX * LX];
  __shared__ T shdy[LX * LX];
  __shared__ T shdz[LX * LX];

  __shared__ T shdxt[LX * LX];
  __shared__ T shdyt[LX * LX];
  __shared__ T shdzt[LX * LX];

  __shared__ T shu[LX * LX * LX];
  __shared__ T shur[LX * LX * LX];
  __shared__ T shus[LX * LX * LX];
  __shared__ T shut[LX * LX * LX];

  const int e = blockIdx.x;
  const int iii = threadIdx.x;
  const int ele = e * LX * LX * LX;
  const int nchunks = (LX * LX * LX - 1) / CHUNKS + 1;

  // Stage the six LX x LX matrices. A strided loop (rather than a single
  // `if (iii < LX*LX)`) guarantees every entry is loaded even when the
  // block has fewer than LX*LX threads.
  for (int ij = iii; ij < LX * LX; ij += CHUNKS) {
    shdx[ij] = dx[ij];
    shdy[ij] = dy[ij];
    shdz[ij] = dz[ij];
    shdxt[ij] = dxt[ij];
    shdyt[ij] = dyt[ij];
    shdzt[ij] = dzt[ij];
  }

  // Stage this element's u.
  for (int ijk = iii; ijk < LX * LX * LX; ijk += CHUNKS) {
    shu[ijk] = u[ijk + ele];
  }

  // Pass 1 reads matrix and u entries loaded by other threads.
  __syncthreads();

  // Pass 1: forward contraction and pointwise multiply by h1 * G.
  for (int n = 0; n < nchunks; n++) {
    const int ijk = iii + n * CHUNKS;
    // i, j, k derived below are always < LX whenever ijk < LX^3.
    if (ijk < LX * LX * LX) {
      const int jk = ijk / LX;
      const int i = ijk - jk * LX;
      const int k = jk / LX;
      const int j = jk - k * LX;
      const T G00 = g11[ijk + ele];
      const T G11 = g22[ijk + ele];
      const T G22 = g33[ijk + ele];
      const T G01 = g12[ijk + ele];
      const T G02 = g13[ijk + ele];
      const T G12 = g23[ijk + ele];
      const T H1 = h1[ijk + ele];
      T rtmp = 0.0;
      T stmp = 0.0;
      T ttmp = 0.0;
#pragma unroll
      for (int l = 0; l < LX; l++) {
        rtmp += shdx[i + l * LX] * shu[l + j * LX + k * LX * LX];
        stmp += shdy[j + l * LX] * shu[i + l * LX + k * LX * LX];
        ttmp += shdz[k + l * LX] * shu[i + j * LX + l * LX * LX];
      }
      shur[ijk] = H1 * (G00 * rtmp + G01 * stmp + G02 * ttmp);
      shus[ijk] = H1 * (G01 * rtmp + G11 * stmp + G12 * ttmp);
      shut[ijk] = H1 * (G02 * rtmp + G12 * stmp + G22 * ttmp);
    }
  }

  // Pass 2 reads shur/shus/shut entries written by other threads.
  __syncthreads();

  // Pass 2: second contraction, written straight to global memory.
  for (int n = 0; n < nchunks; n++) {
    const int ijk = iii + n * CHUNKS;
    if (ijk < LX * LX * LX) {
      const int jk = ijk / LX;
      const int i = ijk - jk * LX;
      const int k = jk / LX;
      const int j = jk - k * LX;
      T wijke = 0.0;
#pragma unroll
      for (int l = 0; l < LX; l++) {
        wijke = wijke
          + shdxt[i + l * LX] * shur[l + j * LX + k * LX * LX]
          + shdyt[j + l * LX] * shus[i + l * LX + k * LX * LX]
          + shdzt[k + l * LX] * shut[i + j * LX + l * LX * LX];
      }
      w[ijk + ele] = wijke;
    }
  }
}
147
/**
 * Device kernel for the element-local Helmholtz operator
 *   w = D^T (h1 * G) D u
 * with one thread block per element, sweeping the element plane by plane
 * in the k direction.
 *
 * Launch layout: gridDim.x = number of elements, blockDim = (LX, LX, 1).
 * Thread (i, j) owns the pencil of points (i, j, 0..LX-1) of its element,
 * held in the per-thread arrays ru (input) and rw (accumulated output).
 * The LX*LX threads also load the LX x LX matrices one entry each.
 *
 * Static shared memory: 6*LX*LX values of type T.
 *
 * No transposed matrices are passed: the second contraction reads
 * dx, dy, dz with swapped indices.
 *
 * w, u       LX^3 values per element, element e at offset e*LX^3
 * dx, dy, dz LX x LX matrices
 * h1         pointwise coefficient
 * g11..g23   per-point geometric factors of a symmetric 3x3 tensor
 */
template< typename T, const int LX >
__global__ void ax_helm_kernel_kstep(T * __restrict__ w,
                                     const T * __restrict__ u,
                                     const T * __restrict__ dx,
                                     const T * __restrict__ dy,
                                     const T * __restrict__ dz,
                                     const T * __restrict__ h1,
                                     const T * __restrict__ g11,
                                     const T * __restrict__ g22,
                                     const T * __restrict__ g33,
                                     const T * __restrict__ g12,
                                     const T * __restrict__ g13,
                                     const T * __restrict__ g23) {

  __shared__ T shdx[LX * LX];
  __shared__ T shdy[LX * LX];
  __shared__ T shdz[LX * LX];

  __shared__ T shu[LX * LX];
  __shared__ T shur[LX * LX];
  __shared__ T shus[LX * LX];

  T ru[LX];
  T rw[LX];
  T rut;

  const int e = blockIdx.x;
  const int j = threadIdx.y;
  const int i = threadIdx.x;
  const int ij = i + j * LX;
  const int ele = e * LX * LX * LX;

  shdx[ij] = dx[ij];
  shdy[ij] = dy[ij];
  shdz[ij] = dz[ij];

#pragma unroll
  for (int k = 0; k < LX; ++k) {
    ru[k] = u[ij + k * LX * LX + ele];
    rw[k] = 0.0;
  }

  // The matrices in shared memory must be complete before first use.
  __syncthreads();

#pragma unroll
  for (int k = 0; k < LX; ++k) {
    const int ijk = ij + k * LX * LX;
    const T G00 = g11[ijk + ele];
    const T G11 = g22[ijk + ele];
    const T G22 = g33[ijk + ele];
    const T G01 = g12[ijk + ele];
    const T G02 = g13[ijk + ele];
    const T G12 = g23[ijk + ele];
    const T H1 = h1[ijk + ele];

    // Publish plane k of u; the t-derivative needs only this thread's pencil.
    T ttmp = 0.0;
    shu[ij] = ru[k];
#pragma unroll
    for (int l = 0; l < LX; l++) {
      ttmp += shdz[k + l * LX] * ru[l];
    }

    // Plane k of shu must be complete before the r/s contractions read
    // neighbouring entries. This barrier also orders this iteration's
    // shur/shus writes after the previous iteration's reads of them.
    __syncthreads();

    T rtmp = 0.0;
    T stmp = 0.0;
#pragma unroll
    for (int l = 0; l < LX; l++) {
      rtmp += shdx[i + l * LX] * shu[l + j * LX];
      stmp += shdy[j + l * LX] * shu[i + l * LX];
    }
    shur[ij] = H1
      * (G00 * rtmp
         + G01 * stmp
         + G02 * ttmp);
    shus[ij] = H1
      * (G01 * rtmp
         + G11 * stmp
         + G12 * ttmp);
    rut = H1
      * (G02 * rtmp
         + G12 * stmp
         + G22 * ttmp);

    // shur/shus must be complete before the transposed contractions. This
    // also guarantees all reads of shu finish before the next plane
    // overwrites it.
    __syncthreads();

    T wijke = 0.0;
#pragma unroll
    for (int l = 0; l < LX; l++) {
      wijke += shur[l + j * LX] * shdx[l + i * LX];
      // The t-contribution scatters into every point of this thread's pencil.
      rw[l] += rut * shdz[k + l * LX];
      wijke += shus[i + l * LX] * shdy[l + j * LX];
    }
    rw[k] += wijke;
  }

#pragma unroll
  for (int k = 0; k < LX; ++k) {
    w[ij + k * LX * LX + ele] = rw[k];
  }
}
245
/**
 * Device kernel for the element-local Helmholtz operator
 *   w = D^T (h1 * G) D u
 * Same algorithm and launch layout as ax_helm_kernel_kstep, but the shared
 * arrays that are read along both index directions use a padded row
 * stride of LX+1 (intended to reduce shared-memory bank conflicts).
 *
 * Launch layout: gridDim.x = number of elements, blockDim = (LX, LX, 1).
 * Thread (i, j) owns the pencil (i, j, 0..LX-1), held in ru / rw.
 *
 * Static shared memory: 5*LX*(LX+1) + LX*LX values of type T.
 *
 * w, u       LX^3 values per element, element e at offset e*LX^3
 * dx, dy, dz LX x LX matrices (read with swapped indices in the second
 *            contraction; no transposes are passed)
 * h1         pointwise coefficient
 * g11..g23   per-point geometric factors of a symmetric 3x3 tensor
 */
template< typename T, const int LX >
__global__ void ax_helm_kernel_kstep_padded(T * __restrict__ w,
                                            const T * __restrict__ u,
                                            const T * __restrict__ dx,
                                            const T * __restrict__ dy,
                                            const T * __restrict__ dz,
                                            const T * __restrict__ h1,
                                            const T * __restrict__ g11,
                                            const T * __restrict__ g22,
                                            const T * __restrict__ g33,
                                            const T * __restrict__ g12,
                                            const T * __restrict__ g13,
                                            const T * __restrict__ g23) {

  __shared__ T shdx[LX * (LX + 1)];
  __shared__ T shdy[LX * (LX + 1)];
  __shared__ T shdz[LX * (LX + 1)];

  __shared__ T shu[LX * (LX + 1)];
  __shared__ T shur[LX * LX]; // only accessed using fastest dimension
  __shared__ T shus[LX * (LX + 1)];

  T ru[LX];
  T rw[LX];
  T rut;

  const int e = blockIdx.x;
  const int j = threadIdx.y;
  const int i = threadIdx.x;
  const int ij = i + j * LX;         // unpadded (global / shur) index
  const int ij_p = i + j * (LX + 1); // padded shared-memory index
  const int ele = e * LX * LX * LX;

  shdx[ij_p] = dx[ij];
  shdy[ij_p] = dy[ij];
  shdz[ij_p] = dz[ij];

#pragma unroll
  for (int k = 0; k < LX; ++k) {
    ru[k] = u[ij + k * LX * LX + ele];
    rw[k] = 0.0;
  }

  // The matrices in shared memory must be complete before first use.
  __syncthreads();

#pragma unroll
  for (int k = 0; k < LX; ++k) {
    const int ijk = ij + k * LX * LX;
    const T G00 = g11[ijk + ele];
    const T G11 = g22[ijk + ele];
    const T G22 = g33[ijk + ele];
    const T G01 = g12[ijk + ele];
    const T G02 = g13[ijk + ele];
    const T G12 = g23[ijk + ele];
    const T H1 = h1[ijk + ele];

    // Publish plane k of u; the t-derivative needs only this thread's pencil.
    T ttmp = 0.0;
    shu[ij_p] = ru[k];
#pragma unroll
    for (int l = 0; l < LX; l++) {
      ttmp += shdz[k + l * (LX + 1)] * ru[l];
    }

    // Plane k of shu must be complete before the r/s contractions; this
    // also orders this iteration's shur/shus writes after the previous
    // iteration's reads of them.
    __syncthreads();

    T rtmp = 0.0;
    T stmp = 0.0;
#pragma unroll
    for (int l = 0; l < LX; l++) {
      rtmp += shdx[i + l * (LX + 1)] * shu[l + j * (LX + 1)];
      stmp += shdy[j + l * (LX + 1)] * shu[i + l * (LX + 1)];
    }
    shur[ij] = H1
      * (G00 * rtmp
         + G01 * stmp
         + G02 * ttmp);
    shus[ij_p] = H1
      * (G01 * rtmp
         + G11 * stmp
         + G12 * ttmp);
    rut = H1
      * (G02 * rtmp
         + G12 * stmp
         + G22 * ttmp);

    // shur/shus must be complete before the transposed contractions; this
    // also guarantees all reads of shu finish before the next plane
    // overwrites it.
    __syncthreads();

    T wijke = 0.0;
#pragma unroll
    for (int l = 0; l < LX; l++) {
      wijke += shur[l + j * LX] * shdx[l + i * (LX + 1)];
      // The t-contribution scatters into every point of this thread's pencil.
      rw[l] += rut * shdz[k + l * (LX + 1)];
      wijke += shus[i + l * (LX + 1)] * shdy[l + j * (LX + 1)];
    }
    rw[k] += wijke;
  }

#pragma unroll
  for (int k = 0; k < LX; ++k) {
    w[ij + k * LX * LX + ele] = rw[k];
  }
}
349
/*
 * Vector versions: the same operator applied to three fields (u, v, w)
 * in a single launch.
 */
353
/**
 * Device kernel applying the element-local Helmholtz operator
 *   a = D^T (h1 * G) D f
 * to three fields at once: au <- u, av <- v, aw <- w. The algorithm and
 * launch layout match ax_helm_kernel_kstep, with the geometric factors
 * loaded once per point and shared by the three fields.
 *
 * Launch layout: gridDim.x = number of elements, blockDim = (LX, LX, 1).
 * Thread (i, j) owns the pencil (i, j, 0..LX-1) of each field.
 *
 * Static shared memory: 12*LX*LX values of type T.
 *
 * au, av, aw   outputs, LX^3 values per element, element e at offset e*LX^3
 * u, v, w      inputs, same layout
 * dx, dy, dz   LX x LX matrices (read with swapped indices in the second
 *              contraction; no transposes are passed)
 * h1           pointwise coefficient
 * g11..g23     per-point geometric factors of a symmetric 3x3 tensor
 */
template< typename T, const int LX >
__global__ void ax_helm_kernel_vector_kstep(T * __restrict__ au,
                                            T * __restrict__ av,
                                            T * __restrict__ aw,
                                            const T * __restrict__ u,
                                            const T * __restrict__ v,
                                            const T * __restrict__ w,
                                            const T * __restrict__ dx,
                                            const T * __restrict__ dy,
                                            const T * __restrict__ dz,
                                            const T * __restrict__ h1,
                                            const T * __restrict__ g11,
                                            const T * __restrict__ g22,
                                            const T * __restrict__ g33,
                                            const T * __restrict__ g12,
                                            const T * __restrict__ g13,
                                            const T * __restrict__ g23) {

  __shared__ T shdx[LX * LX];
  __shared__ T shdy[LX * LX];
  __shared__ T shdz[LX * LX];

  __shared__ T shu[LX * LX];
  __shared__ T shur[LX * LX];
  __shared__ T shus[LX * LX];

  __shared__ T shv[LX * LX];
  __shared__ T shvr[LX * LX];
  __shared__ T shvs[LX * LX];

  __shared__ T shw[LX * LX];
  __shared__ T shwr[LX * LX];
  __shared__ T shws[LX * LX];

  // Per-thread pencils: inputs (ru, rv, rw) and accumulated outputs.
  T ru[LX];
  T rv[LX];
  T rw[LX];

  T ruw[LX];
  T rvw[LX];
  T rww[LX];

  T rut;
  T rvt;
  T rwt;

  const int e = blockIdx.x;
  const int j = threadIdx.y;
  const int i = threadIdx.x;
  const int ij = i + j * LX;
  const int ele = e * LX * LX * LX;

  shdx[ij] = dx[ij];
  shdy[ij] = dy[ij];
  shdz[ij] = dz[ij];

#pragma unroll
  for (int k = 0; k < LX; ++k) {
    ru[k] = u[ij + k * LX * LX + ele];
    ruw[k] = 0.0;

    rv[k] = v[ij + k * LX * LX + ele];
    rvw[k] = 0.0;

    rw[k] = w[ij + k * LX * LX + ele];
    rww[k] = 0.0;
  }

  // The matrices in shared memory must be complete before first use.
  __syncthreads();

#pragma unroll
  for (int k = 0; k < LX; ++k) {
    const int ijk = ij + k * LX * LX;
    const T G00 = g11[ijk + ele];
    const T G11 = g22[ijk + ele];
    const T G22 = g33[ijk + ele];
    const T G01 = g12[ijk + ele];
    const T G02 = g13[ijk + ele];
    const T G12 = g23[ijk + ele];
    const T H1 = h1[ijk + ele];

    // Publish plane k; t-derivatives need only this thread's pencils.
    T uttmp = 0.0;
    T vttmp = 0.0;
    T wttmp = 0.0;
    shu[ij] = ru[k];
    shv[ij] = rv[k];
    shw[ij] = rw[k];
#pragma unroll
    for (int l = 0; l < LX; l++) {
      uttmp += shdz[k + l * LX] * ru[l];
      vttmp += shdz[k + l * LX] * rv[l];
      wttmp += shdz[k + l * LX] * rw[l];
    }

    // Plane k must be complete before the r/s contractions; this also
    // orders this iteration's *r/*s writes after the previous reads.
    __syncthreads();

    T urtmp = 0.0;
    T ustmp = 0.0;

    T vrtmp = 0.0;
    T vstmp = 0.0;

    T wrtmp = 0.0;
    T wstmp = 0.0;
#pragma unroll
    for (int l = 0; l < LX; l++) {
      urtmp += shdx[i + l * LX] * shu[l + j * LX];
      ustmp += shdy[j + l * LX] * shu[i + l * LX];

      vrtmp += shdx[i + l * LX] * shv[l + j * LX];
      vstmp += shdy[j + l * LX] * shv[i + l * LX];

      wrtmp += shdx[i + l * LX] * shw[l + j * LX];
      wstmp += shdy[j + l * LX] * shw[i + l * LX];
    }

    shur[ij] = H1
      * (G00 * urtmp
         + G01 * ustmp
         + G02 * uttmp);
    shus[ij] = H1
      * (G01 * urtmp
         + G11 * ustmp
         + G12 * uttmp);
    rut = H1
      * (G02 * urtmp
         + G12 * ustmp
         + G22 * uttmp);

    shvr[ij] = H1
      * (G00 * vrtmp
         + G01 * vstmp
         + G02 * vttmp);
    shvs[ij] = H1
      * (G01 * vrtmp
         + G11 * vstmp
         + G12 * vttmp);
    rvt = H1
      * (G02 * vrtmp
         + G12 * vstmp
         + G22 * vttmp);

    shwr[ij] = H1
      * (G00 * wrtmp
         + G01 * wstmp
         + G02 * wttmp);
    shws[ij] = H1
      * (G01 * wrtmp
         + G11 * wstmp
         + G12 * wttmp);
    rwt = H1
      * (G02 * wrtmp
         + G12 * wstmp
         + G22 * wttmp);

    // *r/*s arrays must be complete before the transposed contractions;
    // also guarantees reads of shu/shv/shw finish before the next plane.
    __syncthreads();

    T uwijke = 0.0;
    T vwijke = 0.0;
    T wwijke = 0.0;
#pragma unroll
    for (int l = 0; l < LX; l++) {
      uwijke += shur[l + j * LX] * shdx[l + i * LX];
      ruw[l] += rut * shdz[k + l * LX];
      uwijke += shus[i + l * LX] * shdy[l + j * LX];

      vwijke += shvr[l + j * LX] * shdx[l + i * LX];
      rvw[l] += rvt * shdz[k + l * LX];
      vwijke += shvs[i + l * LX] * shdy[l + j * LX];

      wwijke += shwr[l + j * LX] * shdx[l + i * LX];
      rww[l] += rwt * shdz[k + l * LX];
      wwijke += shws[i + l * LX] * shdy[l + j * LX];
    }
    ruw[k] += uwijke;
    rvw[k] += vwijke;
    rww[k] += wwijke;
  }

#pragma unroll
  for (int k = 0; k < LX; ++k) {
    au[ij + k * LX * LX + ele] = ruw[k];
    av[ij + k * LX * LX + ele] = rvw[k];
    aw[ij + k * LX * LX + ele] = rww[k];
  }
}
536
/**
 * Device kernel applying the element-local Helmholtz operator
 *   a = D^T (h1 * G) D f
 * to three fields at once (au <- u, av <- v, aw <- w). It is the padded
 * variant of ax_helm_kernel_vector_kstep: shared arrays read along both
 * index directions use a row stride of LX+1, intended to reduce
 * shared-memory bank conflicts.
 *
 * Launch layout: gridDim.x = number of elements, blockDim = (LX, LX, 1).
 * Thread (i, j) owns the pencil (i, j, 0..LX-1) of each field.
 *
 * Static shared memory: 9*LX*(LX+1) + 3*LX*LX values of type T.
 *
 * au, av, aw   outputs, LX^3 values per element, element e at offset e*LX^3
 * u, v, w      inputs, same layout
 * dx, dy, dz   LX x LX matrices (read with swapped indices in the second
 *              contraction; no transposes are passed)
 * h1           pointwise coefficient
 * g11..g23     per-point geometric factors of a symmetric 3x3 tensor
 */
template< typename T, const int LX >
__global__ void ax_helm_kernel_vector_kstep_padded(T * __restrict__ au,
                                                   T * __restrict__ av,
                                                   T * __restrict__ aw,
                                                   const T * __restrict__ u,
                                                   const T * __restrict__ v,
                                                   const T * __restrict__ w,
                                                   const T * __restrict__ dx,
                                                   const T * __restrict__ dy,
                                                   const T * __restrict__ dz,
                                                   const T * __restrict__ h1,
                                                   const T * __restrict__ g11,
                                                   const T * __restrict__ g22,
                                                   const T * __restrict__ g33,
                                                   const T * __restrict__ g12,
                                                   const T * __restrict__ g13,
                                                   const T * __restrict__ g23) {

  __shared__ T shdx[LX * (LX + 1)];
  __shared__ T shdy[LX * (LX + 1)];
  __shared__ T shdz[LX * (LX + 1)];

  // The *r arrays are only accessed along the fastest dimension: unpadded.
  __shared__ T shu[LX * (LX + 1)];
  __shared__ T shur[LX * LX];
  __shared__ T shus[LX * (LX + 1)];

  __shared__ T shv[LX * (LX + 1)];
  __shared__ T shvr[LX * LX];
  __shared__ T shvs[LX * (LX + 1)];

  __shared__ T shw[LX * (LX + 1)];
  __shared__ T shwr[LX * LX];
  __shared__ T shws[LX * (LX + 1)];

  // Per-thread pencils: inputs (ru, rv, rw) and accumulated outputs.
  T ru[LX];
  T rv[LX];
  T rw[LX];

  T ruw[LX];
  T rvw[LX];
  T rww[LX];

  T rut;
  T rvt;
  T rwt;

  const int e = blockIdx.x;
  const int j = threadIdx.y;
  const int i = threadIdx.x;
  const int ij = i + j * LX;         // unpadded (global / *r) index
  const int ij_p = i + j * (LX + 1); // padded shared-memory index
  const int ele = e * LX * LX * LX;

  shdx[ij_p] = dx[ij];
  shdy[ij_p] = dy[ij];
  shdz[ij_p] = dz[ij];

#pragma unroll
  for (int k = 0; k < LX; ++k) {
    ru[k] = u[ij + k * LX * LX + ele];
    ruw[k] = 0.0;

    rv[k] = v[ij + k * LX * LX + ele];
    rvw[k] = 0.0;

    rw[k] = w[ij + k * LX * LX + ele];
    rww[k] = 0.0;
  }

  // The matrices in shared memory must be complete before first use.
  __syncthreads();

#pragma unroll
  for (int k = 0; k < LX; ++k) {
    const int ijk = ij + k * LX * LX;
    const T G00 = g11[ijk + ele];
    const T G11 = g22[ijk + ele];
    const T G22 = g33[ijk + ele];
    const T G01 = g12[ijk + ele];
    const T G02 = g13[ijk + ele];
    const T G12 = g23[ijk + ele];
    const T H1 = h1[ijk + ele];

    // Publish plane k; t-derivatives need only this thread's pencils.
    T uttmp = 0.0;
    T vttmp = 0.0;
    T wttmp = 0.0;
    shu[ij_p] = ru[k];
    shv[ij_p] = rv[k];
    shw[ij_p] = rw[k];
#pragma unroll
    for (int l = 0; l < LX; l++) {
      uttmp += shdz[k + l * (LX + 1)] * ru[l];
      vttmp += shdz[k + l * (LX + 1)] * rv[l];
      wttmp += shdz[k + l * (LX + 1)] * rw[l];
    }

    // Plane k must be complete before the r/s contractions; this also
    // orders this iteration's *r/*s writes after the previous reads.
    __syncthreads();

    T urtmp = 0.0;
    T ustmp = 0.0;

    T vrtmp = 0.0;
    T vstmp = 0.0;

    T wrtmp = 0.0;
    T wstmp = 0.0;
#pragma unroll
    for (int l = 0; l < LX; l++) {
      urtmp += shdx[i + l * (LX + 1)] * shu[l + j * (LX + 1)];
      ustmp += shdy[j + l * (LX + 1)] * shu[i + l * (LX + 1)];

      vrtmp += shdx[i + l * (LX + 1)] * shv[l + j * (LX + 1)];
      vstmp += shdy[j + l * (LX + 1)] * shv[i + l * (LX + 1)];

      wrtmp += shdx[i + l * (LX + 1)] * shw[l + j * (LX + 1)];
      wstmp += shdy[j + l * (LX + 1)] * shw[i + l * (LX + 1)];
    }

    shur[ij] = H1
      * (G00 * urtmp
         + G01 * ustmp
         + G02 * uttmp);
    shus[ij_p] = H1
      * (G01 * urtmp
         + G11 * ustmp
         + G12 * uttmp);
    rut = H1
      * (G02 * urtmp
         + G12 * ustmp
         + G22 * uttmp);

    shvr[ij] = H1
      * (G00 * vrtmp
         + G01 * vstmp
         + G02 * vttmp);
    shvs[ij_p] = H1
      * (G01 * vrtmp
         + G11 * vstmp
         + G12 * vttmp);
    rvt = H1
      * (G02 * vrtmp
         + G12 * vstmp
         + G22 * vttmp);

    shwr[ij] = H1
      * (G00 * wrtmp
         + G01 * wstmp
         + G02 * wttmp);
    shws[ij_p] = H1
      * (G01 * wrtmp
         + G11 * wstmp
         + G12 * wttmp);
    rwt = H1
      * (G02 * wrtmp
         + G12 * wstmp
         + G22 * wttmp);

    // *r/*s arrays must be complete before the transposed contractions;
    // also guarantees reads of shu/shv/shw finish before the next plane.
    __syncthreads();

    T uwijke = 0.0;
    T vwijke = 0.0;
    T wwijke = 0.0;
#pragma unroll
    for (int l = 0; l < LX; l++) {
      uwijke += shur[l + j * LX] * shdx[l + i * (LX + 1)];
      ruw[l] += rut * shdz[k + l * (LX + 1)];
      uwijke += shus[i + l * (LX + 1)] * shdy[l + j * (LX + 1)];

      vwijke += shvr[l + j * LX] * shdx[l + i * (LX + 1)];
      rvw[l] += rvt * shdz[k + l * (LX + 1)];
      vwijke += shvs[i + l * (LX + 1)] * shdy[l + j * (LX + 1)];

      wwijke += shwr[l + j * LX] * shdx[l + i * (LX + 1)];
      rww[l] += rwt * shdz[k + l * (LX + 1)];
      wwijke += shws[i + l * (LX + 1)] * shdy[l + j * (LX + 1)];
    }
    ruw[k] += uwijke;
    rvw[k] += vwijke;
    rww[k] += wwijke;
  }

#pragma unroll
  for (int k = 0; k < LX; ++k) {
    au[ij + k * LX * LX + ele] = ruw[k];
    av[ij + k * LX * LX + ele] = rvw[k];
    aw[ij + k * LX * LX + ele] = rww[k];
  }
}
720
/**
 * Pointwise update for the three-field operator:
 *   au += h2 * B * u,  av += h2 * B * v,  aw += h2 * B * w
 * over n entries.
 *
 * Uses a grid-stride loop, so any 1-D launch configuration covers all
 * n entries (n <= 0 is a no-op).
 *
 * au, av, aw   in/out arrays of length n
 * u, v, w      input arrays of length n
 * h2, B        pointwise coefficients of length n
 */
template< typename T >
__global__ void ax_helm_kernel_vector_part2(T * __restrict__ au,
                                            T * __restrict__ av,
                                            T * __restrict__ aw,
                                            const T * __restrict__ u,
                                            const T * __restrict__ v,
                                            const T * __restrict__ w,
                                            const T * __restrict__ h2,
                                            const T * __restrict__ B,
                                            const int n) {

  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int str = blockDim.x * gridDim.x;

  for (int i = idx; i < n; i += str) {
    // Shared by all three fields; same left-to-right product as h2*B*u.
    const T h2B = h2[i] * B[i];
    au[i] = au[i] + h2B * u[i];
    av[i] = av[i] + h2B * v[i];
    aw[i] = aw[i] + h2B * w[i];
  }
}
742#endif // __MATH_AX_HELM_KERNEL_H__
T rv[LX]
__shared__ T shu[LX *LX]
const int ij_p
__shared__ T shdy[LX *LX]
__shared__ T shw[LX *LX]
T rvw[LX]
__global__ void T *__restrict__ T *__restrict__ aw
__shared__ T shdz[LX *LX]
T rww[LX]
__global__ void T *__restrict__ T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ w
__shared__ T shvs[LX *LX]
__shared__ T shws[LX *LX]
const int i
const int ij
shdx[ij]
__global__ void T *__restrict__ T *__restrict__ const T *__restrict__ u
__global__ void T *__restrict__ T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ dx
const int e
__global__ void T *__restrict__ av
T ruw[LX]
T ru[LX]
__shared__ T shwr[LX *LX]
const int ele
T rw[LX]
__global__ void T *__restrict__ T *__restrict__ const T *__restrict__ const T *__restrict__ v
const int j
__shared__ T shus[LX *LX]
__syncthreads()
__global__ void T *__restrict__ T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ dz
__global__ void T *__restrict__ T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ h1
__shared__ T shv[LX *LX]
__global__ void T *__restrict__ T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ dy
__shared__ T shvr[LX *LX]
__shared__ T shur[LX *LX]
__global__ void ax_helm_kernel_vector_kstep(T *__restrict__ au, T *__restrict__ av, T *__restrict__ aw, const T *__restrict__ u, const T *__restrict__ v, const T *__restrict__ w, const T *__restrict__ dx, const T *__restrict__ dy, const T *__restrict__ dz, const T *__restrict__ h1, const T *__restrict__ g11, const T *__restrict__ g22, const T *__restrict__ g33, const T *__restrict__ g12, const T *__restrict__ g13, const T *__restrict__ g23)
__global__ void ax_helm_kernel_1d(T *__restrict__ w, const T *__restrict__ u, const T *__restrict__ dx, const T *__restrict__ dy, const T *__restrict__ dz, const T *__restrict__ dxt, const T *__restrict__ dyt, const T *__restrict__ dzt, const T *__restrict__ h1, const T *__restrict__ g11, const T *__restrict__ g22, const T *__restrict__ g33, const T *__restrict__ g12, const T *__restrict__ g13, const T *__restrict__ g23)
__global__ void ax_helm_kernel_kstep(T *__restrict__ w, const T *__restrict__ u, const T *__restrict__ dx, const T *__restrict__ dy, const T *__restrict__ dz, const T *__restrict__ h1, const T *__restrict__ g11, const T *__restrict__ g22, const T *__restrict__ g33, const T *__restrict__ g12, const T *__restrict__ g13, const T *__restrict__ g23)
__global__ void ax_helm_kernel_kstep_padded(T *__restrict__ w, const T *__restrict__ u, const T *__restrict__ dx, const T *__restrict__ dy, const T *__restrict__ dz, const T *__restrict__ h1, const T *__restrict__ g11, const T *__restrict__ g22, const T *__restrict__ g33, const T *__restrict__ g12, const T *__restrict__ g13, const T *__restrict__ g23)
__global__ void ax_helm_kernel_vector_kstep_padded(T *__restrict__ au, T *__restrict__ av, T *__restrict__ aw, const T *__restrict__ u, const T *__restrict__ v, const T *__restrict__ w, const T *__restrict__ dx, const T *__restrict__ dy, const T *__restrict__ dz, const T *__restrict__ h1, const T *__restrict__ g11, const T *__restrict__ g22, const T *__restrict__ g33, const T *__restrict__ g12, const T *__restrict__ g13, const T *__restrict__ g23)
__global__ void ax_helm_kernel_vector_part2(T *__restrict__ au, T *__restrict__ av, T *__restrict__ aw, const T *__restrict__ u, const T *__restrict__ v, const T *__restrict__ w, const T *__restrict__ h2, const T *__restrict__ B, const int n)
__shared__ T shdyt[LX *LX]
__shared__ T shdzt[LX *LX]
__global__ void const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ dzt
__global__ void const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ dyt
shdxt[ij]
__global__ void const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ dxt
__global__ void dirichlet_apply_scalar_kernel(const int *__restrict__ msk, T *__restrict__ x, const T g, const int m)
__global__ void const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ g23
__global__ void const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ g22
__global__ void const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ g13
__global__ void const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ g12
__global__ void const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ g33
__global__ void const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ g11