@@ -8,24 +8,24 @@ template <typename ExpandType>
88ExpandType *newForExpand (LIBRARY_STATE_TYPE_NOARGS);
99
1010template <typename TensorType>
11- int expand (LIBRARY_STATE_TYPE TensorType *r, TensorType *tensor, THLongStorage *sizes, int raiseErrors );
11+ void expand (LIBRARY_STATE_TYPE TensorType *r, TensorType *tensor, THLongStorage *sizes);
1212
1313template <typename TensorType>
14- int expand2 (LIBRARY_STATE_TYPE TensorType *r1, TensorType *r2,
15- TensorType *e1 , TensorType *e2 , int raiseErrors );
14+ void expand2 (LIBRARY_STATE_TYPE TensorType *r1, TensorType *r2,
15+ TensorType *e1 , TensorType *e2 );
1616
1717template <typename TensorType>
18- int expand3 (LIBRARY_STATE_TYPE TensorType *r1, TensorType *r2, TensorType *r3,
19- TensorType *e1 , TensorType *e2 , TensorType *e3 , int raiseErrors );
18+ void expand3 (LIBRARY_STATE_TYPE TensorType *r1, TensorType *r2, TensorType *r3,
19+ TensorType *e1 , TensorType *e2 , TensorType *e3 );
2020
2121template <typename ExpandType, typename TensorType>
2222void check_backincompat_expand_warn (ExpandType *to_expand, TensorType *tensor,
2323 char *to_expand_name, char *tensor_name, bool fallback,
24- ptrdiff_t to_expand_nElem, ptrdiff_t tensor_nElem, int to_expand_err ) {
24+ ptrdiff_t to_expand_nElem, ptrdiff_t tensor_nElem) {
2525 if (fallback && getBackCompatBroadcastWarn ()) {
2626 bool same_shape = THSize_isSameSizeAs (tensor->size , tensor->nDimension ,
2727 to_expand->size , to_expand->nDimension );
28- if (!same_shape && to_expand_err == 0 && (tensor_nElem == to_expand_nElem)) {
28+ if (!same_shape && (tensor_nElem == to_expand_nElem)) {
2929 std::ostringstream warn;
3030 warn << tensor_name << " and " << to_expand_name << " do not have the same shape, but are "
3131 << " broadcastable, and have the same number of elements. Changing behavior in a backwards incompatible "
@@ -36,120 +36,109 @@ void check_backincompat_expand_warn(ExpandType *to_expand, TensorType *tensor,
3636}
3737
3838template <typename ExpandType, typename TensorType>
39- int expand_inplace (LIBRARY_STATE_TYPE ExpandType *r, ExpandType *to_expand, TensorType *tensor,
40- char *to_expand_name, char *tensor_name, bool fallback,
41- THLongStorage *tensor_size, ptrdiff_t to_expand_nElem, ptrdiff_t tensor_nElem, bool raise) {
42- int ret = 0 ;
43-
44- int to_expand_err = expand<ExpandType>(LIBRARY_STATE r, to_expand, tensor_size, raise);
45- if (to_expand_err != 0 && !raise) {
46- ret = to_expand_err;
47- std::ostringstream warn;
48- warn << to_expand_name << " is not broadcastable to " << tensor_name
49- << " , but they have the same number of elements. Falling back to deprecated pointwise behavior." ;
50- PyErr_WarnEx (PyExc_UserWarning, warn.str ().c_str (), 1 );
39+ void expand_inplace (LIBRARY_STATE_TYPE ExpandType *r, ExpandType *to_expand, TensorType *tensor,
40+ char *to_expand_name, char *tensor_name, bool fallback,
41+ THLongStorage *tensor_size, ptrdiff_t to_expand_nElem, ptrdiff_t tensor_nElem, bool raise) {
42+ try {
43+ expand<ExpandType>(LIBRARY_STATE r, to_expand, tensor_size);
44+ } catch (std::exception &e) {
45+ if (!raise) {
46+ std::ostringstream warn;
47+ warn << to_expand_name << " is not broadcastable to " << tensor_name
48+ << " , but they have the same number of elements. Falling back to deprecated pointwise behavior." ;
49+ PyErr_WarnEx (PyExc_UserWarning, warn.str ().c_str (), 1 );
50+ }
51+ throw ;
5152 }
52-
53- return ret;
5453}
5554
5655template <typename ExpandType, typename TensorType>
57- int expand_inplace1 (LIBRARY_STATE_TYPE ExpandType *r, ExpandType *to_expand, TensorType *tensor,
58- char *to_expand_name, char *tensor_name, bool fallback) {
56+ void expand_inplace1 (LIBRARY_STATE_TYPE ExpandType *r, ExpandType *to_expand, TensorType *tensor,
57+ char *to_expand_name, char *tensor_name, bool fallback) {
5958 ptrdiff_t to_expand_nElem = THSize_nElement (to_expand->nDimension , to_expand->size );
6059 ptrdiff_t tensor_nElem = THSize_nElement (tensor->nDimension , tensor->size );
61- bool to_expand_raise = !fallback || (to_expand_nElem != tensor_nElem);
60+ bool to_expand_raise = !fallback || (to_expand_nElem != tensor_nElem) || to_expand_nElem == 0 ;
6261 THLongStoragePtr tensor_size (THLongStorage_newWithSize (tensor->nDimension ));
6362 THLongStorage_rawCopy (tensor_size.get (), tensor->size );
6463
65- int ret = expand_inplace (LIBRARY_STATE r, to_expand, tensor, to_expand_name, tensor_name, fallback,
66- tensor_size, to_expand_nElem, tensor_nElem, to_expand_raise);
67-
64+ expand_inplace (LIBRARY_STATE r, to_expand, tensor, to_expand_name, tensor_name, fallback,
65+ tensor_size, to_expand_nElem, tensor_nElem, to_expand_raise);
6866 check_backincompat_expand_warn<ExpandType, TensorType>(to_expand, tensor, to_expand_name, tensor_name, fallback,
69- to_expand_nElem, tensor_nElem, ret);
70-
71- return ret;
67+ to_expand_nElem, tensor_nElem);
7268}
7369
7470template <typename TensorType>
75- int expand_inplace2 (LIBRARY_STATE_TYPE TensorType *r1, TensorType *r2,
76- TensorType *to_expand1, TensorType *to_expand2, TensorType *tensor,
77- char *to_expand1_name, char *to_expand2_name, char *tensor_name, bool fallback) {
71+ void expand_inplace2 (LIBRARY_STATE_TYPE TensorType *r1, TensorType *r2,
72+ TensorType *to_expand1, TensorType *to_expand2, TensorType *tensor,
73+ char *to_expand1_name, char *to_expand2_name, char *tensor_name, bool fallback) {
7874 ptrdiff_t tensor_nElem = THSize_nElement (tensor->nDimension , tensor->size );
7975 ptrdiff_t to_expand1_nElem = THSize_nElement (to_expand1->nDimension , to_expand1->size );
8076 ptrdiff_t to_expand2_nElem = THSize_nElement (to_expand2->nDimension , to_expand2->size );
81- bool to_expand1_raise = !fallback || (tensor_nElem != to_expand1_nElem);
82- bool to_expand2_raise = !fallback || (tensor_nElem != to_expand2_nElem);
77+ bool to_expand1_raise = !fallback || (tensor_nElem != to_expand1_nElem) || tensor_nElem == 0 ;
78+ bool to_expand2_raise = !fallback || (tensor_nElem != to_expand2_nElem) || tensor_nElem == 0 ;
8379 THLongStoragePtr tensor_size (THLongStorage_newWithSize (tensor->nDimension ));
8480 THLongStorage_rawCopy (tensor_size.get (), tensor->size );
8581
86- int ret = expand_inplace (LIBRARY_STATE r1, to_expand1, tensor, to_expand1_name, tensor_name, fallback,
87- tensor_size, to_expand1_nElem, tensor_nElem, to_expand1_raise || to_expand2_raise);
88-
89- int ret2 = 0 ;
90- if (ret == 0 ) {
91- ret2 = expand_inplace (LIBRARY_STATE r2, to_expand2, tensor, to_expand2_name, tensor_name, fallback,
92- tensor_size, to_expand2_nElem, tensor_nElem, to_expand1_raise || to_expand2_raise);
93- }
82+ expand_inplace (LIBRARY_STATE r1, to_expand1, tensor, to_expand1_name, tensor_name, fallback,
83+ tensor_size, to_expand1_nElem, tensor_nElem, to_expand1_raise || to_expand2_raise);
84+ expand_inplace (LIBRARY_STATE r2, to_expand2, tensor, to_expand2_name, tensor_name, fallback,
85+ tensor_size, to_expand2_nElem, tensor_nElem, to_expand1_raise || to_expand2_raise);
9486
9587 check_backincompat_expand_warn<TensorType, TensorType>(to_expand1, tensor, to_expand1_name, tensor_name, fallback,
96- to_expand1_nElem, tensor_nElem, ret );
88+ to_expand1_nElem, tensor_nElem);
9789 check_backincompat_expand_warn<TensorType, TensorType>(to_expand2, tensor, to_expand2_name, tensor_name, fallback,
98- to_expand2_nElem, tensor_nElem, ret2);
99-
100- return ret == 0 && ret2 == 0 ? 0 : -1 ;
90+ to_expand2_nElem, tensor_nElem);
10191}
10292
10393template <typename TensorType>
104- int expand_outplace2 (LIBRARY_STATE_TYPE TensorType *r1, TensorType *r2,
105- TensorType *to_expand1, TensorType *to_expand2,
106- char *to_expand1_name, char *to_expand2_name, bool fallback) {
94+ void expand_outplace2 (LIBRARY_STATE_TYPE TensorType *r1, TensorType *r2,
95+ TensorType *to_expand1, TensorType *to_expand2,
96+ char *to_expand1_name, char *to_expand2_name, bool fallback) {
10797 ptrdiff_t to_expand1_nElem = THSize_nElement (to_expand1->nDimension , to_expand1->size );
10898 ptrdiff_t to_expand2_nElem = THSize_nElement (to_expand2->nDimension , to_expand2->size );
109- bool raise = !fallback || (to_expand1_nElem != to_expand2_nElem);
110-
111- int ret = 0 ;
112- int err = expand2<TensorType>(LIBRARY_STATE r1, r2, to_expand1, to_expand2, raise);
113- if (err != 0 && !raise) {
114- ret = err;
115- std::ostringstream warn;
116- warn << to_expand1_name << " and " << to_expand2_name << " not broadcastable, but have the same number of "
117- " elements. Falling back to deprecated pointwise behavior." ;
118- PyErr_WarnEx (PyExc_UserWarning, warn.str ().c_str (), 1 );
99+ bool raise = !fallback || (to_expand1_nElem != to_expand2_nElem) || to_expand1_nElem == 0 ;
100+ try {
101+ expand2<TensorType>(LIBRARY_STATE r1, r2, to_expand1, to_expand2);
102+ } catch (std::exception &e) {
103+ if (!raise) {
104+ std::ostringstream warn;
105+ warn << to_expand1_name << " and " << to_expand2_name << " not broadcastable, but have the same number of "
106+ << " elements. Falling back to deprecated pointwise behavior." ;
107+ PyErr_WarnEx (PyExc_UserWarning, warn.str ().c_str (), 1 );
108+ }
109+ throw ;
119110 }
111+
120112 check_backincompat_expand_warn<TensorType, TensorType>(to_expand1, to_expand2, to_expand1_name, to_expand2_name,
121- fallback, to_expand1_nElem, to_expand2_nElem, ret);
122- return ret;
113+ fallback, to_expand1_nElem, to_expand2_nElem);
123114}
124115
125116template <typename TensorType>
126- int expand_outplace3 (LIBRARY_STATE_TYPE TensorType *r1, TensorType *r2, TensorType *r3,
127- TensorType *to_expand1, TensorType *to_expand2, TensorType *to_expand3,
128- char *to_expand1_name, char *to_expand2_name, char *to_expand3_name, bool fallback) {
117+ void expand_outplace3 (LIBRARY_STATE_TYPE TensorType *r1, TensorType *r2, TensorType *r3,
118+ TensorType *to_expand1, TensorType *to_expand2, TensorType *to_expand3,
119+ char *to_expand1_name, char *to_expand2_name, char *to_expand3_name, bool fallback) {
129120 ptrdiff_t to_expand1_nElem = THSize_nElement (to_expand1->nDimension , to_expand1->size );
130121 ptrdiff_t to_expand2_nElem = THSize_nElement (to_expand2->nDimension , to_expand2->size );
131122 ptrdiff_t to_expand3_nElem = THSize_nElement (to_expand3->nDimension , to_expand3->size );
132- bool to_expand2_raise = !fallback || (to_expand1_nElem != to_expand2_nElem);
133- bool to_expand3_raise = !fallback || (to_expand1_nElem != to_expand2_nElem);
134-
135- int ret = 0 ;
136- int err = expand3<TensorType>(LIBRARY_STATE r1, r2, r3,
137- to_expand1, to_expand2, to_expand3,
138- to_expand2_raise || to_expand3_raise);
139-
140- if (err != 0 && !to_expand2_raise && !to_expand3_raise) {
141- ret = err;
142- std::ostringstream warn;
143- warn << to_expand1_name << " , " << to_expand2_name << " , and " << to_expand3_name << " not broadcastable,"
144- << " but have the same number of elements. Falling back to deprecated pointwise behavior." ;
145- PyErr_WarnEx (PyExc_UserWarning, warn.str ().c_str (), 1 );
123+ bool to_expand2_raise = !fallback || (to_expand1_nElem != to_expand2_nElem) || to_expand1_nElem == 0 ;
124+ bool to_expand3_raise = !fallback || (to_expand1_nElem != to_expand2_nElem) || to_expand1_nElem == 0 ;
125+
126+ try {
127+ expand3<TensorType>(LIBRARY_STATE r1, r2, r3, to_expand1, to_expand2, to_expand3);
128+ } catch (std::exception &e) {
129+ if (!to_expand2_raise && !to_expand3_raise) {
130+ std::ostringstream warn;
131+ warn << to_expand1_name << " , " << to_expand2_name << " , and " << to_expand3_name << " not broadcastable,"
132+ << " but have the same number of elements. Falling back to deprecated pointwise behavior." ;
133+ PyErr_WarnEx (PyExc_UserWarning, warn.str ().c_str (), 1 );
134+ }
135+ throw ;
146136 }
147137
148138 check_backincompat_expand_warn<TensorType, TensorType>(to_expand1, to_expand2, to_expand1_name, to_expand2_name,
149- fallback, to_expand1_nElem, to_expand2_nElem, ret );
139+ fallback, to_expand1_nElem, to_expand2_nElem);
150140 check_backincompat_expand_warn<TensorType, TensorType>(to_expand1, to_expand3, to_expand1_name, to_expand3_name,
151- fallback, to_expand1_nElem, to_expand3_nElem, ret);
152- return ret;
141+ fallback, to_expand1_nElem, to_expand3_nElem);
153142}
154143
155144#endif
0 commit comments