forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path_serialization.py
More file actions
158 lines (126 loc) · 4.48 KB
/
_serialization.py
File metadata and controls
158 lines (126 loc) · 4.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import pickle
from dataclasses import dataclass
from io import BufferedIOBase
from typing import Any
import torch
import torch._weights_only_unpickler as _weights_only_unpickler
from torch.serialization import _load, _save, DEFAULT_PROTOCOL, MAP_LOCATION
# This module is private: an empty ``__all__`` keeps ``from ... import *``
# from exporting any of its names.
__all__: list[str] = []
@dataclass
class _Entry:
    """Header describing one record in a ``_PseudoZipFile`` stream.

    ``_PseudoZipFile.write_to`` pickles a list of these ahead of the raw
    payload bytes, and ``_PseudoZipFile.read_from`` uses them to split the
    payloads back apart. Because instances travel through pickle, the field
    names are part of the stream format.
    """

    key: str  # record name the payload was stored under
    is_storage: bool  # True when the payload is a torch.UntypedStorage
    length: int  # number of payload bytes read_from consumes for this record
# ``_PseudoZipFile.read_from`` decodes the entry header with
# ``_weights_only_unpickler.load``. NOTE(review): registering ``_Entry`` as a
# safe global is presumably what allows that restricted unpickler to rebuild
# ``_Entry`` instances; this runs once at import time.
_weights_only_unpickler._add_safe_globals([_Entry])
class _PseudoZipFile:
def __init__(self) -> None:
self.records: dict[str, tuple[object, int]] = {}
def write_record(self, key: str, data: object, length: int) -> None:
self.records[key] = (data, length)
def write_to(self, f: BufferedIOBase) -> None:
entries = []
for key, (data, length) in self.records.items():
entries.append(
_Entry(
key=key,
is_storage=isinstance(data, torch.UntypedStorage),
length=length,
)
)
pickle.dump(entries, f, protocol=DEFAULT_PROTOCOL)
for data, _ in self.records.values():
if isinstance(data, bytes):
f.write(data)
elif isinstance(data, str):
f.write(data.encode("utf-8"))
elif isinstance(data, torch.UntypedStorage):
data._write_file(f, False, False, 1)
else:
raise TypeError(f"unknown type: {type(data)}")
def read_from(self, f: BufferedIOBase) -> None:
entries = _weights_only_unpickler.load(f)
for entry in entries:
data = f.read(entry.length)
if entry.is_storage:
if entry.length == 0:
storage = torch.UntypedStorage(0)
else:
storage = torch.frombuffer(
data,
dtype=torch.uint8,
).untyped_storage()
self.records[entry.key] = (
storage,
entry.length,
)
else:
self.records[entry.key] = (data, entry.length)
def has_record(self, key: str) -> bool:
return key in self.records
def get_record(self, key: str) -> object:
return self.records[key][0]
def get_storage_from_record(
self, key: str, _length: int, _type: int
) -> torch.Tensor:
return torch.tensor(self.records[key][0], dtype=torch.uint8)
def serialization_id(self) -> str:
return "torchft"
def _streaming_save(
    obj: object,
    f: BufferedIOBase,
    pickle_module: Any = pickle,
    pickle_protocol: int = DEFAULT_PROTOCOL,
) -> None:
    """Write ``obj`` to ``f`` as one forward-only byte stream.

    Works like :func:`torch.save`, with these differences:

    * the output never has to be seeked, so the matching loader can consume
      it from a non-seekable source such as a network socket;
    * the format carries no forwards/backwards compatibility guarantee -- it
      is only meant for transient transport (sockets, temp files) within a
      single PyTorch version;
    * mmap is not supported.

    See :func:`torch.save` for the meaning of ``pickle_module`` and
    ``pickle_protocol``.
    """
    # Let torch's serializer emit all of its records into an in-memory pseudo
    # zip file first, then flush them to ``f`` in one sequential pass.
    pseudo_zip = _PseudoZipFile()
    save_kwargs = {
        "zip_file": pseudo_zip,
        "pickle_module": pickle_module,
        "pickle_protocol": pickle_protocol,
        "_disable_byteorder_record": False,
    }
    _save(obj, **save_kwargs)
    pseudo_zip.write_to(f)
def _streaming_load(
f: BufferedIOBase,
map_location: MAP_LOCATION = None,
pickle_module: Any = None,
*,
weights_only: bool = True,
**pickle_load_args: Any,
) -> object:
"""
Load the object from a file-like object in a streaming fashion compatible with
network sockets.
See :func:`_streaming_save` for more details about the streaming behavior.
See :func:`torch.load` for more details on specific arguments.
"""
if weights_only:
if pickle_module is not None:
raise RuntimeError(
"Can not safely load weights when explicit pickle_module is specified"
)
pickle_module = _weights_only_unpickler
else:
if pickle_module is None:
pickle_module = pickle
if "encoding" not in pickle_load_args:
pickle_load_args["encoding"] = "utf-8"
zip_file = _PseudoZipFile()
zip_file.read_from(f)
return _load(
zip_file=zip_file,
map_location=map_location,
pickle_module=pickle_module,
**pickle_load_args,
)