@@ -367,52 +367,49 @@ mp_obj_t create_diag(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args)
367367 mp_arg_val_t args [MP_ARRAY_SIZE (allowed_args )];
368368 mp_arg_parse_all (n_args , pos_args , kw_args , MP_ARRAY_SIZE (allowed_args ), allowed_args , args );
369369
370- if (! mp_obj_is_type ( args [0 ].u_obj , & ulab_ndarray_type )) {
371- mp_raise_TypeError ( translate ( "input must be an ndarray" )) ;
372- }
373- ndarray_obj_t * source = MP_OBJ_TO_PTR ( args [ 0 ]. u_obj );
374- if ( source -> ndim == 1 ) { // return a rank-2 tensor with the prescribed diagonal
375- ndarray_obj_t * target = ndarray_new_dense_ndarray ( 2 , ndarray_shape_vector ( 0 , 0 , source -> len , source -> len ) , source -> dtype );
370+ ndarray_obj_t * source = ndarray_from_iterable ( args [0 ].u_obj , NDARRAY_FLOAT );
371+ ndarray_obj_t * target = NULL ;
372+
373+ if ( source -> ndim == 2 ) { // return the diagonal
374+ size_t len = MIN ( source -> shape [ ULAB_MAX_DIMS - 1 ], source -> shape [ ULAB_MAX_DIMS - 2 ]);
375+ target = ndarray_new_linear_array ( len , source -> dtype );
376376 uint8_t * sarray = (uint8_t * )source -> array ;
377377 uint8_t * tarray = (uint8_t * )target -> array ;
378- for (size_t i = 0 ; i < source -> len ; i ++ ) {
378+ for (size_t i = 0 ; i < len ; i ++ ) {
379379 memcpy (tarray , sarray , source -> itemsize );
380- sarray += source -> strides [ULAB_MAX_DIMS - 1 ];
381- tarray += ( source -> len + 1 ) * target -> itemsize ;
380+ sarray += ( source -> strides [ULAB_MAX_DIMS - 1 ] + source -> strides [ ULAB_MAX_DIMS - 2 ]) ;
381+ tarray += target -> itemsize ;
382382 }
383- return MP_OBJ_FROM_PTR (target );
384- }
385- if (source -> ndim > 2 ) {
386- mp_raise_TypeError (translate ("input must be a tensor of rank 2" ));
387- }
388- int32_t k = args [1 ].u_int ;
389- size_t len = 0 ;
390- uint8_t * sarray = (uint8_t * )source -> array ;
391- if (k < 0 ) { // move the pointer "vertically"
392- if (- k < (int32_t )source -> shape [ULAB_MAX_DIMS - 2 ]) {
393- sarray -= k * source -> strides [ULAB_MAX_DIMS - 2 ];
394- len = MIN (source -> shape [ULAB_MAX_DIMS - 2 ] + k , source -> shape [ULAB_MAX_DIMS - 1 ]);
383+ } else if (source -> ndim == 1 ) { // return a rank-2 tensor with the prescribed diagonal
384+ int32_t k = args [1 ].u_int ;
385+ size_t len = source -> len ;
386+ if (k < 0 ) {
387+ len -= k ;
388+ } else {
389+ len += k ;
395390 }
396- } else { // move the pointer "horizontally"
397- if (k < (int32_t )source -> shape [ULAB_MAX_DIMS - 1 ]) {
398- sarray += k * source -> strides [ULAB_MAX_DIMS - 1 ];
399- len = MIN (source -> shape [ULAB_MAX_DIMS - 1 ] - k , source -> shape [ULAB_MAX_DIMS - 2 ]);
391+ target = ndarray_new_dense_ndarray (2 , ndarray_shape_vector (0 , 0 , len , len ), source -> dtype );
392+ uint8_t * sarray = (uint8_t * )source -> array ;
393+ uint8_t * tarray = (uint8_t * )target -> array ;
394+
395+ if (k < 0 ) {
396+ k = - k ;
397+ tarray += len * k * target -> itemsize ;
398+ } else {
399+ tarray += k * target -> itemsize ;
400+ }
401+ for (size_t i = 0 ; i < source -> len ; i ++ ) {
402+ memcpy (tarray , sarray , source -> itemsize );
403+ sarray += source -> strides [ULAB_MAX_DIMS - 1 ];
404+ tarray += (len + 1 ) * target -> itemsize ;
400405 }
401406 }
402-
403- if ( len == 0 ) {
404- mp_raise_ValueError (translate ("offset is too large " ));
407+ #if ULAB_MAX_DIMS > 2
408+ else {
409+ mp_raise_ValueError (translate ("input must be 1- or 2-d " ));
405410 }
411+ #endif
406412
407- ndarray_obj_t * target = ndarray_new_linear_array (len , source -> dtype );
408- uint8_t * tarray = (uint8_t * )target -> array ;
409-
410- for (size_t i = 0 ; i < len ; i ++ ) {
411- memcpy (tarray , sarray , source -> itemsize );
412- sarray += source -> strides [ULAB_MAX_DIMS - 2 ];
413- sarray += source -> strides [ULAB_MAX_DIMS - 1 ];
414- tarray += source -> itemsize ;
415- }
416413 return MP_OBJ_FROM_PTR (target );
417414}
418415
0 commit comments