Skip to content

Commit 6b0842e

Browse files
committed
Cleanup eva kwarg handling for naflexvit adaptation, remove debug print
1 parent 9625dc2 commit 6b0842e

File tree

1 file changed

+10
-32
lines changed

1 file changed

+10
-32
lines changed

timm/models/naflexvit.py

Lines changed: 10 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2003,42 +2003,21 @@ def _create_naflexvit_from_eva(
20032003
Returns:
20042004
NaFlexVit model instance
20052005
"""
2006-
# Map EVA-specific parameters to NaFlexVit equivalents
2007-
2008-
# Handle EVA's unique parameters
2009-
kwargs.pop('no_embed_class', None) # EVA specific, not used in NaFlexVit
2010-
2011-
# abs pos embed
2012-
use_abs_pos_emb = kwargs.pop('use_abs_pos_emb', True)
2006+
# Handle EVA's unique parameters & block args
2007+
kwargs.pop('no_embed_class', None) # EVA specific, not used in NaFlexVit (always no-embed)
20132008

20142009
# Map EVA's rope parameters
20152010
use_rot_pos_emb = kwargs.pop('use_rot_pos_emb', False)
20162011
rope_mixed_mode = kwargs.pop('rope_mixed_mode', False)
20172012
rope_temperature = kwargs.pop('rope_temperature', 10000.)
20182013
rope_grid_offset = kwargs.pop('rope_grid_offset', 0.)
20192014
rope_grid_indexing = kwargs.pop('rope_grid_indexing', 'ij')
2020-
2021-
# Get EVA's attn_type directly
2022-
attn_type = kwargs.pop('attn_type', 'eva')
2023-
2024-
# Determine rope_type based on EVA parameters
20252015
if use_rot_pos_emb:
20262016
rope_type = 'mixed' if rope_mixed_mode else 'axial'
20272017
else:
20282018
rope_type = 'none'
20292019

2030-
# Handle EVA's swiglu_mlp and scale_mlp
2031-
swiglu_mlp = kwargs.pop('swiglu_mlp', False)
2032-
scale_mlp = kwargs.pop('scale_mlp', False)
2033-
scale_attn_inner = kwargs.pop('scale_attn_inner', False)
2034-
2035-
# Map qkv_fused parameter
2036-
qkv_fused = kwargs.pop('qkv_fused', True)
2037-
2038-
# Handle register tokens
2039-
num_reg_tokens = kwargs.pop('num_reg_tokens', kwargs.get('reg_tokens', 0))
2040-
2041-
# Handle global pooling
2020+
# Handle global pooling logic to mirror EVA use
20422021
gp = kwargs.pop('global_pool', 'avg')
20432022
fc_norm = kwargs.pop('fc_norm', None)
20442023
if fc_norm is None and gp == 'avg':
@@ -2048,23 +2027,22 @@ def _create_naflexvit_from_eva(
20482027
naflex_kwargs = {
20492028
'pos_embed_grid_size': None, # rely on img_size (// patch_size)
20502029
'class_token': kwargs.get('class_token', True),
2051-
'reg_tokens': num_reg_tokens,
2030+
'reg_tokens': kwargs.pop('num_reg_tokens', kwargs.get('reg_tokens', 0)),
20522031
'global_pool': gp,
20532032
'fc_norm': fc_norm,
2054-
'pos_embed': 'learned' if use_abs_pos_emb else 'none',
2033+
'pos_embed': 'learned' if kwargs.pop('use_abs_pos_emb', True) else 'none',
20552034
'rope_type': rope_type,
20562035
'rope_temperature': rope_temperature,
20572036
'rope_grid_offset': rope_grid_offset,
20582037
'rope_grid_indexing': rope_grid_indexing,
20592038
'rope_ref_feat_shape': kwargs.get('ref_feat_shape', None),
2060-
'attn_type': attn_type,
2061-
'swiglu_mlp': swiglu_mlp,
2062-
'scale_mlp': scale_mlp,
2063-
'scale_attn_inner': scale_attn_inner,
2064-
'qkv_fused': qkv_fused,
2039+
'attn_type': kwargs.pop('attn_type', 'eva'),
2040+
'swiglu_mlp': kwargs.pop('swiglu_mlp', False),
2041+
'qkv_fused': kwargs.pop('qkv_fused', True),
2042+
'scale_mlp_norm': kwargs.pop('scale_mlp', False),
2043+
'scale_attn_inner_norm': kwargs.pop('scale_attn_inner', False),
20652044
**kwargs # Pass remaining kwargs through
20662045
}
2067-
print(naflex_kwargs)
20682046

20692047
return _create_naflexvit(variant, pretrained, **naflex_kwargs)
20702048

0 commit comments

Comments
 (0)