|
1 | 1 | from . import CWrapPlugin |
2 | 2 | from string import Template |
3 | 3 |
|
| 4 | +# For out of place: |
| 5 | +# Two args: expand the two args together |
| 6 | +# Three args (fused kernels): (e.g. addcmul) expand all three args together |
| 7 | +# Sketch of proof that this is the same: |
| 8 | +# consider addcmul, under expansion we want: a + (b * c) = (a + b * c) [all expanded together] |
| 9 | +# Let e(i, j) be the expansion of i with j, e(i, j, k) be the expansion of i with j,k |
| 10 | +# |
| 11 | +# Then a + (b * c) = e(a, e(b,c) * e(c,b)) + e(e(b,c) * e(c,b), a) |
| 12 | +# = e(a, e(b,c)) + e(e(b,c) * e(c,b), a) (only size matters for second param) |
| 13 | +# = e(a,b,c) + e(e(b,c) * e(c,b), a) (by associativity of max in expand) |
| 14 | +# = e(a,b,c) + e(b,c,a) * e(c,b,a) (see L1) |
| 15 | +# which is a + b * c all expanded together |
| 16 | +# |
| 17 | +# L1: Show e(i * j, a) = e(i,a) * e(j,a) where i,j have same size |
| 18 | +# Consider any index _{ s_0, ..., s_n} |
| 19 | +# e(i * j, a) = (i*j)_{f(s_0), ...,f(s_n)} where f is the expansion of that dimension with a |
| 20 | +# = i_{f(s_0), ..., f(s_n)} * j_{f(s_0), ..., f(s_n)} by definition of pointwise operator |
| 21 | +# = e(i,a) * e(j,a) |
4 | 22 |
|
5 | 23 | class Broadcast(CWrapPlugin): |
6 | 24 | DEPRECATED_WARNING = \ |
|
0 commit comments