@@ -95,7 +95,6 @@ static mp_obj_t numerical_all_any(mp_obj_t oin, mp_obj_t axis, uint8_t optype) {
9595 bool anytype = optype == NUMERICAL_ALL ? 1 : 0 ;
9696 if (mp_obj_is_type (oin , & ulab_ndarray_type )) {
9797 ndarray_obj_t * ndarray = MP_OBJ_TO_PTR (oin );
98- COMPLEX_DTYPE_NOT_IMPLEMENTED (ndarray -> dtype )
9998 uint8_t * array = (uint8_t * )ndarray -> array ;
10099 if (ndarray -> len == 0 ) { // return immediately with empty arrays
101100 if (optype == NUMERICAL_ALL ) {
@@ -132,33 +131,71 @@ static mp_obj_t numerical_all_any(mp_obj_t oin, mp_obj_t axis, uint8_t optype) {
132131 size_t l = 0 ;
133132 if (axis == mp_const_none ) {
134133 do {
135- mp_float_t value = func (array );
136- if ((value != MICROPY_FLOAT_CONST (0.0 )) & !anytype ) {
137- // optype = NUMERICAL_ANY
138- return mp_const_true ;
139- } else if ((value == MICROPY_FLOAT_CONST (0.0 )) & anytype ) {
140- // optype == NUMERICAL_ALL
141- return mp_const_false ;
134+ #if ULAB_SUPPORTS_COMPLEX
135+ if (ndarray -> dtype == NDARRAY_COMPLEX ) {
136+ mp_float_t real = * ((mp_float_t * )array );
137+ mp_float_t imag = * ((mp_float_t * )(array + sizeof (mp_float_t )));
138+ if (((real != MICROPY_FLOAT_CONST (0.0 )) | (imag != MICROPY_FLOAT_CONST (0.0 ))) & !anytype ) {
139+ // optype = NUMERICAL_ANY
140+ return mp_const_true ;
141+ } else if (((real == MICROPY_FLOAT_CONST (0.0 )) & (imag == MICROPY_FLOAT_CONST (0.0 ))) & anytype ) {
142+ // optype == NUMERICAL_ALL
143+ return mp_const_false ;
144+ }
145+ } else {
146+ #endif
147+ mp_float_t value = func (array );
148+ if ((value != MICROPY_FLOAT_CONST (0.0 )) & !anytype ) {
149+ // optype = NUMERICAL_ANY
150+ return mp_const_true ;
151+ } else if ((value == MICROPY_FLOAT_CONST (0.0 )) & anytype ) {
152+ // optype == NUMERICAL_ALL
153+ return mp_const_false ;
154+ }
155+ #if ULAB_SUPPORTS_COMPLEX
142156 }
157+ #endif
143158 array += _shape_strides .strides [0 ];
144159 l ++ ;
145160 } while (l < _shape_strides .shape [0 ]);
146161 } else { // a scalar axis keyword was supplied
147162 do {
148- mp_float_t value = func (array );
149- if ((value != MICROPY_FLOAT_CONST (0.0 )) & !anytype ) {
150- // optype == NUMERICAL_ANY
151- * rarray = 1 ;
152- // since we are breaking out of the loop, move the pointer forward
153- array += _shape_strides .strides [0 ] * (_shape_strides .shape [0 ] - l );
154- break ;
155- } else if ((value == MICROPY_FLOAT_CONST (0.0 )) & anytype ) {
156- // optype == NUMERICAL_ALL
157- * rarray = 0 ;
158- // since we are breaking out of the loop, move the pointer forward
159- array += _shape_strides .strides [0 ] * (_shape_strides .shape [0 ] - l );
160- break ;
163+ #if ULAB_SUPPORTS_COMPLEX
164+ if (ndarray -> dtype == NDARRAY_COMPLEX ) {
165+ mp_float_t real = * ((mp_float_t * )array );
166+ mp_float_t imag = * ((mp_float_t * )(array + sizeof (mp_float_t )));
167+ if (((real != MICROPY_FLOAT_CONST (0.0 )) | (imag != MICROPY_FLOAT_CONST (0.0 ))) & !anytype ) {
168+ // optype = NUMERICAL_ANY
169+ * rarray = 1 ;
170+ // since we are breaking out of the loop, move the pointer forward
171+ array += _shape_strides .strides [0 ] * (_shape_strides .shape [0 ] - l );
172+ break ;
173+ } else if (((real == MICROPY_FLOAT_CONST (0.0 )) & (imag == MICROPY_FLOAT_CONST (0.0 ))) & anytype ) {
174+ // optype == NUMERICAL_ALL
175+ * rarray = 0 ;
176+ // since we are breaking out of the loop, move the pointer forward
177+ array += _shape_strides .strides [0 ] * (_shape_strides .shape [0 ] - l );
178+ break ;
179+ }
180+ } else {
181+ #endif
182+ mp_float_t value = func (array );
183+ if ((value != MICROPY_FLOAT_CONST (0.0 )) & !anytype ) {
184+ // optype == NUMERICAL_ANY
185+ * rarray = 1 ;
186+ // since we are breaking out of the loop, move the pointer forward
187+ array += _shape_strides .strides [0 ] * (_shape_strides .shape [0 ] - l );
188+ break ;
189+ } else if ((value == MICROPY_FLOAT_CONST (0.0 )) & anytype ) {
190+ // optype == NUMERICAL_ALL
191+ * rarray = 0 ;
192+ // since we are breaking out of the loop, move the pointer forward
193+ array += _shape_strides .strides [0 ] * (_shape_strides .shape [0 ] - l );
194+ break ;
195+ }
196+ #if ULAB_SUPPORTS_COMPLEX
161197 }
198+ #endif
162199 array += _shape_strides .strides [0 ];
163200 l ++ ;
164201 } while (l < _shape_strides .shape [0 ]);
0 commit comments