
Commit 5c56c72

gchanan authored and soumith committed

Add broadcasting support for map_, map2_.

1 parent fd2554f commit 5c56c72

File tree (4 files changed, +123 -18 lines):

  test/test_torch.py
  tools/cwrap/plugins/Broadcast.py
  torch/_tensor_docs.py
  torch/csrc/generic/methods/TensorApply.cwrap

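For context, the behavior this change enables, sketched at the Python level (a minimal sketch; shapes are illustrative, and the map_/map2_ signatures follow the tests and docstring below):

import torch

a = torch.randn(4, 3)
b = torch.randn(3)      # broadcastable to (4, 3)

# In-place map: b is expanded to a's shape, then the callable is applied
# to each pair of elements and the result is written back into a.
a.map_(b, lambda x, y: x + y)

# map2_ takes two source tensors and a three-argument callable; both
# sources must be broadcastable to a's shape.
c = torch.randn(1, 3)
a.map2_(b, c, lambda x, y, z: x * y + z)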

test/test_torch.py (19 additions, 8 deletions)

@@ -985,20 +985,22 @@ def _test_broadcast(self, cast):
                "dist", "atan2", "pow", "lerp", "add",
                "sub", "mul", "div", "fmod", "remainder",
                "eq", "ge", "gt", "le", "lt", "max", "min", "ne",
-               "addcdiv", "addcmul", "masked_copy", "masked_fill"
+               "addcdiv", "addcmul", "masked_copy", "masked_fill",
+               "map", "map2"
                ]
         # functions with no torch. equivalent
-        fns_no_torch = ["sub", "masked_copy", "masked_fill"]
+        fns_no_torch = ["sub", "masked_copy", "masked_fill", "map", "map2"]
         # functions with no inplace equivalent
         fns_no_inplace = ["dist", "max", "min"]
         # functions with no out-of-place tensor version
-        fns_no_out_place = ["masked_copy", "masked_fill"]
+        fns_no_out_place = ["masked_copy", "masked_fill", "map", "map2"]
         # functions with fallback to equal nElem behavior
         fns_fallback = ["add", "sub", "div", "mul", "pow", "fmod", "remainder",
                         "eq", "ge", "gt", "le", "lt", "max", "min", "ne",
-                        "addcdiv", "addcmul", "masked_copy", "masked_fill"]
+                        "addcdiv", "addcmul", "masked_copy", "masked_fill",
+                        "map", "map2"]
         # functions with three tensor arguments
-        fns_3_args = ["addcdiv", "addcmul"]
+        fns_3_args = ["addcdiv", "addcmul", "map2"]

         for fn in fns:
             (dims_small, dims_large, dims_full) = TestTorch._select_broadcastable_dims(self)

@@ -1091,6 +1093,10 @@ def tensorfn_inplace(t0, t1, t2=None):
                 return t0_fn(t1 < 0.5, t1.expand_as(t0))
             elif fn == "masked_fill":
                 return t0_fn(t1 < 0.5, 1.0)
+            elif fn == "map":
+                return t0_fn(t1, lambda x, y: x + y)
+            elif fn == "map2":
+                return t0_fn(t1, t2, lambda x, y, z: x + y + z)
             elif fn in fns_3_args:
                 return t0_fn(1.0, t1, t2)
             else:

@@ -1134,13 +1140,14 @@ def _test_broadcast_fallback(self, cast):
         # functions that should fallback to pointwise behavior
         fns_fallback = ["add", "sub", "div", "mul", "pow", "fmod", "remainder",
                         "eq", "ge", "gt", "le", "lt", "max", "min", "ne",
-                        "addcdiv", "addcmul", "masked_copy", "masked_fill"]
+                        "addcdiv", "addcmul", "masked_copy", "masked_fill",
+                        "map", "map2"]
         # functions with three tensor arguments
-        fns_3_args = ["addcdiv", "addcmul"]
+        fns_3_args = ["addcdiv", "addcmul", "map2"]
         # functions with no inplace equivalent
         fns_no_inplace = ["max", "min"]
         # functions with no out-of-place tensor version
-        fns_no_out_place = ["masked_copy", "masked_fill"]
+        fns_no_out_place = ["masked_copy", "masked_fill", "map", "map2"]

         for fn in fns_fallback:
             # case 1: both broadcastable and nElems equal -- verify that we broadcast

@@ -1162,6 +1169,10 @@ def tensorfn(myfn, t1, t2):
                 return myfn(t1 < 0.5, torch.randn(4*4).float())
             elif fn == "masked_fill":
                 return myfn(t1 < 0.5, 1.0)
+            elif fn == "map":
+                return myfn(t1, lambda x, y: x + y)
+            elif fn == "map2":
+                return myfn(t1, t2, lambda x, y, z: x + y + z)
             elif fn in fns_3_args:
                 return myfn(1.0, t1, t2)
             else:
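The fallback case these tests cover, sketched at the Python level (a minimal sketch, assuming the deprecation warning added in TensorApply.cwrap below):

import torch

t1 = torch.randn(4, 4)
t2 = torch.randn(16)    # not broadcastable to (4, 4), but same element count

# Falls back to the deprecated pointwise behavior (both operands are
# traversed as flat 1-dimensional arrays) and emits a UserWarning.
t1.map_(t2, lambda x, y: x + y)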

tools/cwrap/plugins/Broadcast.py (9 additions, 8 deletions)

@@ -55,7 +55,8 @@ def getPreArgStringTemplate(self, includeElementCount=True, type=None):
         if type == None:
             ret = """THTensor *${arg_op_other}_save = ${arg_op_other};
                      THTensorPtr ${arg_op_other}_guard = THTensor_(new)(LIBRARY_STATE_NOARGS);
-                     ${arg_op_other}=${arg_op_other}_guard.get();"""
+                     ${arg_op_other}=${arg_op_other}_guard.get();
+                     """
             if includeElementCount:
                 ret += "ptrdiff_t ${arg_op_other}_nElem = THTensor_(nElement)(LIBRARY_STATE ${arg_op_other}_save);"
         else:

@@ -165,9 +166,9 @@ def getInPlacePreExpand1Template(self, type=None):
         ret = """bool ${arg_op_other}_raise = ${raise_errors} || (${arg_op_a}_nElem != ${arg_op_other}_nElem);
                  int ${arg_op_other}_err ="""
         if type == None:
-            ret += """!skip_expand && THTensor_(expand)(LIBRARY_STATE"""
+            ret += """!skip_expand && THTensor_(expand)(LIBRARY_STATE\n"""
         else:
-            ret += """!skip_expand && THBroadcastTensor_(expand)(LIBRARY_STATE"""
+            ret += """!skip_expand && THBroadcastTensor_(expand)(LIBRARY_STATE\n"""

         ret += """ ${arg_op_other},
                    ${arg_op_other}_save,

@@ -187,9 +188,9 @@ def getInPlacePreExpand2Template(self, type1=None, type2=None):
                  bool ${arg_op_other2}_raise = ${raise_errors} || (${arg_op_a}_nElem != ${arg_op_other2}_nElem);
                  int ${arg_op_other1}_err ="""
         if type1 is None:
-            ret += """!skip_expand && THTensor_(expand)(LIBRARY_STATE"""
+            ret += """!skip_expand && THTensor_(expand)(LIBRARY_STATE\n"""
         else:
-            ret += """!skip_expand && THBroadcastTensor_(expand)(LIBRARY_STATE"""
+            ret += """!skip_expand && THBroadcastTensor_(expand)(LIBRARY_STATE\n"""

         ret += """ ${arg_op_other1},
                    ${arg_op_other1}_save,

@@ -202,9 +203,9 @@ def getInPlacePreExpand2Template(self, type1=None, type2=None):
         ret += """}
                   int ${arg_op_other2}_err ="""
         if type2 == None:
-            ret += """!skip_expand && THTensor_(expand)(LIBRARY_STATE"""
+            ret += """!skip_expand && THTensor_(expand)(LIBRARY_STATE\n"""
         else:
-            ret += """!skip_expand && THBroadcastTensor_(expand)(LIBRARY_STATE"""
+            ret += """!skip_expand && THBroadcastTensor_(expand)(LIBRARY_STATE\n"""
         ret += """ ${arg_op_other2},
                    ${arg_op_other2}_save,
                    ${arg_op_a}_size.get(),

@@ -317,7 +318,7 @@ def process_option_code_template(self, template, option):
                 arg_op_other2=arg_op_c,
                 post_code=post_code)
             expand_code += self.IN_PLACE_BACK_COMPAT_WARN_TEMPLATE.substitute(op_b_mapping)
-            expand_code += self.IN_PLACE_BACK_COMPAT_WARN_TEMPLATE.substitute(op_b_mapping)
+            expand_code += self.IN_PLACE_BACK_COMPAT_WARN_TEMPLATE.substitute(op_c_mapping)
         else:
             expand_code = self.getInPlacePreExpand1Template(type=type_op_b).substitute(op_b_mapping, post_code=post_code)
             expand_code += self.IN_PLACE_BACK_COMPAT_WARN_TEMPLATE.substitute(op_b_mapping)
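For readers unfamiliar with the plugin machinery: these methods assemble C source from string.Template fragments, and .substitute(...) fills the ${...} placeholders, as in the op_b_mapping/op_c_mapping calls above. A simplified, self-contained sketch of the mechanism (the fragment below is illustrative, not the exact template):

from string import Template

# Roughly the shape of what getInPlacePreExpand1Template builds (simplified).
fragment = Template(
    "int ${arg_op_other}_err = !skip_expand && THTensor_(expand)(LIBRARY_STATE\n"
    "    ${arg_op_other},\n"
    "    ${arg_op_other}_save,\n"
    "    ${arg_op_a}_size.get(),\n"
    "    ${arg_op_other}_raise);"
)

# Substituting a mapping such as op_b_mapping yields concrete C code:
print(fragment.substitute(arg_op_other="src", arg_op_a="tensor"))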

torch/_tensor_docs.py (4 additions, 2 deletions)

@@ -792,8 +792,10 @@
 map_(tensor, callable)

 Applies :attr:`callable` for each element in this tensor and the given tensor
-and stores the results in this tensor. The :attr:`callable` should have the
-signature::
+and stores the results in this tensor. This tensor and the given tensor must be
+:ref:`broadcastable <broadcasting-semantics>`.
+
+The :attr:`callable` should have the signature::

     def callable(a, b) -> number
 """)

torch/csrc/generic/methods/TensorApply.cwrap (91 additions, 0 deletions)

@@ -57,6 +57,36 @@ static PyObject * THPTensor_(map)(THPTensor *self, PyObject *args)

   THTensor *tensor = self->cdata;
   THTensor *src = src_object->cdata;
+
+  THLongStoragePtr tensor_size = THTensor_(newSizeOf)(LIBRARY_STATE tensor);
+  ptrdiff_t tensor_nElem = THTensor_(nElement)(LIBRARY_STATE tensor);
+  bool skip_expand = false;
+  THTensor *src_save = src;
+  THTensorPtr src_guard = THTensor_(new)(LIBRARY_STATE_NOARGS);
+  src = src_guard.get();
+  ptrdiff_t src_nElem = THTensor_(nElement)(LIBRARY_STATE src_save);
+
+  bool src_raise = false || (tensor_nElem != src_nElem);
+  int src_err = !skip_expand && THTensor_(expand)(LIBRARY_STATE
+                                                  src,
+                                                  src_save,
+                                                  tensor_size.get(),
+                                                  src_raise);
+  if (src_err != 0 && !src_raise) {
+    skip_expand = true; // don't do further expansions
+    src = src_save;
+    PyErr_WarnEx(PyExc_UserWarning, "src is not broadcastable to tensor, but they have the same number of "
+                 "elements. Falling back to deprecated pointwise behavior.", 1);
+  }
+  if (getBackCompatBroadcastWarn()) {
+    bool same_shape = THSize_isSameSizeAs(tensor->size, tensor->nDimension,
+                                          src_save->size, src_save->nDimension);
+    if (!same_shape && src_err == 0 && (tensor_nElem == src_nElem) && !false) {
+      PyErr_WarnEx(PyExc_UserWarning, "tensor and src do not have the same shape, but are broadcastable, and have the same number of "
+                   "elements. Changing behavior in a backwards incompatible manner to broadcasting rather than viewing as 1-dimensional.", 1);
+    }
+  }
+
   TH_TENSOR_APPLY2(real, tensor, real, src,
                    PyObject *ret =
                      PyObject_CallFunction(fn, (char*)(BUILD_REAL_FMT BUILD_REAL_FMT),

@@ -71,6 +101,8 @@ static PyObject * THPTensor_(map)(THPTensor *self, PyObject *args)
                    Py_DECREF(ret);
                    );

+  src = src_save;
+
   Py_INCREF(self);
   return (PyObject*)self;
   END_HANDLE_TH_ERRORS
@@ -95,6 +127,62 @@ static PyObject * THPTensor_(map2)(THPTensor *self, PyObject *args)
   THTensor *tensor = self->cdata;
   THTensor *src1 = src1_object->cdata;
   THTensor *src2 = src2_object->cdata;
+
+  THLongStoragePtr tensor_size = THTensor_(newSizeOf)(LIBRARY_STATE tensor);
+  ptrdiff_t tensor_nElem = THTensor_(nElement)(LIBRARY_STATE tensor);
+  bool skip_expand = false;
+  THTensor *src1_save = src1;
+  THTensorPtr src1_guard = THTensor_(new)(LIBRARY_STATE_NOARGS);
+  src1 = src1_guard.get();
+  ptrdiff_t src1_nElem = THTensor_(nElement)(LIBRARY_STATE src1_save);
+  THTensor *src2_save = src2;
+  THTensorPtr src2_guard = THTensor_(new)(LIBRARY_STATE_NOARGS);
+  src2 = src2_guard.get();
+  ptrdiff_t src2_nElem = THTensor_(nElement)(LIBRARY_STATE src2_save);
+  bool src1_raise = false || (tensor_nElem != src1_nElem);
+  bool src2_raise = false || (tensor_nElem != src2_nElem);
+  int src1_err = !skip_expand && THTensor_(expand)(LIBRARY_STATE
+                                                   src1,
+                                                   src1_save,
+                                                   tensor_size.get(),
+                                                   src1_raise || src2_raise);
+  if (src1_err != 0 && !(src1_raise || src2_raise)) {
+    skip_expand = true; // don't do further expansions
+    src1 = src1_save;
+    src2 = src2_save;
+    PyErr_WarnEx(PyExc_UserWarning, "src1, src2 are not broadcastable to self, but they all have the same number of "
+                 "elements. Falling back to deprecated pointwise behavior.", 1);
+  }
+  int src2_err = !skip_expand && THTensor_(expand)(LIBRARY_STATE
+                                                   src2,
+                                                   src2_save,
+                                                   tensor_size.get(),
+                                                   src1_raise || src2_raise);
+  if (src2_err != 0 && !(src1_raise || src2_raise)) {
+    skip_expand = true; // don't do further expansions
+    src1 = src1_save;
+    src2 = src2_save;
+    PyErr_WarnEx(PyExc_UserWarning, "src1, src2 are not broadcastable to self, but they all have the same number of "
+                 "elements. Falling back to deprecated pointwise behavior.", 1);
+  }
+  if (getBackCompatBroadcastWarn()) {
+    bool same_shape = THSize_isSameSizeAs(tensor->size, tensor->nDimension,
+                                          src1_save->size, src1_save->nDimension);
+    if (!same_shape && src1_err == 0 && (tensor_nElem == src1_nElem) && !false) {
+      PyErr_WarnEx(PyExc_UserWarning, "tensor and src1 do not have the same shape, but are broadcastable, and have the same number of "
+                   "elements. Changing behavior in a backwards incompatible manner to broadcasting rather than viewing as 1-dimensional.", 1);
+    }
+  }
+  if (getBackCompatBroadcastWarn()) {
+    bool same_shape = THSize_isSameSizeAs(tensor->size, tensor->nDimension,
+                                          src2_save->size, src2_save->nDimension);
+    if (!same_shape && src2_err == 0 && (tensor_nElem == src2_nElem) && !false) {
+      PyErr_WarnEx(PyExc_UserWarning, "tensor and src2 do not have the same shape, but are broadcastable, and have the same number of "
+                   "elements. Changing behavior in a backwards incompatible manner to broadcasting rather than viewing as 1-dimensional.", 1);
+    }
+  }
+
   TH_TENSOR_APPLY3(real, tensor, real, src1, real, src2,
                    PyObject *ret =
                      PyObject_CallFunction(fn, (char*)(BUILD_REAL_FMT BUILD_REAL_FMT BUILD_REAL_FMT),

@@ -109,6 +197,9 @@ static PyObject * THPTensor_(map2)(THPTensor *self, PyObject *args)
                    Py_DECREF(ret);
                    );

+  src1 = src1_save;
+  src2 = src2_save;
+
   Py_INCREF(self);
   return (PyObject*)self;
   END_HANDLE_TH_ERRORS
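The getBackCompatBroadcastWarn() path above is opt-in; at the Python level it was toggled through torch.utils.backcompat in this era (a sketch under that assumption):

import torch
import torch.utils.backcompat

# Assumed toggle of this era: surface warnings where broadcasting changes
# previously pointwise results.
torch.utils.backcompat.broadcast_warning.enabled = True

t = torch.randn(1, 4)
s = torch.randn(4)   # broadcastable to (1, 4), same element count, different shape

# Broadcasts under the new semantics and warns that the result may differ
# from the old view-as-1-dimensional behavior.
t.map_(s, lambda x, y: x + y)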
