From bab83a4c17d8cd6120767658b9f17e0df80bc736 Mon Sep 17 00:00:00 2001 From: Garrett Goon <44747910+garrett361@users.noreply.github.com> Date: Wed, 26 Mar 2025 17:34:22 -0400 Subject: [PATCH 01/18] [DeepSeek] remove numpy, avoid tolist in gatherd_idxs (#1019) Removes the `numpy` usage and `tolist` CUDA sync when computing `gatherd_idxs`. --- torchtitan/experiments/deepseek_v3/model.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/torchtitan/experiments/deepseek_v3/model.py b/torchtitan/experiments/deepseek_v3/model.py index e9e867d8e9..f8fcff289d 100644 --- a/torchtitan/experiments/deepseek_v3/model.py +++ b/torchtitan/experiments/deepseek_v3/model.py @@ -29,8 +29,6 @@ import math from typing import Optional, Tuple -import numpy as np - import torch import torch.distributed as dist @@ -622,12 +620,14 @@ def moe_forward(self, x, topk_ids, topk_weight): # the tokens in `gathered_tokens` are headed for. This part doesn't need # gradient. with torch.no_grad(): - gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32) - s = 0 - # TODO: remove `tolist()` - for i, k in enumerate(tokens_per_expert_group.tolist()): - gatherd_idxs[s : s + k] = i % self.experts_per_rank - s += k + gatherd_idxs = ( + torch.arange( + tokens_per_expert_group.numel(), + device=tokens_per_expert_group.device, + ) + % self.experts_per_rank + ) + gatherd_idxs = gatherd_idxs.repeat_interleave(tokens_per_expert_group) # Prepare buffer for tokens processed by experts if self.shuffle_method == "symm_mem": From 077adb9ea0feef24b01e45f1e196a0218d398e6e Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Wed, 26 Mar 2025 16:52:00 -0700 Subject: [PATCH 02/18] fix simple_fsdp logic when CP is used (#1011) I observed that 1. When CP is used without other data parallel (`dp_replicate` / `dp_shard`), it works fine with and without torch.compile 2. When CP is combined with any other data parallel, it only works without torch.compile. It would fail under compile, with error info like "_UnboundLocalError: cannot access local variable 'buf899' where it is not associated with a value_" So I'm marking CP as "to be fixed". --- torchtitan/experiments/simple_fsdp/README.md | 6 ++-- .../simple_fsdp/parallelize_llama.py | 29 ++++++++++++------- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/torchtitan/experiments/simple_fsdp/README.md b/torchtitan/experiments/simple_fsdp/README.md index ab413a80b3..783439d93b 100644 --- a/torchtitan/experiments/simple_fsdp/README.md +++ b/torchtitan/experiments/simple_fsdp/README.md @@ -20,10 +20,10 @@ Some of the features require the updates from PyTorch, with which we are working |Activation Checkpointing| ✅ | |Mixed Precision Training| 🚧 | |Tensor Parallelism| 🚧 | -|Context Parallelism| ✅ | +|Context Parallelism| 🚧 | |Pipeline Parallelism| ✅ | -|Distributed Checkpointing| 🚧 | -|Float8 Training| ❌| +|Distributed Checkpointing| 🚧 | +|Float8 Training| ❌ | ### Citation diff --git a/torchtitan/experiments/simple_fsdp/parallelize_llama.py b/torchtitan/experiments/simple_fsdp/parallelize_llama.py index 8b819047ed..25e44e69e5 100644 --- a/torchtitan/experiments/simple_fsdp/parallelize_llama.py +++ b/torchtitan/experiments/simple_fsdp/parallelize_llama.py @@ -24,7 +24,8 @@ def parallelize_llama( job_config: JobConfig, ): """ - Apply activation checkpointing, torch.compile, and simplefsdp to the model. + Apply tensor parallelism, activation checkpointing, torch.compile, and data + parallelism to the model. NOTE: The passed-in model preferably should be on meta device. Otherwise, the model must fit on GPU or CPU memory. @@ -59,16 +60,22 @@ def parallelize_llama( if job_config.activation_checkpoint.mode != "none": apply_ac(model, job_config.activation_checkpoint) - if parallel_dims.dp_shard_enabled or parallel_dims.dp_replicate_enabled: - if parallel_dims.dp_replicate_enabled and parallel_dims.dp_shard_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") - fsdp_mode = "hybrid_shard" - elif parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ("dp_replicate",) - fsdp_mode = "replicate" + # apply data parallel + if ( + parallel_dims.dp_replicate_enabled + or parallel_dims.dp_shard_enabled + or parallel_dims.cp_enabled + ): + if parallel_dims.dp_replicate_enabled: + if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled: + dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + dp_mode = "hybrid_shard" + else: + dp_mesh_dim_names = ("dp_replicate",) + dp_mode = "replicate" else: dp_mesh_dim_names = ("dp_shard_cp",) - fsdp_mode = "fully_shard" + dp_mode = "fully_shard" mp_policy = MixedPrecisionPolicy( param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], @@ -78,11 +85,11 @@ def parallelize_llama( model = data_parallel( model, world_mesh[tuple(dp_mesh_dim_names)], - mode=fsdp_mode, + mode=dp_mode, ac_mode=job_config.activation_checkpoint.mode, mp_policy=mp_policy, ) - logger.info("Applied SimpleFSDP (fsdp mode=%s) to the model", fsdp_mode) + logger.info("Applied Data Parallel (dp mode=%s) to the model", dp_mode) if job_config.training.compile: torch._inductor.config.reorder_for_peak_memory = False From 2404197326669db64bc80f515d7bc9f69863f466 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=83=A1=E8=AF=91=E6=96=87?= <1020030101@qq.com> Date: Thu, 27 Mar 2025 12:38:07 +0800 Subject: [PATCH 03/18] Fix ZeroDivisionError when decay_steps=0 (#1010) When `decay_steps=0`, a ZeroDivisionError occurs. Example: Given parameters: - decay_steps=0 - warmup_steps=5 - training_steps=10 Then: - warmup_stable_steps=warmup_steps+stable_steps=10 - current_step ranges from 1 to 10 With the original code, when `current_step` equals `warmup_stable_steps`, we fall through to the `else` branch, which triggers a ZeroDivisionError due to `decay_steps` being 0. ```python if current_step < warmup_steps: # linear warmup # 0-indexed step, hence + 1 adjustments current_step += 1 curr_adjustment = float(current_step / (warmup_steps + 1)) elif current_step < warmup_stable_steps: curr_adjustment = 1.0 else: progress = float(current_step - warmup_stable_steps) / decay_steps if lr_decay_type == "linear": curr_adjustment = 1 - progress elif lr_decay_type == "sqrt": curr_adjustment = 1 - math.sqrt(progress) elif lr_decay_type == "cosine": curr_adjustment = 0.5 * (1.0 + math.cos(math.pi * progress)) curr_adjustment = lr_min + (1 - lr_min) * curr_adjustment ``` This PR changes `current_step < warmup_stable_steps` to `current_step <= warmup_stable_steps` to better handle the boundary case and prevent the ZeroDivisionError when `decay_steps=0`. --- torchtitan/components/lr_scheduler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchtitan/components/lr_scheduler.py b/torchtitan/components/lr_scheduler.py index 4943922b83..e0925ab3cd 100644 --- a/torchtitan/components/lr_scheduler.py +++ b/torchtitan/components/lr_scheduler.py @@ -150,7 +150,9 @@ def linear_warmup_stable_decay( elif current_step < warmup_stable_steps: curr_adjustment = 1.0 else: - progress = float(current_step - warmup_stable_steps) / decay_steps + # 0-indexed step, hence + 1 adjustments + current_step += 1 + progress = float(current_step - warmup_stable_steps) / (decay_steps + 1) if lr_decay_type == "linear": curr_adjustment = 1 - progress From ddf6ac2d9a22df181ecaf7d9254d47985107ce9c Mon Sep 17 00:00:00 2001 From: Less Wright Date: Wed, 26 Mar 2025 22:01:39 -0700 Subject: [PATCH 04/18] move usage.md to DS folder (where intended) (#1020) quick relocation of usage.md to DeepSeek folder, which ended up one folder higher than intended. --- torchtitan/experiments/{ => deepseek_v3}/USAGE.md | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename torchtitan/experiments/{ => deepseek_v3}/USAGE.md (100%) diff --git a/torchtitan/experiments/USAGE.md b/torchtitan/experiments/deepseek_v3/USAGE.md similarity index 100% rename from torchtitan/experiments/USAGE.md rename to torchtitan/experiments/deepseek_v3/USAGE.md From 29e3080a542df6069fe925187558b3be4a244b55 Mon Sep 17 00:00:00 2001 From: Panagiotis Kourdis Date: Wed, 26 Mar 2025 22:17:58 -0700 Subject: [PATCH 05/18] [XPU] Enable profiling for XPU devices (#1018) Adds support for profiling `XPU` devices also. --- torchtitan/tools/profiling.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py index 9cb684ab0a..050b992cc8 100644 --- a/torchtitan/tools/profiling.py +++ b/torchtitan/tools/profiling.py @@ -57,10 +57,15 @@ def trace_handler(prof): assert ( wait >= 0 ), "profile_freq must be greater than or equal to warmup + active" + gpu_device_profiled = None + if torch.cuda.is_available(): + gpu_device_profiled = torch.profiler.ProfilerActivity.CUDA + elif torch.xpu.is_available(): + gpu_device_profiled = torch.profiler.ProfilerActivity.XPU with torch.profiler.profile( activities=[ torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, + gpu_device_profiled, ], schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active), on_trace_ready=trace_handler, From 3c1298518a3d46ae70156c48095dbae70d30dd15 Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Thu, 27 Mar 2025 13:36:02 -0700 Subject: [PATCH 06/18] fix simple fsdp readme (#1022) This is following up on https://github.com/pytorch/torchtitan/pull/1011 Previously CP didn't work on my end because I was accidentally using an old version of pytorch, which had a mismatch to triton version. It works for me now. I'll add CI later. --- torchtitan/experiments/simple_fsdp/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/experiments/simple_fsdp/README.md b/torchtitan/experiments/simple_fsdp/README.md index 783439d93b..887653ac02 100644 --- a/torchtitan/experiments/simple_fsdp/README.md +++ b/torchtitan/experiments/simple_fsdp/README.md @@ -20,7 +20,7 @@ Some of the features require the updates from PyTorch, with which we are working |Activation Checkpointing| ✅ | |Mixed Precision Training| 🚧 | |Tensor Parallelism| 🚧 | -|Context Parallelism| 🚧 | +|Context Parallelism| ✅ | |Pipeline Parallelism| ✅ | |Distributed Checkpointing| 🚧 | |Float8 Training| ❌ | From f29e3424e57b4048a732c8085e02fb967e752cb6 Mon Sep 17 00:00:00 2001 From: Less Wright Date: Thu, 27 Mar 2025 18:19:23 -0700 Subject: [PATCH 07/18] ensure logging ranks are visible to titan (#1025) This PR ensures that the current logging environment is visible inside of titan. This way checks like PP logging visible can operate successfully by matching the logged ranks are also valid for showing the PP loss. Testing: a - run with 1F1B without valid loss rank (4 in this case), verify warning fires b - run with 1F1B with valid loss rank, ensure no warning displayed c - run with one and multiple ranks, verify export has no effect on actual rank logging (same as before). --- run_train.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/run_train.sh b/run_train.sh index d1c51a05a2..3b1c0c6494 100755 --- a/run_train.sh +++ b/run_train.sh @@ -11,7 +11,7 @@ set -ex # e.g. # LOG_RANK=0,1 NGPU=4 ./run_train.sh NGPU=${NGPU:-"8"} -LOG_RANK=${LOG_RANK:-0} +export LOG_RANK=${LOG_RANK:-0} CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama/train_configs/debug_model.toml"} overrides="" From 12983cc48267a288b0d0ca098e1d20d030487645 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 27 Mar 2025 23:08:22 -0700 Subject: [PATCH 08/18] Avoid perform trainer.train() for seed checkpoint creation case (#1023) This pull request splits the `Trainer.__init__()` method to separate necessary initializations from actions that require an initialized trainer. The main() function now handles these actions, which currently include seed checkpoint creation and training step. This modification resolves the current seed checkpoint failure and enables the implementation of future actions that are not part of the training process. --- torchtitan/distributed/utils.py | 4 +++- torchtitan/train.py | 40 +++++++++++++++++---------------- 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 5afa0f7af9..ffda4c7cd0 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -135,7 +135,9 @@ def create_context_parallel_ctx( ) -def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool): +def get_train_context( + enable_loss_parallel: bool, enable_compiled_autograd: bool +) -> Generator[None, None, None]: @contextlib.contextmanager def context(cp_context: Optional[Generator[None, None, None]] = None): with contextlib.ExitStack() as stack: diff --git a/torchtitan/train.py b/torchtitan/train.py index af984fdb65..bd1cd7bb19 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -8,7 +8,7 @@ import os import time from datetime import timedelta -from typing import Any, Iterable, Optional +from typing import Any, Generator, Iterable, Optional import torch @@ -43,6 +43,7 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful): dataloader: train_spec_module.BaseDataLoader metrics_processor: train_spec_module.MetricsProcessor checkpointer: CheckpointManager + train_context: Generator[None, None, None] model_parts: list[torch.nn.Module] optimizers: train_spec_module.OptimizersContainer @@ -276,32 +277,18 @@ def __init__(self, job_config: JobConfig): ft_manager=ft_manager, ) - if job_config.checkpoint.create_seed_checkpoint: - assert ( - world_size == 1 - ), "Must create seed checkpoint using a single device, to disable sharding" - assert ( - job_config.checkpoint.enable_checkpoint - ), "Must enable checkpointing when creating a seed checkpoint" - self.checkpointer.save(curr_step=0, force=True) - logger.info("Created seed checkpoint") - return - - self.checkpointer.load(step=job_config.checkpoint.load_step) - self.train_context = dist_utils.get_train_context( parallel_dims.loss_parallel_enabled, parallelism_config.enable_compiled_autograd, ) logger.info( - "Trainer initialized. " - f"Training starts at step {self.step + 1}, " - f"with local batch size {job_config.training.batch_size}, " + "Trainer is initialized with " + f"local batch size {job_config.training.batch_size}, " f"global batch size {job_config.training.batch_size * dp_degree}, " f"sequence length {job_config.training.seq_len}, " f"total steps {job_config.training.steps} " - f"(warmup {job_config.lr_scheduler.warmup_steps})" + f"(warmup {job_config.lr_scheduler.warmup_steps})." ) def next_batch(self, data_iterator: Iterable) -> tuple[torch.Tensor, torch.Tensor]: @@ -402,6 +389,10 @@ def train_step(self, inputs: torch.Tensor, labels: torch.Tensor): @record def train(self): job_config = self.job_config + + trainer.checkpointer.load(step=job_config.checkpoint.load_step) + logger.info("Training starts at step {self.step + 1}.") + with maybe_enable_profiling( job_config, global_step=self.step ) as torch_profiler, maybe_enable_memory_snapshot( @@ -460,7 +451,18 @@ def close(self) -> None: try: trainer = Trainer(config) - trainer.train() + + if config.checkpoint.create_seed_checkpoint: + assert int( + os.environ["WORLD_SIZE"] + ), "Must create seed checkpoint using a single device, to disable sharding." + assert ( + config.checkpoint.enable_checkpoint + ), "Must enable checkpointing when creating a seed checkpoint." + trainer.checkpointer.save(curr_step=0, force=True) + logger.info("Created seed checkpoint") + else: + trainer.train() finally: if trainer: trainer.close() From c7a8a5a9bbdf080a12cb0ddc9bc2a023a4d1da8c Mon Sep 17 00:00:00 2001 From: Abdul Muneer <379885+abdulmuneer@users.noreply.github.com> Date: Sat, 29 Mar 2025 02:31:33 +0400 Subject: [PATCH 09/18] Adding support for collate_fn parameter in ParallelAwareDataloader class (#1021) Currently there is no provision to provide a custom collate function to the dataloader in torchtitan. This PR aims to address that. For example, if users are having multimodal datasets, they might want to collate the data in a specific way. They can't use the provided `build_hf_dataloader` function from `hf_dataset.py` today because the `ParallelAwareDataloader` class inside it does not take a `collate_fn` parameter. We can solve this by simply accepting a `collate_fn` parameter in the `__init__` of the `ParallelAwareDataloader` class. This class is inheriting from `StatefulDataLoader` where the collate function is already supported! All that we have to do is to accept this parameter and pass it while calling the init of the base class from ParallelAwareDataloader. An example usage after this change would be: ``` def build_hf_dataloader( dp_world_size: int, dp_rank: int, tokenizer: Tokenizer, job_config: JobConfig, infinite: bool = True, ) -> ParallelAwareDataloader: """Build a data loader for HuggingFace datasets.""" dataset_name = job_config.training.dataset dataset_path = job_config.training.dataset_path batch_size = job_config.training.batch_size seq_len = job_config.training.seq_len hf_ds = HuggingFaceDataset( dataset_name=dataset_name, dataset_path=dataset_path, tokenizer=tokenizer, seq_len=seq_len, dp_rank=dp_rank, dp_world_size=dp_world_size, infinite=infinite, ) # from import custom_collate_fn return ParallelAwareDataloader( dataset=hf_ds, dp_rank=dp_rank, dp_world_size=dp_world_size, batch_size=batch_size, collate_fn=custom_collate_fn ) ``` --- torchtitan/components/dataloader.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/torchtitan/components/dataloader.py b/torchtitan/components/dataloader.py index 38aab8fae6..ca0aa7f53a 100644 --- a/torchtitan/components/dataloader.py +++ b/torchtitan/components/dataloader.py @@ -8,11 +8,13 @@ import pickle from abc import ABC, abstractmethod -from typing import Any +from collections.abc import Callable +from typing import Any, Optional from torch.distributed.checkpoint.stateful import Stateful from torch.utils.data import IterableDataset from torchdata.stateful_dataloader import StatefulDataLoader +from torchtitan.tools.logging import logger class BaseDataLoader(Stateful, ABC): @@ -39,6 +41,7 @@ class ParallelAwareDataloader(StatefulDataLoader, BaseDataLoader): dp_rank: Data parallelism rank for this dataloader. dp_world_size: The world size of the data parallelism. batch_size: The batch size to use for each iteration. + collate_fn: Optional function to collate samples in a batch. """ dp_rank: int @@ -51,11 +54,12 @@ def __init__( dp_rank: int, dp_world_size: int, batch_size: int, + collate_fn: Optional[Callable] = None, ): self.dp_world_size = dp_world_size self.dp_rank = dp_rank self.batch_size = batch_size - super().__init__(dataset, batch_size) + super().__init__(dataset, batch_size, collate_fn=collate_fn) self._rank_id = f"dp_rank_{dp_rank}" def state_dict(self) -> dict[str, Any]: From ecf26c82e328916eade6720aafe2e4c7e7622e7b Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Sat, 29 Mar 2025 15:19:17 -0700 Subject: [PATCH 10/18] Fix missing f-string in logging (#1032) ### Summary I noticed this log line when running a torchtitan training run, which seemed to indicate a missing f-string: `Training starts at step {self.step + 1}.` I took a look at the code and it does seem to be a case of accidentally using a string literal instead of f-string, so I fixed it. --- torchtitan/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index bd1cd7bb19..dcefbff421 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -391,7 +391,7 @@ def train(self): job_config = self.job_config trainer.checkpointer.load(step=job_config.checkpoint.load_step) - logger.info("Training starts at step {self.step + 1}.") + logger.info(f"Training starts at step {self.step + 1}.") with maybe_enable_profiling( job_config, global_step=self.step From 3e75baee54b9a462c61ac5684c8b0575f9adb74d Mon Sep 17 00:00:00 2001 From: Antoni-Joan Solergibert <74564958+TJ-Solergibert@users.noreply.github.com> Date: Mon, 31 Mar 2025 23:25:56 +0200 Subject: [PATCH 11/18] Adding OBELICS DataLoader (#663) Hi, In this PR I present a first draft of the Multimodal DataLoader. First I will describe how the batches are created and then I will explain the padding problem. image Let's begin checking the [OBELICS](https://huggingface.co/datasets/HuggingFaceM4/OBELICS) dataset. For every sample on the dataset we have 4 keys, but we are just interested in 2 of them: - `images`: A list either with URLs of images OR `None`s to specify the position of the text. - `texts`: A list either with text strings OR `None`s to specify the position of the images. It's important to highlight that `len(images)==len(texts)` and that for each index, **one element and only one** is not `None`. The [`format_obelics`](https://github.com/TJ-Solergibert/torchtitan/blob/07a7a12af64075225874b18c99ab158776c2d254/torchtitan/datasets/multimodal/utils.py#L438) function will transform each sample to a format that can be later fed into the transform block that will prepare the samples to the target type. Each formatted sample will be a dictionary containing 2 keys: - `images`: `List` of PIL Images with the loaded images. - `text`: `str` with the text of the sample ready to be tokenized, including the image tokens. image Once formatted, we will process each sample with the transform block. This transform block is composed of [`CLIPPreprocess`](https://github.com/TJ-Solergibert/torchtitan/blob/07a7a12af64075225874b18c99ab158776c2d254/torchtitan/datasets/multimodal/clip.py#L25), `TikTokenizer` & [`VisionCrossAttentionMask`](https://github.com/TJ-Solergibert/torchtitan/blob/multimodal_dl/torchtitan/datasets/multimodal/vision_attention_mask.py) modules. ## `CLIPPreprocess` image This module will prepare the List of images to be fed into the CLIP model. The most relevant steps is resizing the image without distortion, dividing the image into tiles and padding if necessary. Highlight the fact that it will still produce a List of tensors and NOT a tensor as every image can have a different number of tiles. This will be addressed in the collator where we will pad the image tiles to the largest in the batch. Also, we keep the maximum number of tiles to 4 and the tile size to 448 for pretraining [[1]](https://github.com/pytorch/torchtune/blob/0cc1b1f6a2a9c54ca640c4eb0a4d0b94ba94bb04/torchtune/models/llama3_2_vision/_model_builders.py#L92), [[2]](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision/blob/3f2e93603aaa5dd142f27d34b06dfa2b6e97b8be/preprocessor_config.json#L22). ## `TikTokenizer` I've included a new method in the tokenizer to encode the multimodal text. In short, it just encodes the text adding the special `image_id` token and returns both the `input_ids` & `labels` masking the `bos`, `eos` & `image_id` tokens. ## `VisionCrossAttentionMask ` image This module will create the attention mask for the Fused layers. In short, for each **TILE** we will have 1025 `image_tokens` and this mask will specify for each `text_token` to which `image_tokens` should attend to. We are returning again a List of tensors as the quantity of `image_tokens` will depend on the number of tiles. Again, we will solve this in the collator. # Padding & the collator As we've previously seen, both the outputs of the `CLIPPreprocess` & `VisionCrossAttentionMask` are list of tensors because of the different number of tiles. Within the same sample we should pad both artifacts to the maximum number of tiles, but the issue arises when we run `batch_size > 1` as we will also need to pad the `input_ids` (& `labels`) which is relatively cheap BUT also the Number of images, as the input to the CLIP model will be a tensor of shape [Batch size, Number of images, Number of tiles, Channels, Tile size, Tile size]. Padding to the maximum number of tiles is bad, but in the worst case scenario you end up increasing the tensor x4 (from 1 tile to maximum number of tiles = 4). But for the number of images it can get really really big, as there are samples with +30 images. To check this phenomenon I've included [`scripts/check_padding_mm.py`](https://github.com/TJ-Solergibert/torchtitan/blob/multimodal_dl/scripts/check_padding_mm.py) which computes the % of padding in a sample. Feel free to give it a try but it's very easy to get samples where the majority of the input is padding. ``` python3 scripts/check_padding_mm.py Unpadded tokens: 8717, Total tokens in batch: 21728 Padded text tokens: 13011, 59.88% ######################################## Unpadded images: 25, Total images in batch: 64 Padded images: 39, 60.94% (Each image with shape [4, 3, 448, 448]) ######################################## Unpadded number of tiles: 61, Total number of tiles: 256 Padded tiles: 195, 68.72% (Each with shape [3, 448, 448]) ######################################## Unpadded cross attention mask elements: 545030425, Total cross attention mask elements: 5701427200 Padded cross attention mask elements: 5156396775, 90.44% ``` That's why I proposed continue working on a DataLoader & Dataset than can pack multiple samples up to a given `input_ids` length OR number of images in a batch. Packing the `input_ids` is fairly easy while packing the cross attention masks will require a bit more effort. Let me know if you would be interested on supporting that feature or you just want to include in the repo an example of the multimodal pipeline despite the padding issue described. I also plan including some unit test, to check the generated samples & recovering from failures abilities. Other comments: - [torchtitan/datasets/mm_datasets.py](https://github.com/TJ-Solergibert/torchtitan/blob/multimodal_dl/torchtitan/datasets/mm_datasets.py): I hardcoded the image token & also the Llama3VisionTransform init. Let me know whether I should leave this (default) values or let the user modify them. - [torchtitan/datasets/multimodal/collator.py](https://github.com/TJ-Solergibert/torchtitan/blob/multimodal_dl/torchtitan/datasets/multimodal/collator.py): I hardcoded the ignore index to mask the special tokens & added some debug lists for the [`scripts/check_padding_mm.py`](https://github.com/TJ-Solergibert/torchtitan/blob/multimodal_dl/scripts/check_padding_mm.py) script. - [torchtitan/datasets/multimodal/llama3_transform.py](https://github.com/TJ-Solergibert/torchtitan/blob/multimodal_dl/torchtitan/datasets/multimodal/llama3_transform.py): More hardcoded values & what should we do (Include and mask) with eos & bos tokens. - [torchtitan/datasets/multimodal/utils.py](https://github.com/TJ-Solergibert/torchtitan/blob/multimodal_dl/torchtitan/datasets/multimodal/utils.py): Most of the code directly copied from `torchtune` cleaning the unnecessary parts like the code for the inference case. Also in the [`format_obelics`](https://github.com/TJ-Solergibert/torchtitan/blob/07a7a12af64075225874b18c99ab158776c2d254/torchtitan/datasets/multimodal/utils.py#L438) function we could drop the last images in the case the sample end with images and not text as no token will attend to them and we dont compute the loss with the image tokens (So they are useless) - [torchtitan/datasets/tokenizer/tiktoken.py](https://github.com/TJ-Solergibert/torchtitan/blob/multimodal_dl/torchtitan/datasets/tokenizer/tiktoken.py): Hardcoded image token + encode_multimodal method. Is it fine to include this method here or should we move it somewhere else? Also we could standardise the nomenclature for the `input_ids`/`tokens` across the repo. Toni --- scripts/estimate/estimation.py | 2 +- torchtitan/experiments/multimodal/__init__.py | 32 ++ .../multimodal/check_padding_mm.py | 109 +++++ .../experiments/multimodal/mm_collator.py | 227 +++++++++ .../experiments/multimodal/mm_dataset.py | 268 +++++++++++ .../experiments/multimodal/requirements.txt | 1 + .../multimodal/tokenizer/tiktoken.py | 232 ++++++++++ .../experiments/multimodal/transform.py | 185 ++++++++ torchtitan/experiments/multimodal/utils.py | 437 ++++++++++++++++++ 9 files changed, 1492 insertions(+), 1 deletion(-) create mode 100644 torchtitan/experiments/multimodal/__init__.py create mode 100644 torchtitan/experiments/multimodal/check_padding_mm.py create mode 100644 torchtitan/experiments/multimodal/mm_collator.py create mode 100644 torchtitan/experiments/multimodal/mm_dataset.py create mode 100644 torchtitan/experiments/multimodal/requirements.txt create mode 100644 torchtitan/experiments/multimodal/tokenizer/tiktoken.py create mode 100644 torchtitan/experiments/multimodal/transform.py create mode 100644 torchtitan/experiments/multimodal/utils.py diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py index 3b5831487a..4f7049be0f 100644 --- a/scripts/estimate/estimation.py +++ b/scripts/estimate/estimation.py @@ -194,7 +194,7 @@ def estimate_memory(job_config: JobConfig): ) print(f"Tracker Max: {tracker_peak / gib} GiB") if job_config.memory_estimation.disable_fake_mode and peak_active > 0: - print(f"Tracker Accuracy: {tracker_peak/peak_active}") + print(f"Tracker Accuracy: {tracker_peak / peak_active}") gc.enable() diff --git a/torchtitan/experiments/multimodal/__init__.py b/torchtitan/experiments/multimodal/__init__.py new file mode 100644 index 0000000000..ceb29e6ce4 --- /dev/null +++ b/torchtitan/experiments/multimodal/__init__.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from mm_dataset import build_mm_dataloader + +from torchtitan.components.loss import cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers +from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer + +from torchtitan.models.llama import parallelize_llama, pipeline_llama + +from torchtitan.models.llama_multimodal import llama3_2_configs, MultimodalDecoder +from torchtitan.protocols.train_spec import register_train_spec, TrainSpec + +register_train_spec( + TrainSpec( + name="llama3", + cls=MultimodalDecoder, # TODO(tj.solergibert) Create VisionEncoder + MultimodalDecoder class? + config=llama3_2_configs, + parallelize_fn=parallelize_llama, + pipelining_fn=pipeline_llama, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_mm_dataloader, + build_tokenizer_fn=build_tiktoken_tokenizer, + loss_fn=cross_entropy_loss, + ) +) diff --git a/torchtitan/experiments/multimodal/check_padding_mm.py b/torchtitan/experiments/multimodal/check_padding_mm.py new file mode 100644 index 0000000000..0345009256 --- /dev/null +++ b/torchtitan/experiments/multimodal/check_padding_mm.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import click + +from mm_dataset import build_mm_dataloader +from tokenizer.tiktoken import build_tiktoken_tokenizer + +from torchtitan.config_manager import JobConfig +from torchtitan.tools.logging import init_logger + + +@click.command() +@click.option("--dataset", default="OBELICS") +@click.option("--batch-size", default=4) +@click.option("--seq-len", default=4096) +@click.option("--tokenizer-path", required=True) +@click.option("--dp-rank", default=0) +@click.option("--dp-world-size", default=2) +@click.option("--batch-number", default=4) +def main( + dataset: str, + batch_size: int, + seq_len: int, + tokenizer_path: str, + dp_rank: int, + dp_world_size: int, + batch_number: int, +): + init_logger() + job_config = JobConfig() + job_config.parse_args( + [ + "--training.dataset", + dataset, + "--training.batch_size", + str(batch_size), + "--training.seq_len", + str(seq_len), + "--model.tokenizer_path", + tokenizer_path, + ] + ) + tokenizer = build_tiktoken_tokenizer(job_config) + dl = build_mm_dataloader( + dp_world_size=dp_world_size, + dp_rank=dp_rank, + tokenizer=tokenizer, + job_config=job_config, + ) + dl_iter = iter(dl) + + for _ in range(batch_number): + batch = next(dl_iter) + + # Analyze Batch + # input_ids + total_input_ids = batch["input_ids"].shape[0] * batch["input_ids"].shape[1] + total_non_padding_tokens = total_input_ids - int( + (batch["input_ids"] == 128004).sum() + ) + total_padding_tokens = total_input_ids - total_non_padding_tokens + print(f"Padding tokens in each sample: {(batch['input_ids'] == 128004).sum(dim=1)}") + print( + f"Unpadded tokens: {total_non_padding_tokens}, Total tokens in batch: {total_input_ids}" + ) + print( + f"Padded text tokens: {total_padding_tokens}, {(total_padding_tokens) / total_input_ids * 100:.2f}%" + ) + print(80 * "#") + # Images + padded_images = 0 + padded_tiles = 0 + for sample in batch["encoder_input"]["images"]: + for image in sample: + if int(image.sum()) == 0: + padded_images += 1 + for tile in image: + if int(tile.sum()) == 0: + padded_tiles += 1 + + total_images = ( + batch["encoder_input"]["images"].shape[0] + * batch["encoder_input"]["images"].shape[1] + ) + + print( + f"Unpadded images: {total_images - padded_images}, Total images in batch: {total_images}" + ) + print( + f'Padded images: {padded_images}, {padded_images / total_images * 100:.2f}% (Each image with shape {list(batch["encoder_input"]["images"][0, 0].shape)})' # noqa: B950 + ) + print(80 * "#") + # Tiles + total_number_of_tiles = total_images * batch["encoder_input"]["images"].shape[2] + + print( + f"Unpadded number of tiles: {total_number_of_tiles - padded_tiles}, Total number of tiles: {total_number_of_tiles}" + ) + print( + f'Padded tiles: {padded_tiles}, {padded_tiles / total_number_of_tiles * 100:.2f}% (Each with shape {list(batch["encoder_input"]["images"][0, 0, 0].shape)})' # noqa: B950 + ) + print(80 * "#") + + +if __name__ == "__main__": + main() diff --git a/torchtitan/experiments/multimodal/mm_collator.py b/torchtitan/experiments/multimodal/mm_collator.py new file mode 100644 index 0000000000..98793a7f6f --- /dev/null +++ b/torchtitan/experiments/multimodal/mm_collator.py @@ -0,0 +1,227 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import torch +import torch.nn.functional as F + +from tokenizer.tiktoken import IGNORE_INDEX + +from torch.nn.utils.rnn import pad_sequence + + +def padded_collate( + batch: List[Dict[str, List[int]]], + padding_idx: int = 0, + ignore_idx: int = -100, +) -> Dict[str, torch.Tensor]: + """Pad a batch of sequences to the longest sequence length in the batch, and + convert integer lists to tensors. + + Args: + batch (List[Dict[str, List[int]]]): A list of dictionaries containing input, label pairs. + padding_idx (int): Padding index for input ids. Defaults to 0. + ignore_idx (int): Padding index for labels. Defaults to -100. + + Returns: + Dict[str, torch.Tensor]: Collated input and label tensors. + + Example: + >>> token_pairs = [ + >>> {"input_ids": [1, 2, 3], "labels": [4, 5, 6]}, + >>> {"input_ids": [7,], "labels": [10,]}, + >>> ] + >>> collated = padded_collate( + >>> batch=token_pairs, + >>> padding_idx=padding_idx, + >>> ignore_idx=ignore_idx, + >>> ) + >>> collated["input_ids"] + >>> tensor([[1, 2, 3], [7, 0, 0]]) + >>> collated["labels"] + >>> tensor([[4, 5, 6], [10, -100, -100]]) + """ + input_ids = pad_sequence( + [x["input_ids"] for x in batch], + batch_first=True, + padding_value=padding_idx, + ) + labels = pad_sequence( + [x["labels"] for x in batch], + batch_first=True, + padding_value=ignore_idx, + ) + + input_ids_seq_len = input_ids.shape[-1] + labels_seq_len = labels.shape[-1] + + # Hack to pad correctly and not use max_seq_len, which is costly + if input_ids_seq_len > labels_seq_len: + labels = F.pad( + labels, (0, input_ids_seq_len - labels_seq_len), value=ignore_idx + ) + elif labels_seq_len > input_ids_seq_len: + input_ids = F.pad( + input_ids, + (0, labels_seq_len - input_ids_seq_len), + value=padding_idx, + ) + return {"input_ids": input_ids, "labels": labels} + + +# NOTE Inspired from torchtune.data._collate.py +@dataclass +class MultiModalCollator: + padding_idx: int = 128004 + ignore_idx: int = IGNORE_INDEX + pad_max_tiles: Optional[int] = None + pad_max_images: Optional[int] = None + + def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: + """Pad a batch of text sequences, tiled image tensors, aspect ratios, + and cross attention masks. This can be used for both training and inference. + + ``batch`` is expected to be a list of sample dicts containing the following:: + - "input_ids": List[int] of length text_seq_len, varies across samples + - "labels": List[int] of length text_seq_len, varies across samples + - "encoder_input": Dict[str, List[torch.Tensor]] + - "images": List[torch.Tensor], each with shape (n_tiles, c, h, w) + - "aspect_ratio": List[torch.Tensor], each with shape (2, ) to indicate h_ratio, w_ratio + + Shape notation: + - c = channel dim + - h = height dim + - w = weight dim + + Note: + For each element in the batch, ``len(images) == len(aspect_ratio)``. + + This collater does the following: + (1) Pad text sequence and encoder mask to the longest sequence length in the batch + (2) Pad image tensors in the tile dimension with zeros to the largest number + of tiles in the batch + (3) Add empty images of zeros to samples up to max number of images in the batch + (4) Pad aspect ratios with (1,1) for all added padding images + + Args: + batch (List[Dict[str, Any]]): A list of sample dicts containing input_ids, + labels, images, and aspect_ratio. + padding_idx (int): Padding index for input token ids. Defaults to 0. + ignore_idx (int): Padding index for labels. Defaults to -100. + pad_max_tiles (Optional[int]): Maximum number of tiles to pad to. If None, will pad to the largest number of tiles + in the batch. Defaults to None. + pad_max_images (Optional[int]): Maximum number of images to pad to. If None, will pad to the largest number of images + in the batch. Defaults to None. + + Returns: + Dict[str, Tensor]: Collated tokens, labels, images, aspect_ratio tensors. + - tokens: Tensor of shape (bsz, max_seq_len) + - labels: Tensor of shape (bsz, max_seq_len) + - images: Tensor of shape (bsz, max_num_images, max_num_tiles, c, h, w) + - aspect_ratio: Tensor of shape (bsz, max_num_images, 2) + + Example: + >>> image_id = 1 + >>> tokens_per_tile = 5 + >>> c, h, w = 1, 1, 1 + >>> batch = [ + ... { + ... "input_ids": [1, 2, 1, 3], "labels": [4, 5, 6, 7], + ... "encoder_input": { + ... # One image with two tiles, one image with three tiles + ... "images": [torch.ones(2, c, h, w), torch.ones(3, c, h, w)], + ... "aspect_ratio": [torch.tensor([1, 2]), torch.tensor([1, 3])], + ... }, + ... }, + ... { + ... "input_ids": [1, 4], "labels": [8, 9], + ... "encoder_input": { + ... # One image with four tiles + ... "images": [torch.ones(4, c, h, w)], + ... "aspect_ratio": [torch.tensor([2, 2])], + ... }, + ... }, + ... ] + ... collator = MultiModalCollator(pad_max_tiles=4) + >>> model_inputs = collator(batch=batch) + >>> print(model_inputs["input_ids"]) + tensor([[1, 2, 1, 3], + [1, 4, 0, 0]]) + >>> print(model_inputs["labels"]) + tensor([[4, 5, 6, 7], + [8, 9, -100, -100]]) + >>> print(model_inputs["encoder_input"]["images"].shape) # (bsz, max_num_images, max_num_tiles, c, h, w) + torch.Size([2, 2, 4, 1, 1, 1]) + >>> print(model_inputs["encoder_input"]["aspect_ratio"].shape) # (bsz, max_num_images, 2) + torch.Size([2, 2, 2]) + >>> print(model_inputs["encoder_input"]["images"][0, 0, ...]) # Image with two tiles got padded to four + tensor([[[[1.]]], [[[1.]]], [[[0.]]], [[[0.]]]]) + >>> print(model_inputs["encoder_input"]["images"][0, 1, ...]) # Image with three tiles got padded to four + tensor([[[[1.]]], [[[1.]]], [[[1.]]], [[[0.]]]]) + >>> print(model_inputs["encoder_input"]["images"][1, 0, ...]) # Image with four tiles did not get padded + tensor([[[[1.]]], [[[1.]]], [[[1.]]], [[[1.]]]]) + >>> print(model_inputs["encoder_input"]["images"][1, 1, ...]) # Extra padding image was added to second sample + tensor([[[[0.]]], [[[0.]]], [[[0.]]], [[[0.]]]]) + """ + # Text tokens can be handled independently by existing collaters + text_only = [ + {"input_ids": sample["input_ids"], "labels": sample["labels"]} + for sample in batch + ] + collated_text = padded_collate(text_only, self.padding_idx, self.ignore_idx) + + if self.pad_max_tiles is None: + # Get max number of tiles in batch + max_num_tiles = max(sample["images_tiles"].shape[0] for sample in batch) + else: + max_num_tiles = self.pad_max_tiles + + # Pad images and aspect ratios to max number of tiles + batch_images = [] + batch_aspect_ratios = [] + + for sample in batch: + sample_images = [] + for image in sample["encoder_input"]["images"]: + # Single image in each sample has shape (n_tiles, c, h, w) + n_tiles = image.shape[0] + # Single mask in each sample corresponds to a single image and has shape (text_seq_len, image_seq_len) + # where image_seq_len = n_tiles * tokens_per_tile + padding_tiles = max_num_tiles - n_tiles + + # Image should now have shape (max_num_tiles, c, h, w) + padded_image = F.pad( + image, (0, 0, 0, 0, 0, 0, 0, padding_tiles), value=0 + ) + + sample_images.append(padded_image) + # Stack multiple images and masks per sample in num_images dimension + batch_images.append(torch.stack(sample_images)) + batch_aspect_ratios.append( + torch.stack(sample["encoder_input"]["aspect_ratio"]) + ) + # Finally, pad images, masks, aspect ratios to max number of images in batch + # (bsz, max_num_images, max_num_tiles, c, h, w) + collated_images = pad_sequence(batch_images, batch_first=True, padding_value=0) + # (bsz, max_num_images, 2) + collated_aspect_ratios = pad_sequence( + batch_aspect_ratios, batch_first=True, padding_value=1 + ) + + batch_dict = { + "input_ids": collated_text["input_ids"], + "labels": collated_text["labels"], + "encoder_input": { + "images": collated_images, + "aspect_ratio": collated_aspect_ratios, + }, + } + + return batch_dict diff --git a/torchtitan/experiments/multimodal/mm_dataset.py b/torchtitan/experiments/multimodal/mm_dataset.py new file mode 100644 index 0000000000..a29627aace --- /dev/null +++ b/torchtitan/experiments/multimodal/mm_dataset.py @@ -0,0 +1,268 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import torch + +from datasets import Dataset, load_dataset +from datasets.distributed import split_dataset_by_node + +from mm_collator import MultiModalCollator +from tokenizer.tiktoken import IGNORE_INDEX, Tokenizer +from torch.distributed.checkpoint.stateful import Stateful +from torch.utils.data import IterableDataset +from transform import CLIPTransform +from utils import load_image + +from torchtitan.components.dataloader import ParallelAwareDataloader +from torchtitan.config_manager import JobConfig +from torchtitan.tools.logging import logger + + +def _load_obelics_dataset(dataset_path: str): + """Load C4 dataset with default configuration.""" + return load_dataset(dataset_path, split="train", streaming=True) + + +def _process_obelics_sample( + sample: dict[str, Any], image_token: str = "<|image|>" +) -> Dict[str, List[Union[str, "PIL.Image.Image"]]]: + """ + This function formats samples from the OBELICS dataset + Returns: + Dict[str, Any]: The transformed sample with the following fields: + - images: List[PIL.Image.Image] with the loaded images + - text: str with the text of the sample ready to be tokenized including the image tokens + Example: + >>> formatted_sample = format_obelics(sample, image_token="<|image|>") + >>> print(formatted_sample["text"]) + ... "<|image|><|image|><|image|> The elephant look cute!<|image|><|image|> The cats are sad :(" + """ + sample_images = [image for image in sample["images"] if image is not None] + sample_text = [ + text if text is not None else image_token for text in sample["texts"] + ] + return { + "images": [load_image(image) for image in sample_images], + "text": "".join(map(str, sample_text)), + } + + +@dataclass +class DatasetConfig: + path: str + loader: Callable + sample_processor: Callable + + +# Add your dataset here here - more information at docs/datasets.md +MM_DATASETS = { + "obelics": DatasetConfig( + path="HuggingFaceM4/OBELICS", + loader=_load_obelics_dataset, + sample_processor=_process_obelics_sample, + ), +} + + +def _validate_mm_dataset( + dataset_name: str, dataset_path: str = None +) -> tuple[str, Callable, Callable]: + """Validate dataset name and path.""" + if dataset_name not in MM_DATASETS: + raise ValueError( + f"Dataset {dataset_name} is not supported. " + f"Supported datasets are: {list(MM_DATASETS.keys())}" + ) + + config = MM_DATASETS[dataset_name] + path = dataset_path or config.path + logger.info(f"Preparing {dataset_name} dataset from {path}") + return path, config.loader, config.sample_processor + + +class MultiModalDataset(IterableDataset, Stateful): + """PyTorch MultiModal Dataset. + + Args: + dataset_name (str): name of the dataset to load + tokenizer (Tokenizer): + Tokenizer used to encode data. Tokenize must implement an `encode` and `decode` method. + world_size (int): number of data parallel processes participating in training + rank (int): rank of the current data parallel process + infinite (bool): whether to loop infinitely over the dataset + + We currently ONLY support the OBELICS dataset + + Example use: + >>> ds = MultiModalDataset(dataset_name="OBELICS", tokenizer=tokenizer) + >>> for batch in Dataloader(ds, batch_size=8): + print(f"Batch size: {len(batch)}") + Batch size: 8 + """ + + def __init__( + self, + dataset_name: str, + dataset_path: Optional[str], + tokenizer: Tokenizer, + image_token: str = "<|image|>", + tile_size: int = 448, + max_num_tiles: int = 4, + seq_len: int = 2048, + dp_rank: int = 0, + dp_world_size: int = 1, + infinite: bool = False, + ) -> None: + # Force lowercase for consistent comparison + dataset_name = dataset_name.lower() + + path, dataset_loader, sample_processor = _validate_mm_dataset( + dataset_name, dataset_path + ) + ds = dataset_loader(path) + + # TODO: support shuffling + self.dataset_name = dataset_name + self._data = split_dataset_by_node(ds, dp_rank, dp_world_size) + self._tokenizer = tokenizer + self.seq_len = seq_len + self.infinite = infinite + self._sample_processor = sample_processor + self.image_token = ( + image_token # TODO(tj.solergibert) Add `image_token` to JobConfig + ) + # TODO(tj.solergibert) Add `tile_size` & `max_num_tiles` to JobConfig + self.transform_image = CLIPTransform( + image_mean=( + 0.48145466, + 0.4578275, + 0.40821073, + ), # TODO(tj.solergibert) What should we do with `image_mean` & `image_std`?, + image_std=(0.26862954, 0.26130258, 0.27577711), + tile_size=tile_size, + possible_resolutions=None, + max_num_tiles=max_num_tiles, + resample="bilinear", + resize_to_max_canvas=False, + ) + + # variables for checkpointing + self._sample_idx = 0 + + def __iter__(self): + + while True: + for sample in self._get_data_iter(): + try: + sample = self._sample_processor( + sample, image_token=self.image_token + ) + except Exception: + continue + self._sample_idx += 1 + + # CLIP Transform + encoder_input = {"images": [], "aspect_ratio": []} + for image in sample["images"]: + out = self.transform_image(image) + encoder_input["images"].append(out["image"]) + encoder_input["aspect_ratio"].append(out["aspect_ratio"]) + sample["encoder_input"] = encoder_input + + # Tokenize + tokens = self._tokenizer.encode( + sample["text"], + bos=True, + eos=True, + allowed_special=set(["<|image|>"]), + ) + sample["input_ids"] = torch.LongTensor(tokens[:-1]) + sample["labels"] = torch.LongTensor(tokens[1:]) + # Mask BOS, EOS & image tokens from the loss + sample["labels"] = torch.where( + torch.isin( + sample["labels"], + torch.LongTensor( + [ + self._tokenizer.bos_id, + self._tokenizer.eos_id, + self._tokenizer.image_id, + ] + ), + ), + IGNORE_INDEX, + sample["labels"], + ) + # Truncate + sample["input_ids"], sample["labels"] = ( + sample["input_ids"][: self.seq_len], + sample["labels"][: self.seq_len], + ) + yield sample + + if not self.infinite: + logger.warning(f"Dataset {self.dataset_name} has run out of data") + break + else: + # Reset offset for the next iteration + self._sample_idx = 0 + logger.warning(f"Dataset {self.dataset_name} is being re-looped") + + def _get_data_iter(self): + if isinstance(self._data, Dataset) and self._sample_idx == len(self._data): + return iter([]) + + it = iter(self._data) + for _ in range(self._sample_idx): + next(it) + return it + + def load_state_dict(self, state_dict): + self._sample_idx = state_dict["sample_idx"] + + def state_dict(self): + return {"sample_idx": self._sample_idx} + + +def build_mm_dataloader( + dp_world_size: int, + dp_rank: int, + tokenizer: Tokenizer, + job_config: JobConfig, + infinite: bool = True, +) -> ParallelAwareDataloader: + """Build a data loader for HuggingFace datasets.""" + dataset_name = job_config.training.dataset + dataset_path = job_config.training.dataset_path + batch_size = job_config.training.batch_size + seq_len = job_config.training.seq_len + pad_max_tiles = 4 # TODO(tj.solergibert) Add `pad_max_tiles` to JobConfig + padding_idx = 128004 # TODO(tj.solergibert) Add `padding_idx` to JobConfig + + hf_ds = MultiModalDataset( + dataset_name=dataset_name, + dataset_path=dataset_path, + tokenizer=tokenizer, + seq_len=seq_len, + dp_rank=dp_rank, + dp_world_size=dp_world_size, + infinite=infinite, + ) + + collate_fn = MultiModalCollator( + padding_idx=padding_idx, pad_max_tiles=pad_max_tiles + ) + + return ParallelAwareDataloader( + dataset=hf_ds, + dp_rank=dp_rank, + dp_world_size=dp_world_size, + batch_size=batch_size, + collate_fn=collate_fn, + ) diff --git a/torchtitan/experiments/multimodal/requirements.txt b/torchtitan/experiments/multimodal/requirements.txt new file mode 100644 index 0000000000..e35531e566 --- /dev/null +++ b/torchtitan/experiments/multimodal/requirements.txt @@ -0,0 +1 @@ +torchvision diff --git a/torchtitan/experiments/multimodal/tokenizer/tiktoken.py b/torchtitan/experiments/multimodal/tokenizer/tiktoken.py new file mode 100644 index 0000000000..9d494a06f6 --- /dev/null +++ b/torchtitan/experiments/multimodal/tokenizer/tiktoken.py @@ -0,0 +1,232 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. + +import os +from pathlib import Path +from typing import ( + AbstractSet, + Any, + cast, + Collection, + Dict, + Iterator, + List, + Literal, + Mapping, + Optional, + Sequence, + Union, +) + +import tiktoken +import torch +from tiktoken.load import load_tiktoken_bpe + +from torchtitan.components.tokenizer import Tokenizer +from torchtitan.config_manager import JobConfig +from torchtitan.tools.logging import logger + +IMAGE_TOKEN_ID = 128256 +IGNORE_INDEX = -100 + + +class TikTokenizer(Tokenizer): + """ + Tokenizing and encoding/decoding text using the Tiktoken tokenizer. + + Args: + model_path (str): The path to the Tiktoken model file. + """ + + special_tokens: Dict[str, int] + + num_reserved_special_tokens = 256 + + pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501, B950 + + def __init__(self, model_path: str): + super().__init__(model_path) + assert os.path.isfile(model_path), model_path + + mergeable_ranks = load_tiktoken_bpe(model_path) + num_base_tokens = len(mergeable_ranks) + special_tokens = [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|reserved_special_token_2|>", + "<|reserved_special_token_3|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|reserved_special_token_4|>", + "<|eot_id|>", # end of turn + ] + [ + f"<|reserved_special_token_{i}|>" + for i in range(5, self.num_reserved_special_tokens - 5) + ] + self.special_tokens = { + token: num_base_tokens + i for i, token in enumerate(special_tokens) + } + self.special_tokens["<|image|>"] = IMAGE_TOKEN_ID + self.model = tiktoken.Encoding( + name=Path(model_path).name, + pat_str=self.pat_str, + mergeable_ranks=mergeable_ranks, + special_tokens=self.special_tokens, + ) + + self._n_words: int = self.model.n_vocab + # BOS / EOS token IDs + self.bos_id: int = self.special_tokens["<|begin_of_text|>"] + self.eos_id: int = self.special_tokens["<|end_of_text|>"] + self.pad_id: int = -1 + self.image_id = IMAGE_TOKEN_ID + self.stop_tokens = { + self.special_tokens["<|end_of_text|>"], + self.special_tokens["<|eot_id|>"], + } + logger.info( + f"TikTokenizer built: #words {self.n_words}, BOS ID {self.bos_id}, EOS ID {self.eos_id}, IMAGE ID {self.image_id}" + ) + + def encode( + self, + s: str, + *, + bos: bool, + eos: bool, + allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None, + disallowed_special: Optional[Union[Literal["all"], Collection[str]]] = None, + ) -> List[int]: + """ + Encodes a string into a list of token IDs. + + Args: + s (str): The input string to be encoded. + bos (bool): Whether to prepend the beginning-of-sequence token. + eos (bool): Whether to append the end-of-sequence token. + allowed_tokens ("all"|set[str]): allowed special tokens in string + disallowed_tokens ("all"|set[str]): special tokens that raise an error when in string + + Returns: + list[int]: A list of token IDs. + + By default, setting disallowed_special=() encodes a string by ignoring + special tokens. Specifically: + - Setting `disallowed_special` to () will cause all text corresponding + to special tokens to be encoded as natural text (insteading of raising + an error). + - Setting `allowed_special` to "all" will treat all text corresponding + to special tokens to be encoded as special tokens. + """ + assert type(s) is str + allowed_special = allowed_special or set() + disallowed_special = disallowed_special or () + + # The tiktoken tokenizer can handle <=400k chars without + # pyo3_runtime.PanicException. + TIKTOKEN_MAX_ENCODE_CHARS = 400_000 + + # https://github.com/openai/tiktoken/issues/195 + # Here we iterate over subsequences and split if we exceed the limit + # of max consecutive non-whitespace or whitespace characters. + MAX_NO_WHITESPACES_CHARS = 25_000 + + substrs = ( + substr + for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS) + for substr in self._split_whitespaces_or_nonwhitespaces( + s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS + ) + ) + t: List[int] = [] + for substr in substrs: + t.extend( + self.model.encode( + substr, + allowed_special=allowed_special, + disallowed_special=disallowed_special, + ) + ) + if bos: + t.insert(0, self.bos_id) + if eos: + t.append(self.eos_id) + return t + + def decode(self, t: Sequence[int]) -> str: + """ + Decodes a list of token IDs into a string. + + Args: + t (List[int]): The list of token IDs to be decoded. + + Returns: + str: The decoded string. + """ + # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. + return self.model.decode(cast(List[int], t)) + + @staticmethod + def _split_whitespaces_or_nonwhitespaces( + s: str, max_consecutive_slice_len: int + ) -> Iterator[str]: + """ + Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len` + consecutive whitespaces or consecutive non-whitespaces. + """ + current_slice_len = 0 + current_slice_is_space = s[0].isspace() if len(s) > 0 else False + slice_start = 0 + + for i in range(len(s)): + is_now_space = s[i].isspace() + + if current_slice_is_space ^ is_now_space: + current_slice_len = 1 + current_slice_is_space = is_now_space + else: + current_slice_len += 1 + if current_slice_len > max_consecutive_slice_len: + yield s[slice_start:i] + slice_start = i + current_slice_len = 1 + yield s[slice_start:] + + def encode_multimodal(self, sample: Mapping[str, Any]) -> List[int]: + """ + Tokenizes a `str` of text and creates `labels` masking BOS, EOS and `image_id` tokens. + """ + # TODO(tj.solergibert) Should we keep `input_ids` OR `tokens` across this class, VisionCrossAttentionMask & the collator? + # For me it makes more sense to split `tokens` between `input_ids` & `labels` as in train.py BUT the `MultimodalDecoder` + # & everything else expects `tokens` + text = sample["text"] + tokens = self.encode( + text, bos=True, eos=True, allowed_special=set(["<|image|>"]) + ) + input_ids = torch.LongTensor(tokens[:-1]) + labels = torch.LongTensor(tokens[1:]) + labels = torch.where( + torch.isin( + labels, torch.LongTensor([self.bos_id, self.eos_id, self.image_id]) + ), + IGNORE_INDEX, + labels, + ) + + assert len(input_ids) == len(labels) # TODO(tj.solergibert) Delete + + sample.update({"tokens": input_ids, "labels": labels}) + + return sample + + +def build_tiktoken_tokenizer(job_config: JobConfig) -> TikTokenizer: + return TikTokenizer(job_config.model.tokenizer_path) diff --git a/torchtitan/experiments/multimodal/transform.py b/torchtitan/experiments/multimodal/transform.py new file mode 100644 index 0000000000..ecb0f989ac --- /dev/null +++ b/torchtitan/experiments/multimodal/transform.py @@ -0,0 +1,185 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, List, Mapping, Optional, Tuple + +import torch + +import torchvision +from torchvision.transforms.v2 import functional as F + +from utils import ( + find_supported_resolutions, + get_canvas_best_fit, + resize_with_pad, + tile_crop, +) + +from torchtitan.tools.logging import logger + + +class CLIPTransform: + """ + This class accepts images of any size and dynamically resizes, pads, normalizes and tiles it + based on the image aspect ratio and the number of image tiles we allow. + + The algorithm will NOT distort the image to fit a certain aspect ratio, because + that leads to a significant degradation in image quality. + + The user can choose if they want to allow upscaling by using the flag ``resize_to_max_canvas``. + + For example, if an input image is of size 300x800, and we want to allow + a maximum of 16 image tiles, with side 224px, then: + + If ``resize_to_max_canvas=False``, then: + best_resolution = (448, 896) -> smallest canvas, up to 16 tiles, that doesn't require downscaling + image is NOT resized + image is padded (300, 800) -> 448,896 + Image is tiled 2x4, for a final output shape of (8, 3, 224, 224) + + If ``resize_to_max_canvas=True``, then: + best_resolution = (448, 1344) # canvas that allows maximum upscaling, with minimum padding, up to 16 tiles + image is resized without distortion (300,800) -> (448, 1194) #448 is the limiting side for the resize + image is padded (448, 1194) -> (448, 1344) + Image is tiled 2x6, for a final output shape of (10, 3, 224, 224) + + Args: + image_mean (Optional[List[float]]): Mean values of each channel, used for normalization. + Should be the same used for the pre-trained model. If None, no normalization is performed. Default None. + image_std (Optional[List[float]]): Standard deviation values of each channel, used for normalization. + Should be the same used for the pre-trained model. If None, no normalization is performed. Default None. + possible_resolutions (Optional[List[Tuple[int, int]]]): List of possible resolutions as tuples (height, width). + where each tuple represents a possible canvas to fit the image into when calling ``get_canvas_best_fit``. + If None, this will be calculated using max_num_tiles and tile_size. Default None. + tile_size (int): Size of the tiles to divide the image into. Default 224. + max_num_tiles (Optional[int]): Only used if possible_resolutions is NOT given. + Maximum number of tiles to break an image into. + This will be used to generate possible_resolutions, + e.g. [(224, 224), (224, 448), (448, 224)] if max_num_tiles = 2 and tile_size = 224. + Default 4. + dtype (torch.dtype): Data type of the output image. Default torch.bfloat16. + resample (str): Resampling method used when resizing images. Supports any enum of + ``torchvision.transforms.InterpolationMode``, e.g. "nearest", "nearest_exact", "bilinear", "bicubic". + Default 'bilinear'. + resize_to_max_canvas (bool): "If True, the image will be upscaled without distortion to fit the largest possible + resolution from possible_resolutions. + If False, it will pick the resolution that minimizes downscaling, including no downscaling at all. + In this case, the image will only be upscaled if it's size < tile_size. Default False. + + Examples: + >>> image_transform = CLIPImageTransform( + ... image_mean=None, + ... image_std=None, + ... tile_size=224, + ... possible_resolutions=None, + ... max_num_tiles=4, + ... resample="bilinear", + ... resize_to_max_canvas=True, + ...) + >>> # create random image + >>> image = (np.random.rand(100,200,3) * 255).astype(np.uint8) + >>> image = PIL.Image.fromarray(image) + >>> output = image_transform(image) + >>> output['image'].shape # [num_tiles, num_channels, tile_size, tile_size] + torch.Size([2, 3, 224, 224]) + >>> output['ar'] # image best fits the canvas 224x448 + torch.tensor([1,2]) + """ + + def __init__( + self, + *, + image_mean: Optional[List[float]] = None, + image_std: Optional[List[float]] = None, + possible_resolutions: Optional[List[Tuple[int, int]]] = None, + tile_size: int = 224, + max_num_tiles: Optional[int] = 4, + dtype: torch.dtype = torch.bfloat16, + resample: str = "bilinear", + resize_to_max_canvas: bool = False, + ) -> None: + + # get_canvas_best_fit + assert ( + possible_resolutions is not None or max_num_tiles is not None + ), f"Either possible_resolutions or max_num_tiles must be given. Got {possible_resolutions} and {max_num_tiles}" + + # If possible_resolutions are not given, then calculate possible ones based on max_num_tiles + if not possible_resolutions and max_num_tiles: + possible_resolutions = find_supported_resolutions( + max_num_tiles=max_num_tiles, tile_size=tile_size + ) + else: + possible_resolutions = possible_resolutions + + self.possible_resolutions = torch.tensor(possible_resolutions).reshape(-1, 2) + logger.debug( + f"Found possible_resolutions: {self.possible_resolutions}. Will fit the images into the canvas with best fit." + ) + + self.resize_to_max_canvas = resize_to_max_canvas + + # normalize + assert (image_mean is None) == ( + image_std is None + ), f"Need to provide both or none of image_mean and image_std. Got {image_mean=} and {image_std=}" + self.mean = image_mean + self.std = image_std + + # resize_with_pad + self.max_size = None if resize_to_max_canvas else tile_size + self.dtype = dtype + self.resample = torchvision.transforms.InterpolationMode[resample.upper()] + + # tile_crop + self.tile_size = tile_size + + def __call__(self, image: torch.Tensor) -> Mapping[str, Any]: + """ + Apply image decoding and transformations to the "image" field in the sample. + + Args: + sample (Mapping[str, Any]): A sample with an "image" field containing + a List[Message] to tokenize + + Returns: + Mapping[str, Any]: The sample with an updated "image" filed and added + "aspect_ratio" field. + """ + assert isinstance(image, torch.Tensor), "Input image must be a torch.Tensor." + + image = F.to_image(image) + image = F.grayscale_to_rgb_image(image) + image = F.to_dtype(image, dtype=self.dtype, scale=True) + + # Find the best canvas to fit the image without distortion + best_resolution = get_canvas_best_fit( + image=image, + possible_resolutions=self.possible_resolutions, + resize_to_max_canvas=self.resize_to_max_canvas, + ) + + # resize without distortion + pad to fit best_resolution + image = resize_with_pad( + image=image, + target_size=best_resolution, + resample=self.resample, + max_size=self.max_size, + ) + + # Normalize + if self.mean: + image = F.normalize(image, mean=self.mean, std=self.std) + + # Divide the image into equally sized tiles + image = tile_crop(image=image, tile_size=self.tile_size) + + aspect_ratio = torch.tensor(best_resolution).reshape(-1) // self.tile_size + + return { + "image": image, + "aspect_ratio": aspect_ratio, + } diff --git a/torchtitan/experiments/multimodal/utils.py b/torchtitan/experiments/multimodal/utils.py new file mode 100644 index 0000000000..c927772a5e --- /dev/null +++ b/torchtitan/experiments/multimodal/utils.py @@ -0,0 +1,437 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math + +from collections import defaultdict + +from pathlib import Path +from typing import List, Optional, Set, Tuple, Union +from urllib import request + +import torch +import torchvision +from torchvision.transforms.v2 import functional as F + +# NOTE Copied from torchtune.modules.transforms.vision_utils.tile_crop.py +def tile_crop(image: torch.Tensor, tile_size: int) -> torch.Tensor: + """ + Divides a tensor into equally sized tiles. The tensor should be divisible by tile_size. + + Args: + image (torch.Tensor): Input image to crop into tiles. + tile_size (int): Size of each tile. + + Returns: + torch.Tensor: torch.Tensor of shape [num_tiles, channel_size, tile_size, tile_size] + + Examples: + >>> image = torch.rand(3, 200, 300) + >>> tiles = tile_crop(image, tile_size=50) + >>> tiles.shape # 4x6 = 24 tiles + torch.Size([24, 3, 50, 50]) + + >>> image = torch.rand(3, 400, 600) + >>> tiles = tile_crop(image, tile_size=200) + >>> tiles.shape # 2x3 = 6 tiles + torch.Size([6, 3, 200, 200]) + """ + + channel_size, height, width = image.shape + + # assert sizes are divisible + assert ( + height % tile_size == 0 and width % tile_size == 0 + ), f"Image size {height}x{width} is not divisible by tile size {tile_size}" + + # Reshape to split height and width into tile_size blocks + tiles_height = height // tile_size + tiles_width = width // tile_size + + reshaped = image.view(channel_size, tiles_height, tile_size, tiles_width, tile_size) + + # Transpose to bring tiles together + # We want [tiles_height, tiles_width, channel_size, tile_size, tile_size] + transposed = reshaped.permute(1, 3, 0, 2, 4) + + # Flatten the tiles + tiles = transposed.contiguous().view( + tiles_height * tiles_width, channel_size, tile_size, tile_size + ) + + return tiles + + +# NOTE Copied from torchtune.modules.transforms.vision_utils.resize_with_pad.py +def resize_with_pad( + image: torch.Tensor, + target_size: Tuple[int, int], + resample: torchvision.transforms.InterpolationMode, + max_size: Optional[int] = None, +) -> torch.Tensor: + """ + Resizes and pads an image to target_size without causing distortion. + The user can set max_size to limit upscaling when target_size exceeds image_size. + + Args: + image (torch.Tensor): The input image tensor in the format [..., H, W]. + target_size (Tuple[int, int]): The desired resolution to fit the image into in the format [height, width]. + resample (torchvision.transforms.InterpolationMode): Resampling method used when resizing images. + Supports torchvision.transforms.InterpolationMode.NEAREST, InterpolationMode.NEAREST_EXACT, + InterpolationMode.BILINEAR and InterpolationMode.BICUBIC. + max_size (Optional[int]): The maximum size to upscale the image to. + If None, will upscale up to target_size. + + Returns: + torch.Tensor: The resized and padded image tensor in the format [..., H, W]. + + Examples: + + Example 1: The image will be upscaled from (300, 800) to (448, 1194), since 448 is the limiting side, + and then padded from (448, 1194) to (448, 1344). + + >>> max_size = None + >>> image = torch.rand([3, 300, 800]) + >>> target_size = (448, 1344) + >>> resample = torchvision.transforms.InterpolationMode.BILINEAR + >>> output = resize_with_pad(image, target_size, resample, max_size) + + Example 2: The image will stay as is, since 800 > 600, and then padded from (300, 800) to (448, 1344). + + >>> max_size = 600 + >>> image = torch.rand([3, 300, 800]) + >>> target_size = (448, 1344) + >>> resample = torchvision.transforms.InterpolationMode.BILINEAR + >>> output = resize_with_pad(image, target_size, resample, max_size) + + Example 3: The image will be downscaled from (500, 1000) to (224, 448), + and padded from (224, 448) to (448, 448). + + >>> max_size = 600 + >>> image = torch.rand([3, 500, 1000]) + >>> target_size = (448, 488) + >>> resample = torchvision.transforms.InterpolationMode.BILINEAR + >>> output = resize_with_pad(image, target_size, resample, max_size) + + """ + + image_height, image_width = image.shape[-2:] + image_size = (image_height, image_width) + + # If target_size requires upscaling, we might want to limit the upscaling to max_size + if max_size is not None: + new_target_height = min(max(image_height, max_size), target_size[0]) + new_target_width = min(max(image_width, max_size), target_size[1]) + target_size_resize = (new_target_height, new_target_width) + else: + target_size_resize = target_size + + # resize to target_size while preserving aspect ratio + new_size_preserving_aspect_ratio = _get_max_res_without_distortion( + image_size=image_size, + target_size=target_size_resize, + ) + + image = F.resize( + inpt=image, + size=list(new_size_preserving_aspect_ratio), + interpolation=resample, + antialias=True, + ) + + image = _pad_image_top_left(image=image, target_size=target_size) + + return image + + +# NOTE Copied from torchtune.modules.transforms.vision_utils.resize_with_pad.py +def _pad_image_top_left( + image: torch.Tensor, + target_size: Tuple[int, int], +) -> torch.Tensor: + """ + Places the image at the top left of the canvas and pads with 0 the right and bottom + to fit to the target resolution. If target_size < image_size, it will crop the image. + + Args: + image (torch.Tensor): The input image tensor in the format [..., H, W]. + target_size (Tuple[int, int]): The desired resolution to fit the image into in the format [height, width]. + + Returns: + torch.Tensor: The padded image tensor in the format [..., H, W]. + """ + + image_size = image.shape[-2:] + + height, width = image_size + target_height, target_width = target_size + + pad_x = target_width - width + pad_y = target_height - height + + padding = [0, 0, pad_x, pad_y] + return F.pad(inpt=image, padding=padding) + + +# NOTE Copied from torchtune.modules.transforms.vision_utils.resize_with_pad.py +def _get_max_res_without_distortion( + image_size: Tuple[int, int], + target_size: Tuple[int, int], +) -> Tuple[int, int]: + """ + Determines the maximum resolution to which an image can be resized to without distorting its + aspect ratio, based on the target resolution. + + For example, if image_size = (200,400) and target_size = (600,800), + scale_h = 600/200 = 3 + scale_w = 800/400 = 2 + So the maximum that we can upscale without distortion is min(scale_h, scale_w) = 2 + + Since scale_w is the limiting side, then new_w = target_w, and new_h = old_h*scale_w + + Args: + image_size (Tuple[int, int]): The original resolution of the image. + target_size (Tuple[int, int]): The desired resolution to fit the image into. + Returns: + Tuple[int, int]: The optimal dimensions to which the image should be resized. + Examples: + >>> _get_max_res_without_distortion([200, 300], target_size = (450, 200)) + (133, 200) + >>> _get_max_res_without_distortion([800, 600], target_size = (450, 1300)) + (450, 337) + """ + + original_height, original_width = image_size + target_height, target_width = target_size + + scale_w = target_width / original_width + scale_h = target_height / original_height + + if scale_w < scale_h: + new_width = target_width + new_height = min(math.floor(original_height * scale_w), target_height) + else: + new_height = target_height + new_width = min(math.floor(original_width * scale_h), target_width) + + return new_height, new_width + + +# NOTE Copied from torchtune.modules.transforms.vision_utils.get_canvas_best_fit.py +def _get_factors(n: int) -> Set[int]: + """ + Calculate all factors of a given number, i.e. a divisor that leaves no remainder. + + Args: + n (int): The number to find factors for. + + Returns: + set: A set containing all factors of the number. + + Examples: + >>> _get_factors(n=12) + {1, 2, 3, 4, 6, 12} + """ + factors_set = set() + + for i in range(1, int(n**0.5) + 1): + if n % i == 0: + factors_set.add(i) + factors_set.add(n // i) + return factors_set + + +# NOTE Copied from torchtune.modules.transforms.vision_utils.get_canvas_best_fit.py +def get_canvas_best_fit( + image: torch.Tensor, possible_resolutions: torch.Tensor, resize_to_max_canvas: bool +) -> Tuple[int, int]: + """ + Determines the best canvas possible from a list of possible resolutions to + resize an image to, without distortion. + + For each possible resolution, calculates the scaling factors for + width and height, and selects the smallest one, which is the limiting side. + E.g. if to match a canvas shape you have to upscale an image's height by 2x, and width by 1.5x, + then the maximum upscaling without distortion is min(2, 1.5) = 1.5. + + If there are multiple canvases that satisfy the conditions, + we pick the one with the lowest area to minimize padding. + + Args: + image (torch.Tensor): The image we want to fit into a canvas. + possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each + row represents a possible canvas. + resize_to_max_canvas (bool): If True, pick the canvas that allows maximum scaling. + If False, pick the canvas that minimizes downscaling, including no downscaling at all. + + Returns: + Tuple[int, int]: The best resolution to fit the image into. + + Examples: + >>> image = torch.rand(3, 200, 300) + >>> possible_resolutions = torch.tensor([ + ... [224, 672], + ... [672, 224], + ... [224, 448], + ... [448, 224], + ... [224, 224] + ... ]) + >>> get_canvas_best_fit(image, possible_resolutions, resize_to_max_canvas=False) + (224, 448) + + In the example above, we calculate the scaling factors for each possible resolution + + >>> scale_height = torch.tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200]) + >>> scale_width = torch.tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467]) + >>> scales = torch.tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467]) + + Two options have scaling_factor > 1, since resize_to_max_canvas is False, we pick the smallest + + >>> upscaling_options = torch.tensor([1.1200, 1.1200]) + >>> selected_scale = torch.tensor(1.1200) + + There are two possible options, so we pick the one with the smallest area + + >>> areas = torch.tensor([150528, 100352]) # for resolutions [672, 224] and [224, 448], respectively + >>> optimal_canvas = torch.tensor([224, 448]) # resolution with the smallest area + """ + + original_height, original_width = image.shape[-2:] + + # possible resolutions heights/widths + target_heights, target_widths = ( + possible_resolutions[:, 0], + possible_resolutions[:, 1], + ) + + # scaling factors to resize the image without distortion + scale_w = target_widths / original_width + scale_h = target_heights / original_height + + # get limiting side scaling -> no distortion + scales = torch.where(scale_w > scale_h, scale_h, scale_w) + + # filter only scales that allow upscaling + upscaling_options = scales[scales >= 1] + if len(upscaling_options) > 0: + if resize_to_max_canvas: + selected_scale = torch.max(upscaling_options) + else: + selected_scale = torch.min(upscaling_options) + else: + # no upscaling possible, + # get the minimum downscaling (max scale for scales<1) + downscaling_options = scales[scales < 1] + selected_scale = torch.max(downscaling_options) + + # get all resolutions that support this scaling factor, + # e.g. you can upscale to 224x224, 224x448, 224x672 without distortion + chosen_canvas = possible_resolutions[scales == selected_scale] + + # if there are multiple resolutions, + # get the one with minimum area to reduce padding + if len(chosen_canvas) > 1: + areas = chosen_canvas[:, 0] * chosen_canvas[:, 1] + optimal_idx = torch.argmin(areas) + optimal_canvas = chosen_canvas[optimal_idx] + else: + optimal_canvas = chosen_canvas[0] + + return tuple(optimal_canvas.tolist()) + + +# NOTE Copied from torchtune.modules.transforms.vision_utils.get_canvas_best_fit.py +def find_supported_resolutions( + max_num_tiles: int, tile_size: int +) -> List[Tuple[int, int]]: + """ + Computes all combinations of resolutions, multiple of tile_size, + that contain up to max_num_tiles. Useful for when dividing an image into tiles. + + For example, if we want at most 2 tiles per image, then we can support the + following resolutions: (1x1, 1x2, 2x1) * tile_size + + Args: + max_num_tiles (int): Maximum number of tiles. + tile_size (int): Size of the side of the tile. + + Returns: + List[Tuple[int, int]]: List of possible resolutions as tuples (height, width). + + Examples: + + >>> max_num_tiles = 4 + >>> tile_size = 224 + >>> find_supported_resolutions(max_num_tiles, tile_size) + [(224, 896), (448, 448), (224, 224), (896, 224), (224, 672), (672, 224), (224, 448), (448, 224)] + """ + + # create dictionary {aspect_ratio: [resolution1, ..., resolution n]} + # example {0.25: [(1,4)], 1.0: [(2,2), (1,1)], 4.0: [(4,1)]} + asp_dict = defaultdict(list) + for _tile_size in range(max_num_tiles, 0, -1): + factors = sorted(_get_factors(_tile_size)) + asp_ratios = [(factor, _tile_size // factor) for factor in factors] + for height, width in asp_ratios: + ratio_float = height / width + asp_dict[ratio_float].append((height, width)) + + # get the resolutions multiplied by the tile_size + possible_resolutions = [] + for ar, resolution in asp_dict.items(): + for height, width in resolution: + possible_resolutions.append((height * tile_size, width * tile_size)) + + return possible_resolutions + + +# NOTE Copied from torchtune.data._utils.py +def load_image(image_loc: Union[Path, str]) -> torch.Tensor: + """ + Convenience method to load an image in torch.Tensor format from a local file path or remote source. + + Args: + image_loc (Union[Path, str]): Local file path or remote source pointing to the image + which will be loaded in PIL format. + + Note: + If loading an image from a remote source, the function expects the URL provided in ``image_loc`` + to start with "http" or "https" e.g. "https://www.wikipedia.org/en/bird.jpg". + + Raises: + ValueError: If the image cannot be loaded from remote source, **or** + if the image cannot be opened as a :class:`~torch.Tensor`. + + Examples: + >>> # Load from remote source + >>> image = load_image("https://www.wikipedia.org/en/bird.jpg") + + >>> # Load from local file path + >>> image = load_image(Path("/home/user/bird.jpg")) + + Returns: + torch.Tensor: The loaded image. + """ + + # If pointing to remote source, try to load to local + if isinstance(image_loc, str) and image_loc.startswith("http"): + try: + image_loc = request.urlopen(image_loc).read() + image = torchvision.io.decode_image( + torch.frombuffer(image_loc, dtype=torch.uint8), + mode="RGB", + ) + except Exception as e: + raise ValueError("Failed to load remote image as torch.Tensor") from e + + # Open the local image as a Tensor image + else: + try: + image = torchvision.io.decode_image(image_loc, mode="RGB") + except Exception as e: + raise ValueError("Failed to load local image as torch.Tensor") from e + + return image From 3ce56164a5575f5dcf7c65fa68b635cd94c31fe5 Mon Sep 17 00:00:00 2001 From: Takuya Akiba <469803+iwiwi@users.noreply.github.com> Date: Wed, 2 Apr 2025 02:04:21 +0900 Subject: [PATCH 12/18] Correct self reference in Trainer (#1037) Fixed a bug where the `Trainer.train` method was incorrectly using `trainer` instead of `self`. When `train.py` is executed and the `if __name__ == "__main__": ...` block runs, the global variable `trainer` is identical to `self`, so execution proceeds without issues. However, when we import and run `Trainer` from other files, the `trainer` variable doesn't exist, causing an error. This PR fixes that issue. --- torchtitan/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index dcefbff421..0ef23f798c 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -390,7 +390,7 @@ def train_step(self, inputs: torch.Tensor, labels: torch.Tensor): def train(self): job_config = self.job_config - trainer.checkpointer.load(step=job_config.checkpoint.load_step) + self.checkpointer.load(step=job_config.checkpoint.load_step) logger.info(f"Training starts at step {self.step + 1}.") with maybe_enable_profiling( From 465b650c4230ad4bd30ad485647ec8da94eb05a1 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 20 Mar 2025 23:35:17 -0700 Subject: [PATCH 13/18] Move SDPA to a seperate file --- torchtitan/models/attention.py | 55 ++++++++++++++++++++ torchtitan/models/llama/model.py | 47 +++-------------- torchtitan/models/llama/parallelize_llama.py | 15 ++++-- 3 files changed, 72 insertions(+), 45 deletions(-) create mode 100644 torchtitan/models/attention.py diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py new file mode 100644 index 0000000000..a0e90b5cba --- /dev/null +++ b/torchtitan/models/attention.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +from typing import Callable, ClassVar, Optional + +import torch +import torch.nn.functional as F +from torch.nn.attention.flex_attention import ( + BlockMask, + create_block_mask, + flex_attention, +) + + +class SDPA(torch.nn.Module): + # We registered flex_attention related attributes as class variables as we + # need to amortize the cost of compilation. Enabling per-instance flex_attention + # is not supported. + block_mask: ClassVar[Optional[BlockMask]] = None + use_flex_attn: ClassVar[bool] = False + flex_attn: ClassVar[Optional[Callable]] = None + + def __init__(self) -> None: + self.use_flex_attn = model_args.use_flex_attn + + def forward( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor + ) -> torch.Tensor: + if self.use_flex_attn: + # assert False, (type(xq), type(xk), type(xv)) + self._init_flex_attn(seqlen=seqlen) + return self.flex_attn(xq, xk, xv, block_mask=self.block_mask) + else: + return F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) + + @torch.no_grad() + def _init_flex_attn(self, seqlen: int) -> None: + if self.block_mask is not None: + return + + def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + compiled_create_block_mask = torch.compile(create_block_mask) + self.block_mask = compiled_create_block_mask( + causal_mask, None, None, seqlen, seqlen + ) + self.flex_attn = torch.compile( + flex_attention, mode="max-autotune-no-cudagraphs" + ) diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index 4ab6da41e8..27b5d8684c 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -8,20 +8,15 @@ from dataclasses import dataclass -from typing import Callable, ClassVar, Optional, Tuple +from typing import Optional import torch -import torch.nn.functional as F from torch import nn -from torch.nn.attention.flex_attention import ( - BlockMask, - create_block_mask, - flex_attention, -) from torchtitan.components.tokenizer import Tokenizer from torchtitan.config_manager import JobConfig +from torchtitan.models.attention import SDPA from torchtitan.models.norms import build_norm from torchtitan.protocols.train_spec import BaseModelArgs, ModelProtocol @@ -123,7 +118,7 @@ def apply_rotary_emb( xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """ Apply rotary embeddings to input tensors using the given frequency tensor. @@ -138,7 +133,7 @@ def apply_rotary_emb( freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. Returns: - Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. """ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) @@ -179,13 +174,6 @@ class Attention(nn.Module): """ - # We registered flex_attention related attributes as class variables as we - # need to amortize the cost of compilation. Enabling per-instance flex_attention - # is not supported. - block_mask: ClassVar[Optional[BlockMask]] = None - use_flex_attn: ClassVar[bool] = False - flex_attn: ClassVar[Optional[Callable]] = None - def __init__(self, model_args: TransformerModelArgs): super().__init__() self.n_heads = model_args.n_heads @@ -205,7 +193,7 @@ def __init__(self, model_args: TransformerModelArgs): self.wo = nn.Linear( model_args.n_heads * self.head_dim, model_args.dim, bias=False ) - self.use_flex_attn = model_args.use_flex_attn + self.sdpa = SDPA(model_args.use_flex_attn) def init_weights(self, init_std: float): for linear in (self.wq, self.wk, self.wv): @@ -249,35 +237,14 @@ def forward( xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - # we use casual mask for training - if self.use_flex_attn: - # assert False, (type(xq), type(xk), type(xv)) - self._init_flex_attn(seqlen=seqlen) - output = self.flex_attn(xq, xk, xv, block_mask=self.block_mask) - else: - output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) + output = self.sdpa(xq, xk, xv) + output = output.transpose( 1, 2 ).contiguous() # (bs, seqlen, n_local_heads, head_dim) output = output.view(bs, seqlen, -1) return self.wo(output) - @torch.no_grad() - def _init_flex_attn(self, seqlen: int) -> None: - if self.block_mask is not None: - return - - def causal_mask(b, h, q_idx, kv_idx): - return q_idx >= kv_idx - - compiled_create_block_mask = torch.compile(create_block_mask) - self.block_mask = compiled_create_block_mask( - causal_mask, None, None, seqlen, seqlen - ) - self.flex_attn = torch.compile( - flex_attention, mode="max-autotune-no-cudagraphs" - ) - class FeedForward(nn.Module): """ diff --git a/torchtitan/models/llama/parallelize_llama.py b/torchtitan/models/llama/parallelize_llama.py index 7649d0dfbe..ed2e6f0c78 100644 --- a/torchtitan/models/llama/parallelize_llama.py +++ b/torchtitan/models/llama/parallelize_llama.py @@ -72,15 +72,20 @@ def parallelize_llama( enable_async_tp=job_config.parallelism.enable_async_tensor_parallel, ) - if job_config.activation_checkpoint.mode != "none": - if ( - job_config.activation_checkpoint.mode == "selective" - and job_config.model.use_flex_attn - ): + if job_config.model.use_flex_attn: + if job_config.activation_checkpoint.mode == "selective": raise ValueError( "FlexAttention is not compatible with selective AC yet. " "See https://github.com/pytorch/pytorch/issues/147879" ) + + if parallel_dims.cp_enabled: + raise ValueError( + "FlexAttention is not compatible with CP yet. " + "We are still working on this." + ) + + if job_config.activation_checkpoint.mode != "none": apply_ac(model, job_config.activation_checkpoint) # turn on per-TransformerBlock compile after AC wrapping and before FSDP From 8188bb22557a6f030ece66dd69355d3a52bfa6f3 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Fri, 21 Mar 2025 09:56:37 -0700 Subject: [PATCH 14/18] Make a separate attention module to cover sdpa logic --- torchtitan/models/attention.py | 11 ++++++----- torchtitan/models/llama/model.py | 1 + 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index a0e90b5cba..de3d9dff3b 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -25,18 +25,19 @@ class SDPA(torch.nn.Module): use_flex_attn: ClassVar[bool] = False flex_attn: ClassVar[Optional[Callable]] = None - def __init__(self) -> None: - self.use_flex_attn = model_args.use_flex_attn + def __init__(self, use_flex_attn) -> None: + super().__init__() + self.use_flex_attn = use_flex_attn def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor ) -> torch.Tensor: if self.use_flex_attn: - # assert False, (type(xq), type(xk), type(xv)) + _, _, seqlen, _ = q.shape self._init_flex_attn(seqlen=seqlen) - return self.flex_attn(xq, xk, xv, block_mask=self.block_mask) + return self.flex_attn(q, k, v, block_mask=self.block_mask) else: - return F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) + return F.scaled_dot_product_attention(q, k, v, is_causal=True) @torch.no_grad() def _init_flex_attn(self, seqlen: int) -> None: diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index 27b5d8684c..027f96cd23 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -11,6 +11,7 @@ from typing import Optional import torch +import torch.nn.functional as F from torch import nn from torchtitan.components.tokenizer import Tokenizer From b3241862deabfa1551f6f1f96a4904e4cffdd243 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 26 Mar 2025 00:17:32 -0700 Subject: [PATCH 15/18] Document attention support --- torchtitan/config_manager.py | 11 +++ torchtitan/models/attention.py | 118 +++++++++++++++++++++++++------ torchtitan/models/llama/model.py | 10 ++- 3 files changed, 116 insertions(+), 23 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index b2134ecc6b..25fa2b3b91 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -198,6 +198,17 @@ def __init__(self): action="store_true", help="Whether to use Flex Attention.", ) + self.parser.add_argument( + "--model.attn_bias_type", + type=str, + default="causal", + choices=["causal", "block_causal"], + help=""" + Specifies the type of bias/mask used for attention. If SDPA is used, + only the causal mask is supported by default. If FlexAttention is used, + both causal and block_causal masks are supported. + """, + ) self.parser.add_argument( "--model.tokenizer_path", type=str, diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index de3d9dff3b..374fa36e08 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -17,40 +17,118 @@ ) -class SDPA(torch.nn.Module): +class FlexAttn(torch.nn.Module): # We registered flex_attention related attributes as class variables as we # need to amortize the cost of compilation. Enabling per-instance flex_attention # is not supported. block_mask: ClassVar[Optional[BlockMask]] = None - use_flex_attn: ClassVar[bool] = False flex_attn: ClassVar[Optional[Callable]] = None + attn_bias_type: ClassVar[Optional[str]] = None + compiled_create_block_mask: ClassVar[Optional[Callable]] = None - def __init__(self, use_flex_attn) -> None: + def __init__(self, attn_bias_type: str) -> None: super().__init__() - self.use_flex_attn = use_flex_attn + if FlexAttn.attn_bias_type is not None: + assert ( + FlexAttn.attn_bias_type == attn_bias_type + ), "All FlexAttention must have the same configurations." + else: + if attn_bias_type not in ["causal", "block_causal"]: + raise ValueError(f"Unrecognized attn_bias_type {attn_bias_type}.") + FlexAttn.attn_bias_type = attn_bias_type def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor ) -> torch.Tensor: - if self.use_flex_attn: - _, _, seqlen, _ = q.shape - self._init_flex_attn(seqlen=seqlen) - return self.flex_attn(q, k, v, block_mask=self.block_mask) - else: - return F.scaled_dot_product_attention(q, k, v, is_causal=True) + assert FlexAttn.block_mask is not None + assert FlexAttn.flex_attn is not None + return FlexAttn.flex_attn(q, k, v, block_mask=FlexAttn.block_mask) + + @classmethod + def _get_causal_mask_fn(cls) -> Callable: + def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + return causal_mask + + @classmethod + def _get_block_causal_mask_fn(cls, batch: torch.Tensor, eos_id: int) -> Callable: + mask = batch == eos_id + mask[:, -1] = True + acc_mask = torch.cumsum(torch.where(mask, 1, 0).flatten(), dim=0) + seq_idx = torch.zeros_like(acc_mask, dtype=torch.int32) + seq_idx[1:] = acc_mask[:-1] + def block_causal_mask(b, h, q_idx, kv_idx): + return (seq_idx[q_idx] == seq_idx[kv_idx]) & (q_idx >= kv_idx) + + return block_causal_mask + + @classmethod @torch.no_grad() - def _init_flex_attn(self, seqlen: int) -> None: - if self.block_mask is not None: + def init_attention_bias( + cls, batch: torch.Tensor, eos_id: Optional[int] = None + ) -> None: + if cls.block_mask is not None and cls.attn_bias_type == "causal": + # We don't need to create another block mask for causal masking if existed. return - def causal_mask(b, h, q_idx, kv_idx): - return q_idx >= kv_idx + match cls.attn_bias_type: + case "causal": + mask_fn = cls._get_causal_mask_fn() + case "block_causal": + mask_fn = cls._get_block_causal_mask_fn(batch, eos_id) + case _: + raise RuntimeError(f"Shouldn't reach here. {cls.attn_bias_type}") - compiled_create_block_mask = torch.compile(create_block_mask) - self.block_mask = compiled_create_block_mask( - causal_mask, None, None, seqlen, seqlen - ) - self.flex_attn = torch.compile( - flex_attention, mode="max-autotune-no-cudagraphs" + seq_len = batch.shape[1] + if cls.compiled_create_block_mask is None: + cls.compiled_create_block_mask = torch.compile(create_block_mask) + cls.block_mask = cls.compiled_create_block_mask( + mask_fn, None, None, seq_len, seq_len ) + cls.flex_attn = torch.compile(flex_attention, mode="max-autotune-no-cudagraphs") + + +class SDPA(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + if attn_bias_type != "causal": + raise ValueError( + "TorchTitan with SDPA currently only supports causal mask." + ) + + def forward( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor + ) -> torch.Tensor: + return F.scaled_dot_product_attention(q, k, v, is_causal=True) + + @classmethod + @torch.no_grad() + def init_attention_bias( + cls, + batch: torch.Tensor, + eos_id: Optional[int] = None, + ) -> None: + # For SDPA, we don't need to do anything. + return + + +_selected_attention = None + + +def build_attention(use_flex_attn: bool, attn_bias_type: str): + global _selected_attention + if use_flex_attn: + assert _selected_attention is None or _selected_attention == FlexAttn + _selected_attention = FlexAttn + return FlexAttn(attn_bias_type) + else: + assert _selected_attention is None or _selected_attention == SDPA + _selected_attention = SDPA + return SDPA(attn_bias_type) + + +def init_attention_bias(batch: torch.Tensor, eos_id: Optional[int] = None) -> None: + global _selected_attention + _selected_attention.init_attention_bias(batch, eos_id) diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index 027f96cd23..6cf155c4e3 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -16,8 +16,7 @@ from torchtitan.components.tokenizer import Tokenizer from torchtitan.config_manager import JobConfig - -from torchtitan.models.attention import SDPA +from torchtitan.models.attention import build_attention, init_attention_bias from torchtitan.models.norms import build_norm from torchtitan.protocols.train_spec import BaseModelArgs, ModelProtocol @@ -41,6 +40,8 @@ class TransformerModelArgs(BaseModelArgs): norm_type: str = "rmsnorm" use_flex_attn: bool = False + attn_bias_type: str = "causal" + eos_id: int = 0 def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None: self.norm_type = job_config.model.norm_type @@ -194,7 +195,7 @@ def __init__(self, model_args: TransformerModelArgs): self.wo = nn.Linear( model_args.n_heads * self.head_dim, model_args.dim, bias=False ) - self.sdpa = SDPA(model_args.use_flex_attn) + self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_bias_type) def init_weights(self, init_std: float): for linear in (self.wq, self.wk, self.wv): @@ -388,6 +389,7 @@ def __init__(self, model_args: TransformerModelArgs): self.model_args = model_args self.vocab_size = model_args.vocab_size self.n_layers = model_args.n_layers + self.eos_id = model_args.eos_id self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) @@ -468,6 +470,8 @@ def forward(self, tokens: torch.Tensor): torch.Tensor: Output logits after applying the Transformer model. """ + init_attention_bias(tokens, eos_id=self.eos_id) + # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens From bfd4472fc53780899891776ce88ad906ffc11dbe Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 26 Mar 2025 09:27:53 -0700 Subject: [PATCH 16/18] Minor fix --- torchtitan/models/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index 374fa36e08..9f7bb5e399 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -91,7 +91,7 @@ def init_attention_bias( class SDPA(torch.nn.Module): - def __init__(self) -> None: + def __init__(self, attn_bias_type: str) -> None: super().__init__() if attn_bias_type != "causal": raise ValueError( From 5f3f809884be595f49c8e2daae71dff8dabe494a Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 26 Mar 2025 09:48:04 -0700 Subject: [PATCH 17/18] Minor fix --- torchtitan/models/llama/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index 6cf155c4e3..dede3e04f0 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -48,6 +48,7 @@ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> Non self.vocab_size = tokenizer.n_words self.max_seq_len = job_config.training.seq_len self.use_flex_attn = job_config.model.use_flex_attn + self.attn_bias_type = job_config.model.attn_bias_type def get_num_flop_per_token(self, num_params: int, seq_len: int) -> int: l, h, q, t = ( From 4d6393991ef4c9aa9fb8fa5545d23356d44fbcbb Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 2 Apr 2025 13:24:42 -0700 Subject: [PATCH 18/18] Allow mix usages of block masks --- tests/integration_tests.py | 1 + torchtitan/config_manager.py | 7 +- torchtitan/models/attention.py | 131 +++++++++++++++---------------- torchtitan/models/llama/model.py | 10 +-- 4 files changed, 74 insertions(+), 75 deletions(-) diff --git a/tests/integration_tests.py b/tests/integration_tests.py index 59a73efbf9..1e005706e3 100755 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -302,6 +302,7 @@ def build_test_list(): "--parallelism.data_parallel_shard_degree=4", "--activation_checkpoint.mode='full'", "--model.use_flex_attn", + "--model.attn_mask_type='block_causal'", ] ], "FSDP+FLEX_ATTN", diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 25fa2b3b91..06452d2617 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -196,10 +196,13 @@ def __init__(self): self.parser.add_argument( "--model.use_flex_attn", action="store_true", - help="Whether to use Flex Attention.", + help=""" + Whether to use Flex Attention. + Mixed usage of SDPA and FlexAttention is not upported yet. + """, ) self.parser.add_argument( - "--model.attn_bias_type", + "--model.attn_mask_type", type=str, default="causal", choices=["causal", "block_causal"], diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index 9f7bb5e399..9d03a6378b 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -17,32 +17,39 @@ ) +BatchBlockMaskType = tuple[Optional[int], BlockMask] + + class FlexAttn(torch.nn.Module): # We registered flex_attention related attributes as class variables as we - # need to amortize the cost of compilation. Enabling per-instance flex_attention - # is not supported. - block_mask: ClassVar[Optional[BlockMask]] = None - flex_attn: ClassVar[Optional[Callable]] = None - attn_bias_type: ClassVar[Optional[str]] = None - compiled_create_block_mask: ClassVar[Optional[Callable]] = None - - def __init__(self, attn_bias_type: str) -> None: + # need to amortize the cost of compilation. + flex_attn: ClassVar[Callable] = torch.compile( + flex_attention, mode="max-autotune-no-cudagraphs" + ) + compiled_create_block_mask: ClassVar[Callable] = torch.compile(create_block_mask) + used_attn_mask_types: ClassVar[set[str]] = set() + # Attention mask type to the created (id(batch), BlockMask). + # This allows us to keep track the created block masks for each + # new batch. We will use this to update the block mask when a + # new batch is created. This also allows user to create different + # block masks for different layers. + block_masks: ClassVar[dict[str, BatchBlockMaskType]] = {} + + # Instance variables. + attn_mask_type: str + + def __init__(self, attn_mask_type: str) -> None: super().__init__() - if FlexAttn.attn_bias_type is not None: - assert ( - FlexAttn.attn_bias_type == attn_bias_type - ), "All FlexAttention must have the same configurations." - else: - if attn_bias_type not in ["causal", "block_causal"]: - raise ValueError(f"Unrecognized attn_bias_type {attn_bias_type}.") - FlexAttn.attn_bias_type = attn_bias_type + if attn_mask_type not in ["causal", "block_causal"]: + raise ValueError(f"Unrecognized attn_mask_type {attn_mask_type}.") + self.attn_mask_type = attn_mask_type + FlexAttn.used_attn_mask_types.add(attn_mask_type) def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor ) -> torch.Tensor: - assert FlexAttn.block_mask is not None - assert FlexAttn.flex_attn is not None - return FlexAttn.flex_attn(q, k, v, block_mask=FlexAttn.block_mask) + block_mask = FlexAttn.block_masks[self.attn_mask_type][1] + return FlexAttn.flex_attn(q, k, v, block_mask=block_mask) @classmethod def _get_causal_mask_fn(cls) -> Callable: @@ -66,34 +73,41 @@ def block_causal_mask(b, h, q_idx, kv_idx): @classmethod @torch.no_grad() - def init_attention_bias( + def init_attention_mask( cls, batch: torch.Tensor, eos_id: Optional[int] = None ) -> None: - if cls.block_mask is not None and cls.attn_bias_type == "causal": - # We don't need to create another block mask for causal masking if existed. - return - - match cls.attn_bias_type: - case "causal": - mask_fn = cls._get_causal_mask_fn() - case "block_causal": - mask_fn = cls._get_block_causal_mask_fn(batch, eos_id) - case _: - raise RuntimeError(f"Shouldn't reach here. {cls.attn_bias_type}") - - seq_len = batch.shape[1] - if cls.compiled_create_block_mask is None: - cls.compiled_create_block_mask = torch.compile(create_block_mask) - cls.block_mask = cls.compiled_create_block_mask( - mask_fn, None, None, seq_len, seq_len - ) - cls.flex_attn = torch.compile(flex_attention, mode="max-autotune-no-cudagraphs") - - -class SDPA(torch.nn.Module): - def __init__(self, attn_bias_type: str) -> None: + for attn_mask_type in cls.used_attn_mask_types: + block_mask = cls.block_masks.get(attn_mask_type, None) + if block_mask is not None: + batch_id = block_mask[0] + if batch_id is None or batch_id == id(batch): + continue + + match attn_mask_type: + case "causal": + batch_id = None + mask_fn = cls._get_causal_mask_fn() + case "block_causal": + batch_id = id(batch) + if eos_id is None: + raise RuntimeError( + "eos_id must be provided for block_causal mask." + ) + mask_fn = cls._get_block_causal_mask_fn(batch, eos_id) + case _: + raise RuntimeError(f"Shouldn't reach here. {attn_mask_type}") + + seq_len = batch.shape[1] + block_mask = cls.compiled_create_block_mask( + mask_fn, None, None, seq_len, seq_len + ) + cls.block_masks[attn_mask_type] = (batch_id, block_mask) + + +class ScaledDotProductAttention(torch.nn.Module): + def __init__(self, attn_mask_type: str) -> None: super().__init__() - if attn_bias_type != "causal": + if attn_mask_type != "causal": raise ValueError( "TorchTitan with SDPA currently only supports causal mask." ) @@ -103,32 +117,13 @@ def forward( ) -> torch.Tensor: return F.scaled_dot_product_attention(q, k, v, is_causal=True) - @classmethod - @torch.no_grad() - def init_attention_bias( - cls, - batch: torch.Tensor, - eos_id: Optional[int] = None, - ) -> None: - # For SDPA, we don't need to do anything. - return - - -_selected_attention = None - -def build_attention(use_flex_attn: bool, attn_bias_type: str): - global _selected_attention +def build_attention(use_flex_attn: bool, attn_mask_type: str): if use_flex_attn: - assert _selected_attention is None or _selected_attention == FlexAttn - _selected_attention = FlexAttn - return FlexAttn(attn_bias_type) + return FlexAttn(attn_mask_type) else: - assert _selected_attention is None or _selected_attention == SDPA - _selected_attention = SDPA - return SDPA(attn_bias_type) + return SDPA(attn_mask_type) -def init_attention_bias(batch: torch.Tensor, eos_id: Optional[int] = None) -> None: - global _selected_attention - _selected_attention.init_attention_bias(batch, eos_id) +def init_attention_mask(batch: torch.Tensor, eos_id: Optional[int] = None) -> None: + FlexAttn.init_attention_mask(batch, eos_id) diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index dede3e04f0..8d483f4a30 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -16,7 +16,7 @@ from torchtitan.components.tokenizer import Tokenizer from torchtitan.config_manager import JobConfig -from torchtitan.models.attention import build_attention, init_attention_bias +from torchtitan.models.attention import build_attention, init_attention_mask from torchtitan.models.norms import build_norm from torchtitan.protocols.train_spec import BaseModelArgs, ModelProtocol @@ -40,7 +40,7 @@ class TransformerModelArgs(BaseModelArgs): norm_type: str = "rmsnorm" use_flex_attn: bool = False - attn_bias_type: str = "causal" + attn_mask_type: str = "causal" eos_id: int = 0 def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None: @@ -48,7 +48,7 @@ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> Non self.vocab_size = tokenizer.n_words self.max_seq_len = job_config.training.seq_len self.use_flex_attn = job_config.model.use_flex_attn - self.attn_bias_type = job_config.model.attn_bias_type + self.attn_mask_type = job_config.model.attn_mask_type def get_num_flop_per_token(self, num_params: int, seq_len: int) -> int: l, h, q, t = ( @@ -196,7 +196,7 @@ def __init__(self, model_args: TransformerModelArgs): self.wo = nn.Linear( model_args.n_heads * self.head_dim, model_args.dim, bias=False ) - self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_bias_type) + self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type) def init_weights(self, init_std: float): for linear in (self.wq, self.wk, self.wv): @@ -471,7 +471,7 @@ def forward(self, tokens: torch.Tensor): torch.Tensor: Output logits after applying the Transformer model. """ - init_attention_bias(tokens, eos_id=self.eos_id) + init_attention_mask(tokens, eos_id=self.eos_id) # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens