forked from ronghuaiyang/arcface-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathvalidate_model.py
More file actions
183 lines (147 loc) · 5.02 KB
/
validate_model.py
File metadata and controls
183 lines (147 loc) · 5.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
简化版模型验证脚本
"""
import torch
from models import resnet_face18
import numpy as np
import os
import cv2
from config.yaml_config import Config
import sys
def load_image(img_path):
    """Load an LFW image as a two-sample batch: the image and its mirror.

    Args:
        img_path: Path of the image file; OpenCV reads it with flag ``0``
            (grayscale).

    Returns:
        A float32 array of shape ``(2, 1, H, W)``. Sample 0 is the image and
        sample 1 is its left-right flip. Every pixel value ``v`` is mapped to
        ``(v - 127.5) / 127.5``. Returns ``None`` when ``cv2.imread`` cannot
        read the file.
    """
    gray = cv2.imread(img_path, 0)
    if gray is None:
        return None
    mirrored = np.fliplr(gray)
    # (H, W, 2) -> (2, H, W) -> (2, 1, H, W): a batch of two one-channel images.
    batch = np.dstack((gray, mirrored)).transpose((2, 0, 1))[:, np.newaxis, :, :]
    batch = batch.astype(np.float32, copy=False)
    return (batch - 127.5) / 127.5
def cosin_metric(x1, x2):
    """Return the cosine similarity of vectors ``x1`` and ``x2``.

    Computed as ``dot(x1, x2) / (||x1|| * ||x2||)``. There is no guard for a
    zero vector, which yields ``nan`` (with a NumPy runtime warning).
    """
    norm_product = np.linalg.norm(x1) * np.linalg.norm(x2)
    return np.dot(x1, x2) / norm_product
def _load_model(model_path, use_se, device):
    """Build ``resnet_face18``, load weights from ``model_path``, return it in eval mode."""
    model = resnet_face18(use_se=use_se)
    state_dict = torch.load(model_path, map_location=device)
    # nn.DataParallel.state_dict() prefixes every key with "module."; strip it
    # so such checkpoints also load into the bare (unwrapped) model.
    prefix = 'module.'
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


def _image_feature(model, img_path, device):
    """Return the flip-augmented embedding of one image, or None if unreadable.

    The model runs on the two-sample batch from ``load_image`` (image and its
    mirror). The two output rows are concatenated into one vector.
    """
    image = load_image(img_path)
    if image is None:
        return None
    output = model(torch.from_numpy(image).to(device))
    features = output.detach().cpu().numpy()
    return np.hstack((features[0], features[1]))


def _pair_similarities(model, pairs, lfw_root, device):
    """Score every ``path1 path2 label`` line in ``pairs`` by cosine similarity.

    Lines that do not have exactly three whitespace-separated fields are
    skipped silently. Each image's feature is computed once and reused by
    every pair that references it.

    Returns:
        ``(sims, labels, processed, failed)``. ``sims`` and ``labels`` are
        equal-length arrays for the scored pairs. ``processed`` is the number
        of scored pairs. ``failed`` is the number of pairs skipped because an
        image could not be read.
    """
    cache = {}

    def feature(rel_path):
        path = os.path.join(lfw_root, rel_path)
        if path not in cache:
            cache[path] = _image_feature(model, path, device)
        return cache[path]

    sims = []
    labels = []
    failed = 0
    processed = 0
    total = len(pairs)
    with torch.no_grad():
        for idx, pair in enumerate(pairs):
            splits = pair.strip().split()
            if len(splits) != 3:
                continue
            label = int(splits[2])
            # Both features are looked up before the check, as the original did.
            feature1 = feature(splits[0])
            feature2 = feature(splits[1])
            if feature1 is None or feature2 is None:
                failed += 1
                continue
            sims.append(cosin_metric(feature1, feature2))
            labels.append(label)
            processed += 1
            # Progress display
            if (idx + 1) % 1000 == 0:
                print(f" 处理进度: {idx+1}/{total}", end='\r', flush=True)
    return np.asarray(sims), np.asarray(labels), processed, failed


def _best_threshold(sims, labels):
    """Return ``(accuracy, threshold)`` that maximizes verification accuracy.

    A pair is predicted "same person" when its similarity is >= threshold.
    Every distinct similarity value is tried as a threshold. On ties the
    smallest threshold wins. Runs in O(n log n) via sorting and cumulative
    counts. Returns ``(0.0, 0.0)`` for empty input.
    """
    if sims.size == 0:
        return 0.0, 0.0
    order = np.argsort(sims, kind='stable')
    sorted_sims = sims[order]
    sorted_labels = labels[order]
    # pos_before[k] / neg_before[k]: how many of the k smallest sims are labelled 1 / 0.
    pos_before = np.concatenate(([0], np.cumsum(sorted_labels == 1)))
    neg_before = np.concatenate(([0], np.cumsum(sorted_labels == 0)))
    # First index of each distinct value (the array is sorted). Thresholding
    # at that value predicts positive for exactly sorted_sims[start:].
    candidates, starts = np.unique(sorted_sims, return_index=True)
    correct = (pos_before[-1] - pos_before[starts]) + neg_before[starts]
    best = int(np.argmax(correct))  # first maximum -> smallest threshold
    return float(correct[best]) / sims.size, float(candidates[best])


def _print_group(title, values):
    """Print the count and, when non-empty, mean ± std and range of ``values``."""
    print(title)
    print(f" 数量: {len(values)}")
    if len(values) == 0:
        return
    print(f" 相似度: {values.mean():.4f} ± {values.std():.4f}")
    print(f" 范围: [{values.min():.4f}, {values.max():.4f}]")


def _discriminability(margin):
    """Map the positive/negative mean-similarity gap to a quality label (nan -> "较差")."""
    if margin > 0.2:
        return "优秀"
    if margin > 0.15:
        return "良好"
    if margin > 0.1:
        return "一般"
    return "较差"


def main():
    """Validate the configured face model on the LFW pair list and print a report.

    Reads ``config.yaml`` (relative to the working directory) and loads the
    model from ``opt.test_model_path``. It then scores every pair in
    ``opt.lfw_test_list`` (paths relative to ``opt.lfw_root``) and finds the
    best accuracy threshold. It also reports the similarity distributions.

    Returns:
        The best pair-verification accuracy in [0, 1]. Returns ``None`` when
        the model file or test list is missing, or when no pair could be
        scored.
    """
    print("="*80)
    print("模型验证报告")
    print("="*80)

    # 1. Load configuration and model
    print("\n[1/4] 加载配置和模型...")
    opt = Config('config.yaml')
    device = torch.device("cuda" if opt.use_gpu and torch.cuda.is_available() else "cpu")
    print(f" 设备: {device}")
    print(f" 模型路径: {opt.test_model_path}")
    if not os.path.exists(opt.test_model_path):
        print(" ✗ 模型文件不存在")
        return None
    model = _load_model(opt.test_model_path, opt.use_se, device)
    print(" ✓ 模型加载成功")

    # 2. LFW pair verification
    print("\n[2/4] LFW配对验证...")
    print(f" 测试列表: {opt.lfw_test_list}")
    if not os.path.exists(opt.lfw_test_list):
        print(" ✗ 测试列表不存在")
        return None
    with open(opt.lfw_test_list, 'r') as f:
        pairs = f.readlines()
    print(f" 总配对数: {len(pairs)}")
    sims, labels, processed, failed = _pair_similarities(model, pairs, opt.lfw_root, device)
    print(f"\n ✓ 处理完成: {processed}个有效配对")
    if failed > 0:
        print(f" ✗ 失败: {failed}个配对")
    if processed == 0:
        # Without any scored pair, accuracy and statistics are undefined.
        print(" ✗ 没有有效配对，无法计算准确率")
        return None

    # 3. Accuracy at the best threshold
    print("\n[3/4] 计算准确率...")
    best_acc, best_th = _best_threshold(sims, labels)
    print(f" 最佳准确率: {best_acc*100:.2f}%")
    print(f" 最佳阈值: {best_th:.4f}")

    # 4. Similarity distribution statistics
    print("\n[4/4] 相似度分布分析...")
    pos_sims = sims[labels == 1]
    neg_sims = sims[labels == 0]
    _print_group(" 同一人 (正样本):", pos_sims)
    _print_group(" 不同人 (负样本):", neg_sims)
    # The gap is undefined (nan) when one of the two groups is empty.
    if len(pos_sims) and len(neg_sims):
        margin = pos_sims.mean() - neg_sims.mean()
    else:
        margin = float('nan')
    print(f"\n 判别间距: {margin:.4f}")
    quality = _discriminability(margin)
    print(f" 特征判别性: {quality}")

    # Final summary
    print("\n" + "="*80)
    print("验证总结")
    print("="*80)
    print(f"✓ LFW准确率: {best_acc*100:.2f}%")
    print(f"✓ 最佳阈值: {best_th:.4f}")
    print(f"✓ 判别间距: {margin:.4f}")
    print(f"✓ 特征质量: {quality}")
    print("="*80)
    return best_acc
if __name__ == "__main__":
    # Script entry point: run the validation report and keep main()'s return
    # value (the best accuracy, or None on early exit) in `acc`.
    acc = main()