Skip to content

Commit aab1ca3

Browse files
authored
Merge pull request #6 from molstar/refactoring
Refactoring
2 parents a1dd7ce + b527772 commit aab1ca3

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+1048
-1305
lines changed

ciftools/binary/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +0,0 @@
1-
from ciftools.binary.decoder import decode_cif_data
2-
from ciftools.binary.writer import BinaryCIFWriter

ciftools/binary/data.py

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
from typing import Any, Dict, List, Optional, Union
2+
3+
import numpy as np
4+
from ciftools.binary.decoder import decode_cif_data
5+
from ciftools.binary.encoded_data import EncodedCIFCategory, EncodedCIFColumn, EncodedCIFFile
6+
from ciftools.models.data import CIFCategory, CIFColumn, CIFDataBlock, CIFFile, CIFValuePresenceEnum
7+
8+
9+
class BinaryCIFColumn(CIFColumn):
10+
def __init__(
11+
self,
12+
name: str,
13+
values: np.ndarray,
14+
value_presence: Optional[np.ndarray],
15+
):
16+
self.name = name
17+
self._values = values
18+
self._value_presence = value_presence
19+
self._row_count = len(values)
20+
21+
def get_string(self, row: int) -> str:
22+
return str(self._values[row])
23+
24+
def get_integer(self, row: int) -> int:
25+
return int(self._values[row])
26+
27+
def get_float(self, row: int) -> float:
28+
return float(self._values[row])
29+
30+
def get_value_presence(self, row: int) -> CIFValuePresenceEnum:
31+
if self._value_presence:
32+
return self._value_presence[row]
33+
return 0 # type: ignore
34+
35+
def are_values_equal(self, row_a: int, row_b: int) -> bool:
36+
return self._values[row_a] == self._values[row_b]
37+
38+
def string_equals(self, row: int, value: str) -> bool:
39+
return str(self._values[row]) == value
40+
41+
def as_ndarray(
42+
self, *, dtype: Optional[Union[np.dtype, str]] = None, start: Optional[int] = None, end: Optional[int] = None
43+
) -> np.ndarray:
44+
if dtype is None and start is None and end is None:
45+
return self._values
46+
if dtype is None:
47+
return self._values[start:end]
48+
return self._values[start:end].astype(dtype)
49+
50+
def __getitem__(self, idx: Any) -> Any:
51+
if isinstance(idx, int) and self._value_presence and self._value_presence[idx]:
52+
return None
53+
return self._values[idx]
54+
55+
def __len__(self):
56+
return self._row_count
57+
58+
@property
59+
def value_presences(self) -> Optional[np.ndarray]:
60+
return self._value_presence
61+
62+
63+
def _decode_cif_column(column: EncodedCIFColumn) -> CIFColumn:
64+
values = decode_cif_data(column["data"])
65+
value_mask = decode_cif_data(column["mask"]) if column["mask"] else None
66+
return BinaryCIFColumn(column["name"], values, value_mask)
67+
68+
69+
class BinaryCIFCategory(CIFCategory):
70+
def __getitem__(self, name: str) -> BinaryCIFColumn:
71+
if name not in self._field_cache:
72+
raise ValueError(f"{name} is not a valid category name")
73+
74+
if not self._field_cache[name]:
75+
self._field_cache[name] = _decode_cif_column(self._columns[name])
76+
77+
return self._field_cache[name] # type: ignore
78+
79+
def __contains__(self, key: str):
80+
return key in self._columns
81+
82+
def __init__(self, category: EncodedCIFCategory, lazy: bool):
83+
self._field_names = [c["name"] for c in category["columns"]]
84+
self._field_cache = {c["name"]: None if lazy else _decode_cif_column(c) for c in category["columns"]}
85+
self._columns: dict[str, EncodedCIFColumn] = {c["name"]: c for c in category["columns"]}
86+
self._n_columns = len(category["columns"])
87+
self._n_rows = category["rowCount"]
88+
self._name = category["name"][1:]
89+
90+
@property
91+
def name(self) -> str:
92+
return self._name
93+
94+
@property
95+
def n_rows(self) -> int:
96+
return self._n_rows
97+
98+
@property
99+
def n_columns(self) -> int:
100+
return self._n_columns
101+
102+
@property
103+
def field_names(self) -> List[str]:
104+
return self._field_names
105+
106+
107+
class BinaryCIFDataBlock(CIFDataBlock):
108+
def __getitem__(self, name: str) -> CIFCategory:
109+
return self._categories[name]
110+
111+
def __contains__(self, key: str):
112+
return key in self._categories
113+
114+
def __init__(self, header: str, categories: Dict[str, BinaryCIFCategory]):
115+
self._header = header
116+
self._categories = categories
117+
118+
@property
119+
def header(self) -> str:
120+
return self._header
121+
122+
@property
123+
def categories(self) -> Dict[str, CIFCategory]:
124+
return self._categories # type: ignore
125+
126+
127+
class BinaryCIFFile(CIFFile):
128+
def __getitem__(self, index_or_name: Union[int, str]):
129+
if isinstance(index_or_name, str):
130+
return self._block_map.get(index_or_name)
131+
else:
132+
return (
133+
self.data_blocks[index_or_name]
134+
if index_or_name < len(self.data_blocks) and index_or_name >= 0
135+
else None
136+
)
137+
138+
def __len__(self):
139+
return len(self._data_blocks)
140+
141+
def __contains__(self, key: str):
142+
return key in self._block_map
143+
144+
def __init__(self, data_blocks: List[BinaryCIFDataBlock]):
145+
self._data_blocks = data_blocks
146+
self._block_map: dict[str, CIFDataBlock] = {b.header: b for b in data_blocks}
147+
148+
@staticmethod
149+
def from_data(data: EncodedCIFFile, *, lazy=True) -> "BinaryCIFFile":
150+
"""
151+
- lazy:
152+
- True: individual columns are decoded only when accessed
153+
- False: decode all columns immediately
154+
"""
155+
156+
min_version = (0, 3, 0)
157+
version = tuple(map(int, data["version"].split(".")))
158+
if version < min_version:
159+
raise ValueError(f"Invalid version {data['version']}, expected >={'.'.join(map(str, min_version))}")
160+
161+
data_blocks = [
162+
BinaryCIFDataBlock(
163+
block["header"],
164+
{category["name"][1:]: BinaryCIFCategory(category, lazy) for category in block["categories"]},
165+
)
166+
for block in data["dataBlocks"]
167+
]
168+
169+
return BinaryCIFFile(data_blocks)
170+
171+
@property
172+
def data_blocks(self) -> List[CIFDataBlock]:
173+
return self._data_blocks # type: ignore

