一个用于训练基于深度神经网络(DNN)的语音增强模型的代码模板,对深度学习工程师来说具有很高的价值,因为它可以显著提升工作效率。尽管不同程序员的编码风格各不相同,有的非常优秀,有的则相对一般,但我的理念始终是优先追求简洁性。在此背景下,我分享了一套在语音增强(Speech Enhancement,SE)任务中非常实用的训练代码文件组织结构。该模板的核心目标是简洁直观,而不是追求面面俱到。
- 2025-03-31:新增
plus分支,提供更优的实现方式,建议直接使用该分支。 - 2024-05-28:新增
pro分支,提供更完善的实现。
configs:训练与推理的配置文件。DNSMOS:来自微软的预训练 DNSMOS 模型权重。evaluation:基于 URGENT 2024 官方脚本改写的评测指标计算代码。models:模型定义。prepare_datasets:用于生成 DNS3 训练数据的脚本。dataloader.py:数据加载器所使用的数据集类。distributed_utils.py:分布式数据并行(DDP)训练相关工具函数。evaluate.py:基于推理阶段生成的 scp 文件进行评估的脚本。infer.py:模型推理脚本。loss_factory.py:语音增强任务中常用的多种损失函数。scheduler.py:Warmup 学习率调度器定义。train.py:训练脚本,支持单 GPU 与多 GPU 训练。
在启动一个新的语音增强(SE)项目时,建议按照以下步骤进行:
- 修改
dataloader.py,使其适配你的数据集; - 在
models目录中定义你自己的模型; - 修改
configs/cfg_train.yaml,以匹配你的训练配置; - 在
loss_factory.py中选择合适的损失函数,或根据需要自行新增; - 运行
train.py开始训练:
python train.py
python train.py -D 1
python train.py -C configs/cfg_train.yaml -D 1
python train.py -C configs/cfg_train.yaml -D 0,1,2,3- 训练完成后,在
configs/cfg_infer.yaml中指定模型 checkpoint 及相关路径; - 运行
evaluate.py进行模型评估。
- 本代码最初是为 Linux 系统 设计的,如果尝试在 Windows
平台上运行,可能会遇到以下问题:
- 路径不兼容:Linux 下使用的文件路径格式可能与 Windows 不一致;
- pesq 包安装困难:在 Windows 系统上安装
pesq包可能需要额外的配置步骤。
- 如果你觉得这个仓库对你有帮助,欢迎点 ⭐ 支持。
本代码模板在多个方面大量借鉴了优秀的开源项目
Sheffield_Clarity_CEC1_Entry,在此表示衷心感谢。