@@ -41,8 +41,6 @@ mp_obj_t filter_convolve(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_a
4141
4242 ndarray_obj_t * a = MP_OBJ_TO_PTR (args [0 ].u_obj );
4343 ndarray_obj_t * c = MP_OBJ_TO_PTR (args [1 ].u_obj );
44- COMPLEX_DTYPE_NOT_IMPLEMENTED (a -> dtype )
45- COMPLEX_DTYPE_NOT_IMPLEMENTED (c -> dtype )
4644 // deal with linear arrays only
4745 #if ULAB_MAX_DIMS > 1
4846 if ((a -> ndim != 1 ) || (c -> ndim != 1 )) {
@@ -56,30 +54,77 @@ mp_obj_t filter_convolve(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_a
5654 }
5755
5856 int len = len_a + len_c - 1 ; // convolve mode "full"
59- ndarray_obj_t * out = ndarray_new_linear_array (len , NDARRAY_FLOAT );
60- mp_float_t * outptr = (mp_float_t * )out -> array ;
57+ int32_t off = len_c - 1 ;
58+ uint8_t dtype = NDARRAY_FLOAT ;
59+
60+ #if ULAB_SUPPORTS_COMPLEX
61+ if ((a -> dtype == NDARRAY_COMPLEX ) || (c -> dtype == NDARRAY_COMPLEX )) {
62+ dtype = NDARRAY_COMPLEX ;
63+ }
64+ #endif
65+ ndarray_obj_t * ndarray = ndarray_new_linear_array (len , dtype );
66+ mp_float_t * array = (mp_float_t * )ndarray -> array ;
67+
6168 uint8_t * aarray = (uint8_t * )a -> array ;
6269 uint8_t * carray = (uint8_t * )c -> array ;
6370
64- int32_t off = len_c - 1 ;
6571 int32_t as = a -> strides [ULAB_MAX_DIMS - 1 ] / a -> itemsize ;
6672 int32_t cs = c -> strides [ULAB_MAX_DIMS - 1 ] / c -> itemsize ;
6773
68- for (int32_t k = - off ; k < len - off ; k ++ ) {
69- mp_float_t accum = (mp_float_t )0.0 ;
74+
75+ #if ULAB_SUPPORTS_COMPLEX
76+ if (dtype == NDARRAY_COMPLEX ) {
77+ mp_float_t a_real , a_imag ;
78+ mp_float_t c_real , c_imag = MICROPY_FLOAT_CONST (0.0 );
79+ for (int32_t k = - off ; k < len - off ; k ++ ) {
80+ mp_float_t accum_real = MICROPY_FLOAT_CONST (0.0 );
81+ mp_float_t accum_imag = MICROPY_FLOAT_CONST (0.0 );
82+
83+ int32_t top_n = MIN (len_c , len_a - k );
84+ int32_t bot_n = MAX (- k , 0 );
85+
86+ for (int32_t n = bot_n ; n < top_n ; n ++ ) {
87+ int32_t idx_c = (len_c - n - 1 ) * cs ;
88+ int32_t idx_a = (n + k ) * as ;
89+ if (a -> dtype != NDARRAY_COMPLEX ) {
90+ a_real = ndarray_get_float_index (aarray , a -> dtype , idx_a );
91+ a_imag = MICROPY_FLOAT_CONST (0.0 );
92+ } else {
93+ a_real = ndarray_get_float_index (aarray , NDARRAY_FLOAT , 2 * idx_a );
94+ a_imag = ndarray_get_float_index (aarray , NDARRAY_FLOAT , 2 * idx_a + 1 );
95+ }
96+
97+ if (c -> dtype != NDARRAY_COMPLEX ) {
98+ c_real = ndarray_get_float_index (carray , c -> dtype , idx_c );
99+ c_imag = MICROPY_FLOAT_CONST (0.0 );
100+ } else {
101+ c_real = ndarray_get_float_index (carray , NDARRAY_FLOAT , 2 * idx_c );
102+ c_imag = ndarray_get_float_index (carray , NDARRAY_FLOAT , 2 * idx_c + 1 );
103+ }
104+ accum_real += a_real * c_real - a_imag * c_imag ;
105+ accum_imag += a_real * c_imag + a_imag * c_real ;
106+ }
107+ * array ++ = accum_real ;
108+ * array ++ = accum_imag ;
109+ }
110+ return MP_OBJ_FROM_PTR (ndarray );
111+ }
112+ #endif
113+
114+ for (int32_t k = - off ; k < len - off ; k ++ ) {
115+ mp_float_t accum = MICROPY_FLOAT_CONST (0.0 );
70116 int32_t top_n = MIN (len_c , len_a - k );
71117 int32_t bot_n = MAX (- k , 0 );
72- for (int32_t n = bot_n ; n < top_n ; n ++ ) {
118+ for (int32_t n = bot_n ; n < top_n ; n ++ ) {
73119 int32_t idx_c = (len_c - n - 1 ) * cs ;
74120 int32_t idx_a = (n + k ) * as ;
75121 mp_float_t ai = ndarray_get_float_index (aarray , a -> dtype , idx_a );
76122 mp_float_t ci = ndarray_get_float_index (carray , c -> dtype , idx_c );
77123 accum += ai * ci ;
78124 }
79- * outptr ++ = accum ;
125+ * array ++ = accum ;
80126 }
81-
82- return out ;
127+ return MP_OBJ_FROM_PTR (ndarray );
83128}
84129
85130MP_DEFINE_CONST_FUN_OBJ_KW (filter_convolve_obj , 2 , filter_convolve );
0 commit comments