
Commit ebb6568

Added support for large tensors, save safetensors on CPU
1 parent 27e0a6b commit ebb6568

File tree: 6 files changed (+34, −35 lines)


TorchSharp.PyBridge.Tests/TorchSharp.PyBridge.Tests.csproj

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@
     <PackageReference Include="NUnit3TestAdapter" Version="4.4.2" />
     <PackageReference Include="NUnit.Analyzers" Version="3.6.1" />
     <PackageReference Include="coverlet.collector" Version="3.2.0" />
-    <PackageReference Include="TorchSharp-cpu" Version="0.101.3" />
+    <PackageReference Include="TorchSharp-cpu" Version="0.102.0" />
   </ItemGroup>

   <ItemGroup>

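If you're applying the same backend bump to your own test project, the NuGet CLI equivalent of the edit above is:

    dotnet add package TorchSharp-cpu --version 0.102.0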
TorchSharp.PyBridge/PyBridgeModuleExtensions.cs

Lines changed: 2 additions & 0 deletions
@@ -121,6 +121,7 @@ public static Module load_py(this Module module, string location, bool strict =
 public static Module load_py(this Module module, System.IO.Stream stream, bool strict = true, IList<string>? skip = null, Dictionary<string, bool>? loadedParameters = null, bool leaveOpen = false) {
     // Create a dispose scope so that we don't keep any of the loaded tensors past this function
     using var d = torch.NewDisposeScope();
+    using var d2 = torch.no_grad(); // To circumvent a bug introduced in 0.102.0

     // Unpickle the state dictionary into memory
     var stateHashtable = PyTorchUnpickler.UnpickleStateDict(stream, leaveOpen);
@@ -182,6 +183,7 @@ public static Module load_safetensors(this Module module, string location, bool
 public static Module load_safetensors(this Module module, System.IO.Stream stream, bool strict = true, IList<string>? skip = null, Dictionary<string, bool>? loadedParameters = null, bool leaveOpen = false) {
     // Create a dispose scope so that we don't keep any of the loaded tensors past this function
     using var d = torch.NewDisposeScope();
+    using var d2 = torch.no_grad(); // To circumvent a bug introduced in 0.102.0

     // Retrieve the current state dict of the module, so that we can make sure to only load the relevant
     // tensors from the file.

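For context, this is how the patched extensions are reached from user code; a minimal sketch, where `MyModel` and the file names are hypothetical stand-ins:

    using TorchSharp;
    using TorchSharp.PyBridge;

    // `MyModel` is any Module subclass; file names are placeholders.
    var model = new MyModel();
    // Tensors are unpickled inside a dispose scope (and under no_grad, per the
    // workaround above), copied into the module's parameters, then released.
    model.load_py("pytorch_model.bin");
    model.load_safetensors("model.safetensors");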
TorchSharp.PyBridge/PyTorchPickler.cs

Lines changed: 2 additions & 2 deletions
@@ -95,8 +95,8 @@ protected override bool persistentId(object pid, out object? newpid) {
     // Start by serializing the object to a file in the archive
     var entry = _archive.CreateEntry($"model/data/{_tensorCount}");
     using (var stream = entry.Open())
-        stream.Write(tensor.bytes.ToArray(), 0, tensor.bytes.Length);
-
+        tensor.WriteBytesToStream(stream);
+
     // Collect the items for our persistentId, as above.
     newpid = new object[] {
         "storage",

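The point of this change: `tensor.bytes` materializes the whole storage as a single .NET buffer, which cannot exceed int.MaxValue bytes, so tensors over 2GB could not be written. `WriteBytesToStream` (new in TorchSharp 0.102.0) streams the native storage directly. A minimal sketch of the same idea, with an illustrative chunk size and names that are not TorchSharp internals:

    using System;
    using System.IO;

    // Copy `totalBytes` from src to dest in 64 MB chunks, never holding a
    // single >2GB array - the pattern WriteBytesToStream applies to native
    // tensor storage.
    static void CopyInChunks(Stream src, Stream dest, long totalBytes)
    {
        var buffer = new byte[64 * 1024 * 1024];
        for (long remaining = totalBytes; remaining > 0; ) {
            int read = src.Read(buffer, 0, (int)Math.Min(buffer.Length, remaining));
            if (read <= 0) throw new EndOfStreamException("source ended early");
            dest.Write(buffer, 0, read);
            remaining -= read;
        }
    }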
TorchSharp.PyBridge/PyTorchUnpickler.cs

Lines changed: 7 additions & 19 deletions
@@ -81,30 +81,17 @@ protected override object persistentLoad(object pid) {
     string archiveKey = (string)opid[2];
     // Tuple Item3: location (cpu/gpu), but we always load onto CPU.
     // Tuple Item4: numElems (the number of elements in the tensor)
-    int numElem = (int)opid[4];
-
+
     // Convert the storage name into the relevant scalar type (e.g., LongStorage => torch.long)
     // and then check how many bytes each element is
     var dtype = GetScalarTypeFromStorageName(storageType);
-    var elemSize = (int)torch.empty(0, dtype).ElementSize;
-
-    int totalSize = numElem * elemSize;
-
-    //
-    // TODO: Fix this so that you can read large tensors. Right now, they are limited to 2GB
-    //
-    if (totalSize > int.MaxValue)
-        throw new NotImplementedException("Loading tensors larger than 2GB");
-
+
     // Retrieve the entry from the archive
     var entry = _archive.Entries.First(f => f.FullName.EndsWith($"data/{archiveKey}"));
-    // Read in the relevant bytes from the entry
-    var bytesBuffer = new byte[totalSize];
-    entry!.Open().Read(bytesBuffer, 0, totalSize);
-
+
     // Send this back, so our TensorObjectConstructor can create our torch.tensor from the object.
     return new TensorObject() {
-        data = bytesBuffer,
+        data = entry!.Open(),
         dtype = dtype
     };
 }
@@ -176,7 +163,8 @@ public object construct(object[] args) {
     torch.Tensor t = shape.Length == 0 ? torch.zeros(1, arg0.dtype)
         : torch.WrappedTensorDisposeScope(() =>
             torch.zeros(shape, arg0.dtype).as_strided(shape, stride, storageOffset));
-    t.bytes = arg0.data;
+    t.ReadBytesFromStream(arg0.data);
+    arg0.data.Close();
     return t;
 }
 }
@@ -201,7 +189,7 @@ public object construct(object[] args) {
 /// Therefore, this class is a simple wrapper for the bytes + dtype of the storage.
 /// </summary>
 class TensorObject {
-    public byte[]? data { get; set; }
+    public Stream data { get; set; }
     public torch.ScalarType dtype { get; set; }
 }
 }

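End to end, the unpickler now hands the raw zip entry stream to the tensor constructor instead of a byte[], and `ReadBytesFromStream` fills the storage in place. A small round-trip sketch of those two TorchSharp 0.102.0 calls (sizes illustrative):

    using System.IO;
    using TorchSharp;

    var src = torch.rand(1024, 1024);     // illustrative; works past 2GB too
    using var ms = new MemoryStream();
    src.WriteBytesToStream(ms);           // raw storage out, no byte[] detour

    ms.Position = 0;
    var dst = torch.empty(1024, 1024);
    dst.ReadBytesFromStream(ms);          // fills dst's storage in place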
TorchSharp.PyBridge/Safetensors.cs

Lines changed: 13 additions & 8 deletions
@@ -29,14 +29,14 @@ static class Safetensors {

     var tensor = torch.empty(kvp.Value.Shape, dtype: ConvertToTorchDType(kvp.Value.DataType));

-    // Make sure the length isn't > int.MaxValue, since .NET has the 2GB limit
+    // Make sure the length matches the number of bytes to load
     long length = kvp.Value.Offsets[1] - kvp.Value.Offsets[0];
-    if (length > int.MaxValue)
-        throw new NotImplementedException("Loading tensors larger than 2GB");
+    if (length != tensor.ElementSize * tensor.NumberOfElements)
+        throw new NotImplementedException($"Error when loading tensor {kvp.Key} - mismatched # of elements");

     stream.Position = offset + kvp.Value.Offsets[0];
-    tensor.bytes = stream.ReadBytes((int)length);
-
+    tensor.ReadBytesFromStream(stream);
+
     ret.Add(kvp.Key, tensor);
 }

@@ -75,9 +75,14 @@ public static void SaveStateDict(Stream stream, Dictionary<string, torch.Tensor>
     var br = new BinaryWriter(stream);
     br.Write((ulong)indexJson.Length);
     br.Write(indexJson);
-    foreach (var kvp in orderedState)
-        br.Write(kvp.Value.bytes);
-
+    foreach (var kvp in orderedState) {
+        if (kvp.Value.device.type == DeviceType.CPU)
+            kvp.Value.WriteBytesToStream(stream);
+        else {
+            using var tmp = kvp.Value.cpu();
+            tmp.WriteBytesToStream(stream);
+        }
+    }
     if (!leaveOpen)
         br.Close();
 }

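In practice this means `save_safetensors` no longer requires moving the whole model to the CPU first; a usage sketch, where `MyModel` is hypothetical and a CUDA-enabled TorchSharp backend is assumed:

    using TorchSharp;
    using TorchSharp.PyBridge;

    // `MyModel` is a hypothetical Module subclass living on the GPU.
    var model = new MyModel().cuda();
    // Each non-CPU tensor is copied to a temporary CPU tensor, streamed into
    // the safetensors payload, and disposed immediately afterwards.
    model.save_safetensors("model.safetensors");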
TorchSharp.PyBridge/TorchSharp.PyBridge.csproj

Lines changed: 9 additions & 5 deletions
@@ -11,7 +11,7 @@

   <ItemGroup>
     <PackageReference Include="Razorvine.Pickle" Version="1.5.0" />
-    <PackageReference Include="TorchSharp" Version="[0.101.3,)" PrivateAssets="All" />
+    <PackageReference Include="TorchSharp" Version="[0.102.0,)" PrivateAssets="All" />
     <PackageReference Include="TqdmSharp" Version="1.3.3" />
   </ItemGroup>

@@ -22,10 +22,14 @@
     <PackageProjectUrl>https://github.com/shaltielshmid/TorchSharp.PyBridge</PackageProjectUrl>
     <RepositoryUrl>https://github.com/shaltielshmid/TorchSharp.PyBridge.git</RepositoryUrl>
     <RepositoryType>git</RepositoryType>
-    <Version>1.2.0</Version>
-    <AssemblyVersion>1.2.0.0</AssemblyVersion>
-    <FileVersion>1.2.0.0</FileVersion>
-    <PackageReleaseNotes>1.2.0: Added `load_safetensors` and `save_safetensors` extensions for modules.</PackageReleaseNotes>
+    <Version>1.3.0</Version>
+    <AssemblyVersion>1.3.0.0</AssemblyVersion>
+    <FileVersion>1.3.0.0</FileVersion>
+    <PackageReleaseNotes>
+      1.3.0:
+      - Added support for loading tensors that are greater than 2GB (following the update in TorchSharp 0.102.0)
+      - Added support for loading and saving safetensors when the model isn't on CPU.
+    </PackageReleaseNotes>
   </PropertyGroup>

   <ItemGroup>

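Note that TorchSharp is referenced with PrivateAssets="All", so it does not flow to consumers as a transitive dependency; a consuming project is expected to reference the library together with a backend of its choice, e.g.:

    <ItemGroup>
      <!-- Versions shown match this release; adjust as needed. -->
      <PackageReference Include="TorchSharp.PyBridge" Version="1.3.0" />
      <PackageReference Include="TorchSharp-cpu" Version="0.102.0" />
    </ItemGroup>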