
[RFC] Do we really need all_reduce in BaseLossContext? #1335

@nil0x9

Description

Currently in the loss calculation, there is an invocation of all-reduce with autograd:

loss = all_reduce(loss, op=dist.ReduceOp.SUM, group=dist.group.WORLD)

IIUC this all-reduce is primarily for logging purposes. It could potentially be removed safely (please correct me if I'm wrong here), because:

  1. FSDP works fine with just the local loss;
  2. with the above reduced loss, the gradient in the backward pass would be identical to that of the local loss; and
  3. the loss is reduced again in TrainEngine (so it would be reduced twice! see the sketch after this list):
    dist.all_reduce(reduced_llm_loss.div_(dist.get_world_size()))
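For context, here is a minimal sketch of the double-reduction concern in item 3. This is not the actual project code path: torch.distributed.nn.functional.all_reduce stands in for the autograd-aware all_reduce quoted above, and local_loss plus the default WORLD group are assumptions for illustration.

    import torch
    import torch.distributed as dist
    # Differentiable collective, used here as a stand-in for the wrapper quoted above.
    from torch.distributed.nn.functional import all_reduce as autograd_all_reduce

    def double_reduction_demo(local_loss: torch.Tensor) -> torch.Tensor:
        # Hypothetical helper for illustration only.
        world_size = dist.get_world_size()

        # Reduction 1 (BaseLossContext, as quoted above): autograd-aware SUM,
        # after which every rank holds the global sum of per-rank losses.
        loss = autograd_all_reduce(
            local_loss, op=dist.ReduceOp.SUM, group=dist.group.WORLD
        )

        # Reduction 2 (TrainEngine, as quoted above): reduce again for logging.
        reduced_llm_loss = loss.detach().clone()
        dist.all_reduce(reduced_llm_loss.div_(world_size))

        # With both reductions applied, this equals sum(per-rank losses), i.e.
        # world_size times the mean that reduction 2 alone would produce from
        # the local loss.
        return reduced_llm_loss

If reduction 1 is dropped, reduction 2 alone already yields the global mean of the per-rank losses, which appears to be the intended logging value.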
