Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions libai/models/bert_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,9 +427,12 @@ def forward(
prediction_scores = self.lm_logits(prediction_scores, word_embeddings_weight)

if lm_labels is not None:
prediction_scores = flow._C.amp_white_identity(prediction_scores)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果要追求极致地性能优化,建议专门搞一个 benchmark 分支去测试性能,类似的修改全部在那个分支引入,main 分支建议不要引入这种 magic trick,用户根本看不懂这是在干嘛。

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果这个接口改为:

prediction_scores = flow.amp.cast_fp16_hint(prediction_scores)

可以接受吗?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prediction_scores = flow.amp.to_fp16(prediction_scores)

https://pytorch.org/docs/stable/amp.html#module-torch.amp

torch.amp 也支持自动 + 手动 混合

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我觉得不是接口的命名问题, 而是用户看到这段代码会很疑惑, 不知道在干嘛

但是我看这个pr是想merge到libai_bench分支的, libai_bench分支是专门追求极致的性能的, 不考虑用户的易用性. 我觉得还ok

seq_relationship_score = flow._C.amp_white_identity(seq_relationship_score)
return self.loss_func(
prediction_scores, lm_labels, loss_mask, seq_relationship_score, ns_labels
)

return {
"prediction_scores": prediction_scores,
"seq_relationship_score": seq_relationship_score,
Expand Down