|
| 1 | +''' |
| 2 | +python Keye_VL_Caption.py Prince_Ciel_Phantomhive_Sebastian_Michaelis_both_Videos_qwen_vl_captioned Prince_Ciel_Phantomhive_Sebastian_Michaelis_both_Videos_keye_captioned --use_flash_attention \ |
| 3 | +--text "给你的视频中可能出现的主要人物为两个(可能出现一个或两个),当人物为一个戴眼罩的男孩时,男孩的名字是'夏尔',当人物是一个穿燕尾西服的成年男子时,男子的名字是'塞巴斯蒂安',在你的视频描述中要使用人物的名字并且简单描述人物的外貌及衣着。 使用中文描述这个视频 /think" |
| 4 | +''' |
| 5 | + |
| 6 | +import os |
| 7 | +import torch |
| 8 | +import argparse |
| 9 | +from pathlib import Path |
| 10 | +import shutil |
| 11 | +from transformers import AutoModel, AutoProcessor |
| 12 | +from keye_vl_utils import process_vision_info |
| 13 | +from moviepy.editor import VideoFileClip |
| 14 | +import re |
| 15 | + |
| 16 | +def get_video_duration(video_path): |
| 17 | + """获取视频时长(秒)""" |
| 18 | + try: |
| 19 | + with VideoFileClip(video_path) as video: |
| 20 | + return video.duration |
| 21 | + except: |
| 22 | + return float('inf') |
| 23 | + |
| 24 | +class KeyeVL_Captioner: |
| 25 | + def __init__(self, args): |
| 26 | + self.args = args |
| 27 | + self.model = None |
| 28 | + self.processor = None |
| 29 | + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| 30 | + |
| 31 | + def setup_model(self): |
| 32 | + """设置和加载Keye-VL模型""" |
| 33 | + if self.model is None: |
| 34 | + self.model = AutoModel.from_pretrained( |
| 35 | + self.args.model_path, |
| 36 | + torch_dtype="auto", |
| 37 | + trust_remote_code=True, |
| 38 | + attn_implementation="flash_attention_2" if self.args.use_flash_attention else "eager", |
| 39 | + ).eval() |
| 40 | + self.model.to(self.device) |
| 41 | + |
| 42 | + if self.processor is None: |
| 43 | + self.processor = AutoProcessor.from_pretrained( |
| 44 | + self.args.model_path, |
| 45 | + trust_remote_code=True |
| 46 | + ) |
| 47 | + |
| 48 | + def determine_thinking_mode(self, text): |
| 49 | + """根据文本内容确定思考模式""" |
| 50 | + if text.endswith('/no_think'): |
| 51 | + return "no_think", text.replace('/no_think', '').strip() |
| 52 | + elif text.endswith('/think'): |
| 53 | + return "think", text.replace('/think', '').strip() |
| 54 | + else: |
| 55 | + return "auto", text |
| 56 | + |
| 57 | + def process_media(self, media_path, output_dir): |
| 58 | + """处理单个媒体文件(图片或视频)""" |
| 59 | + # 检查文件是否存在 |
| 60 | + if not os.path.exists(media_path): |
| 61 | + print(f"文件不存在: {media_path}") |
| 62 | + return None |
| 63 | + |
| 64 | + # 获取文件扩展名确定媒体类型 |
| 65 | + ext = os.path.splitext(media_path)[1].lower() |
| 66 | + is_video = ext in ['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv'] |
| 67 | + is_image = ext in ['.jpg', '.jpeg', '.png', '.bmp', '.gif'] |
| 68 | + |
| 69 | + if not (is_video or is_image): |
| 70 | + print(f"不支持的文件格式: {media_path}") |
| 71 | + return None |
| 72 | + |
| 73 | + # 视频时长过滤 |
| 74 | + if is_video: |
| 75 | + duration = get_video_duration(media_path) |
| 76 | + print(f"视频: {media_path}, 时长: {duration}秒") |
| 77 | + if self.args.max_duration > 0 and duration > self.args.max_duration: |
| 78 | + print(f"跳过时长超过限制的视频: {duration}秒 > {self.args.max_duration}秒") |
| 79 | + return None |
| 80 | + |
| 81 | + # 确定思考模式 |
| 82 | + thinking_mode, processed_text = self.determine_thinking_mode(self.args.text) |
| 83 | + processed_text = self.args.text |
| 84 | + |
| 85 | + # 准备媒体输入 |
| 86 | + media_content = [] |
| 87 | + if is_video: |
| 88 | + media_content.append({ |
| 89 | + "type": "video", |
| 90 | + "video": media_path, |
| 91 | + "fps": self.args.fps, |
| 92 | + "max_frames": self.args.max_frames |
| 93 | + }) |
| 94 | + else: |
| 95 | + media_content.append({ |
| 96 | + "type": "image", |
| 97 | + "image": media_path |
| 98 | + }) |
| 99 | + |
| 100 | + # 构建消息 |
| 101 | + messages = [ |
| 102 | + { |
| 103 | + "role": "user", |
| 104 | + "content": media_content + [ |
| 105 | + {"type": "text", "text": processed_text + " \think" }, |
| 106 | + ], |
| 107 | + } |
| 108 | + ] |
| 109 | + |
| 110 | + # 处理视觉信息并生成输入 |
| 111 | + text = self.processor.apply_chat_template( |
| 112 | + messages, tokenize=False, add_generation_prompt=True |
| 113 | + ) |
| 114 | + image_inputs, video_inputs, mm_processor_kwargs = process_vision_info(messages) |
| 115 | + |
| 116 | + inputs = self.processor( |
| 117 | + text=[text], |
| 118 | + images=image_inputs, |
| 119 | + videos=video_inputs, |
| 120 | + padding=True, |
| 121 | + return_tensors="pt", |
| 122 | + **mm_processor_kwargs |
| 123 | + ) |
| 124 | + inputs = inputs.to(self.device) |
| 125 | + |
| 126 | + # 生成描述 |
| 127 | + with torch.no_grad(): |
| 128 | + generated_ids = self.model.generate( |
| 129 | + **inputs, |
| 130 | + max_new_tokens=self.args.max_new_tokens, |
| 131 | + temperature=self.args.temperature, |
| 132 | + do_sample=self.args.temperature > 0 |
| 133 | + ) |
| 134 | + generated_ids_trimmed = [ |
| 135 | + out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) |
| 136 | + ] |
| 137 | + output_text = self.processor.batch_decode( |
| 138 | + generated_ids_trimmed, |
| 139 | + skip_special_tokens=True, |
| 140 | + clean_up_tokenization_spaces=False |
| 141 | + ) |
| 142 | + |
| 143 | + result = output_text[0] if isinstance(output_text, list) else output_text |
| 144 | + |
| 145 | + # 清理结果 |
| 146 | + #result = re.sub(r'<[^>]*>', '', result).strip() |
| 147 | + |
| 148 | + # 保存结果 |
| 149 | + media_name = os.path.basename(media_path) |
| 150 | + txt_filename = os.path.splitext(media_name)[0] + ".txt" |
| 151 | + txt_path = os.path.join(output_dir, txt_filename) |
| 152 | + |
| 153 | + with open(txt_path, 'w', encoding='utf-8') as f: |
| 154 | + f.write(result) |
| 155 | + |
| 156 | + # 复制媒体文件到输出目录 |
| 157 | + output_media_path = os.path.join(output_dir, media_name) |
| 158 | + shutil.copy2(media_path, output_media_path) |
| 159 | + |
| 160 | + print(f"文件: {media_name}") |
| 161 | + print(f"思考模式: {thinking_mode}") |
| 162 | + print(f"描述: {result}") |
| 163 | + print("-" * 50) |
| 164 | + |
| 165 | + return result |
| 166 | + |
| 167 | + def process_all_media(self): |
| 168 | + """处理所有媒体文件""" |
| 169 | + os.makedirs(self.args.output_dir, exist_ok=True) |
| 170 | + self.setup_model() |
| 171 | + |
| 172 | + if os.path.isfile(self.args.source_path): |
| 173 | + self.process_media(self.args.source_path, self.args.output_dir) |
| 174 | + elif os.path.isdir(self.args.source_path): |
| 175 | + supported_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', |
| 176 | + '.jpg', '.jpeg', '.png', '.bmp', '.gif'] |
| 177 | + |
| 178 | + for file in os.listdir(self.args.source_path): |
| 179 | + if any(file.lower().endswith(ext) for ext in supported_extensions): |
| 180 | + media_path = os.path.join(self.args.source_path, file) |
| 181 | + self.process_media(media_path, self.args.output_dir) |
| 182 | + |
| 183 | + if not self.args.keep_model_loaded: |
| 184 | + self.cleanup() |
| 185 | + |
| 186 | + def cleanup(self): |
| 187 | + """清理模型和释放内存""" |
| 188 | + del self.model |
| 189 | + del self.processor |
| 190 | + self.model = None |
| 191 | + self.processor = None |
| 192 | + if torch.cuda.is_available(): |
| 193 | + torch.cuda.empty_cache() |
| 194 | + |
| 195 | +def main(): |
| 196 | + parser = argparse.ArgumentParser(description="Keye-VL媒体描述生成工具") |
| 197 | + |
| 198 | + # 必需参数 |
| 199 | + parser.add_argument("source_path", help="输入媒体文件路径或包含媒体文件的文件夹路径") |
| 200 | + parser.add_argument("output_dir", help="输出目录路径") |
| 201 | + |
| 202 | + # 模型参数 |
| 203 | + parser.add_argument("--model_path", default="Kwai-Keye/Keye-VL-1_5-8B", |
| 204 | + help="Keye-VL模型路径") |
| 205 | + parser.add_argument("--use_flash_attention", action="store_true", |
| 206 | + help="是否使用flash attention加速") |
| 207 | + |
| 208 | + # 处理参数 |
| 209 | + parser.add_argument("--text", default="请描述这个内容", |
| 210 | + help="描述提示文本,可添加/think或/no_think后缀指定模式") |
| 211 | + parser.add_argument("--max_duration", type=float, default=10.0, |
| 212 | + help="最大处理视频时长(秒),-1表示无限制") |
| 213 | + parser.add_argument("--fps", type=float, default=1.0, |
| 214 | + help="视频采样帧率") |
| 215 | + parser.add_argument("--max_frames", type=int, default=16, |
| 216 | + help="最大处理帧数") |
| 217 | + |
| 218 | + # 生成参数 |
| 219 | + parser.add_argument("--temperature", type=float, default=0.7, |
| 220 | + help="生成温度(0-1)") |
| 221 | + parser.add_argument("--max_new_tokens", type=int, default=1024, |
| 222 | + help="最大新生成token数量") |
| 223 | + |
| 224 | + # 其他参数 |
| 225 | + parser.add_argument("--keep_model_loaded", action="store_true", |
| 226 | + help="处理完成后保持模型加载状态") |
| 227 | + |
| 228 | + args = parser.parse_args() |
| 229 | + |
| 230 | + # 创建处理器并处理媒体文件 |
| 231 | + captioner = KeyeVL_Captioner(args) |
| 232 | + captioner.process_all_media() |
| 233 | + |
| 234 | +if __name__ == "__main__": |
| 235 | + main() |
0 commit comments