@@ -79,14 +79,12 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
     }
     // Setup layer.
     const LayerParameter& layer_param = param.layer(layer_id);
-
-    if (layer_param.skip_propagate_down_size() > 0) {
-      CHECK_EQ(layer_param.skip_propagate_down_size(),
+    if (layer_param.propagate_down_size() > 0) {
+      CHECK_EQ(layer_param.propagate_down_size(),
           layer_param.bottom_size())
-          << "skip_propagate_down param must be specified "
+          << "propagate_down param must be specified "
          << "either 0 or bottom_size times ";
     }
-
     layers_.push_back(LayerRegistry<Dtype>::CreateLayer(layer_param));
     layer_names_.push_back(layer_param.name());
     LOG(INFO) << "Creating Layer " << layer_param.name();
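
A quick sketch of what this validation accepts, using the protoc-generated setters for the repeated fields (the layer and blob names here are hypothetical):

    #include "caffe/proto/caffe.pb.h"

    caffe::LayerParameter layer_param;
    layer_param.set_name("loss");           // hypothetical layer
    layer_param.add_bottom("pred");
    layer_param.add_bottom("label");
    // Legal: exactly one propagate_down entry per bottom blob.
    layer_param.add_propagate_down(true);   // back-propagate into "pred"
    layer_param.add_propagate_down(false);  // but not into "label"
    // Also legal: zero entries, meaning propagate into every bottom.
    // A single entry for two bottoms would trip the CHECK_EQ above.
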
@@ -95,13 +93,8 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
     // Figure out this layer's input and output
     for (int bottom_id = 0; bottom_id < layer_param.bottom_size();
          ++bottom_id) {
-      bool skip_propagate_down = false;
-      // Check if the backpropagation on bottom_id should be skipped
-      if (layer_param.skip_propagate_down_size() > 0)
-        skip_propagate_down = layer_param.skip_propagate_down(bottom_id);
       const int blob_id = AppendBottom(param, layer_id, bottom_id,
-                                       &available_blobs, &blob_name_to_idx,
-                                       skip_propagate_down);
+                                       &available_blobs, &blob_name_to_idx);
       // If a blob needs backward, this layer should provide it.
       need_backward |= blob_need_backward_[blob_id];
     }
@@ -165,15 +158,33 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
   // Go through the net backwards to determine which blobs contribute to the
   // loss. We can skip backward computation for blobs that don't contribute
   // to the loss.
+  // Also check whether all bottom blobs don't need backward computation
+  // (possible because of the propagate_down param) so we can skip backward
+  // computation for the entire layer
   set<string> blobs_under_loss;
+  set<string> blobs_skip_backp;
   for (int layer_id = layers_.size() - 1; layer_id >= 0; --layer_id) {
     bool layer_contributes_loss = false;
+    bool layer_skip_propagate_down = true;
     for (int top_id = 0; top_id < top_vecs_[layer_id].size(); ++top_id) {
       const string& blob_name = blob_names_[top_id_vecs_[layer_id][top_id]];
       if (layers_[layer_id]->loss(top_id) ||
           (blobs_under_loss.find(blob_name) != blobs_under_loss.end())) {
         layer_contributes_loss = true;
+      }
+      if (blobs_skip_backp.find(blob_name) == blobs_skip_backp.end()) {
+        layer_skip_propagate_down = false;
+      }
+      if (layer_contributes_loss && !layer_skip_propagate_down)
         break;
+    }
+    // If this layer can skip backward computation, all its bottom blobs
+    // don't need backpropagation either
+    if (layer_need_backward_[layer_id] && layer_skip_propagate_down) {
+      layer_need_backward_[layer_id] = false;
+      for (int bottom_id = 0; bottom_id < bottom_vecs_[layer_id].size();
+           ++bottom_id) {
+        bottom_need_backward_[layer_id][bottom_id] = false;
       }
     }
     if (!layer_contributes_loss) { layer_need_backward_[layer_id] = false; }
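
The sweep walks the layers from last to first, so by the time a producer layer is visited, blobs_skip_backp already holds the verdicts of all its downstream consumers. A minimal standalone sketch of the rule this hunk implements (the helper is hypothetical, not part of the PR): a layer may skip backward only when every one of its top blobs has been marked skippable.

    #include <set>
    #include <string>
    #include <vector>

    // Hypothetical distillation of the loop above: true iff every top
    // blob of a layer was marked skippable by its downstream consumers.
    bool LayerSkipsBackward(const std::vector<std::string>& top_blobs,
                            const std::set<std::string>& blobs_skip_backp) {
      for (size_t i = 0; i < top_blobs.size(); ++i) {
        if (blobs_skip_backp.count(top_blobs[i]) == 0) {
          return false;  // some consumer still wants this gradient
        }
      }
      return true;  // no top needs a gradient; skip the whole layer
    }
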
@@ -192,6 +203,11 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
       } else {
         bottom_need_backward_[layer_id][bottom_id] = false;
       }
+      if (!bottom_need_backward_[layer_id][bottom_id]) {
+        const string& blob_name =
+                   blob_names_[bottom_id_vecs_[layer_id][bottom_id]];
+        blobs_skip_backp.insert(blob_name);
+      }
     }
   }
   // Handle force_backward if needed.
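
This is the other half of the handshake: every bottom blob that ends up not needing backward is recorded in blobs_skip_backp, which is exactly the set a producer layer consults for its tops in the previous hunk. Continuing the hypothetical sketch from above:

    std::set<std::string> blobs_skip_backp;
    // A consumer set propagate_down: false on its bottom "feat" (made-up
    // name), so the loop above records that blob as skippable.
    blobs_skip_backp.insert("feat");
    // When the backward sweep reaches the layer producing "feat" as its
    // only top, the check sketched earlier fires and the producer's
    // backward pass is switched off too.
    std::vector<std::string> producer_tops(1, "feat");
    bool skip = LayerSkipsBackward(producer_tops, blobs_skip_backp);  // true
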
@@ -383,7 +399,7 @@ void Net<Dtype>::AppendTop(const NetParameter& param, const int layer_id,
 template <typename Dtype>
 int Net<Dtype>::AppendBottom(const NetParameter& param, const int layer_id,
     const int bottom_id, set<string>* available_blobs,
-    map<string, int>* blob_name_to_idx, bool skip_propagate_down) {
+    map<string, int>* blob_name_to_idx) {
   const LayerParameter& layer_param = param.layer(layer_id);
   const string& blob_name = layer_param.bottom(bottom_id);
   if (available_blobs->find(blob_name) == available_blobs->end()) {
@@ -395,8 +411,12 @@ int Net<Dtype>::AppendBottom(const NetParameter& param, const int layer_id,
   bottom_vecs_[layer_id].push_back(blobs_[blob_id].get());
   bottom_id_vecs_[layer_id].push_back(blob_id);
   available_blobs->erase(blob_name);
+  bool propagate_down = true;
+  // Check if the backpropagation on bottom_id should be skipped
+  if (layer_param.propagate_down_size() > 0)
+    propagate_down = layer_param.propagate_down(bottom_id);
   const bool need_backward = blob_need_backward_[blob_id] &
-                             !skip_propagate_down;
+                             propagate_down;
   bottom_need_backward_[layer_id].push_back(need_backward);
   return blob_id;
 }
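
The flag a bottom ends up with is just the conjunction of what the producer can deliver and what the user asked for; the code uses bitwise & on two bools, which behaves like && here since both operands are plain flags. A hypothetical distillation with its truth table:

    // Sketch of the need_backward computation above (not PR code).
    bool NeedBackward(bool blob_need_backward, bool propagate_down) {
      return blob_need_backward && propagate_down;
    }
    // NeedBackward(true,  true)  == true : gradient exists and is wanted.
    // NeedBackward(true,  false) == false: gradient exists but suppressed.
    // NeedBackward(false, true)  == false: requested, but the producer
    //                                      generates no gradient to pass.
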