Skip to content

Commit 6fb3536

Browse files
committed
Fix PE ViT norm layers in NaFlexVit. Still need a PE Giant fix...
1 parent 68790f9 commit 6fb3536

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

timm/models/naflexvit.py

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2017,19 +2017,23 @@ def _create_naflexvit_from_eva(
20172017
else:
20182018
rope_type = 'none'
20192019

2020-
# Handle global pooling logic to mirror EVA use
2020+
# Handle norm/pool resolution logic to mirror EVA
20212021
gp = kwargs.pop('global_pool', 'avg')
2022-
fc_norm = kwargs.pop('fc_norm', None)
2023-
if fc_norm is None and gp == 'avg':
2024-
fc_norm = True
2022+
use_pre_transformer_norm = kwargs.pop('use_pre_transformer_norm', False)
2023+
use_post_transformer_norm = kwargs.pop('use_post_transformer_norm', True)
2024+
use_fc_norm = kwargs.pop('use_fc_norm', None)
2025+
if use_fc_norm is None:
2026+
use_fc_norm = gp == 'avg' # default on if avg pool used
20252027

20262028
# Set NaFlexVit-specific parameters
20272029
naflex_kwargs = {
20282030
'pos_embed_grid_size': None, # rely on img_size (// patch_size)
20292031
'class_token': kwargs.get('class_token', True),
20302032
'reg_tokens': kwargs.pop('num_reg_tokens', kwargs.get('reg_tokens', 0)),
20312033
'global_pool': gp,
2032-
'fc_norm': fc_norm,
2034+
'pre_norm': use_pre_transformer_norm,
2035+
'final_norm': use_post_transformer_norm,
2036+
'fc_norm': use_fc_norm,
20332037
'pos_embed': 'learned' if kwargs.pop('use_abs_pos_emb', True) else 'none',
20342038
'rope_type': rope_type,
20352039
'rope_temperature': rope_temperature,

0 commit comments

Comments (0)