/*
 * Neko 0.9.99 - A portable framework for high-order spectral element
 * flow simulations.
 *
 * File: ax_helm_kernel.h
 */
#ifndef __MATH_AX_HELM_KERNEL_H__
#define __MATH_AX_HELM_KERNEL_H__
/*
 Copyright (c) 2021-2024, The Neko Authors
 All rights reserved.

 Redistribution and use in source and binary forms, with or without
 modification, are permitted provided that the following conditions
 are met:

   * Redistributions of source code must retain the above copyright
     notice, this list of conditions and the following disclaimer.

   * Redistributions in binary form must reproduce the above
     copyright notice, this list of conditions and the following
     disclaimer in the documentation and/or other materials provided
     with the distribution.

   * Neither the name of the authors nor the names of its
     contributors may be used to endorse or promote products derived
     from this software without specific prior written permission.

 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
 FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
 COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
 INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
 ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 POSSIBILITY OF SUCH DAMAGE.
*/

/**
 * Device kernel for Ax helm, 1D thread layout.
 *
 * One thread block handles one element (e = blockIdx.x). Each thread
 * walks the LX^3 points of the element in strides of CHUNKS.
 *
 * Stage 1: for every point (i,j,k) the three directional sums
 *            rtmp = sum_l dx[i + l*LX] * u(l,j,k)
 *            stmp = sum_l dy[j + l*LX] * u(i,l,k)
 *            ttmp = sum_l dz[k + l*LX] * u(i,j,l)
 *          are combined with the symmetric per-point factors
 *          g11..g23 and scaled by h1; results go to shared memory.
 * Stage 2: the stage-1 fields are contracted with dxt, dyt, dzt and
 *          the sum is written to w.
 *
 * @param w    output, LX^3 entries per element
 * @param u    input, LX^3 entries per element
 * @param dx,dy,dz     LX*LX matrices used in stage 1
 * @param dxt,dyt,dzt  LX*LX matrices used in stage 2 (presumably the
 *                     transposes of dx, dy, dz - confirm with caller)
 * @param h1   per-point coefficient, LX^3 entries per element
 * @param g11,g22,g33,g12,g13,g23  per-point factors, LX^3 entries per
 *                     element
 *
 * NOTE(review): assumes a 1D launch with blockDim.x == CHUNKS and
 * CHUNKS >= LX*LX (only threads with threadIdx.x < LX*LX stage the
 * LX*LX matrices) - confirm against the host-side launcher.
 */
