52 void *b1,
void *b2,
void *b3,
55 const dim3 nthrds(1024, 1, 1);
56 const dim3 nblcks(((*n)+1024 - 1)/ 1024, 1, 1);
59 hipLaunchKernelGGL(HIP_KERNEL_NAME(fusedcg_cpld_part1_kernel<real>),
60 nblcks, nthrds, 0, stream,
68 void *z1,
void *z2,
void *z3,
69 void *po1,
void *po2,
void *po3,
72 const dim3 nthrds(1024, 1, 1);
73 const dim3 nblcks(((*n)+1024 - 1)/ 1024, 1, 1);
77 hipLaunchKernelGGL(HIP_KERNEL_NAME(fusedcg_cpld_update_p_kernel<real>),
78 nblcks, nthrds, 0, stream,
88 void *p1,
void *p2,
void *p3,
89 void *alpha,
int *p_cur,
int *n) {
91 const dim3 nthrds(1024, 1, 1);
92 const dim3 nblcks(((*n)+1024 - 1)/ 1024, 1, 1);
96 hipLaunchKernelGGL(HIP_KERNEL_NAME(fusedcg_cpld_update_x_kernel<real>),
97 nblcks, nthrds, 0, stream,
99 (
const real **) p1, (
const real **) p2,
100 (
const real **) p3, (
const real *) alpha,
106 void *c1,
void *c2,
void *c3,
void *alpha_d ,
107 real *alpha,
int *p_cur,
int * n) {
109 const dim3 nthrds(1024, 1, 1);
110 const dim3 nblcks(((*n)+1024 - 1)/ 1024, 1, 1);
111 const int nb = ((*n) + 1024 - 1)/ 1024;
130 real *alpha_d_p_cur = ((
real *) alpha_d) + ((*p_cur - 1));
132 sizeof(
real), hipMemcpyHostToDevice,
135 hipLaunchKernelGGL(HIP_KERNEL_NAME(fusedcg_cpld_part2_kernel<real>),
136 nblcks, nthrds, 0, stream,
143 hipLaunchKernelGGL(HIP_KERNEL_NAME(reduce_kernel<real>),
147 #ifdef HAVE_DEVICE_MPI
148 hipStreamSynchronize(stream);
153 hipMemcpyDeviceToHost, stream));
154 hipStreamSynchronize(stream);
void device_mpi_allreduce(void *buf_d, void *buf, int count, int nbytes, int op)
void hip_fusedcg_cpld_update_x(void *x1, void *x2, void *x3, void *p1, void *p2, void *p3, void *alpha, int *p_cur, int *n)
void hip_fusedcg_cpld_update_p(void *p1, void *p2, void *p3, void *z1, void *z2, void *z3, void *po1, void *po2, void *po3, real *beta, int *n)
real hip_fusedcg_cpld_part2(void *a1, void *a2, void *a3, void *b, void *c1, void *c2, void *c3, void *alpha_d, real *alpha, int *p_cur, int *n)
void hip_fusedcg_cpld_part1(void *a1, void *a2, void *a3, void *b1, void *b2, void *b3, void *tmp, int *n)
real * fusedcg_cpld_buf_d