@@ -234,10 +234,40 @@ def apply_rot_embed_cat(x: torch.Tensor, emb):
    return x * cos_emb + rot(x) * sin_emb


-def apply_keep_indices_nlc(x, pos_embed, keep_indices):
-    pos_embed = pos_embed.unsqueeze(0).expand(x.shape[0], -1, -1)
-    pos_embed = pos_embed.gather(1, keep_indices.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1]))
-    return pos_embed
+def apply_keep_indices_nlc(
+        x: torch.Tensor,
+        pos_embed: torch.Tensor,
+        keep_indices: torch.Tensor,
+        pos_embed_has_batch: bool = False,
+) -> torch.Tensor:
+    """Apply keep indices to different ROPE shapes.
+
+    Expected shapes:
+     * pos_embed shape [seq_len, pos_embed_dim] → output [batch_size, seq_len, pos_embed_dim]
+     * pos_embed shape [num_heads, seq_len, pos_embed_dim] → output [batch_size, num_heads, seq_len, pos_embed_dim]
+     * pos_embed shape [depth, num_heads, seq_len, pos_embed_dim] → output [batch_size, depth, num_heads, seq_len, pos_embed_dim]
+
+    And all of the above with a leading batch dimension already present if `pos_embed_has_batch == True`.
+    """
+    if pos_embed_has_batch:
+        # Pos embed already includes batch dim
+        _assert(pos_embed.ndim >= 3, 'Incorrect number of dimensions')  # At least [batch, seq_len, pos_embed_dim]
+    else:
+        # Add batch dimension and expand to batch size
+        _assert(pos_embed.ndim >= 2, 'Incorrect number of dimensions')  # At least [seq_len, pos_embed_dim]
+        expand_shape = (x.shape[0],) + (-1,) * pos_embed.ndim
+        pos_embed = pos_embed.unsqueeze(0).expand(expand_shape)
+
+    # Reshape keep_indices to add singleton dims
+    keep_shape = (keep_indices.shape[0],) + (1,) * (pos_embed.ndim - 3) + (keep_indices.shape[1], 1)
+    keep_indices = keep_indices.view(keep_shape)
+
+    # Expand all dims to match the position embedding except the gather dim (second-last)
+    keep_expand = list(pos_embed.shape)
+    keep_expand[-2] = -1
+    keep_indices = keep_indices.expand(keep_expand)
+
+    return pos_embed.gather(-2, keep_indices)


def build_rotary_pos_embed(
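A minimal usage sketch (not part of the diff) for the generalized gather above, applying per-sample token keep indices to a plain [seq_len, pos_embed_dim] ROPE table. It assumes apply_keep_indices_nlc is importable from this module (timm.layers.pos_embed_sincos); the shapes are illustrative only.

import torch
from timm.layers.pos_embed_sincos import apply_keep_indices_nlc  # import path assumed

B, seq_len, dim, keep = 2, 16, 32, 8
x = torch.randn(B, seq_len, dim)               # NLC token sequence, only the batch size is used
pos_embed = torch.randn(seq_len, dim)          # [seq_len, pos_embed_dim] ROPE table
# Indices of tokens kept per sample, e.g. after random masking
keep_indices = torch.stack([torch.randperm(seq_len)[:keep] for _ in range(B)])

kept = apply_keep_indices_nlc(x, pos_embed, keep_indices)
print(kept.shape)  # torch.Size([2, 8, 32]) -> [batch_size, keep_len, pos_embed_dim]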
@@ -484,6 +514,59 @@ def get_embed(self, shape: Optional[List[int]] = None):
        else:
            assert False, "get_embed() requires pre-computed pos embed or valid shape w/ pre-computed bands"

+    def get_batch_embeds(
+            self,
+            shapes: List[Tuple[int, int]],
+            seq_len: Optional[int] = None,
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        """Generate ROPE embeddings for multiple grid shapes efficiently.
+
+        Computes embeddings for the maximum grid size once, then extracts
+        and flattens the relevant portions for each requested shape.
+
+        Args:
+            shapes: List of (H, W) tuples representing different grid sizes
+            seq_len: If provided, return a padded tensor of this length. Otherwise return a list.
+
+        Returns:
+            If seq_len is provided: padded tensor of shape (len(shapes), seq_len, dim).
+            Otherwise: list of concatenated sin/cos embeddings, one per shape,
+            where each tensor has shape (H*W, dim).
+        """
+        if not shapes:
+            return []
+
+        # Check if we have pre-computed bands
+        if self.bands is None:
+            # If we have a pre-computed pos_embed for a fixed shape, we can't do batch generation
+            raise RuntimeError("Batch embedding generation requires cached bands, not pre-computed embeddings")
+
+        # Find max dimensions across all shapes
+        max_h = max(h for h, w in shapes)
+        max_w = max(w for h, w in shapes)
+
+        # Generate embeddings for the max size ONCE
+        sin_emb, cos_emb = build_rotary_pos_embed(
+            feat_shape=(max_h, max_w),
+            bands=self.bands,
+            in_pixels=self.in_pixels,
+            ref_feat_shape=self.ref_feat_shape,
+            grid_offset=self.grid_offset,
+            grid_indexing=self.grid_indexing,
+        )
+
+        # sin_emb and cos_emb are (max_h * max_w, dim//2)
+        # concat and reshape to a 2D grid for slicing
+        rope_embed_2d = torch.cat([sin_emb, cos_emb], dim=-1).view(max_h, max_w, -1)
+
+        if seq_len is not None:
+            flat_embeds = torch.zeros(len(shapes), seq_len, rope_embed_2d.shape[-1]).type_as(sin_emb)
+            for i, (h, w) in enumerate(shapes):
+                src_len = h * w
+                flat_embeds[i, :src_len] = rope_embed_2d[:h, :w].reshape(src_len, -1)
+            return flat_embeds
+        else:
+            flat_embeds_list = [rope_embed_2d[:h, :w].reshape(h * w, -1) for h, w in shapes]
+            return flat_embeds_list
+
    def forward(self, x):
        # assuming channel-first tensor where spatial dim are >= 2
        pos_embed = self.get_embed(x.shape[2:])
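A hedged usage sketch (not part of the diff) for the method above. It assumes the method lands on a cat-style rotary module such as timm's RotaryEmbeddingCat, constructed without a fixed feat_shape so the frequency bands stay cached; the constructor arguments shown are assumptions and may differ between versions.

import torch
from timm.layers.pos_embed_sincos import RotaryEmbeddingCat  # class and kwargs assumed

rope = RotaryEmbeddingCat(dim=64, in_pixels=False, ref_feat_shape=(16, 16))
shapes = [(8, 8), (12, 10), (16, 16)]

# List form: one tensor per requested grid shape, each flattened to (H*W, embed_dim)
embeds = rope.get_batch_embeds(shapes)
print([tuple(e.shape) for e in embeds])

# Padded form: a single (len(shapes), seq_len, embed_dim) tensor, zero-padded past H*W
padded = rope.get_batch_embeds(shapes, seq_len=16 * 16)
print(tuple(padded.shape))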
@@ -642,6 +725,62 @@ def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor:

        return get_mixed_freqs(self.freqs, t_x, t_y)

+    def get_batch_embeds(
+            self,
+            shapes: List[Tuple[int, int]],
+            seq_len: Optional[int] = None,
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        """Generate ROPE embeddings for multiple grid shapes efficiently.
+
+        Computes embeddings for the maximum grid size once, then extracts
+        and flattens the relevant portions for each requested shape.
+
+        Args:
+            shapes: List of (H, W) tuples representing different grid sizes
+            seq_len: If provided, return a padded tensor of this length. Otherwise return a list.
+
+        Returns:
+            If seq_len is provided: padded tensor of shape (len(shapes), depth, num_heads, seq_len, dim).
+            Otherwise: list of tensors with shape (depth, num_heads, H*W, dim), one per shape.
+        """
+        if not shapes:
+            return []
+
+        # Find max dimensions
+        max_h = max(h for h, w in shapes)
+        max_w = max(w for h, w in shapes)
+
+        # Generate embeddings for the max size ONCE
+        t_x, t_y = get_mixed_grid(
+            [max_h, max_w],
+            grid_indexing=self.grid_indexing,
+            device=self.freqs.device,
+        )
+        max_embed = get_mixed_freqs(self.freqs, t_x, t_y)  # (depth, num_heads, max_h * max_w, dim)
+
+        # Reshape to a 2D grid for easy slicing
+        depth, num_heads, _, dim = max_embed.shape
+        max_embed_2d = max_embed.view(depth, num_heads, max_h, max_w, dim)
+
+        if seq_len is not None:
+            # Return padded tensor
+            B = len(shapes)
+            padded = torch.zeros(B, depth, num_heads, seq_len, dim, device=self.freqs.device, dtype=self.freqs.dtype)
+            for i, (h, w) in enumerate(shapes):
+                # Slice and flatten
+                embed_slice = max_embed_2d[:, :, :h, :w].reshape(depth, num_heads, h * w, dim)
+                actual_len = h * w
+                padded[i, :, :, :actual_len] = embed_slice
+            return padded
+        else:
+            # Return list
+            results = []
+            for h, w in shapes:
+                # Slice and flatten
+                embed_slice = max_embed_2d[:, :, :h, :w].reshape(depth, num_heads, h * w, dim)
+                results.append(embed_slice)
+            return results
+
    def forward(self, x):
        # assuming channel-first tensor where spatial dim are >= 2
        pos_embed = self.get_embed(x.shape[2:])
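A shape-plumbing sketch (not part of the diff) tying the two additions together: the padded per-batch output of the mixed-ROPE get_batch_embeds can be fed straight to apply_keep_indices_nlc with pos_embed_has_batch=True. Random tensors stand in for the real embeddings, and the import path is assumed.

import torch
from timm.layers.pos_embed_sincos import apply_keep_indices_nlc  # import path assumed

B, depth, num_heads, seq_len, dim, keep = 2, 12, 6, 256, 64, 128
x = torch.randn(B, seq_len, dim)
# Stand-in for get_batch_embeds(shapes, seq_len=256) -> (B, depth, num_heads, seq_len, dim)
batch_embeds = torch.randn(B, depth, num_heads, seq_len, dim)
keep_indices = torch.stack([torch.randperm(seq_len)[:keep] for _ in range(B)])

kept = apply_keep_indices_nlc(x, batch_embeds, keep_indices, pos_embed_has_batch=True)
print(kept.shape)  # torch.Size([2, 12, 6, 128, 64])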