
Commit 6063f82

Add ROPE support to NaFlexVit (axial and mixed), and support loading of most (all?) EVA-based ViT models & weights. Fix PatchDropout use w/ ROPE and NaFlexVit. Fix #2549
1 parent b2034bb commit 6063f82

7 files changed: +883 -171 lines changed


timm/layers/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -104,7 +104,7 @@
     unfreeze_batch_norm_2d,
 )
 from .padding import get_padding, get_same_padding, pad_same
-from .patch_dropout import PatchDropout
+from .patch_dropout import PatchDropout, PatchDropoutWithIndices, patch_dropout_forward
 from .patch_embed import PatchEmbed, PatchEmbedWithSize, PatchEmbedInterpolator, resample_patch_embed
 from .pool1d import global_pool_nlc
 from .pool2d_same import AvgPool2dSame, create_pool2d
@@ -144,7 +144,7 @@
 from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
 from .test_time_pool import TestTimePoolHead, apply_test_time_pool
 from .trace_utils import _assert, _float_to_int
-from .typing import LayerType, PadType
+from .typing import LayerType, PadType, disable_compiler
 from .weight_init import (
     trunc_normal_,
     trunc_normal_tf_,

timm/layers/attention.py

Lines changed: 2 additions & 1 deletion
@@ -120,6 +120,7 @@ def __init__(
             norm_layer: Type[nn.Module] = None,
             qk_norm: bool = False,
             scale_norm: bool = False,
+            proj_bias: bool = True,
     ):
         """Initialize the Attention module.
 
@@ -161,7 +162,7 @@ def __init__(
         self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
         self.norm = norm_layer(attn_dim) if scale_norm else nn.Identity()
-        self.proj = nn.Linear(attn_dim, dim)
+        self.proj = nn.Linear(attn_dim, dim, bias=proj_bias)
         self.proj_drop = nn.Dropout(proj_drop)
 
     def forward(
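The new proj_bias flag lets the attention output projection drop its bias, presumably so EVA-style weights without a projection bias can be loaded into this layer. A minimal sketch of constructing timm.layers.Attention with the bias disabled; the dim, num_heads, and input shape below are illustrative, not taken from the commit.

import torch
from timm.layers.attention import Attention

# proj_bias=False builds the output projection without a bias term
attn = Attention(dim=768, num_heads=12, proj_bias=False)
x = torch.randn(2, 197, 768)  # (batch, tokens, dim), illustrative shape
out = attn(x)
print(out.shape)  # torch.Size([2, 197, 768])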

timm/layers/patch_dropout.py

Lines changed: 85 additions & 31 deletions
@@ -4,50 +4,104 @@
 import torch.nn as nn
 
 
+def patch_dropout_forward(
+        x: torch.Tensor,
+        prob: float,
+        num_prefix_tokens: int,
+        ordered: bool,
+        training: bool,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    """
+    Common forward logic for patch dropout.
+
+    Args:
+        x: Input tensor of shape (B, L, D)
+        prob: Dropout probability
+        num_prefix_tokens: Number of prefix tokens to preserve
+        ordered: Whether to maintain patch order
+        training: Whether in training mode
+
+    Returns:
+        Tuple of (output tensor, keep_indices or None)
+    """
+    if not training or prob == 0.:
+        return x, None
+
+    if num_prefix_tokens:
+        prefix_tokens, x = x[:, :num_prefix_tokens], x[:, num_prefix_tokens:]
+    else:
+        prefix_tokens = None
+
+    B = x.shape[0]
+    L = x.shape[1]
+    num_keep = max(1, int(L * (1. - prob)))
+    keep_indices = torch.argsort(torch.randn(B, L, device=x.device), dim=-1)[:, :num_keep]
+
+    if ordered:
+        # NOTE does not need to maintain patch order in typical transformer use,
+        # but possibly useful for debug / visualization
+        keep_indices = keep_indices.sort(dim=-1)[0]
+
+    x = x.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + x.shape[2:]))
+
+    if prefix_tokens is not None:
+        x = torch.cat((prefix_tokens, x), dim=1)
+
+    return x, keep_indices
+
+
 class PatchDropout(nn.Module):
     """
+    Patch Dropout without returning indices.
     https://arxiv.org/abs/2212.00794 and https://arxiv.org/pdf/2208.07220
     """
-    return_indices: torch.jit.Final[bool]
 
     def __init__(
             self,
             prob: float = 0.5,
             num_prefix_tokens: int = 1,
             ordered: bool = False,
-            return_indices: bool = False,
     ):
         super().__init__()
         assert 0 <= prob < 1.
         self.prob = prob
         self.num_prefix_tokens = num_prefix_tokens  # exclude CLS token (or other prefix tokens)
         self.ordered = ordered
-        self.return_indices = return_indices
-
-    def forward(self, x) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
-        if not self.training or self.prob == 0.:
-            if self.return_indices:
-                return x, None
-            return x
-
-        if self.num_prefix_tokens:
-            prefix_tokens, x = x[:, :self.num_prefix_tokens], x[:, self.num_prefix_tokens:]
-        else:
-            prefix_tokens = None
-
-        B = x.shape[0]
-        L = x.shape[1]
-        num_keep = max(1, int(L * (1. - self.prob)))
-        keep_indices = torch.argsort(torch.randn(B, L, device=x.device), dim=-1)[:, :num_keep]
-        if self.ordered:
-            # NOTE does not need to maintain patch order in typical transformer use,
-            # but possibly useful for debug / visualization
-            keep_indices = keep_indices.sort(dim=-1)[0]
-        x = x.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + x.shape[2:]))
-
-        if prefix_tokens is not None:
-            x = torch.cat((prefix_tokens, x), dim=1)
-
-        if self.return_indices:
-            return x, keep_indices
-        return x
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        output, _ = patch_dropout_forward(
+            x,
+            self.prob,
+            self.num_prefix_tokens,
+            self.ordered,
+            self.training
+        )
+        return output
+
+
+class PatchDropoutWithIndices(nn.Module):
+    """
+    Patch Dropout that returns both output and keep indices.
+    https://arxiv.org/abs/2212.00794 and https://arxiv.org/pdf/2208.07220
+    """
+
+    def __init__(
+            self,
+            prob: float = 0.5,
+            num_prefix_tokens: int = 1,
+            ordered: bool = False,
+    ):
+        super().__init__()
+        assert 0 <= prob < 1.
+        self.prob = prob
+        self.num_prefix_tokens = num_prefix_tokens  # exclude CLS token (or other prefix tokens)
+        self.ordered = ordered
+
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        return patch_dropout_forward(
+            x,
+            self.prob,
+            self.num_prefix_tokens,
+            self.ordered,
+            self.training
+        )
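Splitting the old return_indices flag into two modules gives each forward a fixed return type (no more Union return) while both share patch_dropout_forward. A rough usage sketch with illustrative shapes; dropout only happens in training mode, otherwise the input passes through unchanged.

import torch
from timm.layers import PatchDropout, PatchDropoutWithIndices

tokens = torch.randn(2, 1 + 196, 384)  # (batch, 1 cls token + 14*14 patches, dim)

# Plain variant: tensor in, tensor out
drop = PatchDropout(prob=0.25, num_prefix_tokens=1).train()
out = drop(tokens)  # (2, 1 + 147, 384), keeps max(1, int(196 * 0.75)) = 147 patches

# Indexed variant: also returns which patch positions were kept,
# so position embeddings such as ROPE can be subset to match
drop_idx = PatchDropoutWithIndices(prob=0.25, num_prefix_tokens=1).train()
out, keep_indices = drop_idx(tokens)  # keep_indices: (2, 147), indices into the 196 patch tokens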

timm/layers/pos_embed_sincos.py

Lines changed: 143 additions & 4 deletions
@@ -234,10 +234,40 @@ def apply_rot_embed_cat(x: torch.Tensor, emb):
     return x * cos_emb + rot(x) * sin_emb
 
 
-def apply_keep_indices_nlc(x, pos_embed, keep_indices):
-    pos_embed = pos_embed.unsqueeze(0).expand(x.shape[0], -1, -1)
-    pos_embed = pos_embed.gather(1, keep_indices.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1]))
-    return pos_embed
+def apply_keep_indices_nlc(
+        x: torch.Tensor,
+        pos_embed: torch.Tensor,
+        keep_indices: torch.Tensor,
+        pos_embed_has_batch: bool = False,
+) -> torch.Tensor:
+    """ Apply keep indices to different ROPE shapes
+    Expected shapes:
+        * pos_embed shape [seq_len, pos_embed_dim] → output [batch_size, seq_len, pos_embed_dim]
+        * pos_embed shape [num_heads, seq_len, pos_embed_dim] → output [batch_size, num_heads, seq_len, pos_embed_dim]
+        * pos_embed shape [depth, num_heads, seq_len, pos_embed_dim] → output [batch_size, depth, num_heads, seq_len, pos_embed_dim]
+
+    And all of the above with leading batch dimension already present if `pos_embed_has_batch == True`
+
+    """
+    if pos_embed_has_batch:
+        # Pos embed already includes batch dim
+        _assert(pos_embed.ndim >= 3, 'Incorrect number of dimensions')  # At least [batch, seq_len, pos_embed_dim]
+    else:
+        # Add batch dimension and expand to batch size
+        _assert(pos_embed.ndim >= 2, 'Incorrect number of dimensions')  # At least [seq_len, pos_embed_dim]
+        expand_shape = (x.shape[0],) + (-1,) * pos_embed.ndim
+        pos_embed = pos_embed.unsqueeze(0).expand(expand_shape)
+
+    # Reshape keep_indices to add singleton dims
+    keep_shape = (keep_indices.shape[0],) + (1,) * (pos_embed.ndim - 3) + (keep_indices.shape[1], 1)
+    keep_indices = keep_indices.view(keep_shape)
+
+    # Expand all dims to match position embedding except the gather dim (second-last)
+    keep_expand = list(pos_embed.shape)
+    keep_expand[-2] = -1
+    keep_indices = keep_indices.expand(keep_expand)
+
+    return pos_embed.gather(-2, keep_indices)
 
 
 def build_rotary_pos_embed(
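The generalized apply_keep_indices_nlc gathers the kept positions along the second-last dimension, so the same keep indices produced by patch dropout can subset an axial ROPE table of shape [seq_len, dim] or a per-layer, per-head mixed table. A small sketch under assumed sizes:

import torch
from timm.layers.pos_embed_sincos import apply_keep_indices_nlc

B, L, dim = 2, 196, 64                       # illustrative sizes
x = torch.randn(B, 147, 384)                 # tokens after patch dropout (only the batch size is read)
keep = torch.argsort(torch.randn(B, L), dim=-1)[:, :147]

axial = torch.randn(L, dim)                  # [seq_len, dim]
print(apply_keep_indices_nlc(x, axial, keep).shape)   # torch.Size([2, 147, 64])

mixed = torch.randn(12, 6, L, dim)           # [depth, num_heads, seq_len, dim]
print(apply_keep_indices_nlc(x, mixed, keep).shape)   # torch.Size([2, 12, 6, 147, 64])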
@@ -484,6 +514,59 @@ def get_embed(self, shape: Optional[List[int]] = None):
         else:
             assert False, "get_embed() requires pre-computed pos embed or valid shape w/ pre-computed bands"
 
+    def get_batch_embeds(
+            self,
+            shapes: List[Tuple[int, int]],
+            seq_len: Optional[int] = None,
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        """Generate ROPE embeddings for multiple grid shapes efficiently.
+
+        Computes embeddings for the maximum grid size once, then extracts
+        and flattens the relevant portions for each requested shape.
+
+        Args:
+            shapes: List of (H, W) tuples representing different grid sizes
+
+        Returns:
+            List of concatenated sin/cos embeddings for each shape,
+            where each tensor has shape (H*W, dim)
+        """
+        if not shapes:
+            return []
+
+        # Check if we have pre-computed bands
+        if self.bands is None:
+            # If we have pre-computed pos_embed for a fixed shape, we can't do batch generation
+            raise RuntimeError("Batch embedding generation requires cached bands, not pre-computed embeddings")
+
+        # Find max dimensions across all shapes
+        max_h = max(h for h, w in shapes)
+        max_w = max(w for h, w in shapes)
+
+        # Generate embeddings for max size ONCE
+        sin_emb, cos_emb = build_rotary_pos_embed(
+            feat_shape=(max_h, max_w),
+            bands=self.bands,
+            in_pixels=self.in_pixels,
+            ref_feat_shape=self.ref_feat_shape,
+            grid_offset=self.grid_offset,
+            grid_indexing=self.grid_indexing,
+        )
+
+        # sin_emb and cos_emb are (max_h * max_w, dim//2)
+        # concat and reshape to 2D for slicing
+        rope_embed_2d = torch.cat([sin_emb, cos_emb], dim=-1).view(max_h, max_w, -1)
+
+        if seq_len is not None:
+            flat_embeds = torch.zeros(len(shapes), seq_len, rope_embed_2d.shape[-1]).type_as(sin_emb)
+            for i, (h, w) in enumerate(shapes):
+                src_len = h * w
+                flat_embeds[i, :src_len] = rope_embed_2d[:h, :w].reshape(src_len, -1)
+            return flat_embeds
+        else:
+            flat_embeds_list = [rope_embed_2d[:h, :w].reshape(h * w, -1) for h, w in shapes]
+            return flat_embeds_list
+
     def forward(self, x):
         # assuming channel-first tensor where spatial dim are >= 2
         pos_embed = self.get_embed(x.shape[2:])
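A hedged sketch of how this batch helper might be called; it assumes the method above lands on RotaryEmbeddingCat (the enclosing class is not shown in the hunk) and that leaving feat_shape unset keeps the frequency bands cached, which the method requires.

import torch
from timm.layers.pos_embed_sincos import RotaryEmbeddingCat

rope = RotaryEmbeddingCat(64, in_pixels=False, ref_feat_shape=(16, 16))  # assumed constructor args

shapes = [(16, 16), (12, 20), (8, 8)]          # per-sample patch grids, illustrative
padded = rope.get_batch_embeds(shapes, seq_len=16 * 20)
print(padded.shape)                            # torch.Size([3, 320, 64]), zero padded past each H*W

per_shape = rope.get_batch_embeds(shapes)      # list of (H*W, 64) tensors instead
print([e.shape[0] for e in per_shape])         # [256, 240, 64]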
@@ -642,6 +725,62 @@ def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor:
 
         return get_mixed_freqs(self.freqs, t_x, t_y)
 
+    def get_batch_embeds(
+            self,
+            shapes: List[Tuple[int, int]],
+            seq_len: Optional[int] = None,
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        """Generate ROPE embeddings for multiple grid shapes efficiently.
+
+        Computes embeddings for the maximum grid size once, then extracts
+        and flattens the relevant portions for each requested shape.
+
+        Args:
+            shapes: List of (H, W) tuples representing different grid sizes
+            seq_len: If provided, return padded tensor of this length. Otherwise return list.
+
+        Returns:
+            If seq_len is provided: Padded tensor of shape (len(shapes), depth, num_heads, seq_len, dim)
+            Otherwise: List of tensors with shape (depth, num_heads, H*W, dim) for each shape
+        """
+        if not shapes:
+            return []
+
+        # Find max dimensions
+        max_h = max(h for h, w in shapes)
+        max_w = max(w for h, w in shapes)
+
+        # Generate embeddings for max size ONCE
+        t_x, t_y = get_mixed_grid(
+            [max_h, max_w],
+            grid_indexing=self.grid_indexing,
+            device=self.freqs.device
+        )
+        max_embed = get_mixed_freqs(self.freqs, t_x, t_y)  # (depth, num_heads, max_h*max_w, dim)
+
+        # Reshape to 2D grid for easy slicing
+        depth, num_heads, _, dim = max_embed.shape
+        max_embed_2d = max_embed.view(depth, num_heads, max_h, max_w, dim)
+
+        if seq_len is not None:
+            # Return padded tensor
+            B = len(shapes)
+            padded = torch.zeros(B, depth, num_heads, seq_len, dim, device=self.freqs.device, dtype=self.freqs.dtype)
+            for i, (h, w) in enumerate(shapes):
+                # Slice and flatten
+                embed_slice = max_embed_2d[:, :, :h, :w].reshape(depth, num_heads, h * w, dim)
+                actual_len = h * w
+                padded[i, :, :, :actual_len] = embed_slice
+            return padded
+        else:
+            # Return list
+            results = []
+            for h, w in shapes:
+                # Slice and flatten
+                embed_slice = max_embed_2d[:, :, :h, :w].reshape(depth, num_heads, h * w, dim)
+                results.append(embed_slice)
+            return results
+
     def forward(self, x):
         # assuming channel-first tensor where spatial dim are >= 2
         pos_embed = self.get_embed(x.shape[2:])
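The mixed-mode variant returns a table with a leading batch dimension when seq_len is given, which is what the pos_embed_has_batch=True path of apply_keep_indices_nlc expects. A hypothetical glue snippet with random stand-in tensors in place of a real model:

import torch
from timm.layers.pos_embed_sincos import apply_keep_indices_nlc

B, depth, heads, seq_len, dim = 2, 12, 6, 256, 64           # illustrative sizes
rope = torch.randn(B, depth, heads, seq_len, dim)           # stand-in for get_batch_embeds(shapes, seq_len=256)
tokens = torch.randn(B, 192, 384)                           # x is unused when the table already has a batch dim
keep = torch.argsort(torch.randn(B, seq_len), dim=-1)[:, :192]

rope_kept = apply_keep_indices_nlc(tokens, rope, keep, pos_embed_has_batch=True)
print(rope_kept.shape)  # torch.Size([2, 12, 6, 192, 64])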

timm/layers/typing.py

Lines changed: 28 additions & 1 deletion
@@ -1,7 +1,34 @@
-from typing import Callable, Tuple, Type, Union
+from contextlib import nullcontext
+from functools import wraps
+from typing import Callable, Tuple, Type, TypeVar, Union, overload, ContextManager
 
 import torch
 
+__all__ = ["LayerType", "PadType", "nullwrap", "disable_compiler"]
+
 
 LayerType = Union[str, Callable, Type[torch.nn.Module]]
 PadType = Union[str, int, Tuple[int, int]]
+
+F = TypeVar("F", bound=Callable[..., object])
+
+
+@overload
+def nullwrap(fn: F) -> F: ...  # decorator form
+
+@overload
+def nullwrap(fn: None = ...) -> ContextManager: ...  # context-manager form
+
+def nullwrap(fn: F | None = None):
+    # as a context manager
+    if fn is None:
+        return nullcontext()  # `with nullwrap():`
+
+    # as a decorator
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        return fn(*args, **kwargs)
+    return wrapper  # `@nullwrap`
+
+
+disable_compiler = getattr(getattr(torch, "compiler", None), "disable", None) or nullwrap
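disable_compiler resolves to torch.compiler.disable when the running PyTorch provides it and otherwise falls back to the no-op nullwrap, so data-dependent code can be excluded from compilation without a version check. A small sketch; gather_kept_tokens is a hypothetical function, not part of the commit.

import torch
from timm.layers import disable_compiler

@disable_compiler  # real torch.compiler.disable on recent PyTorch, no-op nullwrap otherwise
def gather_kept_tokens(x: torch.Tensor, keep_indices: torch.Tensor) -> torch.Tensor:
    # data-dependent gather that we may prefer to keep out of the compiled graph
    return x.gather(1, keep_indices.unsqueeze(-1).expand(-1, -1, x.shape[-1]))

x = torch.randn(2, 196, 64)
keep = torch.argsort(torch.randn(2, 196), dim=-1)[:, :147]
print(gather_kept_tokens(x, keep).shape)  # torch.Size([2, 147, 64])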
