Skip to content

Commit 0f6226f

Browse files
committed
update train.py
1 parent 8860e12 commit 0f6226f

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

train.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,8 @@
397397
wanted_step = 5e4 if optimizer_type == "sgd" else 1.5e4
398398
total_step = num_train // Unfreeze_batch_size * UnFreeze_Epoch
399399
if total_step <= wanted_step:
400+
if num_train // Unfreeze_batch_size == 0:
401+
raise ValueError('数据集过小,无法进行训练,请扩充数据集。')
400402
wanted_epoch = wanted_step // (num_train // Unfreeze_batch_size) + 1
401403
print("\n\033[1;33;44m[Warning] 使用%s优化器时,建议将训练总步长设置到%d以上。\033[0m"%(optimizer_type, wanted_step))
402404
print("\033[1;33;44m[Warning] 本次运行的总训练数据量为%d,Unfreeze_batch_size为%d,共训练%d个Epoch,计算出总训练步长为%d。\033[0m"%(num_train, Unfreeze_batch_size, UnFreeze_Epoch, total_step))

0 commit comments

Comments
 (0)