@@ -95,53 +95,6 @@ def parse_args():
95
95
return args
96
96
97
97
98
- def decompress (path ):
99
- t = tarfile .open (path )
100
- print ("path[0] {}" .format (os .path .split (path )[0 ]))
101
- t .extractall (path = os .path .split (path )[0 ])
102
- t .close ()
103
-
104
-
105
- def download (url , path ):
106
- weight_dir = os .path .split (path )[0 ]
107
- if not os .path .exists (weight_dir ):
108
- os .makedirs (weight_dir )
109
-
110
- path = path + ".tar.gz"
111
- print ("path {}" .format (path ))
112
- wget .download (url , path )
113
- decompress (path )
114
-
115
-
116
- def pretrain_info ():
117
- return (
118
- 'ResNet50_pretrained' ,
119
- 'https://paddlemodels.bj.bcebos.com/video_classification/ResNet50_pretrained.tar.gz'
120
- )
121
-
122
-
123
- def download_pretrained (pretrained ):
124
- if pretrained is not None :
125
- WEIGHT_DIR = pretrained
126
- else :
127
- WEIGHT_DIR = os .path .join (os .path .expanduser ('~' ), '.paddle' , 'weights' )
128
-
129
- path , url = pretrain_info ()
130
- if not path :
131
- return None
132
-
133
- path = os .path .join (WEIGHT_DIR , path )
134
- if not os .path .isdir (WEIGHT_DIR ):
135
- logger .info ('{} not exists, will be created automatically.' .format (
136
- WEIGHT_DIR ))
137
- os .makedirs (WEIGHT_DIR )
138
- if os .path .exists (path ):
139
- return path
140
- logger .info ("Download pretrain weights of ResNet50 from {}" .format (url ))
141
- download (url , path )
142
- return path
143
-
144
-
145
98
def init_model (model , pre_state_dict ):
146
99
param_state_dict = {}
147
100
model_dict = model .state_dict ()
@@ -224,17 +177,14 @@ def train(args):
224
177
train_config = merge_configs (config , 'train' , vars (args ))
225
178
valid_config = merge_configs (config , 'valid' , vars (args ))
226
179
print_configs (train_config , 'Train' )
227
-
228
- # get the pretrained weights
229
- pretrained_path = download_pretrained (args .pretrain )
230
-
231
180
use_data_parallel = args .use_data_parallel
181
+
232
182
trainer_count = fluid .dygraph .parallel .Env ().nranks
233
183
234
184
# (data_parallel step1/6)
235
185
place = fluid .CUDAPlace (fluid .dygraph .parallel .Env ().dev_id ) \
236
186
if use_data_parallel else fluid .CUDAPlace (0 )
237
- pre_state_dict = fluid .load_program_state (pretrained_path )
187
+ pre_state_dict = fluid .load_program_state (args . pretrain )
238
188
239
189
with fluid .dygraph .guard (place ):
240
190
if use_data_parallel :
@@ -342,7 +292,8 @@ def train(args):
342
292
model_path = os .path .join (
343
293
args .checkpoint ,
344
294
"_" + model_path_pre + "_epoch{}" .format (epoch ))
345
- fluid .dygraph .save_dygraph (video_model .state_dict (), model_path )
295
+ fluid .dygraph .save_dygraph (
296
+ video_model .state_dict (), model_path )
346
297
fluid .dygraph .save_dygraph (optimizer .state_dict (), model_path )
347
298
348
299
if args .validate :
0 commit comments