add infra support for HF checkpoint conversion#1404

Merged
tianyu-l merged 1 commit into main from hf_conversion
Jul 20, 2025

Conversation

@tianyu-l
Contributor

@tianyu-l tianyu-l commented Jul 16, 2025

This PR adds a new field StateDictAdapter in TrainSpec

  • Currently it contains only a pair of static methods, to_hf and from_hf. Later we could add other pairs such as to_meta / from_meta, to_vllm / from_vllm, etc.
  • It is passed to CheckpointManager to convert between the torchtitan model and the HF model during checkpoint save / load.
  • It could also potentially be used by downstream inference engines that only support HF models.

In order to save / load in HF format, a model is required to have a corresponding StateDictAdapter subclass implementation. For Llama3, I created a placeholder Llama3StateDictAdapter to be implemented. cc @wesleytruong
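
Based on the description above, here is a minimal sketch of what such an adapter interface could look like. The to_hf / from_hf names come from the PR, but the concrete subclass and its key mapping below are purely illustrative assumptions, not torchtitan's actual implementation:

```python
from abc import ABC, abstractmethod
from typing import Any


class StateDictAdapter(ABC):
    """Converts a torchtitan state dict to/from another format.

    The PR defines one pair (to_hf / from_hf); future pairs such as
    to_meta / from_meta could follow the same shape.
    """

    @staticmethod
    @abstractmethod
    def to_hf(state_dict: dict[str, Any]) -> dict[str, Any]:
        """torchtitan state dict -> HF state dict."""
        ...

    @staticmethod
    @abstractmethod
    def from_hf(hf_state_dict: dict[str, Any]) -> dict[str, Any]:
        """HF state dict -> torchtitan state dict."""
        ...


# Hypothetical adapter that only renames keys, for illustration;
# a real adapter may also need to reshape or split tensors.
class RenamingAdapter(StateDictAdapter):
    _titan_to_hf = {"tok_embeddings.weight": "model.embed_tokens.weight"}

    @staticmethod
    def to_hf(state_dict):
        m = RenamingAdapter._titan_to_hf
        return {m.get(k, k): v for k, v in state_dict.items()}

    @staticmethod
    def from_hf(hf_state_dict):
        inv = {v: k for k, v in RenamingAdapter._titan_to_hf.items()}
        return {inv.get(k, k): v for k, v in hf_state_dict.items()}
```

Because both methods are static, CheckpointManager can call them without instantiating the adapter.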

This PR also renames the checkpoint config options initial_load_model_weights_only and last_save_model_weights_only to simply initial_load_model_only and last_save_model_only, respectively.
It seems to me that the original names were chosen to correspond to torch.load(..., weights_only=True). As long as we document and test clearly when this correspondence holds, I prefer the names in torchtitan to be simple and less ambiguous.
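
For reference, under the renamed options the checkpoint section of a torchtitan TOML config would read roughly as follows (the surrounding keys are assumptions for illustration):

```toml
[checkpoint]
enable_checkpoint = true
# renamed from initial_load_model_weights_only / last_save_model_weights_only
initial_load_model_only = true
last_save_model_only = true
```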

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jul 16, 2025
@tianyu-l tianyu-l force-pushed the hf_conversion branch 2 times, most recently from 06eb133 to a5dfd45 Compare July 16, 2025 06:48
)
list(map(func, self.states[MODEL].model))
else:
dcp.load(state_dict, checkpoint_id=checkpoint_id)
Contributor

Given that we now always flatten the model state_dict, meaning no model.load_state_dict() will be called, we should call model.load_state_dict() manually if the model state_dict is present in state_dict. You can hard-code this information by passing a flag to this API. Please add a TODO so that I know what to fix after I'm back.
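
To make the reviewer's point concrete, here is a hedged sketch (all names are assumptions, not the actual torchtitan code) of re-applying the model portion of a flattened state_dict after dcp.load has filled it in place:

```python
def apply_model_state(model, state_dict, model_prefix: str = "model."):
    """Hypothetical helper: dcp.load fills the flattened state_dict in
    place but never calls model.load_state_dict() itself, so if model
    entries are present we re-apply them manually.

    The "model." prefix is an assumption about how the flattened keys
    are namespaced.
    """
    model_sd = {
        k[len(model_prefix):]: v
        for k, v in state_dict.items()
        if k.startswith(model_prefix)
    }
    if model_sd:
        # Strip the prefix so keys match the module's own parameter names.
        model.load_state_dict(model_sd)
    return model_sd
```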

Contributor Author

will do


@staticmethod
@abstractmethod
def get_hf_state_dict(state_dict: dict[str, Any]) -> dict[str, Any]:
Contributor

convert_to_hf_state_dict is a more accurate name.

Contributor Author

Please see the new API. I'm taking a titan-centric approach -- I only define to_hf / from_hf, which I think should be clear from the context.

The context is that the class is called StateDictAdapter and args & returns are both state_dict.

In the future we can have pairs like to_meta / from_meta, to_vllm / from_vllm, etc.

Contributor

I'm a little bit confused about what "Any" represents here. It might be too idealistic to keep the whole state_dict with Tensors in memory -- for a large model like DeepSeek-V3, should we only keep the state_dict names here and leave the "Any" as None?

@tianyu-l tianyu-l force-pushed the hf_conversion branch 3 times, most recently from de3fc7f to 9291c71 Compare July 16, 2025 09:11
@idoh
Contributor

idoh commented Jul 16, 2025

@tianyu-l Thanks for getting HF checkpoint conversion working.
This is incredibly helpful for research and compatibility with evaluation frameworks.

@@ -605,8 +624,8 @@ def _find_checkpoint_type(self, checkpoint_id: str) -> CheckpointType:

for filename in os.listdir(checkpoint_id):
if filename == "model.safetensors.index.json":
Contributor

@wwwjn wwwjn Jul 17, 2025

Quick note: this check is not accurate. Smaller models whose weights fit in a single safetensors file (e.g., https://huggingface.co/meta-llama/Llama-3.2-1B/tree/main) don't have a "model.safetensors.index.json" file; that file only exists when the model weights are split across several safetensors files.
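
One way to make the detection robust -- offered as a hedged sketch rather than the actual fix, with a hypothetical function name -- is to treat a directory as an HF safetensors checkpoint if it contains either the multi-file index or any single .safetensors shard:

```python
import os


def is_hf_safetensors_checkpoint(checkpoint_id: str) -> bool:
    """Heuristic: an HF safetensors checkpoint directory contains
    either the sharded-weights index file or at least one
    .safetensors file (single-file models have no index)."""
    for filename in os.listdir(checkpoint_id):
        if filename == "model.safetensors.index.json":
            return True
        if filename.endswith(".safetensors"):
            return True
    return False
```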

Contributor Author

Good point. Let me address this in the next PR, as CI is broken. I don't want to submit another commit and get blocked.

@tianyu-l tianyu-l merged commit beb29a1 into main Jul 20, 2025
10 checks passed
@tianyu-l tianyu-l deleted the hf_conversion branch July 20, 2025 23:21
idoh pushed a commit to idoh/torchtitan that referenced this pull request Jul 28, 2025
joellidin pushed a commit to one-covenant/torchtitan that referenced this pull request Aug 8, 2025
joellidin pushed a commit to one-covenant/torchtitan that referenced this pull request Aug 8, 2025
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 13, 2026
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 25, 2026

Labels

CLA Signed This label is managed by the Meta Open Source bot.


5 participants