@@ -6,41 +6,41 @@ using namespace metal;
6
6
// Smaller of x and y.
// The '(' must follow the macro name with no space in between; otherwise the
// preprocessor defines an object-like macro and MIN(a, b) does not expand as a call.
// Each argument may be evaluated twice, so avoid side effects such as MIN(i++, j).
#define MIN(x, y) ((x) < (y) ? (x) : (y))
7
7
8
8
// Element-wise addition: dst[i] = src0[i] + src1[i].
// Each thread processes one float4 element, selected by its position in the grid.
//   src0, src1 - input buffers, read as float4
//   dst        - output buffer, written as float4
//   tpig       - this thread's position in the grid, used directly as the element index
// NOTE(review): there is no bounds check; presumably the host dispatches exactly
// one thread per float4 element - confirm against the caller.
kernel void kernel_add(
        device const float4 * src0,
        device const float4 * src1,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = src0[tpig] + src1[tpig];
}
15
15
16
16
// assumption: src1 is a row
// broadcast src1 into src0
//
// Computes dst[i] = src0[i] + src1[i % nb], one float4 element per thread.
//   src0 - input buffer, read as float4 at the thread index
//   src1 - input buffer, read as float4 at (thread index % nb), so its first
//          nb elements are reused cyclically
//   dst  - output buffer, written as float4 at the thread index
//   nb   - modulus applied to the thread index when indexing src1
//          (presumably the row length in float4 units - TODO confirm with the host code)
//   tpig - this thread's position in the grid
kernel void kernel_add_row(
        device const float4 * src0,
        device const float4 * src1,
        device       float4 * dst,
        constant   int64_t & nb,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = src0[tpig] + src1[tpig % nb];
}
26
26
27
27
// Element-wise multiplication: dst[i] = src0[i] * src1[i].
// Each thread processes one float4 element, selected by its position in the grid.
//   src0, src1 - input buffers, read as float4
//   dst        - output buffer, written as float4
//   tpig       - this thread's position in the grid, used directly as the element index
// NOTE(review): there is no bounds check; presumably the host dispatches exactly
// one thread per float4 element - confirm against the caller.
kernel void kernel_mul(
        device const float4 * src0,
        device const float4 * src1,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = src0[tpig] * src1[tpig];
}
34
34
35
35
// assumption: src1 is a row
// broadcast src1 into src0
//
// Computes dst[i] = src0[i] * src1[i % nb], one float4 element per thread.
//   src0 - input buffer, read as float4 at the thread index
//   src1 - input buffer, read as float4 at (thread index % nb), so its first
//          nb elements are reused cyclically
//   dst  - output buffer, written as float4 at the thread index
//   nb   - modulus applied to the thread index when indexing src1
//          (presumably the row length in float4 units - TODO confirm with the host code)
//   tpig - this thread's position in the grid
kernel void kernel_mul_row(
        device const float4 * src0,
        device const float4 * src1,
        device       float4 * dst,
        constant   int64_t & nb,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = src0[tpig] * src1[tpig % nb];
}
45
45
46
46
kernel void kernel_scale (
0 commit comments