add infra support for HF checkpoint conversion #1404
Conversation
```python
    )
    list(map(func, self.states[MODEL].model))
else:
    dcp.load(state_dict, checkpoint_id=checkpoint_id)
```
Given that we now always flatten the model state_dict, no `model.load_state_dict()` will be called, so we should call `model.load_state_dict()` manually if the model state_dict is in `state_dict`. You can hard-code this information by passing a flag to this API. Please add a TODO so that I know what to fix after I'm back.
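A minimal sketch of the fix this comment asks for, with assumed names (`load_checkpoint`, the `model_only` flag, and the `MODEL` key are illustrative stand-ins, not the PR's actual code):

```python
# Hypothetical sketch: when the flattened model state_dict is loaded via
# dcp.load(), the Stateful.load_state_dict() hook is bypassed, so we call
# model.load_state_dict() ourselves, gated by a hard-coded flag.
from typing import Any

import torch


def load_checkpoint(
    model: torch.nn.Module,
    state_dict: dict[str, Any],
    model_only: bool,  # assumed flag: "state_dict holds flattened model keys"
) -> None:
    # In the real code path, dcp.load(state_dict, checkpoint_id=...)
    # would populate state_dict in place before this point.
    if model_only:
        # TODO: flattened keys skip the Stateful protocol, so restore
        # the model manually here.
        model.load_state_dict(state_dict, strict=False)
```

The flag avoids inspecting the state_dict shape at load time, matching the "hard code this information" suggestion above.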
torchtitan/protocols/hf_adapter.py
Outdated
```python
@staticmethod
@abstractmethod
def get_hf_state_dict(state_dict: dict[str, Any]) -> dict[str, Any]:
```
convert_to_hf_state_dict is a more accurate name.
Please see the new API. I'm taking a titan-centric approach -- I only define `to_hf` / `from_hf`, which I think should be clear from context: the class is called `StateDictAdapter`, and both the args and returns are state_dicts.
In the future we can add pairs like `to_meta` / `from_meta`, `to_vllm` / `from_vllm`, etc.
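A sketch of the adapter protocol as described in this thread; the `Llama3StateDictAdapter` body is purely illustrative (the PR leaves it as a placeholder, and the real key mapping is model-specific):

```python
# Sketch of the StateDictAdapter protocol discussed above. Method names
# (to_hf / from_hf) come from the thread; exact signatures may differ.
from abc import ABC, abstractmethod
from typing import Any


class StateDictAdapter(ABC):
    """Converts between a torchtitan state_dict and another format.

    Only the HF pair exists for now; to_meta/from_meta, to_vllm/from_vllm
    could be added later following the same pattern.
    """

    @staticmethod
    @abstractmethod
    def to_hf(state_dict: dict[str, Any]) -> dict[str, Any]:
        """Map torchtitan keys/tensors to the HuggingFace layout."""

    @staticmethod
    @abstractmethod
    def from_hf(hf_state_dict: dict[str, Any]) -> dict[str, Any]:
        """Map HuggingFace keys/tensors back to the torchtitan layout."""


class Llama3StateDictAdapter(StateDictAdapter):
    # Placeholder, as in the PR. A trivial key-renaming example so the
    # round trip is concrete -- NOT Llama3's real mapping.
    @staticmethod
    def to_hf(state_dict: dict[str, Any]) -> dict[str, Any]:
        return {f"model.{k}": v for k, v in state_dict.items()}

    @staticmethod
    def from_hf(hf_state_dict: dict[str, Any]) -> dict[str, Any]:
        return {k.removeprefix("model."): v for k, v in hf_state_dict.items()}
```

Because both methods are static and purely key/tensor transforms, the adapter can be attached to `TrainSpec` as a class rather than an instance.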
I'm a little confused about what "Any" represents here; it might be too idealistic to keep the whole state_dict with Tensors in memory. For a large model like DeepSeek-V3, should we keep only the state_dict keys here and leave the "Any" values as "None"?
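An illustrative sketch of this suggestion (the helper name is made up): keep only the key mapping and defer tensor materialization, instead of holding every tensor in memory at once.

```python
# Hypothetical helper for the reviewer's suggestion: for very large
# models (e.g. DeepSeek-V3), strip tensor payloads and keep only keys;
# downstream code can then fill in values shard by shard.
from typing import Any


def keys_only(state_dict: dict[str, Any]) -> dict[str, None]:
    # Preserve key order and names, drop the (potentially huge) values.
    return {k: None for k in state_dict}
```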
@tianyu-l Thanks for getting HF checkpoint conversion working.
```diff
@@ -605,8 +624,8 @@ def _find_checkpoint_type(self, checkpoint_id: str) -> CheckpointType:
     for filename in os.listdir(checkpoint_id):
         if filename == "model.safetensors.index.json":
```
Quick note: this function is not accurate. Smaller models with only one safetensors file (e.g., https://huggingface.co/meta-llama/Llama-3.2-1B/tree/main) don't have a "model.safetensors.index.json" file; it only exists when the model weights are split across several safetensors files.
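One way the check could be made robust, sketched with an assumed function name (`is_hf_checkpoint`; the PR's `_find_checkpoint_type` returns an enum instead): accept either the index file (sharded weights) or any `.safetensors` file (single-file models such as Llama-3.2-1B).

```python
# Illustrative detection covering both sharded and single-file HF
# checkpoints. Names are assumptions, not the PR's final code.
import os


def is_hf_checkpoint(checkpoint_id: str) -> bool:
    for filename in os.listdir(checkpoint_id):
        if filename == "model.safetensors.index.json":
            return True  # sharded HF checkpoint
        if filename.endswith(".safetensors"):
            return True  # single-file HF checkpoint
    return False
```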
There was a problem hiding this comment.
Good point. Let me address this in the next PR, as CI is broken and I don't want to submit another commit and get blocked.
This PR adds a new field `StateDictAdapter` in `TrainSpec`:
- Currently it only contains a pair of static methods, `to_hf` and `from_hf`. Later we could add other pairs like `to_meta` / `from_meta`, `to_vllm` / `from_vllm`, etc.
- It is passed to `CheckpointManager` to convert between the torchtitan model and the HF model during checkpoint save / load.
- It could also potentially be used by downstream inference engines which only support HF models.

In order to save / load in HF format, a model is required to have a corresponding `StateDictAdapter` subclass implementation. For Llama3, I created a placeholder `Llama3StateDictAdapter` to be implemented. cc @wesleytruong

This PR also renames the checkpoint config options `initial_load_model_weights_only` and `last_save_model_weights_only` to simply `initial_load_model_only` and `last_save_model_only`, respectively. It seems the original names were chosen to correspond to `torch.load(..., weights_only=True)`. As long as we document & test clearly when this correspondence holds, I prefer the names in torchtitan to be simple and less ambiguous.