Skip to content

Commit 4701338

Browse files
committed
add rclass
1 parent 7a183e1 commit 4701338

File tree

1 file changed

+73
-0
lines changed

1 file changed

+73
-0
lines changed

snippets/rclass.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
from ulab import numpy as np
2+
from typing import List, Tuple, Union # upip.install("pycopy-typing")
3+
4+
ndarray = np.array
5+
_DType = int
6+
RClassKeyType = Union[slice, int, float]
7+
8+
# this is a stripped down version of RClass (used by np.r_[...etc])
9+
# it doesn't include support for string arguments as the first index element
10+
class RClass:
11+
12+
def __getitem__(self, key: Union[RClassKeyType, Tuple[RClassKeyType, ...]]):
13+
14+
if not isinstance(key, tuple):
15+
key = (key,)
16+
17+
objs: List[ndarray] = []
18+
scalars: List[int] = []
19+
arraytypes: List[_DType] = []
20+
scalartypes: List[_DType] = []
21+
22+
# these may get overridden in following loop
23+
axis = 0
24+
25+
for idx, item in enumerate(key):
26+
scalar = False
27+
if isinstance(item, slice):
28+
step = item.step
29+
start = item.start
30+
stop = item.stop
31+
if start is None:
32+
start = 0
33+
if step is None:
34+
nstep = 1
35+
if isinstance(step, complex):
36+
size = int(abs(step))
37+
newobj = cast(ndarray, linspace(start, stop, num=size))
38+
else:
39+
newobj = np.arange(start, stop, step)
40+
41+
# if is number
42+
elif isinstance(item, (int, float)):
43+
newobj = np.array([item])
44+
scalars.append(len(objs))
45+
scalar = True
46+
scalartypes.append(newobj.dtype())
47+
48+
else:
49+
raise Exception("index object %s of type %s is not supported by r_[]" % (
50+
str(item), type(item)))
51+
52+
objs.append(newobj)
53+
if not scalar and isinstance(newobj, ndarray):
54+
arraytypes.append(newobj.dtype())
55+
56+
# Ensure that scalars won't up-cast unless warranted
57+
# TODO: ensure that this actually works for dtype coercion
58+
# likelihood is we're going to have to do some funky logic for this
59+
final_dtype = max(arraytypes + scalartypes)
60+
if final_dtype is not None:
61+
for idx in scalars:
62+
objs[idx] = np.array(objs[idx], dtype=final_dtype)
63+
64+
res = np.concatenate(tuple(objs), axis=axis)
65+
66+
return res
67+
68+
# this seems weird - not sure what it's for
69+
def __len__(self):
70+
return 0
71+
72+
73+
r_ = RClass()

0 commit comments

Comments
 (0)