19 changes: 17 additions & 2 deletions torchtitan/components/validate.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Generator
+from typing import Any, Generator
 
 import torch
 import torch.nn as nn
@@ -18,6 +18,7 @@
 from torchtitan.hf_datasets.text_datasets import build_text_validation_dataloader
 from torchtitan.tools import utils
 from torchtitan.tools.logging import logger
+import torchtitan.protocols.train_spec as train_spec_module
 
 
 class BaseValidator:
@@ -62,6 +63,7 @@ def __init__(
         self.job_config = job_config
         self.parallel_dims = parallel_dims
         self.loss_fn = loss_fn
+        self.tokenizer = tokenizer
         self.validation_dataloader = build_text_validation_dataloader(
             job_config=job_config,
             dp_world_size=dp_world_size,
@@ -76,12 +78,18 @@ def __init__(
         self.pp_has_first_stage = pp_has_first_stage
         self.pp_has_last_stage = pp_has_last_stage
 
+        self.train_spec = train_spec_module.get_train_spec(job_config.model.name)
+        model_args = self.train_spec.model_args[job_config.model.flavor]
+        model_args.update_from_config(job_config)
+        self.model_args = model_args
+
         if self.job_config.validation.steps == -1:
             logger.warning(
                 "Setting validation steps to -1 might cause hangs because of "
                 "unequal sample counts across ranks when dataset is exhausted."
             )
 
+
     @torch.no_grad()
     def validate(
         self,
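The four added lines in the hunk above resolve the validation model's arguments the same way the trainer does: look up the registered train spec by `job_config.model.name`, pick the flavor's argument set, then fold in job-config overrides via `update_from_config`. A rough sketch of that registry-lookup pattern, with hypothetical `SimpleArgs` and `_REGISTRY` names standing in for torchtitan's real types:

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class SimpleArgs:
    """Hypothetical stand-in for a model-args dataclass."""

    dim: int = 256
    attn_type: str = "sdpa"


# Hypothetical registry mirroring the lookup shape above: name -> flavor -> args.
_REGISTRY: dict[str, dict[str, SimpleArgs]] = {
    "llama3": {
        "debugmodel": SimpleArgs(dim=256),
        "8B": SimpleArgs(dim=4096),
    },
}


def get_args(name: str, flavor: str, attn_type_override: str | None = None) -> SimpleArgs:
    # Select the flavor's defaults, then apply config overrides,
    # analogous to model_args.update_from_config(job_config).
    args = _REGISTRY[name][flavor]
    if attn_type_override is not None:
        args = replace(args, attn_type=attn_type_override)
    return args


args = get_args("llama3", "debugmodel", attn_type_override="flex")
assert args.attn_type == "flex"
```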
@@ -105,6 +113,13 @@ def validate(
             ):
                 break
 
+            # prepare attention masks when the configured attention type needs them
+            extra_kwargs: dict[str, Any] = {"attention_masks": model_parts[0].get_attention_masks(
+                input_batch=input_dict["input"],
+                tokenizer=self.tokenizer,
+                extra_inputs={k: v for k, v in input_dict.items() if k != "input"},
+            )} if getattr(self.model_args, "attn_type", "sdpa") in ["flex", "varlen"] else {}
+
             self.metrics_processor.ntokens_since_last_log += labels.numel()
             for k, v in input_dict.items():
                 input_dict[k] = v.to(device_type)
@@ -155,7 +170,7 @@ def validate(
             with self.validation_context(optional_context_parallel_ctx):
                 assert len(model_parts) == 1
                 with self.maybe_enable_amp:
-                    predictions = model_parts[0](inputs)
+                    predictions = model_parts[0](inputs, **extra_kwargs)
                     loss = self.loss_fn(predictions, labels)
 
             accumulated_losses.append(loss.detach())
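Why the call site changes to `model_parts[0](inputs, **extra_kwargs)`: an empty dict expands to no keyword arguments, so the default SDPA path is untouched, while flex/varlen models receive their masks. A minimal, self-contained sketch of that conditional-kwargs pattern (the `TinyModel` class and the boolean mask are illustrative assumptions, not torchtitan's actual `get_attention_masks` output):

```python
from typing import Any

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for model_parts[0]; not torchtitan's API."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x: torch.Tensor, attention_masks: Any = None) -> torch.Tensor:
        # A real model would thread attention_masks into its attention layers;
        # this toy just accepts the kwarg to show the call pattern.
        return self.proj(x)


attn_type = "flex"  # assumed config value; could also be "sdpa" or "varlen"
model = TinyModel()
inputs = torch.randn(2, 4, 8)

# Build kwargs only when the attention type needs explicit masks,
# so the forward call below works unchanged on the "sdpa" path.
extra_kwargs: dict[str, Any] = (
    {"attention_masks": torch.ones(2, 4, dtype=torch.bool)}
    if attn_type in ["flex", "varlen"]
    else {}
)

predictions = model(inputs, **extra_kwargs)  # empty dict expands to nothing
```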