36 #include <OpenCL/cl.h>
47 #include "tensor_kernel.cl.h"
50 void *A,
void *Bt,
void *Ct,
int *nel) {
56 cl_kernel kernel = clCreateKernel(
tensor_program,
"tnsr3d_kernel", &err);
59 CL_CHECK(clSetKernelArg(kernel, 0,
sizeof(cl_mem), (
void *) &
v));
60 CL_CHECK(clSetKernelArg(kernel, 1,
sizeof(
int), nv));
61 CL_CHECK(clSetKernelArg(kernel, 2,
sizeof(cl_mem), (
void *) &
u));
62 CL_CHECK(clSetKernelArg(kernel, 3,
sizeof(
int), nu));
63 CL_CHECK(clSetKernelArg(kernel, 4,
sizeof(cl_mem), (
void *) &A));
64 CL_CHECK(clSetKernelArg(kernel, 5,
sizeof(cl_mem), (
void *) &Bt));
65 CL_CHECK(clSetKernelArg(kernel, 6,
sizeof(cl_mem), (
void *) &Ct));
67 const size_t global_item_size = 256 * (*nel);
68 const size_t local_item_size = 256;
71 NULL, &global_item_size, &local_item_size,
76 void *A,
void *Bt,
void *Ct,
int *elements,
83 cl_kernel kernel = clCreateKernel(
tensor_program,
"tnsr3d_el_kernel", &err);
86 CL_CHECK(clSetKernelArg(kernel, 0,
sizeof(cl_mem), (
void *) &
v));
87 CL_CHECK(clSetKernelArg(kernel, 1,
sizeof(
int), nv));
88 CL_CHECK(clSetKernelArg(kernel, 2,
sizeof(cl_mem), (
void *) &
u));
89 CL_CHECK(clSetKernelArg(kernel, 3,
sizeof(
int), nu));
90 CL_CHECK(clSetKernelArg(kernel, 4,
sizeof(cl_mem), (
void *) &A));
91 CL_CHECK(clSetKernelArg(kernel, 5,
sizeof(cl_mem), (
void *) &Bt));
92 CL_CHECK(clSetKernelArg(kernel, 6,
sizeof(cl_mem), (
void *) &Ct));
93 CL_CHECK(clSetKernelArg(kernel, 7,
sizeof(cl_mem), (
void *) &elements));
94 CL_CHECK(clSetKernelArg(kernel, 8,
sizeof(
int), n_points));
96 const size_t global_item_size = 256 * (*n_points);
97 const size_t local_item_size = 256;
100 NULL, &global_item_size, &local_item_size,
__global__ void T *__restrict__ T *__restrict__ const T *__restrict__ u
__global__ void T *__restrict__ T *__restrict__ const T *__restrict__ const T *__restrict__ v
void opencl_kernel_jit(const char *kernel, cl_program *program)
void opencl_tnsr3d(void *v, int *nv, void *u, int *nu, void *A, void *Bt, void *Ct, int *nel)
void opencl_tnsr3d_el_list(void *v, int *nv, void *u, int *nu, void *A, void *Bt, void *Ct, int *elements, int *n_points)