@@ -2017,19 +2017,23 @@ def _create_naflexvit_from_eva(
2017
2017
else :
2018
2018
rope_type = 'none'
2019
2019
2020
- # Handle global pooling logic to mirror EVA use
2020
+ # Handle norm/pool resolution logic to mirror EVA
2021
2021
gp = kwargs .pop ('global_pool' , 'avg' )
2022
- fc_norm = kwargs .pop ('fc_norm' , None )
2023
- if fc_norm is None and gp == 'avg' :
2024
- fc_norm = True
2022
+ use_pre_transformer_norm = kwargs .pop ('use_pre_transformer_norm' , False )
2023
+ use_post_transformer_norm = kwargs .pop ('use_post_transformer_norm' , True )
2024
+ use_fc_norm = kwargs .pop ('use_fc_norm' , None )
2025
+ if use_fc_norm is None :
2026
+ use_fc_norm = gp == 'avg' # default on if avg pool used
2025
2027
2026
2028
# Set NaFlexVit-specific parameters
2027
2029
naflex_kwargs = {
2028
2030
'pos_embed_grid_size' : None , # rely on img_size (// patch_size)
2029
2031
'class_token' : kwargs .get ('class_token' , True ),
2030
2032
'reg_tokens' : kwargs .pop ('num_reg_tokens' , kwargs .get ('reg_tokens' , 0 )),
2031
2033
'global_pool' : gp ,
2032
- 'fc_norm' : fc_norm ,
2034
+ 'pre_norm' : use_pre_transformer_norm ,
2035
+ 'final_norm' : use_post_transformer_norm ,
2036
+ 'fc_norm' : use_fc_norm ,
2033
2037
'pos_embed' : 'learned' if kwargs .pop ('use_abs_pos_emb' , True ) else 'none' ,
2034
2038
'rope_type' : rope_type ,
2035
2039
'rope_temperature' : rope_temperature ,
0 commit comments