35 #include <hip/hip_runtime.h>
55 void *xbar,
int *
j,
int *n){
61 const int nt = 1024/pow2;
62 const dim3 glsc3_nthrds(pow2, nt, 1);
63 const dim3 glsc3_nblcks(((*n)+nt - 1)/nt, 1, 1);
64 const int glsc3_nb = ((*n) + nt - 1)/nt;
74 hipLaunchKernelGGL(HIP_KERNEL_NAME( glsc3_many_kernel<real> ),
75 glsc3_nblcks, glsc3_nthrds,
77 (
const real *) b, (
const real **) xx,
80 hipLaunchKernelGGL(HIP_KERNEL_NAME( glsc3_reduce_kernel<real> ),
85 hipMemcpyDeviceToDevice,
92 const dim3 vec_nthrds(1024, 1, 1);
93 const dim3 vec_nblcks(((*n)+1024 - 1)/ 1024, 1, 1);
96 hipLaunchKernelGGL(HIP_KERNEL_NAME( project_on_vec_kernel<real> ),
97 vec_nblcks, vec_nthrds,
100 (
const real *) alpha, *
j, *n);
102 hipLaunchKernelGGL(HIP_KERNEL_NAME( glsc3_many_kernel<real> ),
103 glsc3_nblcks, glsc3_nthrds,
105 (
const real *) b, (
const real **) xx,
108 hipLaunchKernelGGL(HIP_KERNEL_NAME( glsc3_reduce_kernel<real> ),
113 hipMemcpyDeviceToDevice,
120 hipLaunchKernelGGL(HIP_KERNEL_NAME(project_on_vec_kernel<real> ),
121 vec_nblcks, vec_nthrds,
124 (
const real *) alpha, *
j, *n);
128 void *
w,
void *xm,
int *
j,
int *n,
real *nrm){
134 const int nt = 1024/pow2;
135 const dim3 glsc3_nthrds(pow2, nt, 1);
136 const dim3 glsc3_nblcks(((*n)+nt - 1)/nt, 1, 1);
137 const int glsc3_nb = ((*n) + nt - 1)/nt;
147 hipLaunchKernelGGL(HIP_KERNEL_NAME( glsc3_many_kernel<real> ),
148 glsc3_nblcks, glsc3_nthrds,
150 (
const real *) b, (
const real **) xx,
153 hipLaunchKernelGGL(HIP_KERNEL_NAME( glsc3_reduce_kernel<real> ),
158 hipMemcpyDeviceToDevice,
165 sizeof(
real), hipMemcpyDeviceToHost,
170 const dim3 vec_nthrds(1024, 1, 1);
171 const dim3 vec_nblcks(((*n)+1024 - 1)/ 1024, 1, 1);
174 hipLaunchKernelGGL( HIP_KERNEL_NAME( project_ortho_vec_kernel<real> ),
178 (
const real *) alpha, *
j, *n);
181 hipLaunchKernelGGL(HIP_KERNEL_NAME( glsc3_many_kernel<real> ),
183 (
const real *) b, (
const real **) xx,
186 hipLaunchKernelGGL(HIP_KERNEL_NAME( glsc3_reduce_kernel<real> ),
191 hipMemcpyDeviceToDevice,
198 hipLaunchKernelGGL( HIP_KERNEL_NAME( project_ortho_vec_kernel<real> ),
202 (
const real *) alpha, *
j, *n);
__global__ void T *__restrict__ T *__restrict__ const T *__restrict__ const T *__restrict__ const T *__restrict__ w
void device_mpi_allreduce_inplace(void *buf_d, int count, int nbytes, int op)
void hip_project_ortho(void *alpha, void *b, void *xx, void *bb, void *w, void *xm, int *j, int *n, real *nrm)
void hip_project_on(void *alpha, void *b, void *xx, void *bb, void *mult, void *xbar, int *j, int *n)