Skip to content

Commit d1b3d40

Browse files
committed
add complex support to all/any
1 parent 8efdec7 commit d1b3d40

File tree

1 file changed

+58
-21
lines changed

1 file changed

+58
-21
lines changed

code/numpy/numerical.c

Lines changed: 58 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ static mp_obj_t numerical_all_any(mp_obj_t oin, mp_obj_t axis, uint8_t optype) {
9595
bool anytype = optype == NUMERICAL_ALL ? 1 : 0;
9696
if(mp_obj_is_type(oin, &ulab_ndarray_type)) {
9797
ndarray_obj_t *ndarray = MP_OBJ_TO_PTR(oin);
98-
COMPLEX_DTYPE_NOT_IMPLEMENTED(ndarray->dtype)
9998
uint8_t *array = (uint8_t *)ndarray->array;
10099
if(ndarray->len == 0) { // return immediately with empty arrays
101100
if(optype == NUMERICAL_ALL) {
@@ -132,33 +131,71 @@ static mp_obj_t numerical_all_any(mp_obj_t oin, mp_obj_t axis, uint8_t optype) {
132131
size_t l = 0;
133132
if(axis == mp_const_none) {
134133
do {
135-
mp_float_t value = func(array);
136-
if((value != MICROPY_FLOAT_CONST(0.0)) & !anytype) {
137-
// optype = NUMERICAL_ANY
138-
return mp_const_true;
139-
} else if((value == MICROPY_FLOAT_CONST(0.0)) & anytype) {
140-
// optype == NUMERICAL_ALL
141-
return mp_const_false;
134+
#if ULAB_SUPPORTS_COMPLEX
135+
if(ndarray->dtype == NDARRAY_COMPLEX) {
136+
mp_float_t real = *((mp_float_t *)array);
137+
mp_float_t imag = *((mp_float_t *)(array + sizeof(mp_float_t)));
138+
if(((real != MICROPY_FLOAT_CONST(0.0)) | (imag != MICROPY_FLOAT_CONST(0.0))) & !anytype) {
139+
// optype = NUMERICAL_ANY
140+
return mp_const_true;
141+
} else if(((real == MICROPY_FLOAT_CONST(0.0)) & (imag == MICROPY_FLOAT_CONST(0.0))) & anytype) {
142+
// optype == NUMERICAL_ALL
143+
return mp_const_false;
144+
}
145+
} else {
146+
#endif
147+
mp_float_t value = func(array);
148+
if((value != MICROPY_FLOAT_CONST(0.0)) & !anytype) {
149+
// optype = NUMERICAL_ANY
150+
return mp_const_true;
151+
} else if((value == MICROPY_FLOAT_CONST(0.0)) & anytype) {
152+
// optype == NUMERICAL_ALL
153+
return mp_const_false;
154+
}
155+
#if ULAB_SUPPORTS_COMPLEX
142156
}
157+
#endif
143158
array += _shape_strides.strides[0];
144159
l++;
145160
} while(l < _shape_strides.shape[0]);
146161
} else { // a scalar axis keyword was supplied
147162
do {
148-
mp_float_t value = func(array);
149-
if((value != MICROPY_FLOAT_CONST(0.0)) & !anytype) {
150-
// optype == NUMERICAL_ANY
151-
*rarray = 1;
152-
// since we are breaking out of the loop, move the pointer forward
153-
array += _shape_strides.strides[0] * (_shape_strides.shape[0] - l);
154-
break;
155-
} else if((value == MICROPY_FLOAT_CONST(0.0)) & anytype) {
156-
// optype == NUMERICAL_ALL
157-
*rarray = 0;
158-
// since we are breaking out of the loop, move the pointer forward
159-
array += _shape_strides.strides[0] * (_shape_strides.shape[0] - l);
160-
break;
163+
#if ULAB_SUPPORTS_COMPLEX
164+
if(ndarray->dtype == NDARRAY_COMPLEX) {
165+
mp_float_t real = *((mp_float_t *)array);
166+
mp_float_t imag = *((mp_float_t *)(array + sizeof(mp_float_t)));
167+
if(((real != MICROPY_FLOAT_CONST(0.0)) | (imag != MICROPY_FLOAT_CONST(0.0))) & !anytype) {
168+
// optype = NUMERICAL_ANY
169+
*rarray = 1;
170+
// since we are breaking out of the loop, move the pointer forward
171+
array += _shape_strides.strides[0] * (_shape_strides.shape[0] - l);
172+
break;
173+
} else if(((real == MICROPY_FLOAT_CONST(0.0)) & (imag == MICROPY_FLOAT_CONST(0.0))) & anytype) {
174+
// optype == NUMERICAL_ALL
175+
*rarray = 0;
176+
// since we are breaking out of the loop, move the pointer forward
177+
array += _shape_strides.strides[0] * (_shape_strides.shape[0] - l);
178+
break;
179+
}
180+
} else {
181+
#endif
182+
mp_float_t value = func(array);
183+
if((value != MICROPY_FLOAT_CONST(0.0)) & !anytype) {
184+
// optype == NUMERICAL_ANY
185+
*rarray = 1;
186+
// since we are breaking out of the loop, move the pointer forward
187+
array += _shape_strides.strides[0] * (_shape_strides.shape[0] - l);
188+
break;
189+
} else if((value == MICROPY_FLOAT_CONST(0.0)) & anytype) {
190+
// optype == NUMERICAL_ALL
191+
*rarray = 0;
192+
// since we are breaking out of the loop, move the pointer forward
193+
array += _shape_strides.strides[0] * (_shape_strides.shape[0] - l);
194+
break;
195+
}
196+
#if ULAB_SUPPORTS_COMPLEX
161197
}
198+
#endif
162199
array += _shape_strides.strides[0];
163200
l++;
164201
} while(l < _shape_strides.shape[0]);

0 commit comments

Comments
 (0)