 import math
 import re
 from functools import partial
-from typing import Optional, Sequence, Tuple
+from typing import List, Optional, Sequence, Tuple, Union
 
 import torch
 from torch import nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import trunc_normal_, to_2tuple
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._registry import register_model, generate_default_cfgs
 from .vision_transformer import Block
 
@@ -254,6 +255,73 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
254 | 255 | if self.head_dist is not None:
|
255 | 256 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
|
256 | 257 |
|
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
| 267 | + """ Forward features that returns intermediates. |
| 268 | +
|
| 269 | + Args: |
| 270 | + x: Input image tensor |
| 271 | + indices: Take last n blocks if int, all if None, select matching indices if sequence |
| 272 | + norm: Apply norm layer to compatible intermediates |
| 273 | + stop_early: Stop iterating over blocks when last desired intermediate hit |
| 274 | + output_fmt: Shape of intermediate feature outputs |
| 275 | + intermediates_only: Only return intermediate features |
| 276 | + Returns: |
| 277 | +
|
| 278 | + """ |
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.transformers), indices)
+
+        # forward pass
+        x = self.patch_embed(x)
+        x = self.pos_drop(x + self.pos_embed)
+        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
+
+        last_idx = len(self.transformers) - 1
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.transformers
+        else:
+            stages = self.transformers[:max_index + 1]
+
+        for feat_idx, stage in enumerate(stages):
+            x, cls_tokens = stage((x, cls_tokens))
+            if feat_idx in take_indices:
+                intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        # only norm the cls tokens when the full stage stack was traversed
+        if feat_idx == last_idx:
+            cls_tokens = self.norm(cls_tokens)
+
+        return cls_tokens, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.transformers), indices)
+        self.transformers = self.transformers[:max_index + 1]  # truncate unused trailing stages
+        if prune_norm:
+            self.norm = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.patch_embed(x)
         x = self.pos_drop(x + self.pos_embed)