Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix: enforce binding addresses order
  • Loading branch information
imyhxy committed Nov 30, 2021
commit 4e4dec321852681b6393288afd1143099e43a6f7
6 changes: 3 additions & 3 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import math
import platform
import warnings
from collections import namedtuple
from collections import OrderedDict, namedtuple
from copy import copy
from pathlib import Path

Expand Down Expand Up @@ -326,14 +326,14 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
logger = trt.Logger(trt.Logger.INFO)
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
model = runtime.deserialize_cuda_engine(f.read())
bindings = dict()
bindings = OrderedDict()
for index in range(model.num_bindings):
name = model.get_binding_name(index)
dtype = trt.nptype(model.get_binding_dtype(index))
shape = tuple(model.get_binding_shape(index))
data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
binding_addrs = {n: d.ptr for n, d in bindings.items()}
binding_addrs = OrderedDict({n: d.ptr for n, d in bindings.items()})
context = model.create_execution_context()
batch_size = bindings['images'].shape[0]
else: # TensorFlow model (TFLite, pb, saved_model)
Expand Down