@@ -2003,42 +2003,21 @@ def _create_naflexvit_from_eva(
2003
2003
Returns:
2004
2004
NaFlexVit model instance
2005
2005
"""
2006
- # Map EVA-specific parameters to NaFlexVit equivalents
2007
-
2008
- # Handle EVA's unique parameters
2009
- kwargs .pop ('no_embed_class' , None ) # EVA specific, not used in NaFlexVit
2010
-
2011
- # abs pos embed
2012
- use_abs_pos_emb = kwargs .pop ('use_abs_pos_emb' , True )
2006
+ # Handle EVA's unique parameters & block args
2007
+ kwargs .pop ('no_embed_class' , None ) # EVA specific, not used in NaFlexVit (always no-embed)
2013
2008
2014
2009
# Map EVA's rope parameters
2015
2010
use_rot_pos_emb = kwargs .pop ('use_rot_pos_emb' , False )
2016
2011
rope_mixed_mode = kwargs .pop ('rope_mixed_mode' , False )
2017
2012
rope_temperature = kwargs .pop ('rope_temperature' , 10000. )
2018
2013
rope_grid_offset = kwargs .pop ('rope_grid_offset' , 0. )
2019
2014
rope_grid_indexing = kwargs .pop ('rope_grid_indexing' , 'ij' )
2020
-
2021
- # Get EVA's attn_type directly
2022
- attn_type = kwargs .pop ('attn_type' , 'eva' )
2023
-
2024
- # Determine rope_type based on EVA parameters
2025
2015
if use_rot_pos_emb :
2026
2016
rope_type = 'mixed' if rope_mixed_mode else 'axial'
2027
2017
else :
2028
2018
rope_type = 'none'
2029
2019
2030
- # Handle EVA's swiglu_mlp and scale_mlp
2031
- swiglu_mlp = kwargs .pop ('swiglu_mlp' , False )
2032
- scale_mlp = kwargs .pop ('scale_mlp' , False )
2033
- scale_attn_inner = kwargs .pop ('scale_attn_inner' , False )
2034
-
2035
- # Map qkv_fused parameter
2036
- qkv_fused = kwargs .pop ('qkv_fused' , True )
2037
-
2038
- # Handle register tokens
2039
- num_reg_tokens = kwargs .pop ('num_reg_tokens' , kwargs .get ('reg_tokens' , 0 ))
2040
-
2041
- # Handle global pooling
2020
+ # Handle global pooling logic to mirror EVA use
2042
2021
gp = kwargs .pop ('global_pool' , 'avg' )
2043
2022
fc_norm = kwargs .pop ('fc_norm' , None )
2044
2023
if fc_norm is None and gp == 'avg' :
@@ -2048,23 +2027,22 @@ def _create_naflexvit_from_eva(
2048
2027
naflex_kwargs = {
2049
2028
'pos_embed_grid_size' : None , # rely on img_size (// patch_size)
2050
2029
'class_token' : kwargs .get ('class_token' , True ),
2051
- 'reg_tokens' : num_reg_tokens ,
2030
+ 'reg_tokens' : kwargs . pop ( ' num_reg_tokens' , kwargs . get ( 'reg_tokens' , 0 )) ,
2052
2031
'global_pool' : gp ,
2053
2032
'fc_norm' : fc_norm ,
2054
- 'pos_embed' : 'learned' if use_abs_pos_emb else 'none' ,
2033
+ 'pos_embed' : 'learned' if kwargs . pop ( ' use_abs_pos_emb' , True ) else 'none' ,
2055
2034
'rope_type' : rope_type ,
2056
2035
'rope_temperature' : rope_temperature ,
2057
2036
'rope_grid_offset' : rope_grid_offset ,
2058
2037
'rope_grid_indexing' : rope_grid_indexing ,
2059
2038
'rope_ref_feat_shape' : kwargs .get ('ref_feat_shape' , None ),
2060
- 'attn_type' : attn_type ,
2061
- 'swiglu_mlp' : swiglu_mlp ,
2062
- 'scale_mlp ' : scale_mlp ,
2063
- 'scale_attn_inner ' : scale_attn_inner ,
2064
- 'qkv_fused ' : qkv_fused ,
2039
+ 'attn_type' : kwargs . pop ( ' attn_type' , 'eva' ) ,
2040
+ 'swiglu_mlp' : kwargs . pop ( ' swiglu_mlp' , False ) ,
2041
+ 'qkv_fused ' : kwargs . pop ( 'qkv_fused' , True ) ,
2042
+ 'scale_mlp_norm ' : kwargs . pop ( 'scale_mlp' , False ) ,
2043
+ 'scale_attn_inner_norm ' : kwargs . pop ( 'scale_attn_inner' , False ) ,
2065
2044
** kwargs # Pass remaining kwargs through
2066
2045
}
2067
- print (naflex_kwargs )
2068
2046
2069
2047
return _create_naflexvit (variant , pretrained , ** naflex_kwargs )
2070
2048
0 commit comments