@@ -522,7 +522,7 @@ def patchify_and_embed(
         max_cap_len = max(l_effective_cap_len)
         max_img_len = max(l_effective_img_len)

-        position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)
+        position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.float32, device=device)

         for i in range(bsz):
             cap_len = l_effective_cap_len[i]
@@ -531,10 +531,22 @@ def patchify_and_embed(
             H_tokens, W_tokens = H // pH, W // pW
             assert H_tokens * W_tokens == img_len

-            position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
+            rope_options = transformer_options.get("rope_options", None)
+            h_scale = 1.0
+            w_scale = 1.0
+            h_start = 0
+            w_start = 0
+            if rope_options is not None:
+                h_scale = rope_options.get("scale_y", 1.0)
+                w_scale = rope_options.get("scale_x", 1.0)
+
+                h_start = rope_options.get("shift_y", 0.0)
+                w_start = rope_options.get("shift_x", 0.0)
+
+            position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.float32, device=device)
             position_ids[i, cap_len:cap_len + img_len, 0] = cap_len
-            row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
-            col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
+            row_ids = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
+            col_ids = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
             position_ids[i, cap_len:cap_len + img_len, 1] = row_ids
             position_ids[i, cap_len:cap_len + img_len, 2] = col_ids

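For context, the new scale and shift factors are read from transformer_options["rope_options"]. A minimal sketch (not part of the commit) of the dictionary shape this code expects, assuming the caller controls transformer_options; how that dict is plumbed in from the outside is not shown here:

    # Hypothetical caller-side setup; only the key names match the diff above.
    transformer_options = {}
    transformer_options["rope_options"] = {
        "scale_x": 2.0,   # multiplies the column (width) position ids
        "scale_y": 2.0,   # multiplies the row (height) position ids
        "shift_x": 0.0,   # offset added to the column position ids
        "shift_y": 0.0,   # offset added to the row position ids
    }
    # With these values row_ids/col_ids become arange(N) * 2.0 + 0.0, which is
    # why position_ids is switched from int32 to float32 in this change.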