Skip to content

Commit 68790f9

Browse files
committed
Fix model patch sizes != 16 in validate.py for naflex_loader use
1 parent 6b0842e commit 68790f9

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

validate.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,10 @@ def validate(args):
300300

301301
crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
302302
if args.naflex_loader:
303+
model_patch_size = None
304+
if hasattr(model, 'embeds') and hasattr(model.embeds, 'patch_size'):
305+
# NaFlexVit models have embeds.patch_size
306+
model_patch_size = model.embeds.patch_size
303307
from timm.data import create_naflex_loader
304308
loader = create_naflex_loader(
305309
dataset,
@@ -315,7 +319,7 @@ def validate(args):
315319
pin_memory=args.pin_mem,
316320
device=device,
317321
img_dtype=model_dtype or torch.float32,
318-
patch_size=16, # Could be derived from model config
322+
patch_size=model_patch_size or (16, 16),
319323
max_seq_len=args.naflex_max_seq_len,
320324
)
321325
else:

0 commit comments

Comments
 (0)