
Commit f8bec5e

able to project the image embedding before applying time positional embedding for accept video wrapper
1 parent 297e7d0 commit f8bec5e

File tree

2 files changed (+28 lines, -6 lines)

setup.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.11.5',
+  version = '1.11.6',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description = long_description,

vit_pytorch/accept_video_wrapper.py

Lines changed: 27 additions & 5 deletions
@@ -2,7 +2,7 @@
 
 import torch
 from torch import is_tensor, randn
-from torch.nn import Module, Parameter
+from torch.nn import Module, Linear, Parameter
 from torch.utils._pytree import tree_flatten, tree_unflatten
 
 from einops import rearrange, repeat
@@ -26,7 +26,8 @@ def __init__(
         dim_emb = None,
         time_seq_len = None,
         embed_is_channel_first = False,
-        output_pos_add_pos_emb = 0 # defaults to first output position to add embedding
+        output_pos_add_pos_emb = 0, # defaults to first output position to add embedding
+        proj_embed_to_dim = None
     ):
         super().__init__()
         self.image_net = image_net
@@ -35,11 +36,23 @@ def __init__(
         self.add_time_pos_emb = add_time_pos_emb
         self.output_pos_add_pos_emb = output_pos_add_pos_emb
 
+        # maybe project the image embedding
+
+        self.embed_proj = None
+
+        if exists(proj_embed_to_dim):
+            assert exists(dim_emb), '`dim_emb` must be passed in'
+            self.embed_proj = Linear(dim_emb, proj_embed_to_dim)
+
+        # time positional embedding
+
         if add_time_pos_emb:
             assert exists(dim_emb) and exists(time_seq_len), '`dim_emb` and `time_seq_len` must be set if adding positional embeddings to the output'
             self.time_seq_len = time_seq_len
 
-            self.pos_emb = Parameter(randn(time_seq_len, dim_emb) * 1e-2)
+            dim_pos_emb = default(proj_embed_to_dim, dim_emb)
+
+            self.pos_emb = Parameter(randn(time_seq_len, dim_pos_emb) * 1e-2)
 
         self.embed_is_channel_first = embed_is_channel_first
 
@@ -79,6 +92,15 @@ def forward(
 
         outputs = tuple(rearrange(t, '(b t) ... -> b t ...', t = time) if is_tensor(t) and t.numel() > 1 else t for t in outputs)
 
+        # maybe project embedding
+
+        if exists(self.embed_proj):
+            outputs = list(outputs)
+
+            embed = outputs[self.output_pos_add_pos_emb]
+
+            outputs[self.output_pos_add_pos_emb] = self.embed_proj(embed)
+
         # maybe add time positional embedding
 
         if add_time_pos_emb:
@@ -131,9 +153,9 @@ def forward(
     from vit_pytorch.extractor import Extractor
     v = Extractor(v)
 
-    video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 12, dim_emb = 1024)
+    video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 12, dim_emb = 1024, proj_embed_to_dim = 512)
 
     logits, embeddings = video_acceptor(videos, eval_with_no_grad = True) # always (batch, channels, time, height, width) - time is always dimension 2
 
     assert logits.shape == (1, 7, 1000)
-    assert embeddings.shape == (1, 7, 65, 1024)
+    assert embeddings.shape == (1, 7, 65, 512)
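
For context, a minimal usage sketch of the new `proj_embed_to_dim` option, mirroring the test block in the diff above; the ViT hyperparameters below are assumed for illustration and are not part of this commit:

import torch
from vit_pytorch import ViT
from vit_pytorch.extractor import Extractor
from vit_pytorch.accept_video_wrapper import AcceptVideoWrapper

# an image model whose patch embeddings have dimension 1024 (assumed config)
v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

# Extractor returns (logits, embeddings) per frame
v = Extractor(v)

# the wrapper now projects the 1024-dim embeddings down to 512
# before adding the learned time positional embedding
video_acceptor = AcceptVideoWrapper(
    v,
    add_time_pos_emb = True,
    output_pos_add_pos_emb = 1,
    time_seq_len = 12,
    dim_emb = 1024,
    proj_embed_to_dim = 512
)

videos = torch.randn(1, 3, 7, 256, 256)  # (batch, channels, time, height, width)

logits, embeddings = video_acceptor(videos, eval_with_no_grad = True)

assert logits.shape == (1, 7, 1000)
assert embeddings.shape == (1, 7, 65, 512)  # 64 patches + CLS, projected from 1024 to 512

Design-wise, passing `proj_embed_to_dim` creates a `Linear(dim_emb, proj_embed_to_dim)` that is applied to the output at `output_pos_add_pos_emb` before the time positional embedding step, and the positional embedding itself is sized to the projected dimension via `default(proj_embed_to_dim, dim_emb)`.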
