49 changes: 45 additions & 4 deletions torchtitan/components/validate.py
@@ -4,10 +4,11 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Generator
from typing import Any, Generator

import torch
import torch.nn as nn
import torchtitan.protocols.train_spec as train_spec_module
from torch.distributed.pipelining.schedules import _PipelineSchedule
from torchtitan.components.dataloader import BaseDataLoader
from torchtitan.components.loss import LossFunction
@@ -62,6 +63,7 @@ def __init__(
self.job_config = job_config
self.parallel_dims = parallel_dims
self.loss_fn = loss_fn
self.tokenizer = tokenizer
self.validation_dataloader = build_text_validation_dataloader(
job_config=job_config,
dp_world_size=dp_world_size,
@@ -76,12 +78,43 @@ def __init__(
self.pp_has_first_stage = pp_has_first_stage
self.pp_has_last_stage = pp_has_last_stage

self.train_spec = train_spec_module.get_train_spec(job_config.model.name)
model_args = self.train_spec.model_args[job_config.model.flavor]
model_args.update_from_config(job_config)
self.model_args = model_args

if self.job_config.validation.steps == -1:
logger.warning(
"Setting validation steps to -1 might cause hangs because of "
"unequal sample counts across ranks when dataset is exhausted."
)

def post_dataloading_process(

Contributor:
can we reuse the one in train.py?


Reply:
In principle, it would be possible, but I would need to make some modifications since their signatures differ. In particular, the trainer accesses model_parts through self.

One possible solution is to turn post_dataloading_process into a utility function that both the trainer and validator can call. However, this would mean you can no longer modify the behavior through inheritance, unless the current post_dataloading_process simply becomes a wrapper around the new utility function.

If you have any suggestions, I can add a commit.
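
To make the wrapper idea above concrete, here is a minimal sketch under the signatures used in this PR. The helper name prepare_batch and its placement are assumptions for illustration, not existing torchtitan code; get_attention_masks, tokenizer, and model_args are taken from the diff above. The shared logic lives in a free function, and Validator (and likewise Trainer) keeps a thin post_dataloading_process method, so behavior can still be changed through inheritance.

from typing import Any

import torch
import torch.nn as nn


def prepare_batch(
    model_parts: list[nn.Module],
    input_dict: dict[str, torch.Tensor],
    labels: torch.Tensor,
    *,
    tokenizer: Any,
    model_args: Any,
) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]:
    """Hypothetical shared helper mirroring the logic added in this PR."""
    inputs = input_dict["input"]
    extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}

    extra_kwargs: dict[str, Any] = {}
    if getattr(model_args, "attn_type", "sdpa") in ["flex", "varlen"]:
        # get_attention_masks is the model method already used in this diff
        extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
            input_batch=inputs,
            tokenizer=tokenizer,
            extra_inputs=extra_inputs,
        )
    return inputs, labels, extra_inputs, extra_kwargs


class Validator:
    # __init__ and the rest of the class elided; see the diff above.

    def post_dataloading_process(self, model_parts, input_dict, labels):
        # Thin wrapper: subclasses can still override this method, while the
        # default behavior delegates to the shared helper.
        return prepare_batch(
            model_parts,
            input_dict,
            labels,
            tokenizer=self.tokenizer,
            model_args=self.model_args,
        )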

Contributor:
I see.
If you want to "modify the behavior", we probably should have a PostDataloadingProcessor class.
If not, I think it's fine to call a util directly in both trainer and validator.

WDYT?
cc @fegin
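
For comparison, a rough sketch of the PostDataloadingProcessor option. The class name comes from the comment above; everything else here is illustrative and not existing torchtitan code. Behavior changes would go through subclassing the processor rather than the trainer or validator.

from typing import Any

import torch
import torch.nn as nn


class PostDataloadingProcessor:
    """Hypothetical processor shared by trainer and validator."""

    def __init__(self, tokenizer: Any, model_args: Any) -> None:
        self.tokenizer = tokenizer
        self.model_args = model_args

    def __call__(
        self,
        model_parts: list[nn.Module],
        input_dict: dict[str, torch.Tensor],
        labels: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]:
        inputs = input_dict["input"]
        extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
        extra_kwargs: dict[str, Any] = {}
        if getattr(self.model_args, "attn_type", "sdpa") in ["flex", "varlen"]:
            extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
                input_batch=inputs,
                tokenizer=self.tokenizer,
                extra_inputs=extra_inputs,
            )
        return inputs, labels, extra_inputs, extra_kwargs

Trainer and Validator would then hold a processor instance and call it where they currently call their own post_dataloading_process methods.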

fegin (Contributor, Dec 12, 2025):
I'm thinking this as well. I'm seeing duplicated code in the trainer and validator; model_parts is not a big problem, you can add it. See #2144. But the duplicated logic is pretty concerning.

My initial proposal is to have a util function in distributed.utils so that both methods can call it. The only worry is whether distributed.utils is the best place to put it. It isn't right now, because the util function purely unwraps the input, but it will encapsulate the CP logic after #2144, so distributed.utils then seems to be a good place.

@francesco-bertolotti you can do it, as your PR may land first. I can rebase on top of your change.

Contributor:
agreed we should share code, but not sure about distributed.utils. Let's discuss offline.

self,
model_parts: list[nn.Module],
input_dict: dict[str, torch.Tensor],
labels: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]:
"""Post-processing hook after data loading and before model forward pass."""

inputs = input_dict["input"]
extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}

# prepare attention mask
extra_kwargs: dict[str, Any] = (
{
"attention_masks": model_parts[0].get_attention_masks(
input_batch=input_dict["input"],
tokenizer=self.tokenizer,
extra_inputs={k: v for k, v in input_dict.items() if k != "input"},
)
}
if getattr(self.model_args, "attn_type", "sdpa") in ["flex", "varlen"]
else {}
)

return inputs, labels, extra_inputs, extra_kwargs

@torch.no_grad()
def validate(
self,
@@ -105,11 +138,15 @@ def validate(
):
break

inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process(
model_parts, input_dict, labels
)

self.metrics_processor.ntokens_since_last_log += labels.numel()
for k, v in input_dict.items():
input_dict[k] = v.to(device_type)
inputs = input_dict["input"]
labels = labels.to(device_type)
inputs = inputs.to(device_type)

optional_context_parallel_ctx = (
dist_utils.create_context_parallel_ctx(
@@ -139,7 +176,9 @@ def validate(
losses=losses,
)
else:
self.pp_schedule.eval(target=targets, losses=losses)
self.pp_schedule.eval(
**extra_kwargs, target=targets, losses=losses
)

# accumulate losses across pipeline microbatches
# TODO: PP+FSDP unexpectedly puts the loss back to the CPU
@@ -155,7 +194,9 @@ def validate(
with self.validation_context(optional_context_parallel_ctx):
assert len(model_parts) == 1
with self.maybe_enable_amp:
predictions = model_parts[0](inputs)
predictions = model_parts[0](
inputs, **extra_inputs, **extra_kwargs
)
loss = self.loss_fn(predictions, labels)

accumulated_losses.append(loss.detach())