Skip to content

Commit b22dc0e

Browse files
committed
add a wrapper for accepting video and processing the images individually, optionally able to add time positional embeddings - for use in two robotics work
1 parent db05a14 commit b22dc0e

File tree

2 files changed

+111
-1
lines changed

2 files changed

+111
-1
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
setup(
77
name = 'vit-pytorch',
88
packages = find_packages(exclude=['examples']),
9-
version = '1.10.1',
9+
version = '1.11.0',
1010
license='MIT',
1111
description = 'Vision Transformer (ViT) - Pytorch',
1212
long_description = long_description,

vit_pytorch/accept_video_wrapper.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import torch
2+
from torch import is_tensor, randn
3+
from torch.nn import Module, Parameter
4+
from torch.utils._pytree import tree_flatten, tree_unflatten
5+
6+
from einops import rearrange, repeat
7+
8+
# helper functions
9+
10+
def exists(v):
11+
return v is not None
12+
13+
def default(v, d):
14+
return v if exists(v) else d
15+
16+
# classes
17+
18+
class AcceptVideoWrapper(Module):
19+
def __init__(
20+
self,
21+
image_net: Module,
22+
forward_function = 'forward',
23+
add_time_pos_emb = False,
24+
dim_emb = None,
25+
time_seq_len = None,
26+
output_pos_add_pos_emb = 0 # defaults to first output position to add embedding
27+
):
28+
super().__init__()
29+
self.image_net = image_net
30+
self.forward_function = forward_function # for openclip, used in TRI-LBM
31+
32+
self.add_time_pos_emb = add_time_pos_emb
33+
self.output_pos_add_pos_emb = output_pos_add_pos_emb
34+
35+
if add_time_pos_emb:
36+
assert exists(dim_emb) and exists(time_seq_len), '`dim_emb` and `time_seq_len` must be set if adding positional embeddings to the output'
37+
self.time_seq_len = time_seq_len
38+
39+
self.pos_emb = Parameter(randn(time_seq_len, dim_emb) * 1e-2)
40+
41+
def forward(
42+
self,
43+
video # (b c t h w)
44+
):
45+
add_time_pos_emb = self.add_time_pos_emb
46+
batch, time = video.shape[0], video.shape[2]
47+
48+
# maybe validate time positional embedding
49+
50+
if add_time_pos_emb:
51+
assert time <= self.time_seq_len, f'received video with {time} frames but `time_seq_len` ({self.time_seq_len}) is too low'
52+
53+
video = rearrange(video, 'b c t h w -> b t c h w')
54+
55+
video = rearrange(video, 'b t ... -> (b t) ...')
56+
57+
func = getattr(self.image_net, self.forward_function)
58+
59+
outputs = func(video)
60+
61+
# handle multiple outputs, say logits and embeddings returned from extractor - also handle some reduce aux loss being returned
62+
63+
outputs, tree_spec = tree_flatten(outputs)
64+
65+
outputs = tuple(rearrange(t, '(b t) ... -> b t ...', t = time) if is_tensor(t) and t.numel() > 1 else t for t in outputs)
66+
67+
# maybe add time positional embedding
68+
69+
if add_time_pos_emb:
70+
pos_emb = repeat(self.pos_emb, 't d -> b t 1 d', b = batch)
71+
72+
outputs = list(outputs)
73+
embed = outputs[self.output_pos_add_pos_emb]
74+
75+
embed = embed + pos_emb
76+
77+
outputs[self.output_pos_add_pos_emb] = embed
78+
79+
return tree_unflatten(outputs, tree_spec)
80+
81+
# main
82+
83+
if __name__ == '__main__':
84+
from vit_pytorch import ViT
85+
86+
v = ViT(
87+
image_size = 256,
88+
patch_size = 32,
89+
num_classes = 1000,
90+
dim = 1024,
91+
depth = 6,
92+
heads = 16,
93+
mlp_dim = 2048,
94+
dropout = 0.1,
95+
emb_dropout = 0.1
96+
)
97+
98+
videos = torch.randn(1, 3, 10, 256, 256)
99+
100+
# step up the difficulty and return embeddings for robotics
101+
102+
from vit_pytorch.extractor import Extractor
103+
v = Extractor(v)
104+
105+
video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 10, dim_emb = 1024)
106+
107+
logits, embeddings = video_acceptor(videos) # always (batch, channels, time, height, width) - time is always dimension 2
108+
109+
assert logits.shape == (1, 10, 1000)
110+
assert embeddings.shape == (1, 10, 65, 1024)

0 commit comments

Comments
 (0)