Skip to content

Commit 8efdec7

Browse files
committed
extend convolve for the complex case
1 parent 6a7d20d commit 8efdec7

File tree

1 file changed

+56
-11
lines changed

1 file changed

+56
-11
lines changed

code/numpy/filter.c

Lines changed: 56 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@ mp_obj_t filter_convolve(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_a
4141

4242
ndarray_obj_t *a = MP_OBJ_TO_PTR(args[0].u_obj);
4343
ndarray_obj_t *c = MP_OBJ_TO_PTR(args[1].u_obj);
44-
COMPLEX_DTYPE_NOT_IMPLEMENTED(a->dtype)
45-
COMPLEX_DTYPE_NOT_IMPLEMENTED(c->dtype)
4644
// deal with linear arrays only
4745
#if ULAB_MAX_DIMS > 1
4846
if((a->ndim != 1) || (c->ndim != 1)) {
@@ -56,30 +54,77 @@ mp_obj_t filter_convolve(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_a
5654
}
5755

5856
int len = len_a + len_c - 1; // convolve mode "full"
59-
ndarray_obj_t *out = ndarray_new_linear_array(len, NDARRAY_FLOAT);
60-
mp_float_t *outptr = (mp_float_t *)out->array;
57+
int32_t off = len_c - 1;
58+
uint8_t dtype = NDARRAY_FLOAT;
59+
60+
#if ULAB_SUPPORTS_COMPLEX
61+
if((a->dtype == NDARRAY_COMPLEX) || (c->dtype == NDARRAY_COMPLEX)) {
62+
dtype = NDARRAY_COMPLEX;
63+
}
64+
#endif
65+
ndarray_obj_t *ndarray = ndarray_new_linear_array(len, dtype);
66+
mp_float_t *array = (mp_float_t *)ndarray->array;
67+
6168
uint8_t *aarray = (uint8_t *)a->array;
6269
uint8_t *carray = (uint8_t *)c->array;
6370

64-
int32_t off = len_c - 1;
6571
int32_t as = a->strides[ULAB_MAX_DIMS - 1] / a->itemsize;
6672
int32_t cs = c->strides[ULAB_MAX_DIMS - 1] / c->itemsize;
6773

68-
for(int32_t k=-off; k < len-off; k++) {
69-
mp_float_t accum = (mp_float_t)0.0;
74+
75+
#if ULAB_SUPPORTS_COMPLEX
76+
if(dtype == NDARRAY_COMPLEX) {
77+
mp_float_t a_real, a_imag;
78+
mp_float_t c_real, c_imag = MICROPY_FLOAT_CONST(0.0);
79+
for(int32_t k = -off; k < len-off; k++) {
80+
mp_float_t accum_real = MICROPY_FLOAT_CONST(0.0);
81+
mp_float_t accum_imag = MICROPY_FLOAT_CONST(0.0);
82+
83+
int32_t top_n = MIN(len_c, len_a - k);
84+
int32_t bot_n = MAX(-k, 0);
85+
86+
for(int32_t n = bot_n; n < top_n; n++) {
87+
int32_t idx_c = (len_c - n - 1) * cs;
88+
int32_t idx_a = (n + k) * as;
89+
if(a->dtype != NDARRAY_COMPLEX) {
90+
a_real = ndarray_get_float_index(aarray, a->dtype, idx_a);
91+
a_imag = MICROPY_FLOAT_CONST(0.0);
92+
} else {
93+
a_real = ndarray_get_float_index(aarray, NDARRAY_FLOAT, 2 * idx_a);
94+
a_imag = ndarray_get_float_index(aarray, NDARRAY_FLOAT, 2 * idx_a + 1);
95+
}
96+
97+
if(c->dtype != NDARRAY_COMPLEX) {
98+
c_real = ndarray_get_float_index(carray, c->dtype, idx_c);
99+
c_imag = MICROPY_FLOAT_CONST(0.0);
100+
} else {
101+
c_real = ndarray_get_float_index(carray, NDARRAY_FLOAT, 2 * idx_c);
102+
c_imag = ndarray_get_float_index(carray, NDARRAY_FLOAT, 2 * idx_c + 1);
103+
}
104+
accum_real += a_real * c_real - a_imag * c_imag;
105+
accum_imag += a_real * c_imag + a_imag * c_real;
106+
}
107+
*array++ = accum_real;
108+
*array++ = accum_imag;
109+
}
110+
return MP_OBJ_FROM_PTR(ndarray);
111+
}
112+
#endif
113+
114+
for(int32_t k = -off; k < len-off; k++) {
115+
mp_float_t accum = MICROPY_FLOAT_CONST(0.0);
70116
int32_t top_n = MIN(len_c, len_a - k);
71117
int32_t bot_n = MAX(-k, 0);
72-
for(int32_t n=bot_n; n < top_n; n++) {
118+
for(int32_t n = bot_n; n < top_n; n++) {
73119
int32_t idx_c = (len_c - n - 1) * cs;
74120
int32_t idx_a = (n + k) * as;
75121
mp_float_t ai = ndarray_get_float_index(aarray, a->dtype, idx_a);
76122
mp_float_t ci = ndarray_get_float_index(carray, c->dtype, idx_c);
77123
accum += ai * ci;
78124
}
79-
*outptr++ = accum;
125+
*array++ = accum;
80126
}
81-
82-
return out;
127+
return MP_OBJ_FROM_PTR(ndarray);
83128
}
84129

85130
MP_DEFINE_CONST_FUN_OBJ_KW(filter_convolve_obj, 2, filter_convolve);

0 commit comments

Comments
 (0)