Skip to content

Commit 4dc9afc

Browse files
committed
fix np.diag
1 parent 973c1f5 commit 4dc9afc

File tree

3 files changed

+41
-38
lines changed

3 files changed

+41
-38
lines changed

code/numpy/create.c

Lines changed: 34 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -367,52 +367,49 @@ mp_obj_t create_diag(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args)
367367
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
368368
mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
369369

370-
if(!mp_obj_is_type(args[0].u_obj, &ulab_ndarray_type)) {
371-
mp_raise_TypeError(translate("input must be an ndarray"));
372-
}
373-
ndarray_obj_t *source = MP_OBJ_TO_PTR(args[0].u_obj);
374-
if(source->ndim == 1) { // return a rank-2 tensor with the prescribed diagonal
375-
ndarray_obj_t *target = ndarray_new_dense_ndarray(2, ndarray_shape_vector(0, 0, source->len, source->len), source->dtype);
370+
ndarray_obj_t *source = ndarray_from_iterable(args[0].u_obj, NDARRAY_FLOAT);
371+
ndarray_obj_t *target = NULL;
372+
373+
if(source->ndim == 2) { // return the diagonal
374+
size_t len = MIN(source->shape[ULAB_MAX_DIMS - 1], source->shape[ULAB_MAX_DIMS - 2]);
375+
target = ndarray_new_linear_array(len, source->dtype);
376376
uint8_t *sarray = (uint8_t *)source->array;
377377
uint8_t *tarray = (uint8_t *)target->array;
378-
for(size_t i=0; i < source->len; i++) {
378+
for(size_t i=0; i < len; i++) {
379379
memcpy(tarray, sarray, source->itemsize);
380-
sarray += source->strides[ULAB_MAX_DIMS - 1];
381-
tarray += (source->len + 1) * target->itemsize;
380+
sarray += (source->strides[ULAB_MAX_DIMS - 1] + source->strides[ULAB_MAX_DIMS - 2]);
381+
tarray += target->itemsize;
382382
}
383-
return MP_OBJ_FROM_PTR(target);
384-
}
385-
if(source->ndim > 2) {
386-
mp_raise_TypeError(translate("input must be a tensor of rank 2"));
387-
}
388-
int32_t k = args[1].u_int;
389-
size_t len = 0;
390-
uint8_t *sarray = (uint8_t *)source->array;
391-
if(k < 0) { // move the pointer "vertically"
392-
if(-k < (int32_t)source->shape[ULAB_MAX_DIMS - 2]) {
393-
sarray -= k * source->strides[ULAB_MAX_DIMS - 2];
394-
len = MIN(source->shape[ULAB_MAX_DIMS - 2] + k, source->shape[ULAB_MAX_DIMS - 1]);
383+
} else if(source->ndim == 1) { // return a rank-2 tensor with the prescribed diagonal
384+
int32_t k = args[1].u_int;
385+
size_t len = source->len;
386+
if(k < 0) {
387+
len -= k;
388+
} else {
389+
len += k;
395390
}
396-
} else { // move the pointer "horizontally"
397-
if(k < (int32_t)source->shape[ULAB_MAX_DIMS - 1]) {
398-
sarray += k * source->strides[ULAB_MAX_DIMS - 1];
399-
len = MIN(source->shape[ULAB_MAX_DIMS - 1] - k, source->shape[ULAB_MAX_DIMS - 2]);
391+
target = ndarray_new_dense_ndarray(2, ndarray_shape_vector(0, 0, len, len), source->dtype);
392+
uint8_t *sarray = (uint8_t *)source->array;
393+
uint8_t *tarray = (uint8_t *)target->array;
394+
395+
if(k < 0) {
396+
k = -k;
397+
tarray += len * k * target->itemsize;
398+
} else {
399+
tarray += k * target->itemsize;
400+
}
401+
for(size_t i = 0; i < source->len; i++) {
402+
memcpy(tarray, sarray, source->itemsize);
403+
sarray += source->strides[ULAB_MAX_DIMS - 1];
404+
tarray += (len + 1) * target->itemsize;
400405
}
401406
}
402-
403-
if(len == 0) {
404-
mp_raise_ValueError(translate("offset is too large"));
407+
#if ULAB_MAX_DIMS > 2
408+
else {
409+
mp_raise_ValueError(translate("input must be 1- or 2-d"));
405410
}
411+
#endif
406412

407-
ndarray_obj_t *target = ndarray_new_linear_array(len, source->dtype);
408-
uint8_t *tarray = (uint8_t *)target->array;
409-
410-
for(size_t i=0; i < len; i++) {
411-
memcpy(tarray, sarray, source->itemsize);
412-
sarray += source->strides[ULAB_MAX_DIMS - 2];
413-
sarray += source->strides[ULAB_MAX_DIMS - 1];
414-
tarray += source->itemsize;
415-
}
416413
return MP_OBJ_FROM_PTR(target);
417414
}
418415

code/ulab.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
#include "user/user.h"
3434
#include "utils/utils.h"
3535

36-
#define ULAB_VERSION 5.0.1
36+
#define ULAB_VERSION 5.0.2
3737
#define xstr(s) str(s)
3838
#define str(s) #s
3939

docs/ulab-change-log.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
Tue, 8 Feb 2022
2+
3+
version 5.0.2
4+
5+
fix np.diag
6+
17
Thu, 3 Feb 2022
28

39
version 5.0.1

0 commit comments

Comments
 (0)