template< typename T, const int LX, const int CHUNKS >
__global__ void ax_helm_kernel_1d(T * __restrict__ w,
                                  const T * __restrict__ u,
                                  const T * __restrict__ dx,
                                  const T * __restrict__ dy,
                                  const T * __restrict__ dz,
                                  const T * __restrict__ dxt,
                                  const T * __restrict__ dyt,
                                  const T * __restrict__ dzt,
                                  const T * __restrict__ h1,
                                  const T * __restrict__ g11,
                                  const T * __restrict__ g22,
                                  const T * __restrict__ g33,
                                  const T * __restrict__ g12,
                                  const T * __restrict__ g13,
                                  const T * __restrict__ g23) {

  __shared__ T shdx[LX * LX];
  __shared__ T shdy[LX * LX];
  __shared__ T shdz[LX * LX];

  __shared__ T shdxt[LX * LX];
  __shared__ T shdyt[LX * LX];
  __shared__ T shdzt[LX * LX];

  /* Whole-element scratch: indexed with ijk < LX*LX*LX below. */
  __shared__ T shu[LX * LX * LX];
  __shared__ T shur[LX * LX * LX];
  __shared__ T shus[LX * LX * LX];
  __shared__ T shut[LX * LX * LX];

  const int e = blockIdx.x;
  const int iii = threadIdx.x;
  /* ceil(LX^3 / CHUNKS): number of strided passes each thread makes. */
  const int nchunks = (LX * LX * LX - 1)/CHUNKS + 1;

  if (iii<LX*LX) {
    shdx[iii] = dx[iii];
    shdy[iii] = dy[iii];
    shdz[iii] = dz[iii];
  }

  {
    int i = iii;
    while (i < LX * LX * LX){
      shu[i] = u[i+e*LX*LX*LX];
      i = i + CHUNKS;
    }
  }

  /* shdx/shdy/shdz/shu are read across threads in stage 1. */
  __syncthreads();

  if (iii<LX*LX){
    shdxt[iii] = dxt[iii];
    shdyt[iii] = dyt[iii];
    shdzt[iii] = dzt[iii];
  }

  for (int n=0; n<nchunks; n++){
    const int ijk = iii+n*CHUNKS;
    const int jk = ijk/LX;
    const int i = ijk-jk*LX;
    const int k = jk/LX;
    const int j = jk-k*LX;
    if (i<LX && j<LX && k<LX && ijk < LX*LX*LX){
      const T G00 = g11[ijk+e*LX*LX*LX];
      const T G11 = g22[ijk+e*LX*LX*LX];
      const T G22 = g33[ijk+e*LX*LX*LX];
      const T G01 = g12[ijk+e*LX*LX*LX];
      const T G02 = g13[ijk+e*LX*LX*LX];
      const T G12 = g23[ijk+e*LX*LX*LX];
      const T H1  = h1[ijk+e*LX*LX*LX];
      T rtmp = 0.0;
      T stmp = 0.0;
      T ttmp = 0.0;
#pragma unroll
      for (int l = 0; l<LX; l++){
        rtmp = rtmp + shdx[i+l*LX] * shu[l+j*LX+k*LX*LX];
        stmp = stmp + shdy[j+l*LX] * shu[i+l*LX+k*LX*LX];
        ttmp = ttmp + shdz[k+l*LX] * shu[i+j*LX+l*LX*LX];
      }
      shur[ijk] = H1 * (G00 * rtmp + G01 * stmp + G02 * ttmp);
      shus[ijk] = H1 * (G01 * rtmp + G11 * stmp + G12 * ttmp);
      shut[ijk] = H1 * (G02 * rtmp + G12 * stmp + G22 * ttmp);
    }
  }

  /* shur/shus/shut (and shdxt/shdyt/shdzt) are read across threads
     in stage 2. */
  __syncthreads();

  for (int n=0; n<nchunks; n++){
    const int ijk = iii+n*CHUNKS;
    const int jk = ijk/LX;
    const int k = jk/LX;
    const int j = jk-k*LX;
    const int i = ijk-jk*LX;
    if (i<LX && j<LX && k<LX && ijk <LX*LX*LX){
      T wijke = 0.0;
#pragma unroll
      for (int l = 0; l<LX; l++){
        wijke = wijke
              + shdxt[i+l*LX] * shur[l+j*LX+k*LX*LX]
              + shdyt[j+l*LX] * shus[i+l*LX+k*LX*LX]
              + shdzt[k+l*LX] * shut[i+j*LX+l*LX*LX];
      }
      w[ijk+e*LX*LX*LX] = wijke;
    }
  }
}

/**
 * Device kernel for Ax helm, k-stepping variant.
 *
 * One thread block handles one element (e = blockIdx.x); thread (i,j)
 * owns the column of LX points (i,j,0..LX-1). The column of u is kept
 * in registers (ru) and results are accumulated in registers (rw),
 * while one k-plane at a time is exchanged through shared memory.
 *
 * @param w    output, LX^3 entries per element
 * @param u    input, LX^3 entries per element
 * @param dx,dy,dz  LX*LX matrices
 * @param h1   per-point coefficient, LX^3 entries per element
 * @param g11,g22,g33,g12,g13,g23  per-point factors, LX^3 entries per
 *             element
 *
 * NOTE(review): assumes blockDim = (LX, LX) - the shared arrays are
 * indexed with ij = threadIdx.x + threadIdx.y*LX; confirm against the
 * host-side launcher.
 */
