Commit 48a6c75

ebsmothers authored and RdoubleA committed
Fix CLIP pos embedding interpolation to work on DTensors (#1739)
1 parent a788425 commit 48a6c75

File tree

1 file changed (+57 −0 lines)


torchtune/models/clip/_position_embeddings.py

Lines changed: 57 additions & 0 deletions
@@ -10,6 +10,7 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
+from torch.distributed._tensor import distribute_tensor, DTensor
 
 
 class TokenPositionalEmbedding(nn.Module):
@@ -137,8 +138,20 @@ def _load_state_dict_hook(
         inpt_local_pos_embed = state_dict.get(
             prefix + "local_token_positional_embedding"
         )
+
         if inpt_local_pos_embed is not None:
 
+            # We can only apply F.interpolate to vanilla tensors, not DTensors
+            # If pos embeds are a DTensor, we gather the full tensor, apply
+            # interpolate, and then reshard after
+            if isinstance(inpt_local_pos_embed, DTensor):
+                local_embed_is_sharded = True
+                local_embed_device_mesh = inpt_local_pos_embed.device_mesh
+                local_embed_placements = inpt_local_pos_embed.placements
+                inpt_local_pos_embed = inpt_local_pos_embed.full_tensor()
+            else:
+                local_embed_is_sharded = False
+
             # sanity check
             inpt_n_tokens_per_tile, inpt_embed_dim = inpt_local_pos_embed.shape
             if math.sqrt(inpt_n_tokens_per_tile - 1) % 1 != 0:
@@ -159,6 +172,13 @@ def _load_state_dict_hook(
                 tgt_patch_grid_size=int(math.sqrt(tgt_n_tokens_per_tile - 1)),
             )
 
+            if local_embed_is_sharded:
+                inpt_local_pos_embed = distribute_tensor(
+                    inpt_local_pos_embed,
+                    device_mesh=local_embed_device_mesh,
+                    placements=local_embed_placements,
+                )
+
             # update state dict
             state_dict[
                 prefix + "local_token_positional_embedding"
@@ -176,8 +196,20 @@ def _load_state_dict_hook(
         inpt_global_pos_embed = state_dict.get(
             prefix + "global_token_positional_embedding"
         )
+
         if inpt_global_pos_embed is not None:
 
+            # We can only apply F.interpolate to vanilla tensors, not DTensors
+            # If pos embeds are a DTensor, we gather the full tensor, apply
+            # interpolate, and then reshard after
+            if isinstance(inpt_global_pos_embed, DTensor):
+                global_embed_is_sharded = True
+                global_embed_device_mesh = inpt_global_pos_embed.device_mesh
+                global_embed_placements = inpt_global_pos_embed.placements
+                inpt_global_pos_embed = inpt_global_pos_embed.full_tensor()
+            else:
+                global_embed_is_sharded = False
+
             _, _, inpt_n_tokens_per_tile, _ = inpt_global_pos_embed.shape
 
             # sanity check
@@ -202,6 +234,13 @@ def _load_state_dict_hook(
                 tgt_patch_grid_size=int(math.sqrt(tgt_n_tokens_per_tile - 1)),
             )
 
+            if global_embed_is_sharded:
+                inpt_global_pos_embed = distribute_tensor(
+                    inpt_global_pos_embed,
+                    device_mesh=global_embed_device_mesh,
+                    placements=global_embed_placements,
+                )
+
             # update state dict
             state_dict[
                 prefix + "global_token_positional_embedding"
@@ -500,6 +539,17 @@ def _load_state_dict_hook(
 
         if embedding is not None:
 
+            # We can only apply F.interpolate to vanilla tensors, not DTensors
+            # If pos embeds are a DTensor, we gather the full tensor, apply
+            # interpolate, and then reshard after
+            if isinstance(embedding, DTensor):
+                embedding_is_sharded = True
+                device_mesh = embedding.device_mesh
+                placements = embedding.placements
+                embedding = embedding.full_tensor()
+            else:
+                embedding_is_sharded = False
+
             # ckpt pos emb
             (
                 tgt_max_num_tiles_x,
@@ -534,6 +584,13 @@ def _load_state_dict_hook(
                 embedding, tgt_max_num_tiles=tgt_max_num_tiles_x
             )
 
+            if embedding_is_sharded:
+                embedding_new = distribute_tensor(
+                    embedding_new,
+                    device_mesh=device_mesh,
+                    placements=placements,
+                )
+
             # update state dict
             state_dict[prefix + "embedding"] = embedding_new
             if embedding_new.shape != self.embedding.shape:
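
The same gather → interpolate → reshard pattern appears three times above (local, global, and tile positional embeddings). Below is a minimal standalone sketch of that pattern for context; the helper name interpolate_pos_embed, the 1D linear resize, and the toy shapes are illustrative assumptions, not torchtune's actual hook code, which works on a 2D patch grid (note the tgt_patch_grid_size arguments in the diff) before interpolating.

import torch
import torch.nn.functional as F
from torch.distributed._tensor import DTensor, distribute_tensor


def interpolate_pos_embed(pos_embed: torch.Tensor, tgt_n_tokens: int) -> torch.Tensor:
    """Resize a (possibly DTensor-sharded) positional embedding to tgt_n_tokens rows."""
    # F.interpolate only accepts vanilla tensors, so if the embedding is a DTensor
    # we record its sharding layout and gather the full tensor first.
    is_sharded = isinstance(pos_embed, DTensor)
    if is_sharded:
        device_mesh = pos_embed.device_mesh
        placements = pos_embed.placements
        pos_embed = pos_embed.full_tensor()

    # Run the plain-tensor op: resize along the token dimension.
    # (n_tokens, embed_dim) -> (1, embed_dim, n_tokens) as F.interpolate expects.
    resized = F.interpolate(
        pos_embed.T.unsqueeze(0), size=tgt_n_tokens, mode="linear", align_corners=True
    )
    pos_embed = resized.squeeze(0).T.contiguous()  # back to (tgt_n_tokens, embed_dim)

    # Reshard onto the original mesh with the original placements afterwards.
    if is_sharded:
        pos_embed = distribute_tensor(
            pos_embed, device_mesh=device_mesh, placements=placements
        )
    return pos_embed

On a single-device load the isinstance check is simply False, so the gather and reshard branches are skipped and the hook behaves exactly as before for non-distributed checkpoints.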

0 commit comments
