
Commit 04882e9

gchanan authored and soumith committed
Support "fused" ops: addcmul/addcdiv.
1 parent 6e927a6 commit 04882e9
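
For context: addcmul and addcdiv are "fused" pointwise operations, where t.addcmul(value, t1, t2) computes t + value * (t1 * t2) and t.addcdiv(value, t1, t2) computes t + value * (t1 / t2). What this commit adds is support for tensor arguments of different (broadcastable) shapes. A minimal sketch of that behavior, using the same positional value call form as the tests in the diff below; the shapes are illustrative only:

import torch

t = torch.randn(4, 4).float()
t1 = torch.randn(1, 4).float()   # broadcastable to (4, 4)
t2 = torch.randn(4, 1).float()   # broadcastable to (4, 4)

# fused op whose tensor arguments have different (but broadcastable) shapes
r = t.addcmul(1.0, t1, t2)

# should match expanding every argument to the full size up front
expected = t + 1.0 * (t1.expand(4, 4) * t2.expand(4, 4))
print((r - expected).abs().max())   # expected to be ~0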

File tree

8 files changed: +374, -189 lines

test/test_torch.py

Lines changed: 140 additions & 57 deletions
@@ -946,12 +946,15 @@ def test_arange(self):
 
     @staticmethod
     def _test_broadcast(self, cast):
-        def select_broadcastable_dims():
+        def select_broadcastable_dims(dims_full=None):
             # select full dimensionality
-            ndims = random.randint(1, 4)
-            dims_full = []
-            for _ in range(ndims):
-                dims_full = dims_full + [random.randint(1, 8)]
+            if dims_full is None:
+                dims_full = []
+                ndims = random.randint(1, 4)
+                for _ in range(ndims):
+                    dims_full = dims_full + [random.randint(1, 8)]
+            else:
+                ndims = len(dims_full)
 
             # select actual dimensions for ops:
             # larger: full ndims, individual sizes may be reduced
@@ -979,15 +982,21 @@ def select_broadcastable_dims():
         fns = [
             "dist", "atan2", "pow", "lerp", "add",
             "sub", "mul", "div", "fmod", "remainder",
-            "eq", "ge", "gt", "le", "lt", "max", "min", "ne"
+            "eq", "ge", "gt", "le", "lt", "max", "min", "ne",
+            "addcdiv", "addcmul"
         ]
         # functions with no torch. equivalent
         fns_no_torch = ["sub"]
         # functions with no inplace equivalent
         fns_no_inplace = ["dist", "max", "min"]
+        # functions with no out-of-place tensor version
+        fns_no_out_place = []
         # functions with fallback to equal nElem behavior
         fns_fallback = ["add", "sub", "div", "mul", "pow", "fmod", "remainder",
-                        "eq", "ge", "gt", "le", "lt", "max", "min", "ne"]
+                        "eq", "ge", "gt", "le", "lt", "max", "min", "ne",
+                        "addcdiv", "addcmul"]
+        # functions with three tensor arguments
+        fns_3_args = ["addcdiv", "addcmul"]
 
         for fn in fns:
             (dims_small, dims_large, dims_full) = select_broadcastable_dims()
@@ -997,108 +1006,182 @@ def select_broadcastable_dims():
             large = cast(large)
             smallExpanded = small.expand(*dims_full)
             largeExpanded = large.expand(*dims_full)
-            # run through tensor versions of functions
-            # and verify fully expanded inputs give same results
-            fntensor_large_expanded = getattr(largeExpanded, fn)
-            fntensor_large_non_expanded = getattr(large, fn)
-
-            def tensorfn(myfn, t):
-                return myfn(t) if fn != "lerp" else myfn(t, 0.5)
-            r1 = tensorfn(fntensor_large_expanded, smallExpanded)
-            r2 = tensorfn(fntensor_large_non_expanded, small)
-            self.assertEqual(r1, r2)
-            # other order
-            fntensor_small_expanded = getattr(smallExpanded, fn)
-            fntensor_small_non_expanded = getattr(small, fn)
-            r1 = tensorfn(fntensor_small_expanded, largeExpanded)
-            r2 = tensorfn(fntensor_small_non_expanded, large)
-            self.assertEqual(r1, r2)
+            small2 = None
+            small2Expanded = None
+            if fn in fns_3_args:
+                # create another smaller tensor
+                (dims_small2, _, _) = select_broadcastable_dims(dims_full)
+                small2 = torch.randn(*dims_small2).float()
+                small2 = cast(small2)
+                small2Expanded = small2.expand(*dims_full)
+
+            if fn not in fns_no_out_place:
+                # run through tensor versions of functions
+                # and verify fully expanded inputs give same results
+                fntensor_large_expanded = getattr(largeExpanded, fn)
+                fntensor_large_non_expanded = getattr(large, fn)
+
+                def tensorfn(myfn, t1, t2):
+                    if fn == "lerp":
+                        return myfn(t1, 0.5)
+                    elif fn in fns_3_args:
+                        return myfn(1, t1, t2)
+                    else:
+                        return myfn(t1)
+                r1 = tensorfn(fntensor_large_expanded, smallExpanded, small2Expanded)
+                r2 = tensorfn(fntensor_large_non_expanded, small, small2)
+                self.assertEqual(r1, r2)
+                # other order
+                fntensor_small_expanded = getattr(smallExpanded, fn)
+                fntensor_small_non_expanded = getattr(small, fn)
+                r1 = tensorfn(fntensor_small_expanded, largeExpanded, small2Expanded)
+                r2 = tensorfn(fntensor_small_non_expanded, large, small2)
+                self.assertEqual(r1, r2)
+                if fn in fns_3_args:
+                    fntensor_small2_expanded = getattr(small2Expanded, fn)
+                    fntensor_small2_non_expanded = getattr(small2, fn)
+                    r1 = tensorfn(fntensor_small2_expanded, smallExpanded, largeExpanded)
+                    r2 = tensorfn(fntensor_small2_non_expanded, small, large)
+                    self.assertEqual(r1, r2)
+                    r1 = tensorfn(fntensor_small2_expanded, largeExpanded, smallExpanded)
+                    r2 = tensorfn(fntensor_small2_non_expanded, large, small)
+                    self.assertEqual(r1, r2)
 
             # now for torch. versions of functions
             if fn not in fns_no_torch:
                 fntorch = getattr(torch, fn)
 
-                def torchfn(t1, t2):
-                    return (fntorch(t1, t2) if fn != "lerp"
-                            else fntorch(t1, t2, 0.5))
-                r1 = torchfn(large, small)
-                r2 = torchfn(largeExpanded, smallExpanded)
+                def torchfn(t1, t2, t3):
+                    if fn == "lerp":
+                        return fntorch(t1, t2, 0.5)
+                    elif fn in fns_3_args:
+                        return fntorch(t1, 1.0, t2, t3)
+                    else:
+                        return fntorch(t1, t2)
+                r1 = torchfn(large, small, small2)
+                r2 = torchfn(largeExpanded, smallExpanded, small2Expanded)
                 self.assertEqual(r1, r2)
                 # other order
-                r1 = torchfn(small, large)
-                r2 = torchfn(smallExpanded, largeExpanded)
+                r1 = torchfn(small, large, small2)
+                r2 = torchfn(smallExpanded, largeExpanded, small2Expanded)
                 self.assertEqual(r1, r2)
+                if fn in fns_3_args:
+                    r1 = torchfn(small2, small, large)
+                    r2 = torchfn(small2Expanded, smallExpanded, largeExpanded)
+                    self.assertEqual(r1, r2)
+                    r1 = torchfn(small2, large, small)
+                    r2 = torchfn(small2Expanded, largeExpanded, smallExpanded)
+                    self.assertEqual(r1, r2)
 
             # now for in place functions
             if fn not in fns_no_inplace:
                 # in-place tensor is not broadcastable; test only guaranteed
-                # to work by broadcasting other argument
+                # to work by broadcasting other argument(s)
 
-                # need to clone largeExpanded so we can reuse
+                # need to clone largeExpanded so we can reuse, since functions are in-place
                 largeExpandedClone = largeExpanded.clone()
 
-                def tensorfn_inplace(t0, t1):
+                def tensorfn_inplace(t0, t1, t2=None):
                     t0_fn = getattr(t0, fn + "_")
-                    return t0_fn(t1) if fn != "lerp" else t0_fn(t1, 0.5)
-                r1 = tensorfn_inplace(largeExpanded, smallExpanded)
-                r2 = tensorfn_inplace(largeExpandedClone, small)
-                # in-place pointwise operations don't actually work on 0-strided tensors
-                # (numpy has the same issue)
-                if (0 not in largeExpanded.stride() and 0 not in smallExpanded.stride()
-                        and 0 not in largeExpandedClone.stride() and 0 not in small.stride()):
+                    if fn == "lerp":
+                        return t0_fn(t1, 0.5)
+                    elif fn in fns_3_args:
+                        return t0_fn(1.0, t1, t2)
+                    else:
+                        return t0_fn(t1)
+                r1 = tensorfn_inplace(largeExpanded, smallExpanded, small2Expanded)
+                r2 = tensorfn_inplace(largeExpandedClone, small, small2)
+                # in-place pointwise operations don't actually work if the in-place
+                # tensor is 0-strided (numpy has the same issue)
+                if (0 not in largeExpanded.stride() and 0 not in largeExpandedClone.stride()):
                    self.assertEqual(r1, r2)
 
-                broadcastable = (dims_small == dims_full)
-                if not broadcastable:
-                    if (fn not in fns_fallback) or (fn in fns_fallback and small.numel() != largeExpanded.numel()):
-                        self.assertRaises(RuntimeError, lambda: tensorfn_inplace(small, largeExpanded))
-                    if (fn not in fns_fallback) or (fn in fns_fallback and small.numel() != large.numel()):
-                        self.assertRaises(RuntimeError, lambda: tensorfn_inplace(small, large))
+                def broadcastable(t0, t1, t2=None):
+                    try:
+                        t1.expand_as(t0)
+                        if t2 is not None:
+                            t2.expand_as(t0)
+                    except RuntimeError:
+                        return False
+                    return True
+
+                def _test_in_place_broadcastable(t0, t1, t2=None):
+                    if not broadcastable(t0, t1, t2):
+                        same_size = t0.numel() == t1.numel() and (t0.numel() == t2.numel() if t2 is not None else True)
+                        if (fn not in fns_fallback) or (fn in fns_fallback and not same_size):
+                            self.assertRaises(RuntimeError, lambda: tensorfn_inplace(t0, t1, t2))
+                        else:
+                            tensorfn_inplace(t0, t1, t2)
+
+                if fn not in fns_3_args:
+                    _test_in_place_broadcastable(small, largeExpanded)
+                    _test_in_place_broadcastable(small, large)
+                else:
+                    _test_in_place_broadcastable(small2, smallExpanded, largeExpanded)
+                    _test_in_place_broadcastable(small2, small, large)
 
     def test_broadcast(self):
-        for i in range(100):
-            self._test_broadcast(self, lambda t: t)
+        self._test_broadcast(self, lambda t: t)
 
     @staticmethod
     def _test_broadcast_fallback(self, cast):
+        # functions that should fall back to pointwise behavior
         fns_fallback = ["add", "sub", "div", "mul", "pow", "fmod", "remainder",
-                        "eq", "ge", "gt", "le", "lt", "max", "min", "ne"]
+                        "eq", "ge", "gt", "le", "lt", "max", "min", "ne",
+                        "addcdiv", "addcmul"]
+        # functions with three tensor arguments
+        fns_3_args = ["addcdiv", "addcmul"]
+
         for fn in fns_fallback:
             # case 1: both broadcastable and nElems equal -- verify that we broadcast
             t0 = torch.randn(1, 4).float()
             t0 = cast(t0)
             t1 = torch.randn(4, 1).float()
             t1 = cast(t1)
+            t2 = torch.randn(4).float()
+            t2 = cast(t2)
             broadcastSize = torch.Size([4, 4])
             t0_fn = getattr(t0, fn)
             t1_fn = getattr(t1, fn)
 
-            def tensorfn(myfn, t):
-                return myfn(t) if fn != "lerp" else myfn(t, 0.5)
-            r0 = tensorfn(t0_fn, t1)
-            r1 = tensorfn(t1_fn, t0)
+            def tensorfn(myfn, t1, t2):
+                if fn == "lerp":
+                    return myfn(t1, 0.5)
+                elif fn in fns_3_args:
+                    return myfn(1.0, t1, t2)
+                else:
+                    return myfn(t1)
+            r0 = tensorfn(t0_fn, t1, t2)
+            r1 = tensorfn(t1_fn, t0, t2)
             self.assertEqual(broadcastSize, r0.size())
             self.assertEqual(broadcastSize, r1.size())
 
             # case 2: broadcastable and nElems not equal -- tested by test_fallback
+
             # case 3: not broadcastable, nElems equal -- verify we fall back
-            t0 = torch.randn(1, 4).float()
-            t1 = torch.randn(2, 2).float()
+            t0 = torch.randn(1, 6).float()
+            t1 = torch.randn(2, 3).float()
+            t2 = torch.randn(3, 2).float()
             t0_fn = getattr(t0, fn)
             t1_fn = getattr(t1, fn)
+            t2_fn = getattr(t2, fn)
 
             def verifyFallbackWarnings(w):
                 self.assertEqual(len(w), 1)
                 self.assertTrue(issubclass(w[0].category, UserWarning))
                 self.assertTrue("Falling back" in str(w[0].message))
             with warnings.catch_warnings(record=True) as w:
-                r0 = tensorfn(t0_fn, t1)
+                r0 = tensorfn(t0_fn, t1, t2)
+            verifyFallbackWarnings(w)
+            with warnings.catch_warnings(record=True) as w:
+                r1 = tensorfn(t1_fn, t0, t2)
             verifyFallbackWarnings(w)
             with warnings.catch_warnings(record=True) as w:
-                r1 = tensorfn(t1_fn, t0)
+                r2 = tensorfn(t2_fn, t0, t1)
             verifyFallbackWarnings(w)
             self.assertEqual(t0.size(), r0.size())
             self.assertEqual(t1.size(), r1.size())
+            self.assertEqual(t2.size(), r2.size())
 
             # case 4: not broadcastable and nElems not equal -- tested by test_fallback