template< typename T, const int LX >
__global__ void __launch_bounds__(LX*LX,3)
  ax_helm_kernel_kstep(T * __restrict__ w,
                       const T * __restrict__ u,
                       const T * __restrict__ dx,
                       const T * __restrict__ dy,
                       const T * __restrict__ dz,
                       const T * __restrict__ h1,
                       const T * __restrict__ g11,
                       const T * __restrict__ g22,
                       const T * __restrict__ g33,
                       const T * __restrict__ g12,
                       const T * __restrict__ g13,
                       const T * __restrict__ g23) {

  __shared__ T shdx[LX * LX];
  __shared__ T shdy[LX * LX];
  __shared__ T shdz[LX * LX];

  __shared__ T shu[LX * LX];
  __shared__ T shur[LX * LX];
  __shared__ T shus[LX * LX];

  T ru[LX];
  T rw[LX];
  T rut;

  const int e = blockIdx.x;
  const int j = threadIdx.y;
  const int i = threadIdx.x;
  const int ij = i + j*LX;
  const int ele = e*LX*LX*LX;

  shdx[ij] = dx[ij];
  shdy[ij] = dy[ij];
  shdz[ij] = dz[ij];

#pragma unroll
  for(int k = 0; k < LX; ++k){
    ru[k] = u[ij + k*LX*LX + ele];
    rw[k] = 0.0;
  }

  /* shdx/shdy/shdz are read across threads from here on. */
  __syncthreads();

#pragma unroll
  for (int k = 0; k < LX; ++k){
    const int ijk = ij + k*LX*LX;
    const T G00 = g11[ijk+ele];
    const T G11 = g22[ijk+ele];
    const T G22 = g33[ijk+ele];
    const T G01 = g12[ijk+ele];
    const T G02 = g13[ijk+ele];
    const T G12 = g23[ijk+ele];
    const T H1  = h1[ijk+ele];
    T ttmp = 0.0;
    /* Publish plane k of u; the k-direction sum only needs registers. */
    shu[ij] = ru[k];
    for (int l = 0; l < LX; l++){
      ttmp += shdz[k+l*LX] * ru[l];
    }
    /* shu is read across threads below. */
    __syncthreads();

    T rtmp = 0.0;
    T stmp = 0.0;
#pragma unroll
    for (int l = 0; l < LX; l++){
      rtmp += shdx[i+l*LX] * shu[l+j*LX];
      stmp += shdy[j+l*LX] * shu[i+l*LX];
    }
    shur[ij] = H1
             * (G00 * rtmp
                + G01 * stmp
                + G02 * ttmp);
    shus[ij] = H1
             * (G01 * rtmp
                + G11 * stmp
                + G12 * ttmp);
    rut      = H1
             * (G02 * rtmp
                + G12 * stmp
                + G22 * ttmp);

    /* shur/shus are read across threads below. */
    __syncthreads();

    T wijke = 0.0;
#pragma unroll
    for (int l = 0; l < LX; l++){
      wijke += shur[l+j*LX] * shdx[l+i*LX];
      /* Scatter the k-direction contribution to every level l. */
      rw[l] += rut * shdz[k+l*LX];
      wijke += shus[i+l*LX] * shdy[l + j*LX];
    }
    rw[k] += wijke;
  }
#pragma unroll
  for (int k = 0; k < LX; ++k){
    w[ij + k*LX*LX + ele] = rw[k];
  }
}

/**
 * Device kernel for Ax helm, k-stepping variant with padded shared
 * memory.
 *
 * Same computation as the unpadded k-step kernel, but the shared
 * arrays that are accessed along both dimensions use a row stride of
 * LX+1 instead of LX (presumably to avoid shared-memory bank
 * conflicts - confirm with profiling).
 *
 * One thread block handles one element (e = blockIdx.x); thread (i,j)
 * owns the column of LX points (i,j,0..LX-1).
 *
 * @param w    output, LX^3 entries per element
 * @param u    input, LX^3 entries per element
 * @param dx,dy,dz  LX*LX matrices
 * @param h1   per-point coefficient, LX^3 entries per element
 * @param g11,g22,g33,g12,g13,g23  per-point factors, LX^3 entries per
 *             element
 *
 * NOTE(review): assumes blockDim = (LX, LX); confirm against the
 * host-side launcher.
 * NOTE(review): the __launch_bounds__(LX*LX,3) qualifier is inferred
 * from the two-line header in the rendered listing; confirm against
 * upstream.
 */
