1010
1111import numpy as np
1212import torch
13- import torchvision .transforms .functional as F
1413
14+ from ultralytics .yolo .data .augment import LetterBox
1515from ultralytics .yolo .utils import LOGGER , SimpleClass , deprecation_warn , ops
1616from ultralytics .yolo .utils .plotting import Annotator , colors
17- from ultralytics .yolo .utils .torch_utils import TORCHVISION_0_10
1817
1918
2019class BaseTensor (SimpleClass ):
@@ -160,6 +159,7 @@ def plot(
160159 pil = False ,
161160 example = 'abc' ,
162161 img = None ,
162+ img_gpu = None ,
163163 kpt_line = True ,
164164 labels = True ,
165165 boxes = True ,
@@ -178,14 +178,15 @@ def plot(
178178 pil (bool): Whether to return the image as a PIL Image.
179179 example (str): An example string to display. Useful for indicating the expected format of the output.
180180 img (numpy.ndarray): Plot to another image. if not, plot to original image.
181+ img_gpu (torch.Tensor): Normalized image in gpu with shape (1, 3, 640, 640), for faster mask plotting.
181182 kpt_line (bool): Whether to draw lines connecting keypoints.
182183 labels (bool): Whether to plot the label of bounding boxes.
183184 boxes (bool): Whether to plot the bounding boxes.
184185 masks (bool): Whether to plot the masks.
185186 probs (bool): Whether to plot classification probability
186187
187188 Returns:
188- (None) or (PIL.Image ): If `pil` is True, a PIL Image is returned. Otherwise, nothing is returned .
189+ (numpy.ndarray ): A numpy array of the annotated image .
189190 """
190191 # Deprecation warn TODO: remove in 8.2
191192 if 'show_conf' in kwargs :
@@ -200,22 +201,20 @@ def plot(
200201 pred_probs , show_probs = self .probs , probs
201202 names = self .names
202203 keypoints = self .keypoints
204+ if pred_masks and show_masks :
205+ if img_gpu is None :
206+ img = LetterBox (pred_masks .shape [1 :])(image = annotator .im )
207+ img_gpu = torch .as_tensor (img , dtype = torch .float16 , device = pred_masks .masks .device ).permute (
208+ 2 , 0 , 1 ).flip (0 ).contiguous () / 255
209+ annotator .masks (pred_masks .data , colors = [colors (x , True ) for x in pred_boxes .cls ], im_gpu = img_gpu )
210+
203211 if pred_boxes and show_boxes :
204212 for d in reversed (pred_boxes ):
205213 c , conf , id = int (d .cls ), float (d .conf ) if conf else None , None if d .id is None else int (d .id .item ())
206214 name = ('' if id is None else f'id:{ id } ' ) + names [c ]
207215 label = (f'{ name } { conf :.2f} ' if conf else name ) if labels else None
208216 annotator .box_label (d .xyxy .squeeze (), label , color = colors (c , True ))
209217
210- if pred_masks and show_masks :
211- im = torch .as_tensor (annotator .im , dtype = torch .float16 , device = pred_masks .data .device ).permute (2 , 0 ,
212- 1 ).flip (0 )
213- if TORCHVISION_0_10 :
214- im = F .resize (im .contiguous (), pred_masks .data .shape [1 :], antialias = True ) / 255
215- else :
216- im = F .resize (im .contiguous (), pred_masks .data .shape [1 :]) / 255
217- annotator .masks (pred_masks .data , colors = [colors (x , True ) for x in pred_boxes .cls ], im_gpu = im )
218-
219218 if pred_probs is not None and show_probs :
220219 n5 = min (len (names ), 5 )
221220 top5i = pred_probs .argsort (0 , descending = True )[:n5 ].tolist () # top 5 indices
@@ -226,7 +225,7 @@ def plot(
226225 for k in reversed (keypoints ):
227226 annotator .kpts (k , self .orig_shape , kpt_line = kpt_line )
228227
229- return np . asarray ( annotator .im ) if annotator . pil else annotator . im
228+ return annotator .result ()
230229
231230
232231class Boxes (BaseTensor ):
0 commit comments