102 changes: 73 additions & 29 deletions timm/layers/pos_embed_sincos.py
@@ -354,6 +354,7 @@ def __init__(
self.dim = dim
self.max_res = max_res
self.temperature = temperature
self.linear_bands = linear_bands
self.in_pixels = in_pixels
self.feat_shape = feat_shape
self.ref_feat_shape = ref_feat_shape
@@ -383,17 +384,7 @@ def __init__(
self.pos_embed_cos = None
else:
# cache full sin/cos embeddings if shape provided up front
- emb_sin, emb_cos = build_rotary_pos_embed(
- feat_shape=feat_shape,
- dim=dim,
- max_res=max_res,
- linear_bands=linear_bands,
- in_pixels=in_pixels,
- ref_feat_shape=self.ref_feat_shape,
- grid_offset=self.grid_offset,
- grid_indexing=self.grid_indexing,
- temperature=self.temperature,
- )
+ emb_sin, emb_cos = self._get_pos_embed_values(feat_shape)
self.bands = None
self.register_buffer(
'pos_embed_sin',
@@ -406,6 +397,30 @@ def __init__(
persistent=False,
)

def _get_pos_embed_values(self, feat_shape: List[int]):
emb_sin, emb_cos = build_rotary_pos_embed(
feat_shape=feat_shape,
dim=self.dim,
max_res=self.max_res,
temperature=self.temperature,
linear_bands=self.linear_bands,
in_pixels=self.in_pixels,
ref_feat_shape=self.ref_feat_shape,
grid_offset=self.grid_offset,
grid_indexing=self.grid_indexing,
)
return emb_sin, emb_cos

def update_feat_shape(self, feat_shape: List[int]):
if self.feat_shape is not None and feat_shape != self.feat_shape:
# only update if feat_shape was set and different from previous value
assert self.pos_embed_sin is not None
assert self.pos_embed_cos is not None
emb_sin, emb_cos = self._get_pos_embed_values(feat_shape)
self.pos_embed_sin = emb_sin.to(self.pos_embed_sin.device, self.pos_embed_sin.dtype)
self.pos_embed_cos = emb_cos.to(self.pos_embed_cos.device, self.pos_embed_cos.dtype)
self.feat_shape = feat_shape

def get_embed(self, shape: Optional[List[int]] = None):
if shape is not None and self.bands is not None:
# rebuild embeddings every call, use if target shape changes
@@ -453,6 +468,7 @@ def __init__(
self.max_res = max_res
self.temperature = temperature
self.in_pixels = in_pixels
self.linear_bands = linear_bands
self.feat_shape = feat_shape
self.ref_feat_shape = ref_feat_shape
self.grid_offset = grid_offset
@@ -480,27 +496,40 @@ def __init__(
self.pos_embed = None
else:
# cache full sin/cos embeddings if shape provided up front
- embeds = build_rotary_pos_embed(
- feat_shape=feat_shape,
- dim=dim,
- max_res=max_res,
- linear_bands=linear_bands,
- in_pixels=in_pixels,
- ref_feat_shape=self.ref_feat_shape,
- grid_offset=self.grid_offset,
- grid_indexing=self.grid_indexing,
- temperature=self.temperature,
- )
self.bands = None
self.register_buffer(
'pos_embed',
- torch.cat(embeds, -1),
+ self._get_pos_embed_values(feat_shape=feat_shape),
persistent=False,
)

def _get_pos_embed_values(self, feat_shape: List[int]):
embeds = build_rotary_pos_embed(
feat_shape=feat_shape,
dim=self.dim,
max_res=self.max_res,
temperature=self.temperature,
linear_bands=self.linear_bands,
in_pixels=self.in_pixels,
ref_feat_shape=self.ref_feat_shape,
grid_offset=self.grid_offset,
grid_indexing=self.grid_indexing,
)
return torch.cat(embeds, -1)

def update_feat_shape(self, feat_shape: List[int]):
if self.feat_shape is not None and feat_shape != self.feat_shape:
# only update if feat_shape was set and different from previous value
assert self.pos_embed is not None
self.pos_embed = self._get_pos_embed_values(feat_shape).to(
device=self.pos_embed.device,
dtype=self.pos_embed.dtype,
)
self.feat_shape = feat_shape

def get_embed(self, shape: Optional[List[int]] = None):
if shape is not None and self.bands is not None:
- # rebuild embeddings every call, use if target shape changes
+ # rebuild embeddings from cached bands every call, use if target shape changes
embeds = build_rotary_pos_embed(
shape,
self.bands,
@@ -684,6 +713,7 @@ def __init__(

head_dim = dim // num_heads
assert head_dim % 4 == 0, f"head_dim must be divisible by 4, got {head_dim}"

freqs = init_random_2d_freqs(
head_dim,
depth,
@@ -692,18 +722,32 @@ def __init__(
rotate=True,
) # (2, depth, num_heads, head_dim//2)
self.freqs = nn.Parameter(freqs)

if feat_shape is not None:
# cache pre-computed grid
- t_x, t_y = get_mixed_grid(
- feat_shape,
- grid_indexing=grid_indexing,
- device=self.freqs.device
- )
+ t_x, t_y = self._get_grid_values(feat_shape)
self.register_buffer('t_x', t_x, persistent=False)
self.register_buffer('t_y', t_y, persistent=False)
else:
self.t_x = self.t_y = None

def _get_grid_values(self, feat_shape: Optional[List[int]]):
t_x, t_y = get_mixed_grid(
feat_shape,
grid_indexing=self.grid_indexing,
device=self.freqs.device
)
return t_x, t_y

def update_feat_shape(self, feat_shape: Optional[List[int]]):
if self.feat_shape is not None and feat_shape != self.feat_shape:
assert self.t_x is not None
assert self.t_y is not None
t_x, t_y = self._get_grid_values(feat_shape)
self.t_x = t_x.to(self.t_x.device, self.t_x.dtype)
self.t_y = t_y.to(self.t_y.device, self.t_y.dtype)
self.feat_shape = feat_shape

def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor:
"""Generate rotary embeddings for the given spatial shape.

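For context, the refactor above keeps a single code path for building the cached tables: both construction and the new `update_feat_shape` go through `_get_pos_embed_values`, and the refreshed buffer is moved onto the old buffer's device/dtype. A minimal usage sketch of the `RotaryEmbeddingCat` variant (illustrative, not part of the PR; `dim` and the grid sizes are arbitrary):

```python
import torch
from timm.layers.pos_embed_sincos import RotaryEmbeddingCat

# Build with a known feature grid so the sin/cos table is cached as a buffer.
rope = RotaryEmbeddingCat(dim=64, in_pixels=False, feat_shape=(14, 14))
emb_14 = rope.get_embed()  # returns the cached concatenated sin/cos table

# After an input-resolution change, refresh the cache in place; the new
# table is rebuilt via the same _get_pos_embed_values path and cast to the
# old buffer's device/dtype.
rope.update_feat_shape((16, 16))
emb_16 = rope.get_embed()  # one row per position of the 16x16 grid
```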
29 changes: 29 additions & 0 deletions timm/models/eva.py
@@ -723,6 +723,35 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None)
self.global_pool = global_pool
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

def set_input_size(
self,
img_size: Optional[Tuple[int, int]] = None,
patch_size: Optional[Tuple[int, int]] = None,
) -> None:
"""Update the input image resolution and patch size.

Args:
img_size: New input resolution; if None, the current resolution is used.
patch_size: New patch size; if None, the existing patch size is used.
"""
prev_grid_size = self.patch_embed.grid_size
self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size)

if self.pos_embed is not None:
num_prefix_tokens = 0 if self.no_embed_class else self.num_prefix_tokens
num_new_tokens = self.patch_embed.num_patches + num_prefix_tokens
if num_new_tokens != self.pos_embed.shape[1]:
self.pos_embed = nn.Parameter(resample_abs_pos_embed(
self.pos_embed,
new_size=self.patch_embed.grid_size,
old_size=prev_grid_size,
num_prefix_tokens=num_prefix_tokens,
verbose=True,
))

if self.rope is not None:
self.rope.update_feat_shape(self.patch_embed.grid_size)

def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if self.dynamic_img_size:
B, H, W, C = x.shape
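And a quick sketch of the new `set_input_size` entry point this adds to the EVA model: it resamples the absolute position embedding when the token count changes and refreshes the RoPE grid via `update_feat_shape`. A usage example, not from the PR itself; the model name is one of the existing EVA02 variants and is illustrative:

```python
import timm
import torch

# Any EVA/EVA02 variant with rope applies; this name is illustrative.
model = timm.create_model('eva02_base_patch14_224.mim_in22k', pretrained=False)

# Resamples the abs pos embed (if present) and updates the cached RoPE
# feature shape, per the changes above. 336 is divisible by the patch size 14.
model.set_input_size(img_size=(336, 336))

out = model(torch.randn(1, 3, 336, 336))  # forward at the new resolution
```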