ciftools/binary/encoding/data_types.py renamed to ciftools/binary/data_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class DataType:
3535

3636
@staticmethod
3737
def from_dtype(dtype: Union[np.dtype, str]) -> DataTypeEnum:
38-
t = str(dtype.str)
38+
t = dtype if isinstance(dtype, str) else str(dtype.str)
3939
if t[0] in (">", "<", "|"):
4040
t = t[1:]
4141
return DataTypeEnum(DataType.__dtypes_to_data_types[t])

ciftools/binary/decoder.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
from typing import Union
2-
31
import numpy as np
4-
from ciftools.binary.encoding.data_types import DataType
5-
from ciftools.binary.encoding.encodings import (
2+
from ciftools.binary.data_types import DataType
3+
from ciftools.binary.encoded_data import EncodedCIFData
4+
from ciftools.binary.encoding_types import (
65
ByteArrayEncoding,
76
DeltaEncoding,
87
FixedPointEncoding,
@@ -11,30 +10,21 @@
1110
RunLengthEncoding,
1211
StringArrayEncoding,
1312
)
14-
from ciftools.binary.encoding.types import EncodedCIFColumn, EncodedCIFData
15-
from ciftools.cif_format.base import CIFColumnBase
16-
from ciftools.cif_format.binary.column import BinaryCIFColumn
17-
18-
19-
def decode_cif_column(column: EncodedCIFColumn) -> CIFColumnBase:
20-
values = decode_cif_data(column["data"])
21-
value_kinds = decode_cif_data(column["mask"]) if column["mask"] else None # type: ignore
22-
return BinaryCIFColumn(column["name"], values, value_kinds) # type: ignore
2313

2414

25-
def decode_cif_data(encoded_data: EncodedCIFData) -> Union[np.ndarray, list[str]]:
15+
def decode_cif_data(encoded_data: EncodedCIFData) -> np.ndarray:
2616
result = encoded_data["data"]
2717
for encoding in encoded_data["encoding"][::-1]:
2818
if encoding["kind"] in _decoders:
2919
result = _decoders[encoding["kind"]](result, encoding) # type: ignore
3020
else:
3121
raise ValueError(f"Unsupported encoding '{encoding['kind']}'")
3222

33-
return result
23+
return result # type: ignore
3424

3525

3626
def _decode_byte_array(data: bytes, encoding: ByteArrayEncoding) -> np.ndarray:
37-
return np.frombuffer(data, dtype="<" + DataType.to_dtype(encoding["type"]))
27+
return np.frombuffer(data, dtype=f"<{str(DataType.to_dtype(encoding['type']))}")
3828

3929

4030
def _decode_fixed_point(data: np.ndarray, encoding: FixedPointEncoding) -> np.ndarray:
@@ -57,6 +47,7 @@ def _decode_delta(data: np.ndarray, encoding: DeltaEncoding) -> np.ndarray:
5747
return np.cumsum(result, out=result)
5848

5949

50+
# TODO: JIT
6051
def _decode_integer_packing_signed(data: np.ndarray, encoding: IntegerPackingEncoding) -> np.ndarray:
6152
upper_limit = 0x7F if encoding["byteCount"] == 1 else 0x7FFF
6253
lower_limit = -upper_limit - 1
@@ -78,6 +69,7 @@ def _decode_integer_packing_signed(data: np.ndarray, encoding: IntegerPackingEnc
7869
return output
7970

8071

72+
# TODO: JIT
8173
def _decode_integer_packing_unsigned(data: np.ndarray, encoding: IntegerPackingEncoding) -> np.ndarray:
8274
upper_limit = 0xFF if encoding["byteCount"] == 1 else 0xFFFF
8375
n = len(data)
@@ -107,7 +99,7 @@ def _decode_integer_packing(data: np.ndarray, encoding: IntegerPackingEncoding)
10799
return _decode_integer_packing_signed(data, encoding)
108100

109101

110-
def _decode_string_array(data: np.ndarray, encoding: StringArrayEncoding) -> list[str]:
102+
def _decode_string_array(data: np.ndarray, encoding: StringArrayEncoding) -> np.ndarray:
111103
offsets = decode_cif_data(EncodedCIFData(encoding=encoding["offsetEncoding"], data=encoding["offsets"]))
112104
indices = decode_cif_data(EncodedCIFData(encoding=encoding["dataEncoding"], data=data))
113105

@@ -117,7 +109,8 @@ def _decode_string_array(data: np.ndarray, encoding: StringArrayEncoding) -> lis
117109
for i in range(1, len(offsets)):
118110
strings.append(string_data[offsets[i - 1] : offsets[i]]) # type: ignore
119111

120-
return [strings[i + 1] for i in indices] # type: ignore
112+
return np.array([strings[i + 1] for i in indices], dtype=np.object_)
113+
# return [strings[i + 1] for i in indices]
121114

122115

123116
_decoders = {

ciftools/binary/encoding/types.py renamed to ciftools/binary/encoded_data.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
from typing import Optional, TypedDict
1+
from typing import Optional, TypedDict, Union
22

3-
from ciftools.binary.encoding.encodings import EncodingBase
3+
import numpy as np
4+
from ciftools.binary.encoding_types import EncodingBase
45

56

67
class EncodedCIFData(TypedDict):
78
encoding: list[EncodingBase]
8-
data: bytes
9+
data: Union[bytes, np.ndarray]
910

1011

1112
class EncodedCIFColumn(TypedDict):

0 commit comments

Comments
 (0)