[BUG] Eva model crashes on applying RotaryEmbeddingCat when patch drop rate is set #2549

@surajpaib

Describe the bug
Eva forward pass results in the following error:

  File "/home/suraj/Repositories/mm_seg/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/suraj/Repositories/mm_seg/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/suraj/Repositories/mm_seg/.venv/lib/python3.11/site-packages/timm/models/eva.py", line 180, in forward
    q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope)], dim=2).type_as(v)
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/suraj/Repositories/mm_seg/.venv/lib/python3.11/site-packages/timm/layers/pos_embed_sincos.py", line 234, in apply_rot_embed_cat
    return x * cos_emb + rot(x) * sin_emb
           ~~^~~~~~~~~
RuntimeError: The size of tensor a (12) must match the size of tensor b (8) at non-singleton dimension 1

This happens when use_rot_pos_emb=True and patch_drop_rate > 0.
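The mismatched sizes in the error (12 vs. 8) match Eva's default num_heads and the batch size, which suggests the rope table picks up a batch dimension after patch dropout while q and k keep their (B, num_heads, N, head_dim) layout, so the elementwise multiply in apply_rot_embed_cat can no longer broadcast. A minimal sketch of that mismatch with bare tensors (shapes assumed for illustration, not taken from timm internals):

import torch

# Assumed shapes: 12 heads, head_dim 64, 49 of 196 tokens kept at drop rate 0.75.
B, num_heads, N_kept, head_dim = 8, 12, 49, 64
q = torch.randn(B, num_heads, N_kept, head_dim)  # queries after patch dropout
cos_emb = torch.randn(B, N_kept, head_dim)       # per-sample rope half

# Trailing dims align (head_dim, N_kept), but dim 1 pits num_heads (12)
# against batch (8), reproducing the reported RuntimeError.
try:
    _ = q * cos_emb
except RuntimeError as e:
    print(e)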

To Reproduce
Steps to reproduce the behavior:

import torch
import timm
from timm.models import Eva

print(timm.__version__)

try:
    print("Running Eva with patch_drop_rate=0.75, use_rot_pos_emb=True, batch size 1")
    model = Eva(
        patch_drop_rate=0.75,
        use_rot_pos_emb=True,
    )
    x = torch.randn(1, 3, 224, 224)
    out = model(x)
    print("Output shape:", out.shape)
except Exception as e:
    print("Error during Eva run with patch_drop_rate=0.75, batch size 1:", e)

try:
    print("Running Eva with patch_drop_rate=0.75, use_rot_pos_emb=True, batch size 8")
    x = torch.randn(8, 3, 224, 224)
    out = model(x)
    print("Output shape:", out.shape)
except Exception as e:
    print("Error during Eva run with patch_drop_rate=0.75, batch size 8:", e)

try:
    print("Running Eva with use_rot_pos_emb=True, batch size 8")
    model = Eva(
        use_rot_pos_emb=True,
    )
    x = torch.randn(8, 3, 224, 224)
    out = model(x)
    print("Output shape:", out.shape)
except Exception as e:
    print("Error during Eva run with use_rot_pos_emb=True, batch size 8:", e)

Run the above code on timm 1.0.17 and later, including the latest release; it works fine on 1.0.16.

Expected behavior
The forward pass should complete without error, as it does on timm 1.0.16.

This PR: 94e13b7#diff-bdbbd2d1e2b861d94d624f4811a259083e58111a2c53142609688113252377b0
removed the code that handled the change in shape after

rot_pos_embed = apply_keep_indices_nlc(x, rot_pos_embed, keep_indices)

is applied to the positional embeddings.
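For illustration, a minimal sketch of one way the shapes can be reconciled (not the actual upstream fix; it assumes rot_pos_embed comes out of apply_keep_indices_nlc with a per-sample shape of (B, N, dim)): inserting a singleton head axis lets apply_rot_embed_cat broadcast over the head dimension again.

import torch
from timm.layers.pos_embed_sincos import apply_rot_embed_cat

# Assumed shapes: q after patch dropout, rope gathered per sample.
B, num_heads, N, head_dim = 8, 12, 49, 64
q = torch.randn(B, num_heads, N, head_dim)
rope = torch.randn(B, N, 2 * head_dim)  # concatenated sin/cos halves (assumed layout)

# (B, N, 2*head_dim) -> (B, 1, N, 2*head_dim) broadcasts against (B, num_heads, N, head_dim).
out = apply_rot_embed_cat(q, rope.unsqueeze(1))
print(out.shape)  # torch.Size([8, 12, 49, 64])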
