Commit bde994e

Refine some configurations in TSN model (#4853)
1 parent 22cf383 commit bde994e

11 files changed, 57 additions and 96 deletions

dygraph/tsn/README.md

Lines changed: 11 additions & 5 deletions
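@@ -1,5 +1,5 @@
 # TSN Video Classification Model
-This directory contains a TSN video classification model implemented with PaddlePaddle dynamic graph (dygraph). The model supports PaddlePaddle Fluid 1.8, GPU, Linux.
+This directory contains a TSN video classification model implemented with PaddlePaddle dynamic graph (dygraph). The model supports PaddlePaddle Fluid 2.0, GPU, Linux.
 
 ---
 ## Contents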
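@@ -50,7 +50,7 @@ Temporal Segment Network (TSN) is a classic 2D-CNN-based
 
 ## Data Preparation
 
-TSN is trained on the UCF101 action recognition dataset. For data download and processing, see the [data notes](./data/dataset/ucf101/README.md). After processing, the following (`一下`, a typo in the original) files are generated under `./data/dataset/ucf101/`
+TSN is trained on the UCF101 action recognition dataset. For data download and processing, see the [data notes](./data/dataset/ucf101/README.md). After processing, the following (`以下`, typo corrected) files are generated under `./data/dataset/ucf101/`
 - `videos/` : stores the UCF101 video files.
 - `rawframes/` : stores the frame data extracted from the UCF101 videos.
 - `annotations/` : stores the UCF101 annotation files.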
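@@ -60,16 +60,21 @@ TSN is trained on the UCF101 action recognition dataset. For data download and processing, see
 
 
 ## Model Training
+Training the TSN model requires loading ResNet50 parameters pretrained on ImageNet. They can be downloaded with the following command (by default the weight files are stored in `./ResNet50_pretrained/` under the current directory):
+```bash
+bash download_pretrain.sh
+```
 
-The TSN model accepts input data in video and frame formats. Once the data is prepared, training for each format can be started as follows.
+The TSN model accepts input data in video and frame formats. Once the data and the pretrained parameters are prepared, training for each format can be started as follows.
 
 1. Multi-GPU training (frame input)
 ```bash
 bash multi_gpus_run.sh ./multi_tsn_frame.yaml
 ```
 The GPUs used for multi-GPU training can be set as follows:
 - Edit `export CUDA_VISIBLE_DEVICES=0,1,2,3` in `multi_gpus_run.sh` (the default 0,1,2,3 means that GPUs 0, 1, 2 and 3 are used for training)
-- Note: the configuration file for multi-GPU, frame-format training is `multi_tsn_frame.yaml`. If you change the batch size, change the learning rate accordingly: a larger batch size uses a larger lr, i.e. both grow (`增长`) or shrink by the same factor. For example, the four-GPU default is batchsize=128, lr=0.001; with batchsize=64, use lr=0.0005.
+- To change the path from which the pretrained weights are loaded, edit the `pretrain` parameter in `multi_gpus_run.sh` (default `pretrain="./ResNet50_pretrained/"`)
+- Note: the configuration file for multi-GPU, frame-format training is `multi_tsn_frame.yaml`. If you change the batch size, change the learning rate accordingly: a larger batch size uses a larger lr, i.e. both increase (`增大`) or shrink by the same factor. For example, the four-GPU default is batchsize=128, lr=0.001; with batchsize=64, use lr=0.0005.
 
 2. Multi-GPU training (video input)
 ```bash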
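@@ -85,7 +90,8 @@ bash single_gpu_run.sh ./single_tsn_frame.yaml
 ```
 The GPU used for single-GPU training can be set as follows:
 - Edit `export CUDA_VISIBLE_DEVICES=0` in `single_gpu_run.sh` (GPU 0 is used for training)
-- Note: the configuration file for single-GPU, frame-format training is `single_tsn_frame.yaml`. If you change the batch size, change the learning rate accordingly: a larger batch size uses a larger lr, i.e. both grow or shrink by the same factor. The single-GPU default is batchsize=64, lr=0.0005; with batchsize=32, use lr=0.00025.
+- To change the path from which the pretrained weights are loaded, edit the `pretrain` parameter in `single_gpu_run.sh` (default `pretrain="./ResNet50_pretrained/"`)
+- Note: the configuration file for single-GPU, frame-format training is `single_tsn_frame.yaml`. If you change the batch size, change the learning rate accordingly: a larger batch size uses a larger lr, i.e. both grow or shrink by the same factor. The single-GPU default is batchsize=32, lr=0.00025; with batchsize=64, use lr=0.0005.
 
 4. Single-GPU training (video input)
 ```bash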

dygraph/tsn/download_pretrain.sh

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+#!/bin/bash
+wget https://paddlemodels.bj.bcebos.com/video_classification/ResNet50_pretrained.tar.gz
+tar -xzvf ResNet50_pretrained.tar.gz
+rm -rf ResNet50_pretrained.tar.gz

dygraph/tsn/multi_gpus_run.sh

Lines changed: 2 additions & 2 deletions
@@ -1,11 +1,11 @@
 configs=$1
-pretrain="" # set pretrain model path if needed
+pretrain="./ResNet50_pretrained/" # set pretrain model path if needed
 resume="" # set checkpoints model path if u want to resume training
 save_dir=""
 use_gpu=True
 use_data_parallel=True
 
-export CUDA_VISIBLE_DEVICES=4,5,6,7
+export CUDA_VISIBLE_DEVICES=0,1,2,3
 
 
 echo $mode "TSN" $configs $resume $pretrain
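
The script above selects GPUs through `CUDA_VISIBLE_DEVICES`, a comma-separated list of device indices. As a minimal sketch of how a launcher might read that list and count the selected cards, the helper below parses the value in Python. The function name `visible_gpu_ids` is hypothetical and not part of this commit, and it assumes plain integer indices (the variable can also hold device UUIDs, which this sketch does not handle).

```python
import os
from typing import List


def visible_gpu_ids(value: str) -> List[int]:
    """Parse a CUDA_VISIBLE_DEVICES-style string such as "0,1,2,3".

    Hypothetical helper for illustration; assumes integer device indices.
    """
    return [int(token) for token in value.split(",") if token.strip()]


if __name__ == "__main__":
    # Fall back to the default that this commit sets in multi_gpus_run.sh.
    raw = os.environ.get("CUDA_VISIBLE_DEVICES", "0,1,2,3")
    ids = visible_gpu_ids(raw)
    print("GPU ids:", ids, "count:", len(ids))
```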

dygraph/tsn/multi_tsn_frame.yaml

Lines changed: 8 additions & 8 deletions
@@ -13,8 +13,8 @@ TRAIN:
     epoch: 80
     short_size: 256
     target_size: 224
-    num_reader_threads: 12
-    buf_size: 1024
+    num_reader_threads: 16
+    buf_size: 256
     batch_size: 128
     use_gpu: True
     filelist: "./data/dataset/ucf101/ucf101_train_split_1_rawframes.txt"
@@ -28,15 +28,15 @@ TRAIN:
 VALID:
     short_size: 256
     target_size: 224
-    num_reader_threads: 12
-    buf_size: 1024
+    num_reader_threads: 16
+    buf_size: 256
     batch_size: 128
     filelist: "./data/dataset/ucf101/ucf101_val_split_1_rawframes.txt"
 
 TEST:
     short_size: 256
     target_size: 224
-    num_reader_threads: 12
-    buf_size: 1024
-    batch_size: 64
-    filelist: "./data/dataset/ucf101/ucf101_val_split_1_rawframes.txt"
+    num_reader_threads: 16
+    buf_size: 256
+    batch_size: 128
+    filelist: "./data/dataset/ucf101/ucf101_val_split_1_rawframes.txt"

dygraph/tsn/multi_tsn_video.yaml

Lines changed: 7 additions & 7 deletions
@@ -13,8 +13,8 @@ TRAIN:
     epoch: 80
     short_size: 256
     target_size: 224
-    num_reader_threads: 12
-    buf_size: 1024
+    num_reader_threads: 16
+    buf_size: 256
     batch_size: 128
     use_gpu: True
     filelist: "./data/dataset/ucf101/ucf101_train_split_1_videos.txt"
@@ -28,15 +28,15 @@ TRAIN:
 VALID:
     short_size: 256
     target_size: 224
-    num_reader_threads: 12
-    buf_size: 1024
+    num_reader_threads: 16
+    buf_size: 256
     batch_size: 128
     filelist: "./data/dataset/ucf101/ucf101_val_split_1_videos.txt"
 
 TEST:
     short_size: 256
     target_size: 224
-    num_reader_threads: 12
-    buf_size: 1024
-    batch_size: 64
+    num_reader_threads: 16
+    buf_size: 256
+    batch_size: 128
     filelist: "./data/dataset/ucf101/ucf101_val_split_1_videos.txt"

dygraph/tsn/reader/__init__.py

Whitespace-only changes.

dygraph/tsn/single_gpu_run.sh

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 configs=$1
-pretrain="" # set pretrain model path if needed
+pretrain="./ResNet50_pretrained/" # set pretrain model path if needed
 resume="" # set checkpoints model path if u want to resume training
 save_dir=""
 use_gpu=True

dygraph/tsn/single_tsn_frame.yaml

Lines changed: 10 additions & 10 deletions
@@ -13,12 +13,12 @@ TRAIN:
     epoch: 80
     short_size: 256
     target_size: 224
-    num_reader_threads: 12
-    buf_size: 1024
-    batch_size: 64
+    num_reader_threads: 8
+    buf_size: 64
+    batch_size: 32
     use_gpu: True
     filelist: "./data/dataset/ucf101/ucf101_train_split_1_rawframes.txt"
-    learning_rate: 0.0005
+    learning_rate: 0.00025
     learning_rate_decay: 0.1
     decay_epochs: [30, 60]
     l2_weight_decay: 1e-4
@@ -28,15 +28,15 @@ TRAIN:
 VALID:
     short_size: 256
     target_size: 224
-    num_reader_threads: 12
-    buf_size: 1024
-    batch_size: 128
+    num_reader_threads: 8
+    buf_size: 64
+    batch_size: 32
     filelist: "./data/dataset/ucf101/ucf101_val_split_1_rawframes.txt"
 
 TEST:
     short_size: 256
     target_size: 224
-    num_reader_threads: 12
-    buf_size: 1024
-    batch_size: 64
+    num_reader_threads: 8
+    buf_size: 64
+    batch_size: 32
     filelist: "./data/dataset/ucf101/ucf101_val_split_1_rawframes.txt"
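
The single-GPU defaults in this file halve both `batch_size` (64 to 32) and `learning_rate` (0.0005 to 0.00025), which matches the rule stated in the README: the learning rate is scaled by the same factor as the batch size. A minimal sketch of that rule follows; the function name `scale_learning_rate` is hypothetical and not part of this commit, and the reference point (batch size 128 with lr 0.001) is the four-GPU default quoted in the README.

```python
def scale_learning_rate(base_lr: float, base_batch_size: int,
                        new_batch_size: int) -> float:
    """Scale the learning rate by the same factor as the batch size.

    Hypothetical helper illustrating the README's rule; it is not code
    from the repository.
    """
    if base_batch_size <= 0 or new_batch_size <= 0:
        raise ValueError("batch sizes must be positive")
    return base_lr * (new_batch_size / base_batch_size)


if __name__ == "__main__":
    # Reference point from the README: batch_size=128 with lr=0.001.
    for batch_size in (128, 64, 32):
        lr = scale_learning_rate(0.001, 128, batch_size)
        print("batch_size={} -> learning_rate={:g}".format(batch_size, lr))
```

If a different batch size is needed, the same function can be used to derive a starting learning rate before editing the YAML file; whether that value trains well still has to be checked empirically.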

dygraph/tsn/single_tsn_video.yaml

Lines changed: 10 additions & 10 deletions
@@ -13,12 +13,12 @@ TRAIN:
     epoch: 80
     short_size: 256
     target_size: 224
-    num_reader_threads: 12
-    buf_size: 1024
-    batch_size: 64
+    num_reader_threads: 8
+    buf_size: 64
+    batch_size: 32
     use_gpu: True
     filelist: "./data/dataset/ucf101/ucf101_train_split_1_videos.txt"
-    learning_rate: 0.0005
+    learning_rate: 0.00025
     learning_rate_decay: 0.1
     decay_epochs: [30, 60]
     l2_weight_decay: 1e-4
@@ -28,15 +28,15 @@ TRAIN:
 VALID:
     short_size: 256
     target_size: 224
-    num_reader_threads: 12
-    buf_size: 1024
-    batch_size: 128
+    num_reader_threads: 8
+    buf_size: 64
+    batch_size: 32
     filelist: "./data/dataset/ucf101/ucf101_val_split_1_videos.txt"
 
 TEST:
     short_size: 256
     target_size: 224
-    num_reader_threads: 12
-    buf_size: 1024
-    batch_size: 64
+    num_reader_threads: 8
+    buf_size: 64
+    batch_size: 32
     filelist: "./data/dataset/ucf101/ucf101_val_split_1_videos.txt"

dygraph/tsn/train.py

Lines changed: 4 additions & 53 deletions
@@ -95,53 +95,6 @@ def parse_args():
     return args
 
 
-def decompress(path):
-    t = tarfile.open(path)
-    print("path[0] {}".format(os.path.split(path)[0]))
-    t.extractall(path=os.path.split(path)[0])
-    t.close()
-
-
-def download(url, path):
-    weight_dir = os.path.split(path)[0]
-    if not os.path.exists(weight_dir):
-        os.makedirs(weight_dir)
-
-    path = path + ".tar.gz"
-    print("path {}".format(path))
-    wget.download(url, path)
-    decompress(path)
-
-
-def pretrain_info():
-    return (
-        'ResNet50_pretrained',
-        'https://paddlemodels.bj.bcebos.com/video_classification/ResNet50_pretrained.tar.gz'
-    )
-
-
-def download_pretrained(pretrained):
-    if pretrained is not None:
-        WEIGHT_DIR = pretrained
-    else:
-        WEIGHT_DIR = os.path.join(os.path.expanduser('~'), '.paddle', 'weights')
-
-    path, url = pretrain_info()
-    if not path:
-        return None
-
-    path = os.path.join(WEIGHT_DIR, path)
-    if not os.path.isdir(WEIGHT_DIR):
-        logger.info('{} not exists, will be created automatically.'.format(
-            WEIGHT_DIR))
-        os.makedirs(WEIGHT_DIR)
-    if os.path.exists(path):
-        return path
-    logger.info("Download pretrain weights of ResNet50 from {}".format(url))
-    download(url, path)
-    return path
-
-
 def init_model(model, pre_state_dict):
     param_state_dict = {}
     model_dict = model.state_dict()
@@ -224,17 +177,14 @@ def train(args):
     train_config = merge_configs(config, 'train', vars(args))
     valid_config = merge_configs(config, 'valid', vars(args))
     print_configs(train_config, 'Train')
-
-    # get the pretrained weights
-    pretrained_path = download_pretrained(args.pretrain)
-
     use_data_parallel = args.use_data_parallel
+
     trainer_count = fluid.dygraph.parallel.Env().nranks
 
     # (data_parallel step1/6)
     place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
         if use_data_parallel else fluid.CUDAPlace(0)
-    pre_state_dict = fluid.load_program_state(pretrained_path)
+    pre_state_dict = fluid.load_program_state(args.pretrain)
 
     with fluid.dygraph.guard(place):
         if use_data_parallel:
@@ -342,7 +292,8 @@ def train(args):
                 model_path = os.path.join(
                     args.checkpoint,
                     "_" + model_path_pre + "_epoch{}".format(epoch))
-                fluid.dygraph.save_dygraph(video_model.state_dict(), model_path)
+                fluid.dygraph.save_dygraph(
+                    video_model.state_dict(), model_path)
                 fluid.dygraph.save_dygraph(optimizer.state_dict(), model_path)
 
             if args.validate:
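
With the automatic download removed, `train.py` passes `args.pretrain` straight to `fluid.load_program_state`, so the weights must already be on disk (for example after running `download_pretrain.sh`). The following is a minimal sketch of a pre-flight check a caller could run before training. The function `check_pretrain_dir` is hypothetical and not part of this commit, and it assumes the weights live in a directory such as the default `./ResNet50_pretrained/`.

```python
import os


def check_pretrain_dir(path: str) -> str:
    """Return `path` if it is a non-empty directory, otherwise raise.

    Hypothetical helper for illustration; the commit itself adds no
    such check.
    """
    if not path:
        raise ValueError("pretrain path is empty; set `pretrain` in the run script")
    if not os.path.isdir(path):
        raise FileNotFoundError(
            "{} not found; run `bash download_pretrain.sh` first".format(path))
    if not os.listdir(path):
        raise FileNotFoundError("{} is empty".format(path))
    return path


if __name__ == "__main__":
    try:
        print("using pretrained weights in", check_pretrain_dir("./ResNet50_pretrained/"))
    except (ValueError, FileNotFoundError) as err:
        print("pretrained weights unavailable:", err)
```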
