|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +""" |
| 8 | +Unified approach for running TorchTitan models with vLLM inference. |
| 9 | +
|
| 10 | +This module automatically registers TorchTitan models with vLLM when imported. |
| 11 | +Uses the canonical TorchTitan model definition directly with vLLM inference engine. |
| 12 | +""" |
| 13 | + |
| 14 | +from torchtitan.protocols.train_spec import get_train_spec, TrainSpec |
| 15 | +from vllm.logger import init_logger |
| 16 | + |
| 17 | +from .utils import create_parallel_dims_from_vllm_config |
| 18 | +from .vllm_wrapper import TorchTitanVLLMModelWrapper |
| 19 | + |
| 20 | + |
| 21 | +logger = init_logger(__name__) |
| 22 | + |
| 23 | + |
def register_torchtitan_model_from_train_spec(
    train_spec: TrainSpec,
    model_name: str,
    model_flavor: str,
) -> None:
    """
    Expose a TorchTitan model to vLLM under ``model_name``.

    A subclass of ``TorchTitanVLLMModelWrapper`` is generated on the fly. Its
    constructor forwards the components held by ``train_spec`` (model class,
    state-dict adapter, parallelize function) together with the model args
    selected by ``model_flavor``. The generated class is then handed to vLLM's
    ``ModelRegistry``.

    Args:
        train_spec: TorchTitan TrainSpec containing model components
        model_name: Name to register in vLLM (e.g., "Qwen3TorchTitanForCausalLM")
        model_flavor: Key into ``train_spec.model_args`` (e.g., "0.6B")

    Raises:
        ValueError: If ``train_spec.model_args`` is not a dict, or if it has no
            entry for ``model_flavor``.
    """
    # Imported at call time, before any validation, exactly as the function
    # has always done.
    from vllm.model_executor.models.registry import ModelRegistry

    # Guard clauses: reject a malformed spec first, then an unknown flavor.
    flavor_table = train_spec.model_args
    if not isinstance(flavor_table, dict):
        raise ValueError(
            "train_spec.model_args must be a dict mapping flavor names to ModelArgs"
        )
    if model_flavor not in flavor_table:
        raise ValueError(
            f"Model flavor '{model_flavor}' not found in train_spec.model_args. "
            f"Available flavors: {list(flavor_table.keys())}"
        )
    selected_args = flavor_table[model_flavor]

    class _SpecBackedModel(TorchTitanVLLMModelWrapper):
        """Wrapper subclass bound to the enclosing TrainSpec and flavor."""

        def __init__(self, *, vllm_config, prefix=""):
            # The train_spec attributes are read when the model is built,
            # not when it is registered.
            super().__init__(
                model_cls=train_spec.model_cls,
                model_args=selected_args,
                state_dict_adapter=train_spec.state_dict_adapter,
                parallelize_fn=train_spec.parallelize_fn,
                vllm_config=vllm_config,
                prefix=prefix,
            )

    # Make the generated class report the registered name instead of the
    # local placeholder name.
    for name_attr in ("__name__", "__qualname__"):
        setattr(_SpecBackedModel, name_attr, model_name)

    ModelRegistry.register_model(model_name, _SpecBackedModel)

    success_message = (
        f"Successfully registered {model_name} with vLLM using TrainSpec "
        f"(model_cls={train_spec.model_cls.__name__}, flavor={model_flavor})"
    )
    logger.info(success_message)
| 76 | + |
| 77 | + |
# Import-time side effect: importing this module registers the Qwen3
# TorchTitan model with vLLM.
#
# TODO: Stop hard-coding model_flavor at registration time and let it come
# from the config system instead. It has to be specified here for now because
# a torchtitan job_config cannot be passed through the LLM() API.
register_torchtitan_model_from_train_spec(
    model_name="Qwen3TorchTitanForCausalLM",
    model_flavor="0.6B",
    train_spec=get_train_spec("qwen3"),
)


# Names exported by this module.
__all__ = [
    "TorchTitanVLLMModelWrapper",
    "create_parallel_dims_from_vllm_config",
    "register_torchtitan_model_from_train_spec",
]
0 commit comments