
Commit f2b6faf

gchanan authored and soumith committed
Proof that broadcasting 3 args (expand3) is equivalent to
breaking up operation.
1 parent e534df8 commit f2b6faf

1 file changed: +18 -0 lines changed

tools/cwrap/plugins/Broadcast.py

Lines changed: 18 additions & 0 deletions
@@ -1,6 +1,24 @@
 from . import CWrapPlugin
 from string import Template
 
+# For out of place:
+# Two args: expand the two args together
+# Three args (fused kernels): (e.g. addcmul) expand all three args together
+# Sketch of proof that this is the same:
+# consider addcmul, under expansion we want: a + (b * c) = (a + b * c) [all expanded together]
+# Let e(i, j) be the expansion of i with j, e(i, j, k) be the expansion of i with j,k
+#
+# Then a + (b * c) = e(a, e(b,c) * e(c,b)) + e(e(b,c) * e(c,b), a)
+#                  = e(a, e(b,c))          + e(e(b,c) * e(c,b), a)  (only size matters for second param)
+#                  = e(a,b,c)              + e(e(b,c) * e(c,b), a)  (by associativity of max in expand)
+#                  = e(a,b,c)              + e(b,c,a) * e(c,b,a)    (see L1)
+# which is a + b * c all expanded together
+#
+# L1: Show e(i * j, a) = e(i,a) * e(j,a) where i,j have same size
+# Consider any index _{ s_0, ..., s_n}
+# e(i * j, a) = (i*j)_{f(s_0), ...,f(s_n)} where f is the expansion of that dimension with a
+#             = i_{f(s_0), ..., f(s_n)} * j_{f(s_0), ..., f(s_n)} by definition of pointwise operator
+#             = e(i,a) * e(j,a)
 
 class Broadcast(CWrapPlugin):
     DEPRECATED_WARNING = \
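
The proof sketch in the added comment can be spot-checked on one concrete example. The snippet below is only an illustrative sketch in pure Python, not code from `Broadcast.py` or cwrap: the helpers `broadcast_shape`, `make`, `e`, and `pointwise` are hypothetical names, and a "tensor" here is assumed to be a `(shape, {index: value})` pair. It compares the broken-up form `a + (b * c)` with the fused form where all three arguments are expanded together, and then checks lemma L1 on the same data.

```python
from itertools import product
from operator import add, mul


def broadcast_shape(*shapes):
    """Right-align the shapes and take the max size per dimension."""
    ndim = max(len(s) for s in shapes)
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    out = []
    for dims in zip(*padded):
        size = max(dims)
        if any(d not in (1, size) for d in dims):
            raise ValueError("shapes are not broadcastable: %r" % (shapes,))
        out.append(size)
    return tuple(out)


def make(shape, fn):
    """Hypothetical toy tensor: (shape, {index tuple: value})."""
    shape = tuple(shape)
    return shape, {idx: fn(idx) for idx in product(*map(range, shape))}


def e(t, *others):
    """e(t, others...): t expanded to the broadcast shape of all arguments."""
    shape, data = t
    target = broadcast_shape(shape, *(o[0] for o in others))
    offset = len(target) - len(shape)

    def f(idx):
        # Map an index of the expanded tensor back into t:
        # size-1 dimensions always read position 0, leading dims are dropped.
        return tuple(0 if shape[k] == 1 else idx[k + offset]
                     for k in range(len(shape)))

    return target, {idx: data[f(idx)] for idx in product(*map(range, target))}


def pointwise(op, x, y):
    """Apply op element by element to two tensors of identical shape."""
    assert x[0] == y[0], "pointwise operands must already have the same shape"
    return x[0], {idx: op(x[1][idx], y[1][idx]) for idx in x[1]}


# Example inputs with three different (but broadcastable) shapes.
a = make((3, 1), lambda idx: 10 * idx[0])
b = make((1, 4), lambda idx: idx[1] + 1)
c = make((2, 1, 1), lambda idx: idx[0] - 5)

# Broken up: compute b * c first, then add a.
bc = pointwise(mul, e(b, c), e(c, b))
broken_up = pointwise(add, e(a, bc), e(bc, a))

# Fused: expand all three arguments together.
fused = pointwise(add, e(a, b, c), pointwise(mul, e(b, c, a), e(c, b, a)))
assert broken_up == fused

# L1 on one example: e(i * j, a) == e(i, a) * e(j, a) for same-size i, j.
i, j = e(b, c), e(c, b)
assert e(pointwise(mul, i, j), a) == pointwise(mul, e(i, a), e(j, a))
```

A passing run only confirms the identity for these particular shapes; the general claim still rests on the argument in the comment.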
