Skip to content

Commit 26975c3

Browse files
committed
Merge pull request arrayfire#548 from pavanky/scan_fixes
Bug fixes for accum in CUDA and OpenCL backends
2 parents f539123 + d58ec95 commit 26975c3

File tree

4 files changed

+122
-109
lines changed

4 files changed

+122
-109
lines changed

src/backend/cuda/kernel/scan_dim.hpp

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,18 @@ namespace kernel
3333
uint blocks_dim,
3434
uint lim)
3535
{
36-
const uint tidx = threadIdx.x;
37-
const uint tidy = threadIdx.y;
38-
const uint tid = tidy * THREADS_X + tidx;
36+
const int tidx = threadIdx.x;
37+
const int tidy = threadIdx.y;
38+
const int tid = tidy * THREADS_X + tidx;
3939

40-
const uint zid = blockIdx.x / blocks_x;
41-
const uint wid = blockIdx.y / blocks_y;
42-
const uint blockIdx_x = blockIdx.x - (blocks_x) * zid;
43-
const uint blockIdx_y = blockIdx.y - (blocks_y) * wid;
44-
const uint xid = blockIdx_x * blockDim.x + tidx;
45-
const uint yid = blockIdx_y; // yid of output. updated for input later.
40+
const int zid = blockIdx.x / blocks_x;
41+
const int wid = blockIdx.y / blocks_y;
42+
const int blockIdx_x = blockIdx.x - (blocks_x) * zid;
43+
const int blockIdx_y = blockIdx.y - (blocks_y) * wid;
44+
const int xid = blockIdx_x * blockDim.x + tidx;
45+
const int yid = blockIdx_y; // yid of output. updated for input later.
4646

47-
uint ids[4] = {xid, yid, zid, wid};
47+
int ids[4] = {xid, yid, zid, wid};
4848

4949
const Ti *iptr = in.ptr;
5050
To *optr = out.ptr;
@@ -54,22 +54,22 @@ namespace kernel
5454
// There are blockDim.y elements per block for in
5555
// Hence increment ids[dim] just after offsetting out and before offsetting in
5656
tptr += ids[3] * tmp.strides[3] + ids[2] * tmp.strides[2] + ids[1] * tmp.strides[1] + ids[0];
57-
const uint blockIdx_dim = ids[dim];
57+
const int blockIdx_dim = ids[dim];
5858

5959
ids[dim] = ids[dim] * blockDim.y * lim + tidy;
6060
optr += ids[3] * out.strides[3] + ids[2] * out.strides[2] + ids[1] * out.strides[1] + ids[0];
6161
iptr += ids[3] * in.strides[3] + ids[2] * in.strides[2] + ids[1] * in.strides[1] + ids[0];
62-
uint id_dim = ids[dim];
63-
const uint out_dim = out.dims[dim];
62+
int id_dim = ids[dim];
63+
const int out_dim = out.dims[dim];
6464

6565
bool is_valid =
6666
(ids[0] < out.dims[0]) &&
6767
(ids[1] < out.dims[1]) &&
6868
(ids[2] < out.dims[2]) &&
6969
(ids[3] < out.dims[3]);
7070

71-
const uint ostride_dim = out.strides[dim];
72-
const uint istride_dim = in.strides[dim];
71+
const int ostride_dim = out.strides[dim];
72+
const int istride_dim = in.strides[dim];
7373

7474
__shared__ To s_val[THREADS_X * DIMY * 2];
7575
__shared__ To s_tmp[THREADS_X];
@@ -92,7 +92,8 @@ namespace kernel
9292
*sptr = val;
9393
__syncthreads();
9494

95-
uint start = 0;
95+
int start = 0;
96+
#pragma unroll
9697
for (int off = 1; off < DIMY; off *= 2) {
9798

9899
if (tidy >= off) val = binop(val, sptr[(start - off) * THREADS_X]);
@@ -103,6 +104,7 @@ namespace kernel
103104
}
104105

105106
val = binop(val, s_tmp[tidx]);
107+
__syncthreads();
106108
if (cond) *optr = val;
107109

108110
id_dim += blockDim.y;
@@ -127,17 +129,17 @@ namespace kernel
127129
uint blocks_dim,
128130
uint lim)
129131
{
130-
const uint tidx = threadIdx.x;
131-
const uint tidy = threadIdx.y;
132+
const int tidx = threadIdx.x;
133+
const int tidy = threadIdx.y;
132134

133-
const uint zid = blockIdx.x / blocks_x;
134-
const uint wid = blockIdx.y / blocks_y;
135-
const uint blockIdx_x = blockIdx.x - (blocks_x) * zid;
136-
const uint blockIdx_y = blockIdx.y - (blocks_y) * wid;
137-
const uint xid = blockIdx_x * blockDim.x + tidx;
138-
const uint yid = blockIdx_y; // yid of output. updated for input later.
135+
const int zid = blockIdx.x / blocks_x;
136+
const int wid = blockIdx.y / blocks_y;
137+
const int blockIdx_x = blockIdx.x - (blocks_x) * zid;
138+
const int blockIdx_y = blockIdx.y - (blocks_y) * wid;
139+
const int xid = blockIdx_x * blockDim.x + tidx;
140+
const int yid = blockIdx_y; // yid of output. updated for input later.
139141

140-
uint ids[4] = {xid, yid, zid, wid};
142+
int ids[4] = {xid, yid, zid, wid};
141143

142144
const To *tptr = tmp.ptr;
143145
To *optr = out.ptr;
@@ -146,12 +148,12 @@ namespace kernel
146148
// There are blockDim.y elements per block for in
147149
// Hence increment ids[dim] just after offsetting out and before offsetting in
148150
tptr += ids[3] * tmp.strides[3] + ids[2] * tmp.strides[2] + ids[1] * tmp.strides[1] + ids[0];
149-
const uint blockIdx_dim = ids[dim];
151+
const int blockIdx_dim = ids[dim];
150152

151153
ids[dim] = ids[dim] * blockDim.y * lim + tidy;
152154
optr += ids[3] * out.strides[3] + ids[2] * out.strides[2] + ids[1] * out.strides[1] + ids[0];
153-
const uint id_dim = ids[dim];
154-
const uint out_dim = out.dims[dim];
155+
const int id_dim = ids[dim];
156+
const int out_dim = out.dims[dim];
155157

156158
bool is_valid =
157159
(ids[0] < out.dims[0]) &&
@@ -165,7 +167,7 @@ namespace kernel
165167
To accum = *(tptr - tmp.strides[dim]);
166168

167169
Binary<To, op> binop;
168-
const uint ostride_dim = out.strides[dim];
170+
const int ostride_dim = out.strides[dim];
169171

170172
for (int k = 0, id = id_dim;
171173
is_valid && k < lim && (id < out_dim);

src/backend/cuda/kernel/scan_first.hpp

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,19 @@ namespace kernel
3131
uint blocks_y,
3232
uint lim)
3333
{
34-
const uint tidx = threadIdx.x;
35-
const uint tidy = threadIdx.y;
34+
const int tidx = threadIdx.x;
35+
const int tidy = threadIdx.y;
3636

37-
const uint zid = blockIdx.x / blocks_x;
38-
const uint wid = blockIdx.y / blocks_y;
39-
const uint blockIdx_x = blockIdx.x - (blocks_x) * zid;
40-
const uint blockIdx_y = blockIdx.y - (blocks_y) * wid;
41-
const uint xid = blockIdx_x * blockDim.x * lim + tidx;
42-
const uint yid = blockIdx_y * blockDim.y + tidy;
37+
const int zid = blockIdx.x / blocks_x;
38+
const int wid = blockIdx.y / blocks_y;
39+
const int blockIdx_x = blockIdx.x - (blocks_x) * zid;
40+
const int blockIdx_y = blockIdx.y - (blocks_y) * wid;
41+
const int xid = blockIdx_x * blockDim.x * lim + tidx;
42+
const int yid = blockIdx_y * blockDim.y + tidy;
43+
44+
bool cond_yzw = (yid < out.dims[1]) && (zid < out.dims[2]) && (wid < out.dims[3]);
45+
46+
if (!cond_yzw) return; // retire warps early
4347

4448
const Ti *iptr = in.ptr;
4549
To *optr = out.ptr;
@@ -49,10 +53,9 @@ namespace kernel
4953
optr += wid * out.strides[3] + zid * out.strides[2] + yid * out.strides[1];
5054
tptr += wid * tmp.strides[3] + zid * tmp.strides[2] + yid * tmp.strides[1];
5155

52-
bool cond_yzw = (yid < out.dims[1]) && (zid < out.dims[2]) && (wid < out.dims[3]);
5356

54-
const uint DIMY = THREADS_PER_BLOCK / DIMX;
55-
const uint SHARED_MEM_SIZE = (2 * DIMX + 1) * (DIMY);
57+
const int DIMY = THREADS_PER_BLOCK / DIMX;
58+
const int SHARED_MEM_SIZE = (2 * DIMX + 1) * (DIMY);
5659

5760
__shared__ To s_val[SHARED_MEM_SIZE];
5861
__shared__ To s_tmp[DIMY];
@@ -63,7 +66,7 @@ namespace kernel
6366
Binary<To, op> binop;
6467

6568
const To init = binop.init();
66-
uint id = xid;
69+
int id = xid;
6770
To val = init;
6871

6972
const bool isLast = (tidx == (DIMX - 1));
@@ -72,13 +75,14 @@ namespace kernel
7275

7376
if (isLast) s_tmp[tidy] = val;
7477

75-
bool cond = (cond_yzw && (id < out.dims[0]));
78+
bool cond = ((id < out.dims[0]));
7679
val = cond ? transform(iptr[id]) : init;
7780
sptr[tidx] = val;
7881
__syncthreads();
7982

8083

81-
uint start = 0;
84+
int start = 0;
85+
#pragma unroll
8286
for (int off = 1; off < DIMX; off *= 2) {
8387

8488
if (tidx >= off) val = binop(val, sptr[(start - off) + tidx]);
@@ -91,9 +95,10 @@ namespace kernel
9195
val = binop(val, s_tmp[tidy]);
9296
if (cond) optr[id] = val;
9397
id += blockDim.x;
98+
__syncthreads();
9499
}
95100

96-
if (!isFinalPass && cond_yzw && isLast) {
101+
if (!isFinalPass && isLast) {
97102
tptr[blockIdx_x] = val;
98103
}
99104
}
@@ -106,27 +111,27 @@ namespace kernel
106111
uint blocks_y,
107112
uint lim)
108113
{
109-
const uint tidx = threadIdx.x;
110-
const uint tidy = threadIdx.y;
114+
const int tidx = threadIdx.x;
115+
const int tidy = threadIdx.y;
116+
117+
const int zid = blockIdx.x / blocks_x;
118+
const int wid = blockIdx.y / blocks_y;
119+
const int blockIdx_x = blockIdx.x - (blocks_x) * zid;
120+
const int blockIdx_y = blockIdx.y - (blocks_y) * wid;
121+
const int xid = blockIdx_x * blockDim.x * lim + tidx;
122+
const int yid = blockIdx_y * blockDim.y + tidy;
111123

112-
const uint zid = blockIdx.x / blocks_x;
113-
const uint wid = blockIdx.y / blocks_y;
114-
const uint blockIdx_x = blockIdx.x - (blocks_x) * zid;
115-
const uint blockIdx_y = blockIdx.y - (blocks_y) * wid;
116-
const uint xid = blockIdx_x * blockDim.x * lim + tidx;
117-
const uint yid = blockIdx_y * blockDim.y + tidy;
124+
if (blockIdx_x == 0) return;
125+
126+
bool cond = (yid < out.dims[1]) && (zid < out.dims[2]) && (wid < out.dims[3]);
127+
if (!cond) return;
118128

119129
To *optr = out.ptr;
120130
const To *tptr = tmp.ptr;
121131

122132
optr += wid * out.strides[3] + zid * out.strides[2] + yid * out.strides[1];
123133
tptr += wid * tmp.strides[3] + zid * tmp.strides[2] + yid * tmp.strides[1];
124134

125-
bool cond = (yid < out.dims[1]) && (zid < out.dims[2]) && (wid < out.dims[3]);
126-
127-
if (!cond) return;
128-
if (blockIdx_x == 0) return;
129-
130135
Binary<To, op> binop;
131136
To accum = tptr[blockIdx_x - 1];
132137

src/backend/opencl/kernel/scan_dim.cl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ void scan_dim_kernel(__global To *oData, KParam oInfo,
2727
const int xid = groupId_x * get_local_size(0) + lidx;
2828
const int yid = groupId_y;
2929

30-
uint ids[4] = {xid, yid, zid, wid};
30+
int ids[4] = {xid, yid, zid, wid};
3131

3232
// There is only one element per group for out
3333
// There are DIMY elements per group for in
@@ -69,7 +69,7 @@ void scan_dim_kernel(__global To *oData, KParam oInfo,
6969
l_val[lid] = val;
7070
barrier(CLK_LOCAL_MEM_FENCE);
7171

72-
uint start = 0;
72+
int start = 0;
7373
for (int off = 1; off < DIMY; off *= 2) {
7474

7575
if (lidy >= off) val = binOp(val, l_val[lid - off * THREADS_X]);
@@ -83,6 +83,7 @@ void scan_dim_kernel(__global To *oData, KParam oInfo,
8383

8484
val = binOp(val, l_tmp[lidx]);
8585
if (cond) *oData = val;
86+
barrier(CLK_LOCAL_MEM_FENCE);
8687

8788
id_dim += DIMY;
8889
iData += DIMY * istride_dim;
@@ -116,13 +117,15 @@ void bcast_dim_kernel(__global To *oData, KParam oInfo,
116117
const int xid = groupId_x * get_local_size(0) + lidx;
117118
const int yid = groupId_y;
118119

119-
uint ids[4] = {xid, yid, zid, wid};
120+
int ids[4] = {xid, yid, zid, wid};
121+
const int groupId_dim = ids[dim];
122+
123+
if (groupId_dim == 0) return;
120124

121125
// There is only one element per group for out
122126
// There are DIMY elements per group for in
123127
// Hence increment ids[dim] just after offsetting out and before offsetting in
124128
tData += ids[3] * tInfo.strides[3] + ids[2] * tInfo.strides[2] + ids[1] * tInfo.strides[1] + ids[0];
125-
const int groupId_dim = ids[dim];
126129

127130
ids[dim] = ids[dim] * DIMY * lim + lidy;
128131
oData += ids[3] * oInfo.strides[3] + ids[2] * oInfo.strides[2] + ids[1] * oInfo.strides[1] + ids[0];
@@ -137,7 +140,6 @@ void bcast_dim_kernel(__global To *oData, KParam oInfo,
137140
(ids[3] < oInfo.dims[3]);
138141

139142
if (!is_valid) return;
140-
if (groupId_dim == 0) return;
141143

142144
To accum = *(tData - tInfo.strides[dim]);
143145

0 commit comments

Comments
 (0)