template< typename T, const int LX >
__global__ void __launch_bounds__(LX*LX,3)
  ax_helm_kernel_kstep_padded(T * __restrict__ w,
                              const T * __restrict__ u,
                              const T * __restrict__ dx,
                              const T * __restrict__ dy,
                              const T * __restrict__ dz,
                              const T * __restrict__ h1,
                              const T * __restrict__ g11,
                              const T * __restrict__ g22,
                              const T * __restrict__ g33,
                              const T * __restrict__ g12,
                              const T * __restrict__ g13,
                              const T * __restrict__ g23) {

  __shared__ T shdx[LX * (LX+1)];
  __shared__ T shdy[LX * (LX+1)];
  __shared__ T shdz[LX * (LX+1)];

  __shared__ T shu[LX * (LX+1)];
  __shared__ T shur[LX * LX]; // only accessed using fastest dimension
  __shared__ T shus[LX * (LX+1)];

  T ru[LX];
  T rw[LX];
  T rut;

  const int e = blockIdx.x;
  const int j = threadIdx.y;
  const int i = threadIdx.x;
  const int ij = i + j*LX;
  const int ij_p = i + j*(LX+1);   /* index into the padded arrays */
  const int ele = e*LX*LX*LX;

  shdx[ij_p] = dx[ij];
  shdy[ij_p] = dy[ij];
  shdz[ij_p] = dz[ij];

#pragma unroll
  for(int k = 0; k < LX; ++k){
    ru[k] = u[ij + k*LX*LX + ele];
    rw[k] = 0.0;
  }

  /* shdx/shdy/shdz are read across threads from here on. */
  __syncthreads();

#pragma unroll
  for (int k = 0; k < LX; ++k){
    const int ijk = ij + k*LX*LX;
    const T G00 = g11[ijk+ele];
    const T G11 = g22[ijk+ele];
    const T G22 = g33[ijk+ele];
    const T G01 = g12[ijk+ele];
    const T G02 = g13[ijk+ele];
    const T G12 = g23[ijk+ele];
    const T H1  = h1[ijk+ele];
    T ttmp = 0.0;
    /* Publish plane k of u; the k-direction sum only needs registers. */
    shu[ij_p] = ru[k];
    for (int l = 0; l < LX; l++){
      ttmp += shdz[k+l*(LX+1)] * ru[l];
    }
    /* shu is read across threads below. */
    __syncthreads();

    T rtmp = 0.0;
    T stmp = 0.0;
#pragma unroll
    for (int l = 0; l < LX; l++){
      rtmp += shdx[i+l*(LX+1)] * shu[l+j*(LX+1)];
      stmp += shdy[j+l*(LX+1)] * shu[i+l*(LX+1)];
    }
    shur[ij] = H1
             * (G00 * rtmp
                + G01 * stmp
                + G02 * ttmp);
    shus[ij_p] = H1
               * (G01 * rtmp
                  + G11 * stmp
                  + G12 * ttmp);
    rut        = H1
               * (G02 * rtmp
                  + G12 * stmp
                  + G22 * ttmp);

    /* shur/shus are read across threads below. */
    __syncthreads();

    T wijke = 0.0;
#pragma unroll
    for (int l = 0; l < LX; l++){
      wijke += shur[l+j*LX] * shdx[l+i*(LX+1)];
      /* Scatter the k-direction contribution to every level l. */
      rw[l] += rut * shdz[k+l*(LX+1)];
      wijke += shus[i+l*(LX+1)] * shdy[l + j*(LX+1)];
    }
    rw[k] += wijke;
  }
#pragma unroll
  for (int k = 0; k < LX; ++k){
    w[ij + k*LX*LX + ele] = rw[k];
  }
}

/*
 * Vector versions
 */

/**
 * Device kernel for Ax helm, k-stepping variant for three fields.
 *
 * Applies the same per-element computation to the three inputs
 * (u, v, w) and writes the results to (au, av, aw); the geometric
 * factors, h1 and the derivative matrices are loaded once and shared
 * between the three fields.
 *
 * One thread block handles one element (e = blockIdx.x); thread (i,j)
 * owns the column of LX points (i,j,0..LX-1).
 *
 * @param au,av,aw  outputs, LX^3 entries per element
 * @param u,v,w     inputs, LX^3 entries per element
 * @param dx,dy,dz  LX*LX matrices
 * @param h1        per-point coefficient, LX^3 entries per element
 * @param g11,g22,g33,g12,g13,g23  per-point factors, LX^3 entries per
 *                  element
 *
 * NOTE(review): assumes blockDim = (LX, LX); confirm against the
 * host-side launcher.
 * NOTE(review): the __launch_bounds__(LX*LX,3) qualifier is inferred
 * from the two-line header in the rendered listing; confirm against
 * upstream.
 */
