@@ -946,12 +946,15 @@ def test_arange(self):
 
     @staticmethod
     def _test_broadcast(self, cast):
-        def select_broadcastable_dims():
+        def select_broadcastable_dims(dims_full=None):
             # select full dimensionality
-            ndims = random.randint(1, 4)
-            dims_full = []
-            for _ in range(ndims):
-                dims_full = dims_full + [random.randint(1, 8)]
+            if dims_full is None:
+                dims_full = []
+                ndims = random.randint(1, 4)
+                for _ in range(ndims):
+                    dims_full = dims_full + [random.randint(1, 8)]
+            else:
+                ndims = len(dims_full)
 
             # select actual dimensions for ops:
             # larger: full ndims, individual sizes may be reduced
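
With the new `dims_full` parameter, the helper can be asked to draw a second broadcastable shape against an already-chosen target shape (used below for the three-tensor ops). A condensed standalone sketch of that idea, assuming a simplified keep-or-collapse-to-1 rule rather than the test's exact size-reduction logic:

```python
import random

def select_broadcastable_dims(dims_full=None):
    # Simplified stand-in for the helper above: pick (or reuse) a full
    # target shape, then derive a shape that broadcasts against it.
    if dims_full is None:
        dims_full = [random.randint(1, 8) for _ in range(random.randint(1, 4))]
    dims_small = [d if random.random() < 0.5 else 1 for d in dims_full]
    return dims_small, dims_full

small, full = select_broadcastable_dims()    # fresh target shape
small2, _ = select_broadcastable_dims(full)  # second operand, same target
```
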
@@ -979,15 +982,21 @@ def select_broadcastable_dims():
         fns = [
             "dist", "atan2", "pow", "lerp", "add",
             "sub", "mul", "div", "fmod", "remainder",
-            "eq", "ge", "gt", "le", "lt", "max", "min", "ne"
+            "eq", "ge", "gt", "le", "lt", "max", "min", "ne",
+            "addcdiv", "addcmul"
         ]
         # functions with no torch. equivalent
         fns_no_torch = ["sub"]
         # functions with no inplace equivalent
         fns_no_inplace = ["dist", "max", "min"]
+        # functions with no out-of-place tensor version
+        fns_no_out_place = []
         # functions with fallback to equal nElem behavior
         fns_fallback = ["add", "sub", "div", "mul", "pow", "fmod", "remainder",
-                        "eq", "ge", "gt", "le", "lt", "max", "min", "ne"]
+                        "eq", "ge", "gt", "le", "lt", "max", "min", "ne",
+                        "addcdiv", "addcmul"]
+        # functions with three tensor arguments
+        fns_3_args = ["addcdiv", "addcmul"]
 
         for fn in fns:
             (dims_small, dims_large, dims_full) = select_broadcastable_dims()
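
`addcdiv` and `addcmul` are the two three-tensor pointwise ops being added to these lists; the diff exercises them with the positional `value` argument (e.g. `fntorch(t1, 1.0, t2, t3)`), a form that later PyTorch releases deprecated in favor of a keyword. For reference, a sketch of what they compute, written with the keyword `value=` form:

```python
import torch

t = torch.randn(4, 4)
t1 = torch.randn(1, 4)  # broadcasts against t
t2 = torch.randn(4, 1)  # broadcasts against t

r_div = torch.addcdiv(t, t1, t2, value=1.0)  # t + value * (t1 / t2)
r_mul = torch.addcmul(t, t1, t2, value=1.0)  # t + value * (t1 * t2)
assert r_div.shape == (4, 4) and r_mul.shape == (4, 4)
```
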
@@ -997,108 +1006,182 @@ def select_broadcastable_dims():
             large = cast(large)
             smallExpanded = small.expand(*dims_full)
             largeExpanded = large.expand(*dims_full)
-            # run through tensor versions of functions
-            # and verify fully expanded inputs give same results
-            fntensor_large_expanded = getattr(largeExpanded, fn)
-            fntensor_large_non_expanded = getattr(large, fn)
-
-            def tensorfn(myfn, t):
-                return myfn(t) if fn != "lerp" else myfn(t, 0.5)
-            r1 = tensorfn(fntensor_large_expanded, smallExpanded)
-            r2 = tensorfn(fntensor_large_non_expanded, small)
-            self.assertEqual(r1, r2)
-            # other order
-            fntensor_small_expanded = getattr(smallExpanded, fn)
-            fntensor_small_non_expanded = getattr(small, fn)
-            r1 = tensorfn(fntensor_small_expanded, largeExpanded)
-            r2 = tensorfn(fntensor_small_non_expanded, large)
-            self.assertEqual(r1, r2)
+            small2 = None
+            small2Expanded = None
+            if fn in fns_3_args:
+                # create another smaller tensor
+                (dims_small2, _, _) = select_broadcastable_dims(dims_full)
+                small2 = torch.randn(*dims_small2).float()
+                small2 = cast(small2)
+                small2Expanded = small2.expand(*dims_full)
+
+            if fn not in fns_no_out_place:
+                # run through tensor versions of functions
+                # and verify fully expanded inputs give same results
+                fntensor_large_expanded = getattr(largeExpanded, fn)
+                fntensor_large_non_expanded = getattr(large, fn)
+
+                def tensorfn(myfn, t1, t2):
+                    if fn == "lerp":
+                        return myfn(t1, 0.5)
+                    elif fn in fns_3_args:
+                        return myfn(1, t1, t2)
+                    else:
+                        return myfn(t1)
+                r1 = tensorfn(fntensor_large_expanded, smallExpanded, small2Expanded)
+                r2 = tensorfn(fntensor_large_non_expanded, small, small2)
+                self.assertEqual(r1, r2)
+                # other order
+                fntensor_small_expanded = getattr(smallExpanded, fn)
+                fntensor_small_non_expanded = getattr(small, fn)
+                r1 = tensorfn(fntensor_small_expanded, largeExpanded, small2Expanded)
+                r2 = tensorfn(fntensor_small_non_expanded, large, small2)
+                self.assertEqual(r1, r2)
+                if fn in fns_3_args:
+                    fntensor_small2_expanded = getattr(small2Expanded, fn)
+                    fntensor_small2_non_expanded = getattr(small2, fn)
+                    r1 = tensorfn(fntensor_small2_expanded, smallExpanded, largeExpanded)
+                    r2 = tensorfn(fntensor_small2_non_expanded, small, large)
+                    self.assertEqual(r1, r2)
+                    r1 = tensorfn(fntensor_small2_expanded, largeExpanded, smallExpanded)
+                    r2 = tensorfn(fntensor_small2_non_expanded, large, small)
+                    self.assertEqual(r1, r2)
 
             # now for torch. versions of functions
             if fn not in fns_no_torch:
                 fntorch = getattr(torch, fn)
 
-                def torchfn(t1, t2):
-                    return (fntorch(t1, t2) if fn != "lerp"
-                            else fntorch(t1, t2, 0.5))
-                r1 = torchfn(large, small)
-                r2 = torchfn(largeExpanded, smallExpanded)
+                def torchfn(t1, t2, t3):
+                    if fn == "lerp":
+                        return fntorch(t1, t2, 0.5)
+                    elif fn in fns_3_args:
+                        return fntorch(t1, 1.0, t2, t3)
+                    else:
+                        return fntorch(t1, t2)
+                r1 = torchfn(large, small, small2)
+                r2 = torchfn(largeExpanded, smallExpanded, small2Expanded)
                 self.assertEqual(r1, r2)
                 # other order
-                r1 = torchfn(small, large)
-                r2 = torchfn(smallExpanded, largeExpanded)
+                r1 = torchfn(small, large, small2)
+                r2 = torchfn(smallExpanded, largeExpanded, small2Expanded)
                 self.assertEqual(r1, r2)
+                if fn in fns_3_args:
+                    r1 = torchfn(small2, small, large)
+                    r2 = torchfn(small2Expanded, smallExpanded, largeExpanded)
+                    self.assertEqual(r1, r2)
+                    r1 = torchfn(small2, large, small)
+                    r2 = torchfn(small2Expanded, largeExpanded, smallExpanded)
+                    self.assertEqual(r1, r2)
 
             # now for in place functions
             if fn not in fns_no_inplace:
                 # in-place tensor is not broadcastable; test only guaranteed
-                # to work by broadcasting other argument
+                # to work by broadcasting other argument(s)
 
-                # need to clone largeExpanded so we can reuse
+                # need to clone largeExpanded so we can reuse, since functions are in-place
                 largeExpandedClone = largeExpanded.clone()
 
-                def tensorfn_inplace(t0, t1):
+                def tensorfn_inplace(t0, t1, t2=None):
                     t0_fn = getattr(t0, fn + "_")
-                    return t0_fn(t1) if fn != "lerp" else t0_fn(t1, 0.5)
-                r1 = tensorfn_inplace(largeExpanded, smallExpanded)
-                r2 = tensorfn_inplace(largeExpandedClone, small)
-                # in-place pointwise operations don't actually work on 0-strided tensors
-                # (numpy has the same issue)
-                if (0 not in largeExpanded.stride() and 0 not in smallExpanded.stride()
-                        and 0 not in largeExpandedClone.stride() and 0 not in small.stride()):
+                    if fn == "lerp":
+                        return t0_fn(t1, 0.5)
+                    elif fn in fns_3_args:
+                        return t0_fn(1.0, t1, t2)
+                    else:
+                        return t0_fn(t1)
+                r1 = tensorfn_inplace(largeExpanded, smallExpanded, small2Expanded)
+                r2 = tensorfn_inplace(largeExpandedClone, small, small2)
+                # in-place pointwise operations don't actually work if the in-place
+                # tensor is 0-strided (numpy has the same issue)
+                if (0 not in largeExpanded.stride() and 0 not in largeExpandedClone.stride()):
                     self.assertEqual(r1, r2)
 
-                broadcastable = (dims_small == dims_full)
-                if not broadcastable:
-                    if (fn not in fns_fallback) or (fn in fns_fallback and small.numel() != largeExpanded.numel()):
-                        self.assertRaises(RuntimeError, lambda: tensorfn_inplace(small, largeExpanded))
-                    if (fn not in fns_fallback) or (fn in fns_fallback and small.numel() != large.numel()):
-                        self.assertRaises(RuntimeError, lambda: tensorfn_inplace(small, large))
+                def broadcastable(t0, t1, t2=None):
+                    try:
+                        t1.expand_as(t0)
+                        if t2 is not None:
+                            t2.expand_as(t0)
+                    except RuntimeError:
+                        return False
+                    return True
+
+                def _test_in_place_broadcastable(t0, t1, t2=None):
+                    if not broadcastable(t0, t1, t2):
+                        same_size = t0.numel() == t1.numel() and (t0.numel() == t2.numel() if t2 is not None else True)
+                        if (fn not in fns_fallback) or (fn in fns_fallback and not same_size):
+                            self.assertRaises(RuntimeError, lambda: tensorfn_inplace(t0, t1, t2))
+                    else:
+                        tensorfn_inplace(t0, t1, t2)
+
+                if fn not in fns_3_args:
+                    _test_in_place_broadcastable(small, largeExpanded)
+                    _test_in_place_broadcastable(small, large)
+                else:
+                    _test_in_place_broadcastable(small2, smallExpanded, largeExpanded)
+                    _test_in_place_broadcastable(small2, small, large)
 
     def test_broadcast(self):
-        for i in range(100):
-            self._test_broadcast(self, lambda t: t)
+        self._test_broadcast(self, lambda t: t)
 
     @staticmethod
     def _test_broadcast_fallback(self, cast):
+        # functions that should fall back to pointwise behavior
         fns_fallback = ["add", "sub", "div", "mul", "pow", "fmod", "remainder",
-                        "eq", "ge", "gt", "le", "lt", "max", "min", "ne"]
+                        "eq", "ge", "gt", "le", "lt", "max", "min", "ne",
+                        "addcdiv", "addcmul"]
+        # functions with three tensor arguments
+        fns_3_args = ["addcdiv", "addcmul"]
+
         for fn in fns_fallback:
             # case 1: both broadcastable and nElems equal -- verify that we broadcast
             t0 = torch.randn(1, 4).float()
             t0 = cast(t0)
             t1 = torch.randn(4, 1).float()
             t1 = cast(t1)
+            t2 = torch.randn(4).float()
+            t2 = cast(t2)
             broadcastSize = torch.Size([4, 4])
             t0_fn = getattr(t0, fn)
             t1_fn = getattr(t1, fn)
 
-            def tensorfn(myfn, t):
-                return myfn(t) if fn != "lerp" else myfn(t, 0.5)
-            r0 = tensorfn(t0_fn, t1)
-            r1 = tensorfn(t1_fn, t0)
+            def tensorfn(myfn, t1, t2):
+                if fn == "lerp":
+                    return myfn(t1, 0.5)
+                elif fn in fns_3_args:
+                    return myfn(1.0, t1, t2)
+                else:
+                    return myfn(t1)
+            r0 = tensorfn(t0_fn, t1, t2)
+            r1 = tensorfn(t1_fn, t0, t2)
             self.assertEqual(broadcastSize, r0.size())
             self.assertEqual(broadcastSize, r1.size())
 
             # case 2: broadcastable and not nElems equal -- tested by test_fallback
+
             # case 3: not broadcastable, nElems equal -- verify we fall back
-            t0 = torch.randn(1, 4).float()
-            t1 = torch.randn(2, 2).float()
+            t0 = torch.randn(1, 6).float()
+            t1 = torch.randn(2, 3).float()
+            t2 = torch.randn(3, 2).float()
             t0_fn = getattr(t0, fn)
             t1_fn = getattr(t1, fn)
+            t2_fn = getattr(t2, fn)
 
             def verifyFallbackWarnings(w):
                 self.assertEqual(len(w), 1)
                 self.assertTrue(issubclass(w[0].category, UserWarning))
                 self.assertTrue("Falling back" in str(w[0].message))
             with warnings.catch_warnings(record=True) as w:
-                r0 = tensorfn(t0_fn, t1)
+                r0 = tensorfn(t0_fn, t1, t2)
+            verifyFallbackWarnings(w)
+            with warnings.catch_warnings(record=True) as w:
+                r1 = tensorfn(t1_fn, t0, t2)
             verifyFallbackWarnings(w)
             with warnings.catch_warnings(record=True) as w:
-                r1 = tensorfn(t1_fn, t0)
+                r2 = tensorfn(t2_fn, t0, t1)
             verifyFallbackWarnings(w)
             self.assertEqual(t0.size(), r0.size())
             self.assertEqual(t1.size(), r1.size())
+            self.assertEqual(t2.size(), r2.size())
 
             # case 4: not broadcastable and not nElems equal -- tested by test_fallback
 
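
The `broadcastable` helper introduced in the in-place section encodes the key rule there: the in-place target's shape is fixed, so every other operand must be expandable to it. A standalone sketch of the same check:

```python
import torch

def broadcastable(t0, t1, t2=None):
    # Mirrors the helper in the diff: expand_as raises RuntimeError
    # when an operand cannot be broadcast to t0's shape.
    try:
        t1.expand_as(t0)
        if t2 is not None:
            t2.expand_as(t0)
    except RuntimeError:
        return False
    return True

t0 = torch.randn(4, 4)
print(broadcastable(t0, torch.randn(1, 4), torch.randn(4, 1)))  # True
print(broadcastable(t0, torch.randn(3, 4)))                     # False: 3 != 4
```
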