4 changes: 2 additions & 2 deletions timm/layers/__init__.py
@@ -104,7 +104,7 @@
unfreeze_batch_norm_2d,
)
from .padding import get_padding, get_same_padding, pad_same
from .patch_dropout import PatchDropout
from .patch_dropout import PatchDropout, PatchDropoutWithIndices, patch_dropout_forward
from .patch_embed import PatchEmbed, PatchEmbedWithSize, PatchEmbedInterpolator, resample_patch_embed
from .pool1d import global_pool_nlc
from .pool2d_same import AvgPool2dSame, create_pool2d
@@ -144,7 +144,7 @@
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .trace_utils import _assert, _float_to_int
from .typing import LayerType, PadType
from .typing import LayerType, PadType, disable_compiler
from .weight_init import (
trunc_normal_,
trunc_normal_tf_,
3 changes: 2 additions & 1 deletion timm/layers/attention.py
@@ -120,6 +120,7 @@ def __init__(
norm_layer: Type[nn.Module] = None,
qk_norm: bool = False,
scale_norm: bool = False,
proj_bias: bool = True,
):
"""Initialize the Attention module.

@@ -161,7 +162,7 @@ def __init__(
self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.norm = norm_layer(attn_dim) if scale_norm else nn.Identity()
self.proj = nn.Linear(attn_dim, dim)
self.proj = nn.Linear(attn_dim, dim, bias=proj_bias)
self.proj_drop = nn.Dropout(proj_drop)

def forward(
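# A hedged sketch of the new `proj_bias` flag in use. Only the arguments touched by
# this hunk are shown above; `dim` and `num_heads` as constructor arguments are
# assumptions about the unchanged part of the signature.
import torch
from timm.layers.attention import Attention

attn = Attention(dim=768, num_heads=12, proj_bias=False)  # no bias on the output projection
x = torch.randn(2, 197, 768)
print(attn(x).shape)  # torch.Size([2, 197, 768])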
116 changes: 85 additions & 31 deletions timm/layers/patch_dropout.py
@@ -4,50 +4,104 @@
import torch.nn as nn


def patch_dropout_forward(
x: torch.Tensor,
prob: float,
num_prefix_tokens: int,
ordered: bool,
training: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Common forward logic for patch dropout.

Args:
x: Input tensor of shape (B, L, D)
prob: Dropout probability
num_prefix_tokens: Number of prefix tokens to preserve
ordered: Whether to maintain patch order
training: Whether in training mode

Returns:
Tuple of (output tensor, keep_indices or None)
"""
if not training or prob == 0.:
return x, None

if num_prefix_tokens:
prefix_tokens, x = x[:, :num_prefix_tokens], x[:, num_prefix_tokens:]
else:
prefix_tokens = None

B = x.shape[0]
L = x.shape[1]
num_keep = max(1, int(L * (1. - prob)))
keep_indices = torch.argsort(torch.randn(B, L, device=x.device), dim=-1)[:, :num_keep]

if ordered:
# NOTE does not need to maintain patch order in typical transformer use,
# but possibly useful for debug / visualization
keep_indices = keep_indices.sort(dim=-1)[0]

x = x.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + x.shape[2:]))

if prefix_tokens is not None:
x = torch.cat((prefix_tokens, x), dim=1)

return x, keep_indices
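# A minimal sketch of the functional helper above (illustrative only; shapes are arbitrary).
# With prob=0.5 over 196 patch tokens, num_keep = max(1, int(196 * 0.5)) = 98, so the output
# keeps the single prefix (CLS) token plus 98 patches and returns their indices.
import torch
from timm.layers.patch_dropout import patch_dropout_forward

x = torch.randn(4, 1 + 196, 768)  # (B, 1 prefix token + 196 patches, D)
out, keep = patch_dropout_forward(x, prob=0.5, num_prefix_tokens=1, ordered=True, training=True)
print(out.shape)   # torch.Size([4, 99, 768])
print(keep.shape)  # torch.Size([4, 98]); indices refer to the patch tokens, excluding the prefix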


class PatchDropout(nn.Module):
"""
Patch Dropout without returning indices.
https://arxiv.org/abs/2212.00794 and https://arxiv.org/pdf/2208.07220
"""
return_indices: torch.jit.Final[bool]

def __init__(
self,
prob: float = 0.5,
num_prefix_tokens: int = 1,
ordered: bool = False,
return_indices: bool = False,
):
super().__init__()
assert 0 <= prob < 1.
self.prob = prob
self.num_prefix_tokens = num_prefix_tokens # exclude CLS token (or other prefix tokens)
self.ordered = ordered
self.return_indices = return_indices

def forward(self, x) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
if not self.training or self.prob == 0.:
if self.return_indices:
return x, None
return x

if self.num_prefix_tokens:
prefix_tokens, x = x[:, :self.num_prefix_tokens], x[:, self.num_prefix_tokens:]
else:
prefix_tokens = None

B = x.shape[0]
L = x.shape[1]
num_keep = max(1, int(L * (1. - self.prob)))
keep_indices = torch.argsort(torch.randn(B, L, device=x.device), dim=-1)[:, :num_keep]
if self.ordered:
# NOTE does not need to maintain patch order in typical transformer use,
# but possibly useful for debug / visualization
keep_indices = keep_indices.sort(dim=-1)[0]
x = x.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + x.shape[2:]))

if prefix_tokens is not None:
x = torch.cat((prefix_tokens, x), dim=1)

if self.return_indices:
return x, keep_indices
return x

def forward(self, x: torch.Tensor) -> torch.Tensor:
output, _ = patch_dropout_forward(
x,
self.prob,
self.num_prefix_tokens,
self.ordered,
self.training
)
return output


class PatchDropoutWithIndices(nn.Module):
"""
Patch Dropout that returns both output and keep indices.
https://arxiv.org/abs/2212.00794 and https://arxiv.org/pdf/2208.07220
"""

def __init__(
self,
prob: float = 0.5,
num_prefix_tokens: int = 1,
ordered: bool = False,
):
super().__init__()
assert 0 <= prob < 1.
self.prob = prob
self.num_prefix_tokens = num_prefix_tokens # exclude CLS token (or other prefix tokens)
self.ordered = ordered

def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
return patch_dropout_forward(
x,
self.prob,
self.num_prefix_tokens,
self.ordered,
self.training
)
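# A short usage sketch contrasting the two modules above (illustrative only).
# PatchDropout now always returns a tensor; PatchDropoutWithIndices returns (tensor, keep_indices).
import torch
from timm.layers.patch_dropout import PatchDropout, PatchDropoutWithIndices

x = torch.randn(2, 1 + 196, 384)

drop = PatchDropout(prob=0.25, num_prefix_tokens=1).train()
y = drop(x)                        # torch.Size([2, 148, 384]): CLS + max(1, int(196 * 0.75)) = 147 patches

drop_idx = PatchDropoutWithIndices(prob=0.25, num_prefix_tokens=1).train()
y, keep = drop_idx(x)              # keep is None in eval mode or when prob == 0
print(y.shape, None if keep is None else keep.shape)  # torch.Size([2, 148, 384]) torch.Size([2, 147])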
148 changes: 144 additions & 4 deletions timm/layers/pos_embed_sincos.py
@@ -234,10 +234,41 @@ def apply_rot_embed_cat(x: torch.Tensor, emb):
return x * cos_emb + rot(x) * sin_emb


def apply_keep_indices_nlc(x, pos_embed, keep_indices):
pos_embed = pos_embed.unsqueeze(0).expand(x.shape[0], -1, -1)
pos_embed = pos_embed.gather(1, keep_indices.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1]))
return pos_embed
def apply_keep_indices_nlc(
x: torch.Tensor,
pos_embed: torch.Tensor,
keep_indices: torch.Tensor,
pos_embed_has_batch: bool = False,
) -> torch.Tensor:
""" Apply keep indices to different ROPE shapes

Expected pos_embed shapes:
* [seq_len, pos_embed_dim] --> output [batch_size, seq_len, pos_embed_dim]
* [num_heads, seq_len, pos_embed_dim] --> output [batch_size, num_heads, seq_len, pos_embed_dim]
* [depth, num_heads, seq_len, pos_embed_dim] --> output [batch_size, depth, num_heads, seq_len, pos_embed_dim]

And all of the above with leading batch dimension already present if `pos_embed_has_batch == True`

"""
if pos_embed_has_batch:
# Pos embed already includes batch dim
_assert(pos_embed.ndim >= 3, 'Incorrect number of dimensions') # At least [batch, seq_len, pos_embed_dim]
else:
# Add batch dimension and expand to batch size
_assert(pos_embed.ndim >= 2, 'Incorrect number of dimensions') # At least [seq_len, pos_embed_dim]
expand_shape = (x.shape[0],) + (-1,) * pos_embed.ndim
pos_embed = pos_embed.unsqueeze(0).expand(expand_shape)

# Reshape keep_indices to add singleton dims
keep_shape = (keep_indices.shape[0],) + (1,) * (pos_embed.ndim - 3) + (keep_indices.shape[1], 1)
keep_indices = keep_indices.view(keep_shape)

# Expand all dims to match position embedding except the gather dim (second-last)
keep_expand = list(pos_embed.shape)
keep_expand[-2] = -1
keep_indices = keep_indices.expand(keep_expand)

return pos_embed.gather(-2, keep_indices)
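# A minimal shape check for apply_keep_indices_nlc above, using the unbatched
# [num_heads, seq_len, dim] case with illustrative sizes (not part of the diff).
import torch
from timm.layers.pos_embed_sincos import apply_keep_indices_nlc

B, num_heads, seq_len, dim = 2, 12, 196, 64
x = torch.randn(B, seq_len, 384)                  # only x.shape[0] (batch size) is used here
pos_embed = torch.randn(num_heads, seq_len, dim)  # per-head ROPE table, no batch dim
keep_indices = torch.stack([torch.randperm(seq_len)[:98] for _ in range(B)])  # (B, 98)

out = apply_keep_indices_nlc(x, pos_embed, keep_indices)
print(out.shape)  # torch.Size([2, 12, 98, 64]) -> (batch, num_heads, kept_tokens, dim)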


def build_rotary_pos_embed(
@@ -484,6 +515,59 @@ def get_embed(self, shape: Optional[List[int]] = None):
else:
assert False, "get_embed() requires pre-computed pos embed or valid shape w/ pre-computed bands"

def get_batch_embeds(
self,
shapes: List[Tuple[int, int]],
seq_len: Optional[int] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Generate ROPE embeddings for multiple grid shapes efficiently.

Computes embeddings for the maximum grid size once, then extracts
and flattens the relevant portions for each requested shape.

Args:
shapes: List of (H, W) tuples representing different grid sizes
seq_len: If provided, return a padded tensor of this length. Otherwise return a list.

Returns:
If seq_len is provided: Padded tensor of shape (len(shapes), seq_len, dim)
Otherwise: List of concatenated sin/cos embeddings, one (H*W, dim) tensor per shape
"""
if not shapes:
return []

# Check if we have pre-computed bands
if self.bands is None:
# If we have pre-computed pos_embed for a fixed shape, we can't do batch generation
raise RuntimeError("Batch embedding generation requires cached bands, not pre-computed embeddings")

# Find max dimensions across all shapes
max_h = max(h for h, w in shapes)
max_w = max(w for h, w in shapes)

# Generate embeddings for max size ONCE
sin_emb, cos_emb = build_rotary_pos_embed(
feat_shape=(max_h, max_w),
bands=self.bands,
in_pixels=self.in_pixels,
ref_feat_shape=self.ref_feat_shape,
grid_offset=self.grid_offset,
grid_indexing=self.grid_indexing,
)

# sin_emb and cos_emb are (max_h * max_w, dim//2)
# concat and reshape to 2D for slicing
rope_embed_2d = torch.cat([sin_emb, cos_emb], dim=-1).view(max_h, max_w, -1)

if seq_len is not None:
flat_embeds = torch.zeros(len(shapes), seq_len, rope_embed_2d.shape[-1]).type_as(sin_emb)
for i, (h, w) in enumerate(shapes):
src_len = h * w
flat_embeds[i, :src_len] = rope_embed_2d[:h, :w].reshape(src_len, -1)
return flat_embeds
else:
flat_embeds_list = [rope_embed_2d[:h, :w].reshape(h * w, -1) for h, w in shapes]
return flat_embeds_list
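# A hedged usage sketch for get_batch_embeds() above. Constructor arguments other than
# `dim` are assumptions; the module is built without a fixed feat_shape so that
# self.bands is cached (pre-computed embeddings raise the RuntimeError above).
import torch
from timm.layers.pos_embed_sincos import RotaryEmbedding

rope = RotaryEmbedding(dim=64, in_pixels=False)   # bands cached, no fixed feat_shape (assumed args)
shapes = [(12, 16), (8, 8), (14, 10)]

embeds = rope.get_batch_embeds(shapes)            # list form: one (H*W, dim) tensor per shape
print([e.shape[0] for e in embeds])               # [192, 64, 140]

padded = rope.get_batch_embeds(shapes, seq_len=256)  # padded form: (len(shapes), seq_len, dim)
print(padded.shape[:2])                              # torch.Size([3, 256])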

def forward(self, x):
# assuming channel-first tensor where spatial dim are >= 2
pos_embed = self.get_embed(x.shape[2:])
@@ -642,6 +726,62 @@ def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor:

return get_mixed_freqs(self.freqs, t_x, t_y)

def get_batch_embeds(
self,
shapes: List[Tuple[int, int]],
seq_len: Optional[int] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Generate ROPE embeddings for multiple grid shapes efficiently.

Computes embeddings for the maximum grid size once, then extracts
and flattens the relevant portions for each requested shape.

Args:
shapes: List of (H, W) tuples representing different grid sizes
seq_len: If provided, return padded tensor of this length. Otherwise return list.

Returns:
If seq_len is provided: Padded tensor of shape (len(shapes), depth, num_heads, seq_len, dim)
Otherwise: List of tensors with shape (depth, num_heads, H*W, dim) for each shape
"""
if not shapes:
return []

# Find max dimensions
max_h = max(h for h, w in shapes)
max_w = max(w for h, w in shapes)

# Generate embeddings for max size ONCE
t_x, t_y = get_mixed_grid(
[max_h, max_w],
grid_indexing=self.grid_indexing,
device=self.freqs.device
)
max_embed = get_mixed_freqs(self.freqs, t_x, t_y) # (depth, num_heads, max_h*max_w, dim)

# Reshape to 2D grid for easy slicing
depth, num_heads, _, dim = max_embed.shape
max_embed_2d = max_embed.view(depth, num_heads, max_h, max_w, dim)

if seq_len is not None:
# Return padded tensor
B = len(shapes)
padded = torch.zeros(B, depth, num_heads, seq_len, dim, device=self.freqs.device, dtype=self.freqs.dtype)
for i, (h, w) in enumerate(shapes):
# Slice and flatten
embed_slice = max_embed_2d[:, :, :h, :w].reshape(depth, num_heads, h * w, dim)
actual_len = h * w
padded[i, :, :, :actual_len] = embed_slice
return padded
else:
# Return list
results = []
for h, w in shapes:
# Slice and flatten
embed_slice = max_embed_2d[:, :, :h, :w].reshape(depth, num_heads, h * w, dim)
results.append(embed_slice)
return results
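# Sketch of how the padded, batch-first output documented above combines with patch-dropout
# keep indices via apply_keep_indices_nlc(..., pos_embed_has_batch=True). `padded` here is a
# stand-in tensor with the documented (B, depth, num_heads, seq_len, dim) output shape rather
# than a real call, since the owning module's constructor is not shown in this hunk.
import torch
from timm.layers.pos_embed_sincos import apply_keep_indices_nlc

B, depth, num_heads, seq_len, dim = 2, 12, 6, 256, 64
padded = torch.randn(B, depth, num_heads, seq_len, dim)
keep_indices = torch.stack([torch.randperm(seq_len)[:128] for _ in range(B)])  # (B, 128)
x = torch.randn(B, seq_len, 384)  # not inspected on the batched path

kept = apply_keep_indices_nlc(x, padded, keep_indices, pos_embed_has_batch=True)
print(kept.shape)  # torch.Size([2, 12, 6, 128, 64])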

def forward(self, x):
# assuming channel-first tensor where spatial dim are >= 2
pos_embed = self.get_embed(x.shape[2:])
29 changes: 28 additions & 1 deletion timm/layers/typing.py
@@ -1,7 +1,34 @@
from typing import Callable, Tuple, Type, Union
from contextlib import nullcontext
from functools import wraps
from typing import Callable, Tuple, Type, TypeVar, Union, overload, ContextManager

import torch

__all__ = ["LayerType", "PadType", "nullwrap", "disable_compiler"]


LayerType = Union[str, Callable, Type[torch.nn.Module]]
PadType = Union[str, int, Tuple[int, int]]

F = TypeVar("F", bound=Callable[..., object])


@overload
def nullwrap(fn: F) -> F: ... # decorator form

@overload
def nullwrap(fn: None = ...) -> ContextManager: ... # context‑manager form

def nullwrap(fn: F | None = None):
# as a context manager
if fn is None:
return nullcontext() # `with nullwrap():`

# as a decorator
@wraps(fn)
def wrapper(*args, **kwargs):
return fn(*args, **kwargs)
return wrapper # `@nullwrap`


disable_compiler = getattr(getattr(torch, "compiler", None), "disable", None) or nullwrap
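# Usage sketch: disable_compiler resolves to torch.compiler.disable when torch.compiler
# is available, otherwise to the no-op nullwrap defined above, so it is safe on older torch.
import torch
from timm.layers.typing import disable_compiler, nullwrap


@disable_compiler
def lookup(table: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    # Data-dependent indexing we may prefer to keep out of a compiled graph.
    return table[idx]


with nullwrap():
    # nullwrap also acts as a do-nothing context manager.
    out = lookup(torch.arange(10.), torch.tensor([1, 3, 5]))
print(out)  # tensor([1., 3., 5.])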