52 void *b1,
void *b2,
void *b3,
55 const dim3 nthrds(1024, 1, 1);
56 const dim3 nblcks(((*n)+1024 - 1)/ 1024, 1, 1);
59 fusedcg_cpld_part1_kernel<real>
60 <<<nblcks, nthrds, 0, stream>>>((
real *) a1, (
real *) a2, (
real *) a3,
67 void *z1,
void *z2,
void *z3,
68 void *po1,
void *po2,
void *po3,
71 const dim3 nthrds(1024, 1, 1);
72 const dim3 nblcks(((*n)+1024 - 1)/ 1024, 1, 1);
75 fusedcg_cpld_update_p_kernel<real>
76 <<<nblcks, nthrds, 0, stream>>>((
real *) p1, (
real *) p2, (
real *) p3,
85 void *p1,
void *p2,
void *p3,
86 void *alpha,
int *p_cur,
int *n) {
88 const dim3 nthrds(1024, 1, 1);
89 const dim3 nblcks(((*n)+1024 - 1)/ 1024, 1, 1);
92 fusedcg_cpld_update_x_kernel<real>
93 <<<nblcks, nthrds, 0, stream>>>((
real *) x1, (
real *) x2, (
real *) x3,
94 (
const real **) p1, (
const real **) p2,
95 (
const real **) p3, (
const real *) alpha,
101 void *c1,
void *c2,
void *c3,
void *alpha_d ,
102 real *alpha,
int *p_cur,
int * n) {
104 const dim3 nthrds(1024, 1, 1);
105 const dim3 nblcks(((*n)+1024 - 1)/ 1024, 1, 1);
106 const int nb = ((*n) + 1024 - 1)/ 1024;
125 real *alpha_d_p_cur = ((
real *) alpha_d) + ((*p_cur - 1));
127 sizeof(
real), cudaMemcpyHostToDevice,
131 fusedcg_cpld_part2_kernel<real>
132 <<<nblcks, nthrds, 0, stream>>>((
real *) a1, (
real *) a2, (
real *) a3,
141 #ifdef HAVE_DEVICE_MPI
142 cudaStreamSynchronize(stream);
147 cudaMemcpyDeviceToHost, stream));
148 cudaStreamSynchronize(stream);
void device_mpi_allreduce(void *buf_d, void *buf, int count, int nbytes, int op)
void cuda_fusedcg_cpld_update_x(void *x1, void *x2, void *x3, void *p1, void *p2, void *p3, void *alpha, int *p_cur, int *n)
void cuda_fusedcg_cpld_part1(void *a1, void *a2, void *a3, void *b1, void *b2, void *b3, void *tmp, int *n)
real cuda_fusedcg_cpld_part2(void *a1, void *a2, void *a3, void *b, void *c1, void *c2, void *c3, void *alpha_d, real *alpha, int *p_cur, int *n)
real * fusedcg_cpld_buf_d
void cuda_fusedcg_cpld_update_p(void *p1, void *p2, void *p3, void *z1, void *z2, void *z3, void *po1, void *po2, void *po3, real *beta, int *n)