You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
I've been trying to use the model for long duration videos with manipulating the script grounded_sam2_tracking_demo_with_continuous_id. However, somehow as frames getting processed vram memory usage keeps accumulating.
I've tried reconstructing both sam2 and dino models at the beginning of each chunk(e.g. 200frames) but it didn't help. Any advice?
Here is my code:
import os
import cv2
import torch
import numpy as np
import supervision as sv
from PIL import Image
from sam2.build_sam import build_sam2_video_predictor, build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from utils.track_utils import sample_points_from_masks
from utils.video_utils import create_video_from_images
from utils.common_utils import CommonUtils
from utils.mask_dictionary_model import MaskDictionaryModel, ObjectInfo
import json
import copy
import gc
import tracemalloc
import shutil
"""
Step 1: Environment settings and model initialization
"""
# use bfloat16 for the entire notebook
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# init sam image predictor and video predictor model
sam2_checkpoint = "./checkpoints/sam2.1_hiera_tiny.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_t.yaml"
device = "cuda" if torch.cuda.is_available() else "cpu"
print("device", device)
# init grounding dino model from huggingface
model_id = "IDEA-Research/grounding-dino-tiny"
# setup the input image and text prompt for SAM 2 and Grounding DINO
# VERY important: text queries need to be lowercased + end with a dot
text = "people."
# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
video_dir = "/content/extracted_video_frames"
# 'output_dir' is the directory to save the annotated frames
output_dir = "./outputs"
# 'output_video_path' is the path to save the final video
output_video_path = "./outputs/jackoutput.mp4"
# create the output directory
CommonUtils.creat_dirs(output_dir)
mask_data_dir = os.path.join(output_dir, "mask_data")
json_data_dir = os.path.join(output_dir, "json_data")
result_dir = os.path.join(output_dir, "result")
CommonUtils.creat_dirs(mask_data_dir)
CommonUtils.creat_dirs(json_data_dir)
# scan all the JPEG frame names in this directory
frame_names = [
p for p in os.listdir(video_dir)
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
print("Total frames:", len(frame_names))
chunk_size = 200
step = 20
for chunk_start in range(0, len(frame_names), chunk_size):
chunk_end = min(chunk_start + chunk_size, len(frame_names))
processor = AutoProcessor.from_pretrained(model_id)
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
sam2_image_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
image_predictor = SAM2ImagePredictor(sam2_image_model)
chunk_dir = f"./temp_chunk_{chunk_start}"
os.makedirs(chunk_dir, exist_ok=True)
chunk_frame_names = frame_names[chunk_start:chunk_end]
for fname in chunk_frame_names:
shutil.copy(
os.path.join(video_dir, fname),
os.path.join(chunk_dir, fname)
)
# CHUNK İÇİN YENİ inference_state OLUŞTUR
inference_state = video_predictor.init_state(
video_path=chunk_dir,
offload_video_to_cpu=True,
async_loading_frames=False
)
PROMPT_TYPE_FOR_VIDEO = "mask"
sam2_masks = MaskDictionaryModel()
objects_count = 0
print(f"Processing frames {chunk_start} to {chunk_end}")
for start_frame_idx in range(chunk_start, chunk_end, step):
img_path = os.path.join(video_dir, frame_names[start_frame_idx])
image = Image.open(img_path)
image_base_name = frame_names[start_frame_idx].split(".")[0]
mask_dict = MaskDictionaryModel(promote_type = PROMPT_TYPE_FOR_VIDEO,
mask_name = f"mask_{image_base_name}.npy")
# Run Grounding DINO
inputs = processor(images=image, text=text, return_tensors="pt").to(device)
with torch.no_grad():
outputs = grounding_model(**inputs)
results = processor.post_process_grounded_object_detection(
outputs,
inputs.input_ids,
box_threshold=0.25,
text_threshold=0.25,
target_sizes=[image.size[::-1]]
)
# Run SAM predictor
image_predictor.set_image(np.array(image.convert("RGB")))
input_boxes = results[0]["boxes"]
OBJECTS = results[0]["labels"]
if input_boxes.shape[0] != 0:
masks, scores, logits = image_predictor.predict(
point_coords=None,
point_labels=None,
box=input_boxes,
multimask_output=False,
)
if masks.ndim == 2:
masks = masks[None]
scores = scores[None]
logits = logits[None]
elif masks.ndim == 4:
masks = masks.squeeze(1)
if mask_dict.promote_type == "mask":
mask_dict.add_new_frame_annotation(
mask_list=torch.tensor(masks).to(device).detach(),
box_list=torch.tensor(input_boxes),
label_list=OBJECTS
)
else:
raise NotImplementedError("Only mask prompts supported")
objects_count = mask_dict.update_masks(
tracking_annotation_dict=sam2_masks,
iou_threshold=0.8,
objects_count=objects_count
)
else:
print(f"No object detected in frame {start_frame_idx}")
mask_dict = sam2_masks
if len(mask_dict.labels) == 0:
mask_dict.save_empty_mask_and_json(
mask_data_dir,
json_data_dir,
image_name_list=frame_names[start_frame_idx:start_frame_idx+step]
)
print(f"No object detected in frame {start_frame_idx}, skipping.")
continue
else:
video_predictor.reset_state(inference_state)
chunk_frame_idx = start_frame_idx - chunk_start
for object_id, object_info in mask_dict.labels.items():
frame_idx, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
inference_state,
chunk_frame_idx,
object_id,
object_info.mask,
)
video_segments = {}
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(
inference_state,
max_frame_num_to_track=step,
start_frame_idx=chunk_frame_idx
):
frame_masks = MaskDictionaryModel()
global_out_frame_idx = out_frame_idx + chunk_start
for i, out_obj_id in enumerate(out_obj_ids):
out_mask = (out_mask_logits[i] > 0.0)
object_info = ObjectInfo(
instance_id=out_obj_id,
mask=out_mask[0],
class_name=mask_dict.get_target_class_name(out_obj_id)
)
object_info.update_box()
frame_masks.labels[out_obj_id] = object_info
image_base_name = frame_names[global_out_frame_idx].split(".")[0]
frame_masks.mask_name = f"mask_{image_base_name}.npy"
frame_masks.mask_height = out_mask.shape[-2]
frame_masks.mask_width = out_mask.shape[-1]
video_segments[global_out_frame_idx] = frame_masks
sam2_masks = frame_masks
# Save segment masks
for frame_idx, frame_masks_info in video_segments.items():
mask = frame_masks_info.labels
mask_img = torch.zeros(
frame_masks_info.mask_height,
frame_masks_info.mask_width
)
for obj_id, obj_info in mask.items():
mask_img[obj_info.mask == True] = obj_id
mask_img = mask_img.cpu().numpy().astype(np.uint16)
np.save(os.path.join(mask_data_dir, frame_masks_info.mask_name), mask_img)
json_data = frame_masks_info.to_dict()
json_data_path = os.path.join(
json_data_dir,
frame_masks_info.mask_name.replace(".npy", ".json")
)
with open(json_data_path, "w") as f:
json.dump(json_data, f)
# Cleanup
del image
del inputs
del outputs
del masks, scores, logits
del mask_dict
del video_segments
gc.collect()
torch.cuda.empty_cache()
shutil.rmtree(chunk_dir)
# CHUNK BİTİNCE TAM TEMİZLİK
del inference_state
del video_predictor
del sam2_image_model
del image_predictor
del processor
del grounding_model
torch.cuda.empty_cache()
reacted with thumbs up emoji reacted with thumbs down emoji reacted with laugh emoji reacted with hooray emoji reacted with confused emoji reacted with heart emoji reacted with rocket emoji reacted with eyes emoji
Uh oh!
There was an error while loading. Please reload this page.
-
I've been trying to use the model for long duration videos with manipulating the script grounded_sam2_tracking_demo_with_continuous_id. However, somehow as frames getting processed vram memory usage keeps accumulating.
I've tried reconstructing both sam2 and dino models at the beginning of each chunk(e.g. 200frames) but it didn't help. Any advice?
Here is my code:
Beta Was this translation helpful? Give feedback.
All reactions