 import torch
 import torch.nn.functional as F
 from torch import nn
+from torch.distributed._tensor import distribute_tensor, DTensor
 
 
 class TokenPositionalEmbedding(nn.Module):
@@ -137,8 +138,20 @@ def _load_state_dict_hook(
         inpt_local_pos_embed = state_dict.get(
             prefix + "local_token_positional_embedding"
         )
+
         if inpt_local_pos_embed is not None:
 
+            # We can only apply F.interpolate to vanilla tensors, not DTensors.
+            # If the pos embeds are a DTensor, we gather the full tensor, apply
+            # interpolate, and then reshard afterwards.
+            if isinstance(inpt_local_pos_embed, DTensor):
+                local_embed_is_sharded = True
+                local_embed_device_mesh = inpt_local_pos_embed.device_mesh
+                local_embed_placements = inpt_local_pos_embed.placements
+                inpt_local_pos_embed = inpt_local_pos_embed.full_tensor()
+            else:
+                local_embed_is_sharded = False
+
             # sanity check
             inpt_n_tokens_per_tile, inpt_embed_dim = inpt_local_pos_embed.shape
             if math.sqrt(inpt_n_tokens_per_tile - 1) % 1 != 0:
@@ -159,6 +172,13 @@ def _load_state_dict_hook(
                 tgt_patch_grid_size=int(math.sqrt(tgt_n_tokens_per_tile - 1)),
             )
 
+            if local_embed_is_sharded:
+                inpt_local_pos_embed = distribute_tensor(
+                    inpt_local_pos_embed,
+                    device_mesh=local_embed_device_mesh,
+                    placements=local_embed_placements,
+                )
+
             # update state dict
             state_dict[
                 prefix + "local_token_positional_embedding"
@@ -176,8 +196,20 @@ def _load_state_dict_hook(
         inpt_global_pos_embed = state_dict.get(
             prefix + "global_token_positional_embedding"
         )
+
         if inpt_global_pos_embed is not None:
 
+            # We can only apply F.interpolate to vanilla tensors, not DTensors.
+            # If the pos embeds are a DTensor, we gather the full tensor, apply
+            # interpolate, and then reshard afterwards.
+            if isinstance(inpt_global_pos_embed, DTensor):
+                global_embed_is_sharded = True
+                global_embed_device_mesh = inpt_global_pos_embed.device_mesh
+                global_embed_placements = inpt_global_pos_embed.placements
+                inpt_global_pos_embed = inpt_global_pos_embed.full_tensor()
+            else:
+                global_embed_is_sharded = False
+
             _, _, inpt_n_tokens_per_tile, _ = inpt_global_pos_embed.shape
 
             # sanity check
@@ -202,6 +234,13 @@ def _load_state_dict_hook(
                 tgt_patch_grid_size=int(math.sqrt(tgt_n_tokens_per_tile - 1)),
             )
 
+            if global_embed_is_sharded:
+                inpt_global_pos_embed = distribute_tensor(
+                    inpt_global_pos_embed,
+                    device_mesh=global_embed_device_mesh,
+                    placements=global_embed_placements,
+                )
+
             # update state dict
             state_dict[
                 prefix + "global_token_positional_embedding"
@@ -500,6 +539,17 @@ def _load_state_dict_hook(
 
         if embedding is not None:
 
+            # We can only apply F.interpolate to vanilla tensors, not DTensors.
+            # If the pos embeds are a DTensor, we gather the full tensor, apply
+            # interpolate, and then reshard afterwards.
+            if isinstance(embedding, DTensor):
+                embedding_is_sharded = True
+                device_mesh = embedding.device_mesh
+                placements = embedding.placements
+                embedding = embedding.full_tensor()
+            else:
+                embedding_is_sharded = False
+
             # ckpt pos emb
             (
                 tgt_max_num_tiles_x,
@@ -534,6 +584,13 @@ def _load_state_dict_hook(
                 embedding, tgt_max_num_tiles=tgt_max_num_tiles_x
             )
 
+            if embedding_is_sharded:
+                embedding_new = distribute_tensor(
+                    embedding_new,
+                    device_mesh=device_mesh,
+                    placements=placements,
+                )
+
             # update state dict
             state_dict[prefix + "embedding"] = embedding_new
             if embedding_new.shape != self.embedding.shape:
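
For reference, the gather, interpolate, then reshard pattern that each of these hunks applies can be sketched on its own as below. This is a minimal, standalone illustration rather than part of the torchtune API: the helper name `_interpolate_maybe_sharded`, the 4-D input shape, and the bilinear interpolation settings are assumptions made only for the example; the real hook runs the module's own resize helpers between the gather and the reshard.

```python
# Minimal sketch of the gather -> interpolate -> reshard pattern, assuming a
# 4-D (N, C, H, W) positional-embedding tensor and an already-initialized
# device mesh. The helper name is hypothetical and only for illustration.
import torch
import torch.nn.functional as F
from torch.distributed._tensor import DTensor, distribute_tensor


def _interpolate_maybe_sharded(embed: torch.Tensor, size) -> torch.Tensor:
    is_sharded = isinstance(embed, DTensor)
    if is_sharded:
        # Remember the original sharding so it can be restored afterwards.
        device_mesh = embed.device_mesh
        placements = embed.placements
        # Materialize the full tensor; F.interpolate cannot operate on DTensors.
        embed = embed.full_tensor()

    # Resize on the plain tensor.
    embed = F.interpolate(embed, size=size, mode="bilinear", align_corners=True)

    if is_sharded:
        # Reshard with the original mesh and placements.
        embed = distribute_tensor(
            embed, device_mesh=device_mesh, placements=placements
        )
    return embed
```

Note that `DTensor.full_tensor()` involves a collective (an all-gather of the shards), so all ranks in the mesh need to take the same code path; the hook in the diff satisfies this because the same state-dict hook runs on every rank.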