Skip to content

Commit e4c3270

Browse files
gchanan authored and soumith committed
Remove raiseErrors from THTensor functions, have THStorage functions
take an error_buffer to return a proper error message while being able to handle memory management correctly from the calling function.
1 parent c0a8304 commit e4c3270

File tree

12 files changed

+247
-230
lines changed

12 files changed

+247
-230
lines changed

tools/cwrap/plugins/Broadcast.py

Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -65,21 +65,31 @@ def getPreArgStringTemplate(self, type=None):
6565
return Template(ret)
6666

6767
OUT_PLACE_PRE_EXPAND2_TEMPLATE = Template(
68-
"""if (!expand_outplace2(LIBRARY_STATE ${arg_op_a}_guard.get(), ${arg_op_other}_guard.get(),
69-
${arg_op_a}, ${arg_op_other},
70-
\"${op_a}\", \"${op_other}\", !${raise_errors})) {
68+
"""bool expand_success = false;
69+
try {
70+
expand_outplace2(LIBRARY_STATE ${arg_op_a}_guard.get(), ${arg_op_other}_guard.get(),
71+
${arg_op_a}, ${arg_op_other},
72+
\"${op_a}\", \"${op_other}\", !${raise_errors});
73+
expand_success = true;
74+
} catch (std::exception &e) {}
75+
if(expand_success) {
7176
${arg_op_a} = ${arg_op_a}_guard.get();
7277
${arg_op_other} = ${arg_op_other}_guard.get();
7378
}""")
7479

7580
OUT_PLACE_PRE_EXPAND3_TEMPLATE = Template(
76-
"""if (!expand_outplace3(LIBRARY_STATE
77-
${arg_op_a}_guard.get(), ${arg_op_other1}_guard.get(), ${arg_op_other2}_guard.get(),
78-
${arg_op_a}, ${arg_op_other1}, ${arg_op_other2},
79-
\"${op_a}\", \"${op_other1}\", \"${op_other2}\", !${raise_errors})) {
80-
${arg_op_a} = ${arg_op_a}_guard.get();
81-
${arg_op_other1} = ${arg_op_other1}_guard.get();
82-
${arg_op_other2} = ${arg_op_other2}_guard.get();
81+
"""bool expand_success = false;
82+
try {
83+
expand_outplace3(LIBRARY_STATE
84+
${arg_op_a}_guard.get(), ${arg_op_other1}_guard.get(), ${arg_op_other2}_guard.get(),
85+
${arg_op_a}, ${arg_op_other1}, ${arg_op_other2},
86+
\"${op_a}\", \"${op_other1}\", \"${op_other2}\", !${raise_errors});
87+
expand_success = true;
88+
} catch (std::exception &e) {}
89+
if (expand_success) {
90+
${arg_op_a} = ${arg_op_a}_guard.get();
91+
${arg_op_other1} = ${arg_op_other1}_guard.get();
92+
${arg_op_other2} = ${arg_op_other2}_guard.get();
8393
}""")
8494

8595
OUT_PLACE_PRE_EXPAND_PRE_DIM_TEMPLATE = Template(
@@ -101,7 +111,12 @@ def getPreArgStringTemplate(self, type=None):
101111
THLongStorage_newWithSize3(${arg_op_a}_dim0_size, ${arg_op_a}_dim1_size, ${arg_op_a}_dim2_size));\n""")
102112

103113
OUT_PLACE_PRE_EXPAND_POST_DIM_TEMPLATE = Template(
104-
"""if (!THTensor_(expand)(LIBRARY_STATE ${arg_op_a}_guard.get(), ${arg_op_a}, ${arg_op_a}_storage, ${raise_errors})) {
114+
"""bool expand_success = false;
115+
try {
116+
expand(LIBRARY_STATE ${arg_op_a}_guard.get(), ${arg_op_a}, ${arg_op_a}_storage);
117+
expand_success = true;
118+
} catch (std::exception &e) {}
119+
if (expand_success) {
105120
${arg_op_a} = ${arg_op_a}_guard.get();
106121
}""")
107122

@@ -110,15 +125,25 @@ def getPreArgStringTemplate(self, type=None):
110125
${expand_code}""")
111126

112127
IN_PLACE_PRE_EXPAND1_TEMPLATE = Template(
113-
"""if (!expand_inplace1(LIBRARY_STATE ${arg_op_other}_guard.get(), ${arg_op_other}, ${arg_op_a},
114-
\"${op_other}\", \"${op_a}\", !${raise_errors})) {
115-
${arg_op_other} = ${arg_op_other}_guard.get();
116-
}""")
128+
"""bool expand_success = false;
129+
try {
130+
expand_inplace1(LIBRARY_STATE ${arg_op_other}_guard.get(), ${arg_op_other}, ${arg_op_a},
131+
\"${op_other}\", \"${op_a}\", !${raise_errors});
132+
expand_success = true;
133+
} catch (std::exception &e) {}
134+
if (expand_success) {
135+
${arg_op_other} = ${arg_op_other}_guard.get();
136+
}""")
117137

118138
IN_PLACE_PRE_EXPAND2_TEMPLATE = Template(
119-
"""if (!expand_inplace2(LIBRARY_STATE ${arg_op_other1}_guard.get(), ${arg_op_other2}_guard.get(),
120-
${arg_op_other1}, ${arg_op_other2}, ${arg_op_a},
121-
\"${op_other1}\", \"${op_other2}\", \"${op_a}\", !${raise_errors})) {
139+
"""bool expand_success = false;
140+
try {
141+
expand_inplace2(LIBRARY_STATE ${arg_op_other1}_guard.get(), ${arg_op_other2}_guard.get(),
142+
${arg_op_other1}, ${arg_op_other2}, ${arg_op_a},
143+
\"${op_other1}\", \"${op_other2}\", \"${op_a}\", !${raise_errors});
144+
expand_success = true;
145+
} catch (std::exception &e) {}
146+
if (expand_success) {
122147
${arg_op_other1} = ${arg_op_other1}_guard.get();
123148
${arg_op_other2} = ${arg_op_other2}_guard.get();
124149
}""")
@@ -158,6 +183,7 @@ def process_option_code_template(self, template, option):
158183
dims_kvs = []
159184
for p in params:
160185
if p.startswith("dims:"):
186+
assert(raise_errors == "true")
161187
if len(dims_kvs) != 0:
162188
raise ValueError("multiple specifications of dims")
163189
dims = p[len("dims:"):].split(",")

torch/csrc/Module.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,11 @@ PyObject *THPModule_inferSize(PyObject *_unused, PyObject *args)
485485
THLongStoragePtr sizes_guard(THLongStorage_new());
486486
THLongStorage *sizes = sizes_guard.get();
487487

488-
THLongStorage_inferSize2(sizes, size1->data, size1->size, size2->data, size2->size, 1);
488+
char error_buffer[1024];
489+
int ret = THLongStorage_inferSize2(sizes, size1->data, size1->size, size2->data, size2->size, error_buffer, 1024);
490+
if (ret != 0) {
491+
THError(error_buffer);
492+
}
489493
return THPSize_New(sizes->size, sizes->data);
490494
END_HANDLE_TH_ERRORS
491495
}

torch/csrc/copy_utils.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,12 @@ void THPInsertTensorCopyFunction(
118118

119119
// support for "broadcast" parameter to copy_.
120120
if (broadcast) {
121-
int ret = expand_inplace1<TensorSrc, TensorDst>(LIBRARY_STATE src_guard.get(), src, dst, "src", "dst", true);
122-
if (ret == 0) {
121+
bool expand_success = false;
122+
try {
123+
expand_inplace1<TensorSrc, TensorDst>(LIBRARY_STATE src_guard.get(), src, dst, "src", "dst", true);
124+
expand_success = true;
125+
} catch (std::exception &e) {}
126+
if (expand_success) {
123127
src = src_guard.get();
124128
}
125129
}

torch/csrc/expand_utils.h

Lines changed: 71 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,24 @@ template <typename ExpandType>
88
ExpandType *newForExpand(LIBRARY_STATE_TYPE_NOARGS);
99

1010
template <typename TensorType>
11-
int expand(LIBRARY_STATE_TYPE TensorType *r, TensorType *tensor, THLongStorage *sizes, int raiseErrors);
11+
void expand(LIBRARY_STATE_TYPE TensorType *r, TensorType *tensor, THLongStorage *sizes);
1212

1313
template <typename TensorType>
14-
int expand2(LIBRARY_STATE_TYPE TensorType *r1, TensorType *r2,
15-
TensorType *e1, TensorType *e2, int raiseErrors);
14+
void expand2(LIBRARY_STATE_TYPE TensorType *r1, TensorType *r2,
15+
TensorType *e1, TensorType *e2);
1616

1717
template <typename TensorType>
18-
int expand3(LIBRARY_STATE_TYPE TensorType *r1, TensorType *r2, TensorType *r3,
19-
TensorType *e1, TensorType *e2, TensorType *e3, int raiseErrors);
18+
void expand3(LIBRARY_STATE_TYPE TensorType *r1, TensorType *r2, TensorType *r3,
19+
TensorType *e1, TensorType *e2, TensorType *e3);
2020

2121
template <typename ExpandType, typename TensorType>
2222
void check_backincompat_expand_warn(ExpandType *to_expand, TensorType *tensor,
2323
char *to_expand_name, char *tensor_name, bool fallback,
24-
ptrdiff_t to_expand_nElem, ptrdiff_t tensor_nElem, int to_expand_err) {
24+
ptrdiff_t to_expand_nElem, ptrdiff_t tensor_nElem) {
2525
if (fallback && getBackCompatBroadcastWarn()) {
2626
bool same_shape = THSize_isSameSizeAs(tensor->size, tensor->nDimension,
2727
to_expand->size, to_expand->nDimension);
28-
if (!same_shape && to_expand_err == 0 && (tensor_nElem == to_expand_nElem)) {
28+
if (!same_shape && (tensor_nElem == to_expand_nElem)) {
2929
std::ostringstream warn;
3030
warn << tensor_name << " and " << to_expand_name << " do not have the same shape, but are "
3131
<< "broadcastable, and have the same number of elements. Changing behavior in a backwards incompatible "
@@ -36,120 +36,109 @@ void check_backincompat_expand_warn(ExpandType *to_expand, TensorType *tensor,
3636
}
3737

3838
template <typename ExpandType, typename TensorType>
39-
int expand_inplace(LIBRARY_STATE_TYPE ExpandType *r, ExpandType *to_expand, TensorType *tensor,
40-
char *to_expand_name, char *tensor_name, bool fallback,
41-
THLongStorage *tensor_size, ptrdiff_t to_expand_nElem, ptrdiff_t tensor_nElem, bool raise) {
42-
int ret = 0;
43-
44-
int to_expand_err = expand<ExpandType>(LIBRARY_STATE r, to_expand, tensor_size, raise);
45-
if (to_expand_err != 0 && !raise) {
46-
ret = to_expand_err;
47-
std::ostringstream warn;
48-
warn << to_expand_name << " is not broadcastable to " << tensor_name
49-
<< ", but they have the same number of elements. Falling back to deprecated pointwise behavior.";
50-
PyErr_WarnEx(PyExc_UserWarning, warn.str().c_str(), 1);
39+
void expand_inplace(LIBRARY_STATE_TYPE ExpandType *r, ExpandType *to_expand, TensorType *tensor,
40+
char *to_expand_name, char *tensor_name, bool fallback,
41+
THLongStorage *tensor_size, ptrdiff_t to_expand_nElem, ptrdiff_t tensor_nElem, bool raise) {
42+
try {
43+
expand<ExpandType>(LIBRARY_STATE r, to_expand, tensor_size);
44+
} catch (std::exception &e) {
45+
if (!raise) {
46+
std::ostringstream warn;
47+
warn << to_expand_name << " is not broadcastable to " << tensor_name
48+
<< ", but they have the same number of elements. Falling back to deprecated pointwise behavior.";
49+
PyErr_WarnEx(PyExc_UserWarning, warn.str().c_str(), 1);
50+
}
51+
throw;
5152
}
52-
53-
return ret;
5453
}
5554

5655
template <typename ExpandType, typename TensorType>
57-
int expand_inplace1(LIBRARY_STATE_TYPE ExpandType *r, ExpandType *to_expand, TensorType *tensor,
58-
char *to_expand_name, char *tensor_name, bool fallback) {
56+
void expand_inplace1(LIBRARY_STATE_TYPE ExpandType *r, ExpandType *to_expand, TensorType *tensor,
57+
char *to_expand_name, char *tensor_name, bool fallback) {
5958
ptrdiff_t to_expand_nElem = THSize_nElement(to_expand->nDimension, to_expand->size);
6059
ptrdiff_t tensor_nElem = THSize_nElement(tensor->nDimension, tensor->size);
61-
bool to_expand_raise = !fallback || (to_expand_nElem != tensor_nElem);
60+
bool to_expand_raise = !fallback || (to_expand_nElem != tensor_nElem) || to_expand_nElem == 0;
6261
THLongStoragePtr tensor_size(THLongStorage_newWithSize(tensor->nDimension));
6362
THLongStorage_rawCopy(tensor_size.get(), tensor->size);
6463

65-
int ret = expand_inplace(LIBRARY_STATE r, to_expand, tensor, to_expand_name, tensor_name, fallback,
66-
tensor_size, to_expand_nElem, tensor_nElem, to_expand_raise);
67-
64+
expand_inplace(LIBRARY_STATE r, to_expand, tensor, to_expand_name, tensor_name, fallback,
65+
tensor_size, to_expand_nElem, tensor_nElem, to_expand_raise);
6866
check_backincompat_expand_warn<ExpandType, TensorType>(to_expand, tensor, to_expand_name, tensor_name, fallback,
69-
to_expand_nElem, tensor_nElem, ret);
70-
71-
return ret;
67+
to_expand_nElem, tensor_nElem);
7268
}
7369

7470
template <typename TensorType>
75-
int expand_inplace2(LIBRARY_STATE_TYPE TensorType *r1, TensorType *r2,
76-
TensorType *to_expand1, TensorType *to_expand2, TensorType *tensor,
77-
char *to_expand1_name, char *to_expand2_name, char *tensor_name, bool fallback) {
71+
void expand_inplace2(LIBRARY_STATE_TYPE TensorType *r1, TensorType *r2,
72+
TensorType *to_expand1, TensorType *to_expand2, TensorType *tensor,
73+
char *to_expand1_name, char *to_expand2_name, char *tensor_name, bool fallback) {
7874
ptrdiff_t tensor_nElem = THSize_nElement(tensor->nDimension, tensor->size);
7975
ptrdiff_t to_expand1_nElem = THSize_nElement(to_expand1->nDimension, to_expand1->size);
8076
ptrdiff_t to_expand2_nElem = THSize_nElement(to_expand2->nDimension, to_expand2->size);
81-
bool to_expand1_raise = !fallback || (tensor_nElem != to_expand1_nElem);
82-
bool to_expand2_raise = !fallback || (tensor_nElem != to_expand2_nElem);
77+
bool to_expand1_raise = !fallback || (tensor_nElem != to_expand1_nElem) || tensor_nElem == 0;
78+
bool to_expand2_raise = !fallback || (tensor_nElem != to_expand2_nElem) || tensor_nElem == 0;
8379
THLongStoragePtr tensor_size(THLongStorage_newWithSize(tensor->nDimension));
8480
THLongStorage_rawCopy(tensor_size.get(), tensor->size);
8581

86-
int ret = expand_inplace(LIBRARY_STATE r1, to_expand1, tensor, to_expand1_name, tensor_name, fallback,
87-
tensor_size, to_expand1_nElem, tensor_nElem, to_expand1_raise || to_expand2_raise);
88-
89-
int ret2 = 0;
90-
if (ret == 0) {
91-
ret2 = expand_inplace(LIBRARY_STATE r2, to_expand2, tensor, to_expand2_name, tensor_name, fallback,
92-
tensor_size, to_expand2_nElem, tensor_nElem, to_expand1_raise || to_expand2_raise);
93-
}
82+
expand_inplace(LIBRARY_STATE r1, to_expand1, tensor, to_expand1_name, tensor_name, fallback,
83+
tensor_size, to_expand1_nElem, tensor_nElem, to_expand1_raise || to_expand2_raise);
84+
expand_inplace(LIBRARY_STATE r2, to_expand2, tensor, to_expand2_name, tensor_name, fallback,
85+
tensor_size, to_expand2_nElem, tensor_nElem, to_expand1_raise || to_expand2_raise);
9486

9587
check_backincompat_expand_warn<TensorType, TensorType>(to_expand1, tensor, to_expand1_name, tensor_name, fallback,
96-
to_expand1_nElem, tensor_nElem, ret);
88+
to_expand1_nElem, tensor_nElem);
9789
check_backincompat_expand_warn<TensorType, TensorType>(to_expand2, tensor, to_expand2_name, tensor_name, fallback,
98-
to_expand2_nElem, tensor_nElem, ret2);
99-
100-
return ret == 0 && ret2 == 0 ? 0 : -1;
90+
to_expand2_nElem, tensor_nElem);
10191
}
10292

10393
template <typename TensorType>
104-
int expand_outplace2(LIBRARY_STATE_TYPE TensorType *r1, TensorType *r2,
105-
TensorType *to_expand1, TensorType *to_expand2,
106-
char *to_expand1_name, char *to_expand2_name, bool fallback) {
94+
void expand_outplace2(LIBRARY_STATE_TYPE TensorType *r1, TensorType *r2,
95+
TensorType *to_expand1, TensorType *to_expand2,
96+
char *to_expand1_name, char *to_expand2_name, bool fallback) {
10797
ptrdiff_t to_expand1_nElem = THSize_nElement(to_expand1->nDimension, to_expand1->size);
10898
ptrdiff_t to_expand2_nElem = THSize_nElement(to_expand2->nDimension, to_expand2->size);
109-
bool raise = !fallback || (to_expand1_nElem != to_expand2_nElem);
110-
111-
int ret = 0;
112-
int err = expand2<TensorType>(LIBRARY_STATE r1, r2, to_expand1, to_expand2, raise);
113-
if (err != 0 && !raise) {
114-
ret = err;
115-
std::ostringstream warn;
116-
warn << to_expand1_name << " and " << to_expand2_name << " not broadcastable, but have the same number of "
117-
"elements. Falling back to deprecated pointwise behavior.";
118-
PyErr_WarnEx(PyExc_UserWarning, warn.str().c_str(), 1);
99+
bool raise = !fallback || (to_expand1_nElem != to_expand2_nElem) || to_expand1_nElem == 0;
100+
try {
101+
expand2<TensorType>(LIBRARY_STATE r1, r2, to_expand1, to_expand2);
102+
} catch (std::exception &e) {
103+
if (!raise) {
104+
std::ostringstream warn;
105+
warn << to_expand1_name << " and " << to_expand2_name << " not broadcastable, but have the same number of "
106+
<< "elements. Falling back to deprecated pointwise behavior.";
107+
PyErr_WarnEx(PyExc_UserWarning, warn.str().c_str(), 1);
108+
}
109+
throw;
119110
}
111+
120112
check_backincompat_expand_warn<TensorType, TensorType>(to_expand1, to_expand2, to_expand1_name, to_expand2_name,
121-
fallback, to_expand1_nElem, to_expand2_nElem, ret);
122-
return ret;
113+
fallback, to_expand1_nElem, to_expand2_nElem);
123114
}
124115

125116
template <typename TensorType>
126-
int expand_outplace3(LIBRARY_STATE_TYPE TensorType *r1, TensorType *r2, TensorType *r3,
127-
TensorType *to_expand1, TensorType *to_expand2, TensorType *to_expand3,
128-
char *to_expand1_name, char *to_expand2_name, char *to_expand3_name, bool fallback) {
117+
void expand_outplace3(LIBRARY_STATE_TYPE TensorType *r1, TensorType *r2, TensorType *r3,
118+
TensorType *to_expand1, TensorType *to_expand2, TensorType *to_expand3,
119+
char *to_expand1_name, char *to_expand2_name, char *to_expand3_name, bool fallback) {
129120
ptrdiff_t to_expand1_nElem = THSize_nElement(to_expand1->nDimension, to_expand1->size);
130121
ptrdiff_t to_expand2_nElem = THSize_nElement(to_expand2->nDimension, to_expand2->size);
131122
ptrdiff_t to_expand3_nElem = THSize_nElement(to_expand3->nDimension, to_expand3->size);
132-
bool to_expand2_raise = !fallback || (to_expand1_nElem != to_expand2_nElem);
133-
bool to_expand3_raise = !fallback || (to_expand1_nElem != to_expand2_nElem);
134-
135-
int ret = 0;
136-
int err = expand3<TensorType>(LIBRARY_STATE r1, r2, r3,
137-
to_expand1, to_expand2, to_expand3,
138-
to_expand2_raise || to_expand3_raise);
139-
140-
if (err != 0 && !to_expand2_raise && !to_expand3_raise) {
141-
ret = err;
142-
std::ostringstream warn;
143-
warn << to_expand1_name << ", " << to_expand2_name << ", and " << to_expand3_name << " not broadcastable,"
144-
<< " but have the same number of elements. Falling back to deprecated pointwise behavior.";
145-
PyErr_WarnEx(PyExc_UserWarning, warn.str().c_str(), 1);
123+
bool to_expand2_raise = !fallback || (to_expand1_nElem != to_expand2_nElem) || to_expand1_nElem == 0;
124+
bool to_expand3_raise = !fallback || (to_expand1_nElem != to_expand2_nElem) || to_expand1_nElem == 0;
125+
126+
try {
127+
expand3<TensorType>(LIBRARY_STATE r1, r2, r3, to_expand1, to_expand2, to_expand3);
128+
} catch (std::exception &e) {
129+
if(!to_expand2_raise && !to_expand3_raise) {
130+
std::ostringstream warn;
131+
warn << to_expand1_name << ", " << to_expand2_name << ", and " << to_expand3_name << " not broadcastable,"
132+
<< " but have the same number of elements. Falling back to deprecated pointwise behavior.";
133+
PyErr_WarnEx(PyExc_UserWarning, warn.str().c_str(), 1);
134+
}
135+
throw;
146136
}
147137

148138
check_backincompat_expand_warn<TensorType, TensorType>(to_expand1, to_expand2, to_expand1_name, to_expand2_name,
149-
fallback, to_expand1_nElem, to_expand2_nElem, ret);
139+
fallback, to_expand1_nElem, to_expand2_nElem);
150140
check_backincompat_expand_warn<TensorType, TensorType>(to_expand1, to_expand3, to_expand1_name, to_expand3_name,
151-
fallback, to_expand1_nElem, to_expand3_nElem, ret);
152-
return ret;
141+
fallback, to_expand1_nElem, to_expand3_nElem);
153142
}
154143

155144
#endif

0 commit comments

Comments (0)