template< typename T, const int LX >
__global__ void __launch_bounds__(LX*LX,3)
  ax_helm_kernel_vector_kstep(T * __restrict__ au,
                              T * __restrict__ av,
                              T * __restrict__ aw,
                              const T * __restrict__ u,
                              const T * __restrict__ v,
                              const T * __restrict__ w,
                              const T * __restrict__ dx,
                              const T * __restrict__ dy,
                              const T * __restrict__ dz,
                              const T * __restrict__ h1,
                              const T * __restrict__ g11,
                              const T * __restrict__ g22,
                              const T * __restrict__ g33,
                              const T * __restrict__ g12,
                              const T * __restrict__ g13,
                              const T * __restrict__ g23) {

  __shared__ T shdx[LX * LX];
  __shared__ T shdy[LX * LX];
  __shared__ T shdz[LX * LX];

  __shared__ T shu[LX * LX];
  __shared__ T shur[LX * LX];
  __shared__ T shus[LX * LX];

  __shared__ T shv[LX * LX];
  __shared__ T shvr[LX * LX];
  __shared__ T shvs[LX * LX];

  __shared__ T shw[LX * LX];
  __shared__ T shwr[LX * LX];
  __shared__ T shws[LX * LX];

  T ru[LX];
  T rv[LX];
  T rw[LX];

  T ruw[LX];
  T rvw[LX];
  T rww[LX];

  T rut;
  T rvt;
  T rwt;

  const int e = blockIdx.x;
  const int j = threadIdx.y;
  const int i = threadIdx.x;
  const int ij = i + j*LX;
  const int ele = e*LX*LX*LX;

  shdx[ij] = dx[ij];
  shdy[ij] = dy[ij];
  shdz[ij] = dz[ij];

#pragma unroll
  for(int k = 0; k < LX; ++k){
    ru[k] = u[ij + k*LX*LX + ele];
    ruw[k] = 0.0;

    rv[k] = v[ij + k*LX*LX + ele];
    rvw[k] = 0.0;

    rw[k] = w[ij + k*LX*LX + ele];
    rww[k] = 0.0;
  }

  /* shdx/shdy/shdz are read across threads from here on. */
  __syncthreads();

#pragma unroll
  for (int k = 0; k < LX; ++k){
    const int ijk = ij + k*LX*LX;
    const T G00 = g11[ijk+ele];
    const T G11 = g22[ijk+ele];
    const T G22 = g33[ijk+ele];
    const T G01 = g12[ijk+ele];
    const T G02 = g13[ijk+ele];
    const T G12 = g23[ijk+ele];
    const T H1  = h1[ijk+ele];
    T uttmp = 0.0;
    T vttmp = 0.0;
    T wttmp = 0.0;
    /* Publish plane k of each field; the k-direction sums only need
       registers. */
    shu[ij] = ru[k];
    shv[ij] = rv[k];
    shw[ij] = rw[k];
    for (int l = 0; l < LX; l++){
      uttmp += shdz[k+l*LX] * ru[l];
      vttmp += shdz[k+l*LX] * rv[l];
      wttmp += shdz[k+l*LX] * rw[l];
    }
    /* shu/shv/shw are read across threads below. */
    __syncthreads();

    T urtmp = 0.0;
    T ustmp = 0.0;

    T vrtmp = 0.0;
    T vstmp = 0.0;

    T wrtmp = 0.0;
    T wstmp = 0.0;
#pragma unroll
    for (int l = 0; l < LX; l++){
      urtmp += shdx[i+l*LX] * shu[l+j*LX];
      ustmp += shdy[j+l*LX] * shu[i+l*LX];

      vrtmp += shdx[i+l*LX] * shv[l+j*LX];
      vstmp += shdy[j+l*LX] * shv[i+l*LX];

      wrtmp += shdx[i+l*LX] * shw[l+j*LX];
      wstmp += shdy[j+l*LX] * shw[i+l*LX];
    }

    shur[ij] = H1
             * (G00 * urtmp
                + G01 * ustmp
                + G02 * uttmp);
    shus[ij] = H1
             * (G01 * urtmp
                + G11 * ustmp
                + G12 * uttmp);
    rut      = H1
             * (G02 * urtmp
                + G12 * ustmp
                + G22 * uttmp);

    shvr[ij] = H1
             * (G00 * vrtmp
                + G01 * vstmp
                + G02 * vttmp);
    shvs[ij] = H1
             * (G01 * vrtmp
                + G11 * vstmp
                + G12 * vttmp);
    rvt      = H1
             * (G02 * vrtmp
                + G12 * vstmp
                + G22 * vttmp);

    shwr[ij] = H1
             * (G00 * wrtmp
                + G01 * wstmp
                + G02 * wttmp);
    shws[ij] = H1
             * (G01 * wrtmp
                + G11 * wstmp
                + G12 * wttmp);
    rwt      = H1
             * (G02 * wrtmp
                + G12 * wstmp
                + G22 * wttmp);

    /* sh?r/sh?s are read across threads below. */
    __syncthreads();

    T uwijke = 0.0;
    T vwijke = 0.0;
    T wwijke = 0.0;
#pragma unroll
    for (int l = 0; l < LX; l++){
      uwijke += shur[l+j*LX] * shdx[l+i*LX];
      ruw[l] += rut * shdz[k+l*LX];
      uwijke += shus[i+l*LX] * shdy[l + j*LX];

      vwijke += shvr[l+j*LX] * shdx[l+i*LX];
      rvw[l] += rvt * shdz[k+l*LX];
      vwijke += shvs[i+l*LX] * shdy[l + j*LX];

      wwijke += shwr[l+j*LX] * shdx[l+i*LX];
      rww[l] += rwt * shdz[k+l*LX];
      wwijke += shws[i+l*LX] * shdy[l + j*LX];
    }
    ruw[k] += uwijke;
    rvw[k] += vwijke;
    rww[k] += wwijke;
  }
#pragma unroll
  for (int k = 0; k < LX; ++k){
    au[ij + k*LX*LX + ele] = ruw[k];
    av[ij + k*LX*LX + ele] = rvw[k];
    aw[ij + k*LX*LX + ele] = rww[k];
  }
}

