diff --git a/requirements.txt b/requirements.txt
index 22b51fc490e3..bfe795f1d378 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,7 +11,7 @@ scipy>=1.4.1
 torch>=1.7.0
 torchvision>=0.8.1
 tqdm>=4.41.0
-
+pycuda
 # Logging -------------------------------------
 tensorboard>=2.4.1
 # wandb
diff --git a/utils/yolov5_trt.py b/utils/yolov5_trt.py
new file mode 100755
index 000000000000..30a54619b3e4
--- /dev/null
+++ b/utils/yolov5_trt.py
@@ -0,0 +1,403 @@
+"""
+An example that uses TensorRT's Python API to run inference.
+"""
+import ctypes
+import os
+import random
+import shutil
+import sys
+import threading
+import time
+
+import cv2
+import numpy as np
+import pycuda.autoinit
+import pycuda.driver as cuda
+import tensorrt as trt
+import torch
+import torchvision
+
+
+def get_img_path_batches(batch_size, img_dir):
+    ret = []
+    batch = []
+    for root, dirs, files in os.walk(img_dir):
+        for name in files:
+            if len(batch) == batch_size:
+                ret.append(batch)
+                batch = []
+            batch.append(os.path.join(root, name))
+    if len(batch) > 0:
+        ret.append(batch)
+    return ret
+
+
+def plot_one_box(x, img, color=None, label=None, line_thickness=None):
+    """
+    description: Plots one bounding box on image img,
+                 this function comes from the YOLOv5 project.
+    param:
+        x:      a box like [x1,y1,x2,y2]
+        img:    an OpenCV image object
+        color:  color to draw the rectangle with, such as (0,255,0)
+        label:  str
+        line_thickness: int
+    return:
+        no return
+    """
+    tl = (
+        line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1
+    )  # line/font thickness
+    color = color or [random.randint(0, 255) for _ in range(3)]
+    c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
+    cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
+    if label:
+        tf = max(tl - 1, 1)  # font thickness
+        t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
+        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
+        cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA)  # filled
+        cv2.putText(
+            img,
+            label,
+            (c1[0], c1[1] - 2),
+            0,
+            tl / 3,
+            [225, 255, 255],
+            thickness=tf,
+            lineType=cv2.LINE_AA,
+        )
+
+
+class YoLov5TRT:
+    """
+    description: A YOLOv5 class that wraps TensorRT ops, preprocess and postprocess ops.
+    """
+
+    def __init__(self, engine_file_path, conf_thresh, iou_thresh):
+        # Create a CUDA context on this device.
+        self.ctx = cuda.Device(0).make_context()
+        self.conf_thresh = conf_thresh
+        self.iou_thresh = iou_thresh
+        stream = cuda.Stream()
+        TRT_LOGGER = trt.Logger(trt.Logger.INFO)
+        runtime = trt.Runtime(TRT_LOGGER)
+
+        # Deserialize the engine from file
+        with open(engine_file_path, "rb") as f:
+            engine = runtime.deserialize_cuda_engine(f.read())
+        context = engine.create_execution_context()
+
+        host_inputs = []
+        cuda_inputs = []
+        host_outputs = []
+        cuda_outputs = []
+        bindings = []
+
+        for binding in engine:
+            print('binding:', binding, engine.get_binding_shape(binding))
+            size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
+            dtype = trt.nptype(engine.get_binding_dtype(binding))
+            # Allocate host and device buffers
+            host_mem = cuda.pagelocked_empty(size, dtype)
+            cuda_mem = cuda.mem_alloc(host_mem.nbytes)
+            # Append the device buffer to device bindings.
+            bindings.append(int(cuda_mem))
+            # Append to the appropriate list.
+            if engine.binding_is_input(binding):
+                self.input_w = engine.get_binding_shape(binding)[-1]
+                self.input_h = engine.get_binding_shape(binding)[-2]
+                host_inputs.append(host_mem)
+                cuda_inputs.append(cuda_mem)
+            else:
+                host_outputs.append(host_mem)
+                cuda_outputs.append(cuda_mem)
+
+        # Store
+        self.stream = stream
+        self.context = context
+        self.engine = engine
+        self.host_inputs = host_inputs
+        self.cuda_inputs = cuda_inputs
+        self.host_outputs = host_outputs
+        self.cuda_outputs = cuda_outputs
+        self.bindings = bindings
+        self.batch_size = engine.max_batch_size
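+
+    # Sizing sketch (hypothetical engine, not guaranteed for every .engine file):
+    # for a batch-1, 640x640 engine the loop above allocates an input buffer of
+    # 1 * 3 * 640 * 640 = 1,228,800 floats, and an output buffer of 6001 floats
+    # per image (1 box count followed by 1000 slots of 6 values each), which is
+    # the layout post_process() below slices with output[i * 6001:(i + 1) * 6001].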
+
+    def infer(self, raw_image_generator):
+        # Make self the active context, pushing it on top of the context stack.
+        self.ctx.push()
+        # Restore the stored stream, context and buffers
+        stream = self.stream
+        context = self.context
+        engine = self.engine
+        host_inputs = self.host_inputs
+        cuda_inputs = self.cuda_inputs
+        host_outputs = self.host_outputs
+        cuda_outputs = self.cuda_outputs
+        bindings = self.bindings
+        # Do image preprocess
+        batch_image_raw = []
+        batch_origin_h = []
+        batch_origin_w = []
+        batch_input_image = np.empty(shape=[self.batch_size, 3, self.input_h, self.input_w])
+        start0 = time.time()
+        for i, image_raw in enumerate(raw_image_generator):
+            input_image, image_raw, origin_h, origin_w = self.preprocess_image(image_raw)
+            batch_image_raw.append(image_raw)
+            batch_origin_h.append(origin_h)
+            batch_origin_w.append(origin_w)
+            np.copyto(batch_input_image[i], input_image)
+        end0 = time.time()
+        print('preprocess time:', end0 - start0)
+        batch_input_image = np.ascontiguousarray(batch_input_image)
+
+        # Copy input image to host buffer
+        np.copyto(host_inputs[0], batch_input_image.ravel())
+        start = time.time()
+        # Transfer input data to the GPU.
+        cuda.memcpy_htod_async(cuda_inputs[0], host_inputs[0], stream)
+        # Run inference.
+        context.execute_async(batch_size=self.batch_size, bindings=bindings, stream_handle=stream.handle)
+        # Transfer predictions back from the GPU.
+        cuda.memcpy_dtoh_async(host_outputs[0], cuda_outputs[0], stream)
+        # Synchronize the stream
+        stream.synchronize()
+        end = time.time()
+        # Remove any context from the top of the context stack, deactivating it.
+        self.ctx.pop()
+
+        # Each image's output occupies 6001 floats: [num_boxes, 1000 * (cx, cy, w, h, conf, cls_id)]
+        output = host_outputs[0]
+        # Do postprocess
+        categories = ['person', 'phone', 'helmet', 'head', 'car', 'truck', 'boat', 'ship', 'fire', 'drop']
+        # Iterate over the images actually received (the last batch may be short)
+        for i in range(len(batch_image_raw)):
+            result_boxes, result_scores, result_classid = self.post_process(
+                output[i * 6001: (i + 1) * 6001], batch_origin_h[i], batch_origin_w[i]
+            )
+            # Draw rectangles and labels on the original image
+            for j in range(len(result_boxes)):
+                box = result_boxes[j]
+                plot_one_box(
+                    box,
+                    batch_image_raw[i],
+                    label="{}:{:.2f}".format(
+                        categories[int(result_classid[j])], result_scores[j]
+                    ),
+                )
+        # Note: the detections returned below belong to the last image in the batch
+        return batch_image_raw, end - start, [result_boxes, result_scores, result_classid]
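+
+    # Example usage (a minimal sketch; the engine path, thresholds and image
+    # path are hypothetical, and get_raw_image expects a batch of image paths):
+    #
+    #     wrapper = YoLov5TRT("yolov5s.engine", 0.5, 0.4)
+    #     imgs, latency, (boxes, scores, classids) = wrapper.infer(
+    #         wrapper.get_raw_image(["samples/bus.jpg"]))
+    #     wrapper.destroy()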
+
+    def destroy(self):
+        # Remove any context from the top of the context stack, deactivating it.
+        self.ctx.pop()
+
+    def get_raw_image(self, image_path_batch):
+        """
+        description: Read images from a batch of image paths
+        """
+        for img_path in image_path_batch:
+            yield cv2.imread(img_path)
+
+    def get_raw_image_zeros(self, image_path_batch=None):
+        """
+        description: Prepare dummy data for warmup
+        """
+        for _ in range(self.batch_size):
+            yield np.zeros([self.input_h, self.input_w, 3], dtype=np.uint8)
+
+    def preprocess_image(self, raw_bgr_image):
+        """
+        description: Convert BGR image to RGB,
+                     resize and pad it to target size, normalize to [0,1],
+                     transform to NCHW format.
+        param:
+            raw_bgr_image: an OpenCV BGR image
+        return:
+            image:     the processed image
+            image_raw: the original image
+            h: original height
+            w: original width
+        """
+        image_raw = raw_bgr_image
+        h, w, c = image_raw.shape
+        image = cv2.cvtColor(image_raw, cv2.COLOR_BGR2RGB)
+        # Calculate width, height and padding
+        r_w = self.input_w / w
+        r_h = self.input_h / h
+        if r_h > r_w:
+            tw = self.input_w
+            th = int(r_w * h)
+            tx1 = tx2 = 0
+            ty1 = int((self.input_h - th) / 2)
+            ty2 = self.input_h - th - ty1
+        else:
+            tw = int(r_h * w)
+            th = self.input_h
+            tx1 = int((self.input_w - tw) / 2)
+            tx2 = self.input_w - tw - tx1
+            ty1 = ty2 = 0
+        # Resize along the long side while keeping the aspect ratio
+        image = cv2.resize(image, (tw, th))
+        # Pad the short side with (128,128,128)
+        image = cv2.copyMakeBorder(
+            image, ty1, ty2, tx1, tx2, cv2.BORDER_CONSTANT, (128, 128, 128)
+        )
+        image = image.astype(np.float32)
+        # Normalize to [0,1]
+        image /= 255.0
+        # HWC to CHW format:
+        image = np.transpose(image, [2, 0, 1])
+        # CHW to NCHW format
+        image = np.expand_dims(image, axis=0)
+        # Convert the image to row-major order, also known as "C order":
+        image = np.ascontiguousarray(image)
+        return image, image_raw, h, w
+
+    def xywh2xyxy(self, origin_h, origin_w, x):
+        """
+        description: Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+        param:
+            origin_h: height of original image
+            origin_w: width of original image
+            x:        a boxes tensor, each row is a box [center_x, center_y, w, h]
+        return:
+            y:        a boxes tensor, each row is a box [x1, y1, x2, y2]
+        """
+        y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
+        r_w = self.input_w / origin_w
+        r_h = self.input_h / origin_h
+        if r_h > r_w:
+            y[:, 0] = x[:, 0] - x[:, 2] / 2
+            y[:, 2] = x[:, 0] + x[:, 2] / 2
+            y[:, 1] = x[:, 1] - x[:, 3] / 2 - (self.input_h - r_w * origin_h) / 2
+            y[:, 3] = x[:, 1] + x[:, 3] / 2 - (self.input_h - r_w * origin_h) / 2
+            y /= r_w
+        else:
+            y[:, 0] = x[:, 0] - x[:, 2] / 2 - (self.input_w - r_h * origin_w) / 2
+            y[:, 2] = x[:, 0] + x[:, 2] / 2 - (self.input_w - r_h * origin_w) / 2
+            y[:, 1] = x[:, 1] - x[:, 3] / 2
+            y[:, 3] = x[:, 1] + x[:, 3] / 2
+            y /= r_h
+
+        return y
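+
+    # Worked example of the letterbox math above, assuming a 640x640 input and
+    # a 1280x720 source image: r_w = 640/1280 = 0.5, r_h = 640/720 = 0.889, so
+    # r_h > r_w; preprocess_image resizes to tw=640, th=int(0.5*720)=360 and
+    # pads ty1 = int((640 - 360) / 2) = 140 gray rows on top and ty2 = 140 on
+    # the bottom. xywh2xyxy inverts the same mapping: it subtracts the 140-px
+    # offset from the y coordinates and divides by r_w = 0.5 to recover
+    # original-image pixels.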
+
+    def post_process(self, output, origin_h, origin_w):
+        """
+        description: postprocess the prediction
+        param:
+            output:   a tensor like [num_boxes, cx,cy,w,h,conf,cls_id, cx,cy,w,h,conf,cls_id, ...]
+            origin_h: height of original image
+            origin_w: width of original image
+        return:
+            result_boxes:   final boxes, a boxes tensor, each row is a box [x1, y1, x2, y2]
+            result_scores:  final scores, a tensor, each element is the score corresponding to a box
+            result_classid: final classid, a tensor, each element is the classid corresponding to a box
+        """
+        # Get the number of boxes detected
+        num = int(output[0])
+        # Reshape to a two-dimensional ndarray
+        pred = np.reshape(output[1:], (-1, 6))[:num, :]
+        # to a torch Tensor
+        pred = torch.Tensor(pred).cuda()
+        # Get the boxes
+        boxes = pred[:, :4]
+        # Get the scores
+        scores = pred[:, 4]
+        # Get the classid
+        classid = pred[:, 5]
+        # Choose those boxes with score > conf_thresh
+        si = scores > self.conf_thresh
+        boxes = boxes[si, :]
+        scores = scores[si]
+        classid = classid[si]
+        # Transform bbox from [center_x, center_y, w, h] to [x1, y1, x2, y2]
+        boxes = self.xywh2xyxy(origin_h, origin_w, boxes)
+        # Do nms
+        indices = torchvision.ops.nms(boxes, scores, iou_threshold=self.iou_thresh).cpu()
+        result_boxes = boxes[indices, :].cpu()
+        result_scores = scores[indices].cpu()
+        result_classid = classid[indices].cpu()
+        return result_boxes, result_scores, result_classid
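+
+    # Layout of the flat output consumed above (values hypothetical): the first
+    # element is the detection count, followed by fixed-size rows of
+    # (cx, cy, w, h, conf, cls_id), e.g.
+    #
+    #     output = [2.0,                                  # 2 boxes detected
+    #               320.0, 320.0, 50.0, 80.0, 0.9, 0.0,   # box 1
+    #               100.0, 200.0, 30.0, 40.0, 0.8, 2.0,   # box 2
+    #               0.0, ...]                              # unused slots
+    #
+    # np.reshape(output[1:], (-1, 6))[:2] then recovers the two valid rows.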
+
+
+class inferThread(threading.Thread):
+    def __init__(self, yolov5_wrapper, image_path_batch):
+        threading.Thread.__init__(self)
+        self.yolov5_wrapper = yolov5_wrapper
+        self.image_path_batch = image_path_batch
+
+    def run(self):
+        batch_image_raw, use_time, _ = self.yolov5_wrapper.infer(
+            self.yolov5_wrapper.get_raw_image(self.image_path_batch))
+        for i, img_path in enumerate(self.image_path_batch):
+            parent, filename = os.path.split(img_path)
+            save_name = os.path.join('output', filename)
+            # Save image
+            cv2.imwrite(save_name, batch_image_raw[i])
+        print(f'input->{self.image_path_batch}, time->{use_time * 1000:.2f}ms, saving into output/')
+
+
+class warmUpThread(threading.Thread):
+    def __init__(self, yolov5_wrapper):
+        threading.Thread.__init__(self)
+        self.yolov5_wrapper = yolov5_wrapper
+
+    def run(self):
+        batch_image_raw, use_time, _ = self.yolov5_wrapper.infer(
+            self.yolov5_wrapper.get_raw_image_zeros())
+        print(f'warm_up->{batch_image_raw[0].shape}, time->{use_time * 1000:.2f}ms')
+
+
+if __name__ == "__main__":
+    # load custom plugins
+    PLUGIN_LIBRARY = "build5/libmyplugins.so"
+    engine_file_path = "build5/yolov5l.engine"
+
+    if len(sys.argv) > 1:
+        engine_file_path = sys.argv[1]
+    if len(sys.argv) > 2:
+        PLUGIN_LIBRARY = sys.argv[2]
+
+    ctypes.CDLL(PLUGIN_LIBRARY)
+
+    # detection thresholds
+    CONF_THRESH = 0.5
+    IOU_THRESH = 0.4
+
+    # load coco labels
+    categories = ["person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
+                  "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
+                  "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
+                  "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
+                  "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
+                  "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
+                  "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
+                  "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
+                  "hair drier", "toothbrush"]
+
+    if os.path.exists('output/'):
+        shutil.rmtree('output/')
+    os.makedirs('output/')
+    # a YoLov5TRT instance
+    yolov5_wrapper = YoLov5TRT(engine_file_path, CONF_THRESH, IOU_THRESH)
+    try:
+        print('batch size is', yolov5_wrapper.batch_size)
+
+        image_dir = "../enginetest/images"
+        image_path_batches = get_img_path_batches(yolov5_wrapper.batch_size, image_dir)
+
+        for i in range(10):
+            # create a new thread to do warm_up
+            thread1 = warmUpThread(yolov5_wrapper)
+            thread1.start()
+            thread1.join()
+        for batch in image_path_batches:
+            print('batch:', batch)
+            # create a new thread to do inference
+            thread1 = inferThread(yolov5_wrapper, batch)
+            thread1.start()
+            thread1.join()
+    finally:
+        # destroy the instance
+        yolov5_wrapper.destroy()
diff --git a/val.py b/val.py
index dfabb65b979c..749a78dde3ba 100644
--- a/val.py
+++ b/val.py
@@ -2,11 +2,16 @@
 """
 Validate a trained YOLOv5 model accuracy on a custom dataset
 
+A TensorRT .engine file can now be validated as well
+
 Usage:
     $ python path/to/val.py --data coco128.yaml --weights yolov5s.pt --img 640
+
+    $ python path/to/val.py --data coco128.yaml --img 640 --engine_library xxxx.so --engine_path xxxx.engine
 """
 
 import argparse
+import ctypes
 import json
 import os
 import sys
@@ -32,6 +37,7 @@
 from utils.metrics import ConfusionMatrix, ap_per_class
 from utils.plots import output_to_target, plot_images, plot_val_study
 from utils.torch_utils import select_device, time_sync
+from utils.yolov5_trt import YoLov5TRT
 
 
 def save_one_txt(predn, save_conf, shape, file):
@@ -107,39 +113,52 @@ def run(data,
         plots=True,
         callbacks=Callbacks(),
         compute_loss=None,
+        engine_library=None,
+        engine_path=None,
         ):
     # Initialize/load model and set device
+    val_engine = engine_path is not None
     training = model is not None
-    if training:  # called by train.py
-        device, pt = next(model.parameters()).device, True  # get model device, PyTorch model
-
-        half &= device.type != 'cpu'  # half precision only supported on CUDA
-        model.half() if half else model.float()
-    else:  # called directly
+    if val_engine:
+        ctypes.CDLL(engine_library)
+        model = YoLov5TRT(engine_path, conf_thres, iou_thres)
         device = select_device(device, batch_size=batch_size)
-
-        # Directories
+        pt = False
+        stride = 32
         save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)  # increment run
         (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir
+    else:
+        if training:  # called by train.py
+            device, pt = next(model.parameters()).device, True  # get model device, PyTorch model
+            half &= device.type != 'cpu'  # half precision only supported on CUDA
+            model.half() if half else model.float()
+            model.eval()
+        else:  # called directly
+            device = select_device(device, batch_size=batch_size)
+
+            # Directories
+            save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)  # increment run
+            (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir
+
+            # Load model
+            model = DetectMultiBackend(weights, device=device, dnn=dnn)
+            stride, pt = model.stride, model.pt
+            imgsz = check_img_size(imgsz, s=stride)  # check image size
+            half &= pt and device.type != 'cpu'  # half precision only supported by PyTorch on CUDA
+            if pt:
+                model.model.half() if half else model.model.float()
+            else:
+                half = False
+                batch_size = 1  # export.py models default to batch-size 1
+                device = torch.device('cpu')
+                LOGGER.info(f'Forcing --batch-size 1 square inference shape(1,3,{imgsz},{imgsz}) for non-PyTorch backends')
 
-    # Load model
-    model = DetectMultiBackend(weights, device=device, dnn=dnn)
-    stride, pt = model.stride, model.pt
-    imgsz = check_img_size(imgsz, s=stride)  # check image size
-    half &= pt and device.type != 'cpu'  # half precision only supported by PyTorch on CUDA
-    if pt:
-        model.model.half() if half else model.model.float()
-    else:
-        half = False
-        batch_size = 1  # export.py models default to batch-size 1
-        device = torch.device('cpu')
-        LOGGER.info(f'Forcing --batch-size 1 square inference shape(1,3,{imgsz},{imgsz}) for non-PyTorch backends')
-
-    # Data
-    data = check_dataset(data)  # check
+    # Data
+    if not val_engine:  # the TRT wrapper has no eval()
+        model.eval()
 
     # Configure
-    model.eval()
+    data = check_dataset(data)  # check
     is_coco = isinstance(data.get('val'), str) and data['val'].endswith('coco/val2017.txt')  # COCO dataset
     nc = 1 if single_cls else int(data['nc'])  # number of classes
     iouv = torch.linspace(0.5, 0.95, 10).to(device)  # iou vector for mAP@0.5:0.95
@@ -151,12 +170,16 @@ def run(data,
         model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.model.parameters())))  # warmup
         pad = 0.0 if task == 'speed' else 0.5
         task = task if task in ('train', 'val', 'test') else 'val'  # path to train/val/test images
-    dataloader = create_dataloader(data[task], imgsz, batch_size, stride, single_cls, pad=pad, rect=pt,
+    batch_size = model.batch_size if val_engine else batch_size
+    dataloader = create_dataloader(data[task], imgsz, batch_size, stride, single_cls, pad=pad, rect=True,
                                    prefix=colorstr(f'{task}: '))[0]
 
     seen = 0
     confusion_matrix = ConfusionMatrix(nc=nc)
-    names = {k: v for k, v in enumerate(model.names if hasattr(model, 'names') else model.module.names)}
+    if val_engine:
+        names = {k: v for k, v in enumerate(data['names'])}
+    else:
+        names = {k: v for k, v in enumerate(model.names if hasattr(model, 'names') else model.module.names)}
     class_map = coco80_to_coco91_class() if is_coco else list(range(1000))
     s = ('%20s' + '%11s' * 6) % ('Class', 'Images', 'Labels', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
     dt, p, r, f1, mp, mr, map50, map = [0.0, 0.0, 0.0], 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
@@ -167,27 +190,38 @@ def run(data,
         t1 = time_sync()
         if pt:
             im = im.to(device, non_blocking=True)
-            targets = targets.to(device)
+        targets = targets.to(device)
         im = im.half() if half else im.float()  # uint8 to fp16/32
-        im /= 255  # 0 - 255 to 0.0 - 1.0
+        if not val_engine:
+            im /= 255  # 0 - 255 to 0.0 - 1.0
         nb, _, height, width = im.shape  # batch size, channels, height, width
         t2 = time_sync()
         dt[0] += t2 - t1
-
+        if val_engine:
+            [result_boxes, result_scores, result_classid] = model.infer(model.get_raw_image(paths))[2]
+            result_boxes = result_boxes.numpy().tolist()
+            result_scores = result_scores.numpy().tolist()
+            result_classid = result_classid.numpy().tolist()
+            result_scores = [[i] for i in result_scores]
+            result_classid = [[i] for i in result_classid]
+            out1 = np.hstack((result_boxes, result_scores, result_classid))
+            out = [torch.from_numpy(out1).to(device='cuda')]
+            targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device)
+        else:
         # Inference
-        out, train_out = model(im) if training else model(im, augment=augment, val=True)  # inference, loss outputs
-        dt[1] += time_sync() - t2
+            out, train_out = model(im) if training else model(im, augment=augment, val=True)  # inference, loss outputs
+            dt[1] += time_sync() - t2
 
-        # Loss
-        if compute_loss:
-            loss += compute_loss([x.float() for x in train_out], targets)[1]  # box, obj, cls
+            # Loss
+            if compute_loss:
+                loss += compute_loss([x.float() for x in train_out], targets)[1]  # box, obj, cls
 
-        # NMS
-        targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device)  # to pixels
-        lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else []  # for autolabelling
-        t3 = time_sync()
-        out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls)
-        dt[2] += time_sync() - t3
+            # NMS
+            targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device)  # to pixels
+            lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else []  # for autolabelling
+            t3 = time_sync()
+            out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls)
+            dt[2] += time_sync() - t3
 
         # Metrics
         for si, pred in enumerate(out):
@@ -206,7 +240,8 @@ def run(data,
             if single_cls:
                 pred[:, 5] = 0
             predn = pred.clone()
-            scale_coords(im[si].shape[1:], predn[:, :4], shape, shapes[si][1])  # native-space pred
+            if not val_engine:
+                scale_coords(im[si].shape[1:], predn[:, :4], shape, shapes[si][1])  # native-space pred
 
             # Evaluate
             if nl:
@@ -228,12 +263,15 @@ def run(data,
             callbacks.run('on_val_image_end', pred, predn, path, names, im[si])
 
         # Plot images
-        if plots and batch_i < 3:
+        if plots and batch_i < 3 and not val_engine:
             f = save_dir / f'val_batch{batch_i}_labels.jpg'  # labels
             Thread(target=plot_images, args=(im, targets, paths, f, names), daemon=True).start()
             f = save_dir / f'val_batch{batch_i}_pred.jpg'  # predictions
             Thread(target=plot_images, args=(im, output_to_target(out), paths, f, names), daemon=True).start()
-
+    if val_engine:
+        model.destroy()
+    else:
+        model.float()  # for training
     # Compute metrics
     stats = [np.concatenate(x, 0) for x in zip(*stats)]  # to numpy
     if len(stats) and stats[0].any():
@@ -291,7 +329,7 @@ def run(data,
         LOGGER.info(f'pycocotools unable to run: {e}')
 
     # Return results
-    model.float()  # for training
+
     if not training:
         s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
         LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
@@ -305,6 +343,7 @@ def parse_opt():
     parser = argparse.ArgumentParser()
     parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
     parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model.pt path(s)')
+    parser.add_argument('--batch-size', type=int, default=32, help='batch size')
     parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)')
     parser.add_argument('--conf-thres', type=float, default=0.001, help='confidence threshold')
@@ -323,6 +362,8 @@ def parse_opt():
     parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
     parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
     parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
+    parser.add_argument('--engine_library', type=str, default=None, help='path to the TensorRT plugin .so library')
+    parser.add_argument('--engine_path', type=str, default=None, help='path to the TensorRT .engine file')
     opt = parser.parse_args()
     opt.data = check_yaml(opt.data)  # check YAML
     opt.save_json |= opt.data.endswith('coco.yaml')