propagated fix for pipeline parallelism
Co-authored-by: Francesco Bertolotti <[email protected]>
francesco-bertolotti and f14-bertolotti committed Dec 12, 2025
commit 628dbb6b5e040e40f2db4f825e34b7c010d5387f
48 changes: 37 additions & 11 deletions torchtitan/components/validate.py
@@ -4,10 +4,11 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Generator, Any
from typing import Any, Generator

import torch
import torch.nn as nn
import torchtitan.protocols.train_spec as train_spec_module
from torch.distributed.pipelining.schedules import _PipelineSchedule
from torchtitan.components.dataloader import BaseDataLoader
from torchtitan.components.loss import LossFunction
@@ -18,7 +19,6 @@
from torchtitan.hf_datasets.text_datasets import build_text_validation_dataloader
from torchtitan.tools import utils
from torchtitan.tools.logging import logger
import torchtitan.protocols.train_spec as train_spec_module


class BaseValidator:
@@ -89,6 +89,31 @@ def __init__(
"unequal sample counts across ranks when dataset is exhausted."
)

def post_dataloading_process(
can we reuse the one in train.py?

In principle, it would be possible, but I would need to make some modifications since their signatures differ. In particular, the trainer accesses model_parts through self.

One possible solution is to turn post_dataloading_process into a utility function that both the trainer and validator can call. However, this would mean you can no longer modify the behavior through inheritance, unless the current post_dataloading_process simply becomes a wrapper around the new utility function.

If you have any suggestions, I can add a commit.

I see.
If you want to "modify the behavior", we probably should have a PostDataloadingProcessor class.
If not, I think it's fine to call a util directly in both trainer and validator.

WDYT?
cc @fegin

@fegin Dec 12, 2025

I'm thinking this as well. I'm seeing the duplicated code in trainer and validator; model_parts is not a big problem, you can add it (see #2144). But the duplicated logic is pretty concerning.

My initial proposal is to have a util function in distributed.utils so that both methods can call it. The only worry is whether distributed.utils is the best place to put it. It isn't an obvious fit right now, because the util function purely unwraps the input, but it will encapsulate the CP logic after #2144, so distributed.utils seems to be a good place.

@francesco-bertolotti you can do it, since your PR may land first. I can rebase on top of your change.

agreed we should share code, but not sure about distributed.utils. Let's discuss offline.

self,
model_parts: list[nn.Module],
input_dict: dict[str, torch.Tensor],
labels: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]:
"""Post-processing hook after data loading and before model forward pass."""

inputs = input_dict["input"]
extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}

# prepare attention mask
extra_kwargs: dict[str, Any] = (
{
"attention_masks": model_parts[0].get_attention_masks(
input_batch=input_dict["input"],
tokenizer=self.tokenizer,
extra_inputs={k: v for k, v in input_dict.items() if k != "input"},
)
}
if getattr(self.model_args, "attn_type", "sdpa") in ["flex", "varlen"]
else {}
)

return inputs, labels, extra_inputs, extra_kwargs

@torch.no_grad()
def validate(
@@ -113,18 +138,15 @@ def validate(
):
break

# prepare attention mask
extra_kwargs : dict[str, Any] = {"attention_masks" : model_parts[0].get_attention_masks(
input_batch=input_dict["input"],
tokenizer=self.tokenizer,
extra_inputs={k: v for k, v in input_dict.items() if k != "input"},
)} if getattr(self.model_args, "attn_type", "sdpa") in ["flex", "varlen"] else {}
inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process(
model_parts, input_dict, labels
)

self.metrics_processor.ntokens_since_last_log += labels.numel()
for k, v in input_dict.items():
input_dict[k] = v.to(device_type)
inputs = input_dict["input"]
labels = labels.to(device_type)
inputs = inputs.to(device_type)

optional_context_parallel_ctx = (
dist_utils.create_context_parallel_ctx(
@@ -154,7 +176,9 @@
losses=losses,
)
else:
self.pp_schedule.eval(target=targets, losses=losses)
self.pp_schedule.eval(
**extra_kwargs, target=targets, losses=losses
)

# accumulate losses across pipeline microbatches
# TODO: PP+FSDP unexpectedly puts the loss back to the CPU
@@ -170,7 +194,9 @@
with self.validation_context(optional_context_parallel_ctx):
assert len(model_parts) == 1
with self.maybe_enable_amp:
predictions = model_parts[0](inputs, **extra_kwargs)
predictions = model_parts[0](
inputs, **extra_inputs, **extra_kwargs
)
loss = self.loss_fn(predictions, labels)

accumulated_losses.append(loss.detach())
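Following the review thread above about sharing post_dataloading_process between the trainer and the validator, below is a minimal sketch of the utility-function approach discussed there. The function name, the explicit tokenizer and model_args parameters, and the placement in distributed.utils are assumptions; the thread leaves the final location open, and none of this is code in the PR itself.

# Sketch only: a shared helper that both Trainer.post_dataloading_process and
# Validator.post_dataloading_process could delegate to. The name, the explicit
# tokenizer/model_args parameters, and the distributed.utils placement are
# assumptions from the review thread, not part of this diff.
from typing import Any

import torch
import torch.nn as nn


def post_dataloading_process(
    model_parts: list[nn.Module],
    input_dict: dict[str, torch.Tensor],
    labels: torch.Tensor,
    tokenizer,    # passed explicitly because a free function has no self.tokenizer
    model_args,   # likewise for self.model_args
) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]:
    """Split the batch into inputs and extra inputs, and build optional attention masks."""
    inputs = input_dict["input"]
    extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}

    # flex/varlen attention needs precomputed masks; plain SDPA does not
    extra_kwargs: dict[str, Any] = (
        {
            "attention_masks": model_parts[0].get_attention_masks(
                input_batch=inputs,
                tokenizer=tokenizer,
                extra_inputs=extra_inputs,
            )
        }
        if getattr(model_args, "attn_type", "sdpa") in ["flex", "varlen"]
        else {}
    )

    return inputs, labels, extra_inputs, extra_kwargs

With a helper like this, the validator's post_dataloading_process added in this diff (and the trainer's equivalent) could become thin wrappers, e.g. return post_dataloading_process(model_parts, input_dict, labels, self.tokenizer, self.model_args), which keeps the behavior overridable through inheritance as noted in the thread. The PostDataloadingProcessor class mentioned above would be the alternative if heavier customization is needed.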