/**
 * Device kernel for Ax helm, k-stepping variant for three fields with
 * padded shared memory.
 *
 * Same computation as the unpadded three-field k-step kernel, but the
 * shared arrays that are accessed along both dimensions use a row
 * stride of LX+1 instead of LX (presumably to avoid shared-memory
 * bank conflicts - confirm with profiling).
 *
 * One thread block handles one element (e = blockIdx.x); thread (i,j)
 * owns the column of LX points (i,j,0..LX-1).
 *
 * @param au,av,aw  outputs, LX^3 entries per element
 * @param u,v,w     inputs, LX^3 entries per element
 * @param dx,dy,dz  LX*LX matrices
 * @param h1        per-point coefficient, LX^3 entries per element
 * @param g11,g22,g33,g12,g13,g23  per-point factors, LX^3 entries per
 *                  element
 *
 * NOTE(review): assumes blockDim = (LX, LX); confirm against the
 * host-side launcher.
 * NOTE(review): the __launch_bounds__(LX*LX,3) qualifier is inferred
 * from the two-line header in the rendered listing; confirm against
 * upstream.
 */
template< typename T, const int LX >
__global__ void __launch_bounds__(LX*LX,3)
  ax_helm_kernel_vector_kstep_padded(T * __restrict__ au,
                                     T * __restrict__ av,
                                     T * __restrict__ aw,
                                     const T * __restrict__ u,
                                     const T * __restrict__ v,
                                     const T * __restrict__ w,
                                     const T * __restrict__ dx,
                                     const T * __restrict__ dy,
                                     const T * __restrict__ dz,
                                     const T * __restrict__ h1,
                                     const T * __restrict__ g11,
                                     const T * __restrict__ g22,
                                     const T * __restrict__ g33,
                                     const T * __restrict__ g12,
                                     const T * __restrict__ g13,
                                     const T * __restrict__ g23) {

  __shared__ T shdx[LX * (LX+1)];
  __shared__ T shdy[LX * (LX+1)];
  __shared__ T shdz[LX * (LX+1)];

  __shared__ T shu[LX * (LX+1)];
  __shared__ T shur[LX * LX];
  __shared__ T shus[LX * (LX+1)];

  __shared__ T shv[LX * (LX+1)];
  __shared__ T shvr[LX * LX];
  __shared__ T shvs[LX * (LX+1)];

  __shared__ T shw[LX * (LX+1)];
  __shared__ T shwr[LX * LX];
  __shared__ T shws[LX * (LX+1)];

  T ru[LX];
  T rv[LX];
  T rw[LX];

  T ruw[LX];
  T rvw[LX];
  T rww[LX];

  T rut;
  T rvt;
  T rwt;

  const int e = blockIdx.x;
  const int j = threadIdx.y;
  const int i = threadIdx.x;
  const int ij = i + j*LX;
  const int ij_p = i + j*(LX+1);   /* index into the padded arrays */
  const int ele = e*LX*LX*LX;

  shdx[ij_p] = dx[ij];
  shdy[ij_p] = dy[ij];
  shdz[ij_p] = dz[ij];

#pragma unroll
  for(int k = 0; k < LX; ++k){
    ru[k] = u[ij + k*LX*LX + ele];
    ruw[k] = 0.0;

    rv[k] = v[ij + k*LX*LX + ele];
    rvw[k] = 0.0;

    rw[k] = w[ij + k*LX*LX + ele];
    rww[k] = 0.0;
  }

  /* shdx/shdy/shdz are read across threads from here on. */
  __syncthreads();

#pragma unroll
  for (int k = 0; k < LX; ++k){
    const int ijk = ij + k*LX*LX;
    const T G00 = g11[ijk+ele];
    const T G11 = g22[ijk+ele];
    const T G22 = g33[ijk+ele];
    const T G01 = g12[ijk+ele];
    const T G02 = g13[ijk+ele];
    const T G12 = g23[ijk+ele];
    const T H1  = h1[ijk+ele];
    T uttmp = 0.0;
    T vttmp = 0.0;
    T wttmp = 0.0;
    /* Publish plane k of each field; the k-direction sums only need
       registers. */
    shu[ij_p] = ru[k];
    shv[ij_p] = rv[k];
    shw[ij_p] = rw[k];
    for (int l = 0; l < LX; l++){
      uttmp += shdz[k+l*(LX+1)] * ru[l];
      vttmp += shdz[k+l*(LX+1)] * rv[l];
      wttmp += shdz[k+l*(LX+1)] * rw[l];
    }
    /* shu/shv/shw are read across threads below. */
    __syncthreads();

    T urtmp = 0.0;
    T ustmp = 0.0;

    T vrtmp = 0.0;
    T vstmp = 0.0;

    T wrtmp = 0.0;
    T wstmp = 0.0;
#pragma unroll
    for (int l = 0; l < LX; l++){
      urtmp += shdx[i+l*(LX+1)] * shu[l+j*(LX+1)];
      ustmp += shdy[j+l*(LX+1)] * shu[i+l*(LX+1)];

      vrtmp += shdx[i+l*(LX+1)] * shv[l+j*(LX+1)];
      vstmp += shdy[j+l*(LX+1)] * shv[i+l*(LX+1)];

      wrtmp += shdx[i+l*(LX+1)] * shw[l+j*(LX+1)];
      wstmp += shdy[j+l*(LX+1)] * shw[i+l*(LX+1)];
    }

    shur[ij] = H1
             * (G00 * urtmp
                + G01 * ustmp
                + G02 * uttmp);
    shus[ij_p] = H1
               * (G01 * urtmp
                  + G11 * ustmp
                  + G12 * uttmp);
    rut        = H1
               * (G02 * urtmp
                  + G12 * ustmp
                  + G22 * uttmp);

    shvr[ij] = H1
             * (G00 * vrtmp
                + G01 * vstmp
                + G02 * vttmp);
    shvs[ij_p] = H1
               * (G01 * vrtmp
                  + G11 * vstmp
                  + G12 * vttmp);
    rvt        = H1
               * (G02 * vrtmp
                  + G12 * vstmp
                  + G22 * vttmp);

    shwr[ij] = H1
             * (G00 * wrtmp
                + G01 * wstmp
                + G02 * wttmp);
    shws[ij_p] = H1
               * (G01 * wrtmp
                  + G11 * wstmp
                  + G12 * wttmp);
    rwt        = H1
               * (G02 * wrtmp
                  + G12 * wstmp
                  + G22 * wttmp);

    /* sh?r/sh?s are read across threads below. */
    __syncthreads();

    T uwijke = 0.0;
    T vwijke = 0.0;
    T wwijke = 0.0;
#pragma unroll
    for (int l = 0; l < LX; l++){
      uwijke += shur[l+j*LX] * shdx[l+i*(LX+1)];
      ruw[l] += rut * shdz[k+l*(LX+1)];
      uwijke += shus[i+l*(LX+1)] * shdy[l + j*(LX+1)];

      vwijke += shvr[l+j*LX] * shdx[l+i*(LX+1)];
      rvw[l] += rvt * shdz[k+l*(LX+1)];
      vwijke += shvs[i+l*(LX+1)] * shdy[l + j*(LX+1)];

      wwijke += shwr[l+j*LX] * shdx[l+i*(LX+1)];
      rww[l] += rwt * shdz[k+l*(LX+1)];
      wwijke += shws[i+l*(LX+1)] * shdy[l + j*(LX+1)];
    }
    ruw[k] += uwijke;
    rvw[k] += vwijke;
    rww[k] += wwijke;
  }
#pragma unroll
  for (int k = 0; k < LX; ++k){
    au[ij + k*LX*LX + ele] = ruw[k];
    av[ij + k*LX*LX + ele] = rvw[k];
    aw[ij + k*LX*LX + ele] = rww[k];
  }
}

