Commit 57a8c53

Create Keye_VL_Caption.py
1 parent ea50e65 commit 57a8c53

File tree: 1 file changed (+235, -0 lines)

Keye_VL_Caption.py

Lines changed: 235 additions & 0 deletions
@@ -0,0 +1,235 @@
'''
python Keye_VL_Caption.py Prince_Ciel_Phantomhive_Sebastian_Michaelis_both_Videos_qwen_vl_captioned Prince_Ciel_Phantomhive_Sebastian_Michaelis_both_Videos_keye_captioned --use_flash_attention \
    --text "给你的视频中可能出现的主要人物为两个(可能出现一个或两个),当人物为一个戴眼罩的男孩时,男孩的名字是'夏尔',当人物是一个穿燕尾西服的成年男子时,男子的名字是'塞巴斯蒂安',在你的视频描述中要使用人物的名字并且简单描述人物的外貌及衣着。 使用中文描述这个视频 /think"

The example prompt above tells the model that one or both of two characters may
appear in the video: a boy with an eyepatch named "夏尔" (Ciel) and an adult man
in a tailcoat named "塞巴斯蒂安" (Sebastian). It asks the model to use their names,
briefly describe their appearance and clothing, and caption the video in Chinese;
the /think suffix requests thinking mode.
'''

import os
import torch
import argparse
from pathlib import Path
import shutil
from transformers import AutoModel, AutoProcessor
from keye_vl_utils import process_vision_info
from moviepy.editor import VideoFileClip
import re
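
# Runtime dependencies implied by the imports above (not pinned in this commit):
# torch, transformers, moviepy 1.x (which still provides moviepy.editor), and the
# keye_vl_utils helper package used by the Keye-VL examples.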

def get_video_duration(video_path):
    """Get the duration of a video in seconds."""
    try:
        with VideoFileClip(video_path) as video:
            return video.duration
    except Exception:
        # Unreadable files are treated as infinitely long so the duration filter skips them
        return float('inf')

class KeyeVL_Captioner:
    def __init__(self, args):
        self.args = args
        self.model = None
        self.processor = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def setup_model(self):
        """Set up and load the Keye-VL model."""
        if self.model is None:
            self.model = AutoModel.from_pretrained(
                self.args.model_path,
                torch_dtype="auto",
                trust_remote_code=True,
                # flash_attention_2 requires the flash-attn package and a supported GPU
                attn_implementation="flash_attention_2" if self.args.use_flash_attention else "eager",
            ).eval()
            self.model.to(self.device)

        if self.processor is None:
            self.processor = AutoProcessor.from_pretrained(
                self.args.model_path,
                trust_remote_code=True
            )

    def determine_thinking_mode(self, text):
        """Determine the thinking mode from the prompt's suffix."""
        if text.endswith('/no_think'):
            return "no_think", text.replace('/no_think', '').strip()
        elif text.endswith('/think'):
            return "think", text.replace('/think', '').strip()
        else:
            return "auto", text
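
    # Illustrative example: determine_thinking_mode("描述这个视频 /no_think")
    # returns ("no_think", "描述这个视频"); a prompt without a suffix returns ("auto", text).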

    def process_media(self, media_path, output_dir):
        """Process a single media file (image or video)."""
        # Check that the file exists
        if not os.path.exists(media_path):
            print(f"File does not exist: {media_path}")
            return None

        # Use the file extension to determine the media type
        ext = os.path.splitext(media_path)[1].lower()
        is_video = ext in ['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv']
        is_image = ext in ['.jpg', '.jpeg', '.png', '.bmp', '.gif']

        if not (is_video or is_image):
            print(f"Unsupported file format: {media_path}")
            return None

        # Filter out videos that are too long
        if is_video:
            duration = get_video_duration(media_path)
            print(f"Video: {media_path}, duration: {duration}s")
            if self.args.max_duration > 0 and duration > self.args.max_duration:
                print(f"Skipping video over the duration limit: {duration}s > {self.args.max_duration}s")
                return None

        # Determine the thinking mode and strip its suffix from the prompt
        thinking_mode, processed_text = self.determine_thinking_mode(self.args.text)
        # Re-append the explicit mode tag (e.g. " /think") so the model still receives it;
        # "auto" leaves the prompt unchanged
        if thinking_mode != "auto":
            processed_text = f"{processed_text} /{thinking_mode}"

        # Prepare the media input
        media_content = []
        if is_video:
            media_content.append({
                "type": "video",
                "video": media_path,
                "fps": self.args.fps,
                "max_frames": self.args.max_frames
            })
        else:
            media_content.append({
                "type": "image",
                "image": media_path
            })

        # Build the chat message
        messages = [
            {
                "role": "user",
                "content": media_content + [
                    {"type": "text", "text": processed_text},
                ],
            }
        ]

        # Process the vision inputs and build the model inputs
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs, mm_processor_kwargs = process_vision_info(messages)

        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
            **mm_processor_kwargs
        )
        inputs = inputs.to(self.device)

        # Generate the caption
        with torch.no_grad():
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=self.args.max_new_tokens,
                temperature=self.args.temperature,
                do_sample=self.args.temperature > 0
            )
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )

        result = output_text[0] if isinstance(output_text, list) else output_text

        # Clean up the result
        #result = re.sub(r'<[^>]*>', '', result).strip()

        # Save the caption as a .txt file with the same base name
        media_name = os.path.basename(media_path)
        txt_filename = os.path.splitext(media_name)[0] + ".txt"
        txt_path = os.path.join(output_dir, txt_filename)

        with open(txt_path, 'w', encoding='utf-8') as f:
            f.write(result)

        # Copy the media file into the output directory
        output_media_path = os.path.join(output_dir, media_name)
        shutil.copy2(media_path, output_media_path)

        print(f"File: {media_name}")
        print(f"Thinking mode: {thinking_mode}")
        print(f"Caption: {result}")
        print("-" * 50)

        return result
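
    # For reference, process_media writes two files per input into output_dir:
    # a copy of the media file (e.g. "clip.mp4") and its caption ("clip.txt").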

    def process_all_media(self):
        """Process all media files under source_path."""
        os.makedirs(self.args.output_dir, exist_ok=True)
        self.setup_model()

        if os.path.isfile(self.args.source_path):
            self.process_media(self.args.source_path, self.args.output_dir)
        elif os.path.isdir(self.args.source_path):
            supported_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv',
                                    '.jpg', '.jpeg', '.png', '.bmp', '.gif']

            for file in os.listdir(self.args.source_path):
                if any(file.lower().endswith(ext) for ext in supported_extensions):
                    media_path = os.path.join(self.args.source_path, file)
                    self.process_media(media_path, self.args.output_dir)

        if not self.args.keep_model_loaded:
            self.cleanup()

    def cleanup(self):
        """Release the model and free memory."""
        del self.model
        del self.processor
        self.model = None
        self.processor = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

def main():
    parser = argparse.ArgumentParser(description="Keye-VL media captioning tool")

    # Required arguments
    parser.add_argument("source_path", help="Path to an input media file or a folder of media files")
    parser.add_argument("output_dir", help="Output directory path")

    # Model arguments
    parser.add_argument("--model_path", default="Kwai-Keye/Keye-VL-1_5-8B",
                        help="Path to the Keye-VL model")
    parser.add_argument("--use_flash_attention", action="store_true",
                        help="Use flash attention to speed up inference")

    # Processing arguments
    parser.add_argument("--text", default="请描述这个内容",
                        help="Caption prompt text; append /think or /no_think to force a thinking mode")
    parser.add_argument("--max_duration", type=float, default=10.0,
                        help="Maximum video duration to process, in seconds; -1 means no limit")
    parser.add_argument("--fps", type=float, default=1.0,
                        help="Video sampling frame rate")
    parser.add_argument("--max_frames", type=int, default=16,
                        help="Maximum number of frames to process")

    # Generation arguments
    parser.add_argument("--temperature", type=float, default=0.7,
                        help="Sampling temperature (0-1)")
    parser.add_argument("--max_new_tokens", type=int, default=1024,
                        help="Maximum number of new tokens to generate")

    # Other arguments
    parser.add_argument("--keep_model_loaded", action="store_true",
                        help="Keep the model loaded after processing finishes")

    args = parser.parse_args()

    # Create the captioner and process the media files
    captioner = KeyeVL_Captioner(args)
    captioner.process_all_media()

if __name__ == "__main__":
    main()
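
A minimal sketch of driving the captioner from Python instead of the CLI. The Namespace fields below mirror the argparse flags defined in main(); the paths and prompt are placeholders, not values from the commit.

from argparse import Namespace

# Placeholder arguments; every field corresponds to a CLI flag defined in main().
args = Namespace(
    source_path="videos/",
    output_dir="videos_keye_captioned/",
    model_path="Kwai-Keye/Keye-VL-1_5-8B",
    use_flash_attention=False,
    text="使用中文描述这个视频 /think",
    max_duration=10.0,
    fps=1.0,
    max_frames=16,
    temperature=0.7,
    max_new_tokens=1024,
    keep_model_loaded=False,
)

captioner = KeyeVL_Captioner(args)
captioner.process_all_media()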
