@@ -33,18 +33,18 @@ namespace kernel
3333 uint blocks_dim,
3434 uint lim)
3535 {
36- const uint tidx = threadIdx.x ;
37- const uint tidy = threadIdx.y ;
38- const uint tid = tidy * THREADS_X + tidx;
36+ const int tidx = threadIdx.x ;
37+ const int tidy = threadIdx.y ;
38+ const int tid = tidy * THREADS_X + tidx;
3939
40- const uint zid = blockIdx.x / blocks_x;
41- const uint wid = blockIdx.y / blocks_y;
42- const uint blockIdx_x = blockIdx.x - (blocks_x) * zid;
43- const uint blockIdx_y = blockIdx.y - (blocks_y) * wid;
44- const uint xid = blockIdx_x * blockDim.x + tidx;
45- const uint yid = blockIdx_y; // yid of output. updated for input later.
40+ const int zid = blockIdx.x / blocks_x;
41+ const int wid = blockIdx.y / blocks_y;
42+ const int blockIdx_x = blockIdx.x - (blocks_x) * zid;
43+ const int blockIdx_y = blockIdx.y - (blocks_y) * wid;
44+ const int xid = blockIdx_x * blockDim.x + tidx;
45+ const int yid = blockIdx_y; // yid of output. updated for input later.
4646
47- uint ids[4 ] = {xid, yid, zid, wid};
47+ int ids[4 ] = {xid, yid, zid, wid};
4848
4949 const Ti *iptr = in.ptr ;
5050 To *optr = out.ptr ;
@@ -54,22 +54,22 @@ namespace kernel
5454 // There are blockDim.y elements per block for in
5555 // Hence increment ids[dim] just after offseting out and before offsetting in
5656 tptr += ids[3 ] * tmp.strides [3 ] + ids[2 ] * tmp.strides [2 ] + ids[1 ] * tmp.strides [1 ] + ids[0 ];
57- const uint blockIdx_dim = ids[dim];
57+ const int blockIdx_dim = ids[dim];
5858
5959 ids[dim] = ids[dim] * blockDim.y * lim + tidy;
6060 optr += ids[3 ] * out.strides [3 ] + ids[2 ] * out.strides [2 ] + ids[1 ] * out.strides [1 ] + ids[0 ];
6161 iptr += ids[3 ] * in.strides [3 ] + ids[2 ] * in.strides [2 ] + ids[1 ] * in.strides [1 ] + ids[0 ];
62- uint id_dim = ids[dim];
63- const uint out_dim = out.dims [dim];
62+ int id_dim = ids[dim];
63+ const int out_dim = out.dims [dim];
6464
6565 bool is_valid =
6666 (ids[0 ] < out.dims [0 ]) &&
6767 (ids[1 ] < out.dims [1 ]) &&
6868 (ids[2 ] < out.dims [2 ]) &&
6969 (ids[3 ] < out.dims [3 ]);
7070
71- const uint ostride_dim = out.strides [dim];
72- const uint istride_dim = in.strides [dim];
71+ const int ostride_dim = out.strides [dim];
72+ const int istride_dim = in.strides [dim];
7373
7474 __shared__ To s_val[THREADS_X * DIMY * 2 ];
7575 __shared__ To s_tmp[THREADS_X];
@@ -92,7 +92,8 @@ namespace kernel
9292 *sptr = val;
9393 __syncthreads ();
9494
95- uint start = 0 ;
95+ int start = 0 ;
96+ #pragma unroll
9697 for (int off = 1 ; off < DIMY; off *= 2 ) {
9798
9899 if (tidy >= off) val = binop (val, sptr[(start - off) * THREADS_X]);
@@ -103,6 +104,7 @@ namespace kernel
103104 }
104105
105106 val = binop (val, s_tmp[tidx]);
107+ __syncthreads ();
106108 if (cond) *optr = val;
107109
108110 id_dim += blockDim.y ;
@@ -127,17 +129,17 @@ namespace kernel
127129 uint blocks_dim,
128130 uint lim)
129131 {
130- const uint tidx = threadIdx.x ;
131- const uint tidy = threadIdx.y ;
132+ const int tidx = threadIdx.x ;
133+ const int tidy = threadIdx.y ;
132134
133- const uint zid = blockIdx.x / blocks_x;
134- const uint wid = blockIdx.y / blocks_y;
135- const uint blockIdx_x = blockIdx.x - (blocks_x) * zid;
136- const uint blockIdx_y = blockIdx.y - (blocks_y) * wid;
137- const uint xid = blockIdx_x * blockDim.x + tidx;
138- const uint yid = blockIdx_y; // yid of output. updated for input later.
135+ const int zid = blockIdx.x / blocks_x;
136+ const int wid = blockIdx.y / blocks_y;
137+ const int blockIdx_x = blockIdx.x - (blocks_x) * zid;
138+ const int blockIdx_y = blockIdx.y - (blocks_y) * wid;
139+ const int xid = blockIdx_x * blockDim.x + tidx;
140+ const int yid = blockIdx_y; // yid of output. updated for input later.
139141
140- uint ids[4 ] = {xid, yid, zid, wid};
142+ int ids[4 ] = {xid, yid, zid, wid};
141143
142144 const To *tptr = tmp.ptr ;
143145 To *optr = out.ptr ;
@@ -146,12 +148,12 @@ namespace kernel
146148 // There are blockDim.y elements per block for in
147149 // Hence increment ids[dim] just after offseting out and before offsetting in
148150 tptr += ids[3 ] * tmp.strides [3 ] + ids[2 ] * tmp.strides [2 ] + ids[1 ] * tmp.strides [1 ] + ids[0 ];
149- const uint blockIdx_dim = ids[dim];
151+ const int blockIdx_dim = ids[dim];
150152
151153 ids[dim] = ids[dim] * blockDim.y * lim + tidy;
152154 optr += ids[3 ] * out.strides [3 ] + ids[2 ] * out.strides [2 ] + ids[1 ] * out.strides [1 ] + ids[0 ];
153- const uint id_dim = ids[dim];
154- const uint out_dim = out.dims [dim];
155+ const int id_dim = ids[dim];
156+ const int out_dim = out.dims [dim];
155157
156158 bool is_valid =
157159 (ids[0 ] < out.dims [0 ]) &&
@@ -165,7 +167,7 @@ namespace kernel
165167 To accum = *(tptr - tmp.strides [dim]);
166168
167169 Binary<To, op> binop;
168- const uint ostride_dim = out.strides [dim];
170+ const int ostride_dim = out.strides [dim];
169171
170172 for (int k = 0 , id = id_dim;
171173 is_valid && k < lim && (id < out_dim);
0 commit comments