/**
 * Second part of the three-field Ax helm operator.
 *
 * Adds the pointwise term h2 * B * (u, v, w) to (au, av, aw) in
 * place, for the first n entries of each array. Uses a grid-stride
 * loop, so any 1D grid/block configuration covers all n entries.
 *
 * @param au,av,aw  arrays updated in place
 * @param u,v,w     input fields
 * @param h2        pointwise coefficient
 * @param B         pointwise factor
 * @param n         number of entries to process
 */
template< typename T >
__global__ void ax_helm_kernel_vector_part2(T * __restrict__ au,
                                            T * __restrict__ av,
                                            T * __restrict__ aw,
                                            const T * __restrict__ u,
                                            const T * __restrict__ v,
                                            const T * __restrict__ w,
                                            const T * __restrict__ h2,
                                            const T * __restrict__ B,
                                            const int n) {

  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int str = blockDim.x * gridDim.x;

  for (int i = idx; i < n; i += str) {
    /* Common factor of the three updates; same evaluation order as
       h2[i] * B[i] * x[i]. */
    const T h2B = h2[i] * B[i];
    au[i] = au[i] + h2B * u[i];
    av[i] = av[i] + h2B * v[i];
    aw[i] = aw[i] + h2B * w[i];
  }

}
#endif // __MATH_AX_HELM_KERNEL_H__
T rv[LX]
__shared__ T shu[LX *LX]
const int ij_p
__shared__ T shdy[LX *LX]
__shared__ T shw[LX *LX]
T rvw[LX]
__global__ void T *__restrict__ T *__restrict__ aw
__shared__ T shdz[LX *LX]
T rww[LX]
__global__ void T *__restrict__ T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ w
__shared__ T shvs[LX *LX]
__shared__ T shws[LX *LX]
const int i
const int ij
shdx[ij]
__global__ void T *__restrict__ T *__restrict__ const T *__restrict__ u
__global__ void T *__restrict__ T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ dx
const int e
__global__ void T *__restrict__ av
T ruw[LX]
T ru[LX]
__shared__ T shwr[LX *LX]
const int ele
T rw[LX]
__global__ void T *__restrict__ T *__restrict__ const T *__restrict__ const T *__restrict__ v
const int j
__shared__ T shus[LX *LX]
__syncthreads()
__global__ void T *__restrict__ T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ dz
__global__ void T *__restrict__ T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ h1
__shared__ T shv[LX *LX]
__global__ void T *__restrict__ T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ dy
__shared__ T shvr[LX *LX]
__shared__ T shur[LX *LX]
__global__ void ax_helm_kernel_vector_kstep(T *__restrict__ au, T *__restrict__ av, T *__restrict__ aw, const T *__restrict__ u, const T *__restrict__ v, const T *__restrict__ w, const T *__restrict__ dx, const T *__restrict__ dy, const T *__restrict__ dz, const T *__restrict__ h1, const T *__restrict__ g11, const T *__restrict__ g22, const T *__restrict__ g33, const T *__restrict__ g12, const T *__restrict__ g13, const T *__restrict__ g23)
__global__ void ax_helm_kernel_1d(T *__restrict__ w, const T *__restrict__ u, const T *__restrict__ dx, const T *__restrict__ dy, const T *__restrict__ dz, const T *__restrict__ dxt, const T *__restrict__ dyt, const T *__restrict__ dzt, const T *__restrict__ h1, const T *__restrict__ g11, const T *__restrict__ g22, const T *__restrict__ g33, const T *__restrict__ g12, const T *__restrict__ g13, const T *__restrict__ g23)
__global__ void ax_helm_kernel_kstep(T *__restrict__ w, const T *__restrict__ u, const T *__restrict__ dx, const T *__restrict__ dy, const T *__restrict__ dz, const T *__restrict__ h1, const T *__restrict__ g11, const T *__restrict__ g22, const T *__restrict__ g33, const T *__restrict__ g12, const T *__restrict__ g13, const T *__restrict__ g23)
__global__ void ax_helm_kernel_kstep_padded(T *__restrict__ w, const T *__restrict__ u, const T *__restrict__ dx, const T *__restrict__ dy, const T *__restrict__ dz, const T *__restrict__ h1, const T *__restrict__ g11, const T *__restrict__ g22, const T *__restrict__ g33, const T *__restrict__ g12, const T *__restrict__ g13, const T *__restrict__ g23)
__global__ void ax_helm_kernel_vector_kstep_padded(T *__restrict__ au, T *__restrict__ av, T *__restrict__ aw, const T *__restrict__ u, const T *__restrict__ v, const T *__restrict__ w, const T *__restrict__ dx, const T *__restrict__ dy, const T *__restrict__ dz, const T *__restrict__ h1, const T *__restrict__ g11, const T *__restrict__ g22, const T *__restrict__ g33, const T *__restrict__ g12, const T *__restrict__ g13, const T *__restrict__ g23)
__global__ void ax_helm_kernel_vector_part2(T *__restrict__ au, T *__restrict__ av, T *__restrict__ aw, const T *__restrict__ u, const T *__restrict__ v, const T *__restrict__ w, const T *__restrict__ h2, const T *__restrict__ B, const int n)
__shared__ T shdyt[LX *LX]
__shared__ T shdzt[LX *LX]
__global__ void const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ dzt
__global__ void const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ dyt
shdxt[ij]
__global__ void const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ dxt
__global__ void dirichlet_apply_scalar_kernel(const int *__restrict__ msk, T *__restrict__ x, const T g, const int m)
__global__ void const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ g23
__global__ void const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ g22
__global__ void const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ g13
__global__ void const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ g12
__global__ void const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ g33
__global__ void __launch_bounds__(LX *LX, 3) ax_helm_kernel_kstep(T *__restrict__ w
__global__ void const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ g11