Closed
Labels: bug (Something isn't working)
Description
Describe the bug
Eva forward pass results in the following error:
```
File "/home/suraj/Repositories/mm_seg/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/suraj/Repositories/mm_seg/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/suraj/Repositories/mm_seg/.venv/lib/python3.11/site-packages/timm/models/eva.py", line 180, in forward
    q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope)], dim=2).type_as(v)
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/suraj/Repositories/mm_seg/.venv/lib/python3.11/site-packages/timm/layers/pos_embed_sincos.py", line 234, in apply_rot_embed_cat
    return x * cos_emb + rot(x) * sin_emb
           ~~^~~~~~~~~
RuntimeError: The size of tensor a (12) must match the size of tensor b (8) at non-singleton dimension 1
```
This occurs when `use_rot_pos_emb=True` and `patch_drop_rate > 0`.
To Reproduce
Steps to reproduce the behavior:
```python
import torch
import timm
from timm.models import Eva

print(timm.__version__)

try:
    print("Running Eva with patch_drop_rate=0.75, use_rot_pos_emb=True, batch size 1")
    model = Eva(
        patch_drop_rate=0.75,
        use_rot_pos_emb=True,
    )
    x = torch.randn(1, 3, 224, 224)
    out = model(x)
    print("Output shape:", out.shape)
except Exception as e:
    print("Error during Eva run with patch_drop_rate=0.75, batch size 1:", e)

try:
    print("Running Eva with patch_drop_rate=0.75, use_rot_pos_emb=True, batch size 8")
    x = torch.randn(8, 3, 224, 224)
    out = model(x)
    print("Output shape:", out.shape)
except Exception as e:
    print("Error during Eva run with patch_drop_rate=0.75, batch size 8:", e)

try:
    print("Running Eva with use_rot_pos_emb=True, batch size 8")
    model = Eva(
        use_rot_pos_emb=True,
    )
    x = torch.randn(8, 3, 224, 224)
    out = model(x)
    print("Output shape:", out.shape)
except Exception as e:
    print("Error during Eva run with use_rot_pos_emb=True, batch size 8:", e)
```
Run the above code on timm 1.0.17 or later (including the latest release). It works fine on 1.0.16.
Expected behavior
This PR: 94e13b7#diff-bdbbd2d1e2b861d94d624f4811a259083e58111a2c53142609688113252377b0
removed the code that handled the change in token shape after patch dropout:

pytorch-image-models/timm/models/eva.py, line 748 in b2034bb:

```python
rot_pos_embed = apply_keep_indices_nlc(x, rot_pos_embed, keep_indices)
```
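A minimal NumPy sketch of why the shapes mismatch, under illustrative shapes (not timm's actual internals): after patch dropout keeps K of N patch tokens, the rotary sin/cos table built for all N positions must be subset with the same keep indices, which is what the removed `apply_keep_indices_nlc` call effectively did.

```python
import numpy as np

rng = np.random.default_rng(0)
N, K, dim = 196, 49, 64                 # full patch grid vs. tokens kept after dropout

x_full = rng.standard_normal((1, N, dim))
rope = rng.standard_normal((N, dim))    # per-position rotary embedding table

keep_indices = np.sort(rng.permutation(N)[:K])
x = x_full[:, keep_indices]             # (1, K, dim): tokens remaining after patch dropout

# Without subsetting, x * rope fails to broadcast (K != N) -- the reported error.
# Selecting the same rows from the rope table restores matching shapes:
rope_kept = rope[keep_indices]          # (K, dim), aligned with the kept tokens
out = x * rope_kept                     # broadcasts cleanly to (1, K, dim)
print(out.shape)
```

The key point is that the keep indices chosen by patch dropout must be applied to the positional table as well as the tokens, or any elementwise op between them raises the size-mismatch error above.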