import torch
from torch import is_tensor, randn
- from torch.nn import Module, Parameter
+ from torch.nn import Module, Linear, Parameter
from torch.utils._pytree import tree_flatten, tree_unflatten

from einops import rearrange, repeat
@@ -26,7 +26,8 @@ def __init__(
        dim_emb = None,
        time_seq_len = None,
        embed_is_channel_first = False,
-         output_pos_add_pos_emb = 0 # defaults to first output position to add embedding
+         output_pos_add_pos_emb = 0, # defaults to first output position to add embedding
+         proj_embed_to_dim = None
    ):
        super().__init__()
        self.image_net = image_net
@@ -35,11 +36,23 @@ def __init__(
        self.add_time_pos_emb = add_time_pos_emb
        self.output_pos_add_pos_emb = output_pos_add_pos_emb

+         # maybe project the image embedding
+
+         self.embed_proj = None
+
+         if exists(proj_embed_to_dim):
+             assert exists(dim_emb), '`dim_emb` must be passed in'
+             self.embed_proj = Linear(dim_emb, proj_embed_to_dim)
+
+         # time positional embedding
+
        if add_time_pos_emb:
            assert exists(dim_emb) and exists(time_seq_len), '`dim_emb` and `time_seq_len` must be set if adding positional embeddings to the output'
            self.time_seq_len = time_seq_len

-             self.pos_emb = Parameter(randn(time_seq_len, dim_emb) * 1e-2)
+             dim_pos_emb = default(proj_embed_to_dim, dim_emb)
+
+             self.pos_emb = Parameter(randn(time_seq_len, dim_pos_emb) * 1e-2)

        self.embed_is_channel_first = embed_is_channel_first
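The sizing change above leans on the `exists` and `default` helpers used elsewhere in the file but not shown in this diff; assuming their usual definitions, the positional embedding is simply created at whatever width the embedding will have after the optional projection. A minimal sketch of that logic:

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

# pos emb width follows the projected dim when a projection is requested,
# and falls back to the raw embedding dim otherwise
dim_emb = 1024
proj_embed_to_dim = 512

dim_pos_emb = default(proj_embed_to_dim, dim_emb)
assert dim_pos_emb == 512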
@@ -79,6 +92,15 @@ def forward(
        outputs = tuple(rearrange(t, '(b t) ... -> b t ...', t = time) if is_tensor(t) and t.numel() > 1 else t for t in outputs)

+         # maybe project embedding
+
+         if exists(self.embed_proj):
+             outputs = list(outputs)
+
+             embed = outputs[self.output_pos_add_pos_emb]
+
+             outputs[self.output_pos_add_pos_emb] = self.embed_proj(embed)
+
        # maybe add time positional embedding

        if add_time_pos_emb:
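In `forward`, the new block projects only the output selected by `output_pos_add_pos_emb`, and it runs before the time positional embedding step, so the pos emb sized above matches the projected width. A rough sketch of that ordering with stand-in tensors (the broadcasting detail at the end is an assumption for illustration, not code from this diff):

import torch
from torch.nn import Linear

embed_proj = Linear(1024, 512)            # dim_emb -> proj_embed_to_dim
pos_emb = torch.randn(12, 512) * 1e-2     # sized to the projected dim, as in __init__

# stand-ins for (logits, embeddings) with 7 frames and 65 tokens
outputs = [torch.randn(1, 7, 1000), torch.randn(1, 7, 65, 1024)]

outputs[1] = embed_proj(outputs[1])               # project the selected output first ...
outputs[1] = outputs[1] + pos_emb[:7, None, :]    # ... then a time pos emb at width 512 can be added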
@@ -131,9 +153,9 @@ def forward(
from vit_pytorch.extractor import Extractor
v = Extractor(v)

- video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 12, dim_emb = 1024)
+ video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 12, dim_emb = 1024, proj_embed_to_dim = 512)

logits, embeddings = video_acceptor(videos, eval_with_no_grad = True) # always (batch, channels, time, height, width) - time is always dimension 2

assert logits.shape == (1, 7, 1000)
- assert embeddings.shape == (1, 7, 65, 1024)
+ assert embeddings.shape == (1, 7, 65, 512)
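For the README numbers above, `proj_embed_to_dim = 512` amounts to a learned `Linear(1024, 512)` applied over the last axis of the extracted embeddings, which is why only the trailing dimension of the expected shape changes. An illustrative shape check (names here are stand-ins, not the wrapper's internals):

import torch
from torch.nn import Linear

embed_proj = Linear(1024, 512)              # dim_emb -> proj_embed_to_dim
embeddings = torch.randn(1, 7, 65, 1024)    # (batch, time, tokens, dim_emb)

projected = embed_proj(embeddings)          # Linear acts on the last dimension only
assert projected.shape == (1, 7, 65, 512)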