Commit 8c3625f

layer_need_backward_ is now correct when propagate_down is specified

- skip_propagate_down is now called propagate_down
- the graph of the net is visited top-down to correctly update layer_need_backward_
- tests updated

1 parent 13bfe68 commit 8c3625f

4 files changed: +65 additions, -24 deletions

include/caffe/net.hpp

Lines changed: 4 additions & 2 deletions
@@ -137,6 +137,9 @@ class Net {
   inline const vector<Dtype>& blob_loss_weights() const {
     return blob_loss_weights_;
   }
+  inline const vector<bool>& layer_need_backward() const {
+    return layer_need_backward_;
+  }
   /// @brief returns the parameters
   inline const vector<shared_ptr<Blob<Dtype> > >& params() const {
     return params_;
@@ -192,8 +195,7 @@ class Net {
   /// @brief Append a new bottom blob to the net.
   int AppendBottom(const NetParameter& param, const int layer_id,
                    const int bottom_id, set<string>* available_blobs,
-                   map<string, int>* blob_name_to_idx,
-                   bool skip_propagate = false);
+                   map<string, int>* blob_name_to_idx);
   /// @brief Append a new parameter blob to the net.
   void AppendParam(const NetParameter& param, const int layer_id,
                    const int param_id);
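
The new layer_need_backward() accessor simply exposes the per-layer flags that Init() computes. A minimal sketch of how a caller might read them after construction; the prototxt filename and the (file, phase) Net constructor are assumptions for illustration, not part of this commit:

    #include <cstddef>
    #include <iostream>
    #include <vector>
    #include "caffe/net.hpp"

    // Sketch only: "train.prototxt" is a hypothetical model file.
    void PrintBackwardFlags() {
      caffe::Net<float> net("train.prototxt", caffe::TRAIN);
      const std::vector<bool>& need_bw = net.layer_need_backward();
      for (size_t i = 0; i < need_bw.size(); ++i) {
        std::cout << net.layer_names()[i] << ": "
                  << (need_bw[i] ? "backward" : "no backward") << std::endl;
      }
    }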

src/caffe/net.cpp

Lines changed: 33 additions & 13 deletions
@@ -79,14 +79,12 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
     }
     // Setup layer.
     const LayerParameter& layer_param = param.layer(layer_id);
-
-    if (layer_param.skip_propagate_down_size() > 0) {
-      CHECK_EQ(layer_param.skip_propagate_down_size(),
+    if (layer_param.propagate_down_size() > 0) {
+      CHECK_EQ(layer_param.propagate_down_size(),
           layer_param.bottom_size())
-          << "skip_propagate_down param must be specified"
+          << "propagate_down param must be specified "
           << "either 0 or bottom_size times ";
     }
-
     layers_.push_back(LayerRegistry<Dtype>::CreateLayer(layer_param));
     layer_names_.push_back(layer_param.name());
     LOG(INFO) << "Creating Layer " << layer_param.name();
@@ -95,13 +93,8 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
     // Figure out this layer's input and output
     for (int bottom_id = 0; bottom_id < layer_param.bottom_size();
          ++bottom_id) {
-      bool skip_propagate_down = false;
-      // Check if the backpropagation on bottom_id should be skipped
-      if (layer_param.skip_propagate_down_size() > 0)
-        skip_propagate_down = layer_param.skip_propagate_down(bottom_id);
       const int blob_id = AppendBottom(param, layer_id, bottom_id,
-                                       &available_blobs, &blob_name_to_idx,
-                                       skip_propagate_down);
+                                       &available_blobs, &blob_name_to_idx);
       // If a blob needs backward, this layer should provide it.
       need_backward |= blob_need_backward_[blob_id];
     }
@@ -165,15 +158,33 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
   // Go through the net backwards to determine which blobs contribute to the
   // loss. We can skip backward computation for blobs that don't contribute
   // to the loss.
+  // Also check whether all bottom blobs don't need backward computation
+  // (possible because of the propagate_down param), in which case we can
+  // skip backward computation for the entire layer.
   set<string> blobs_under_loss;
+  set<string> blobs_skip_backp;
   for (int layer_id = layers_.size() - 1; layer_id >= 0; --layer_id) {
     bool layer_contributes_loss = false;
+    bool layer_skip_propagate_down = true;
     for (int top_id = 0; top_id < top_vecs_[layer_id].size(); ++top_id) {
       const string& blob_name = blob_names_[top_id_vecs_[layer_id][top_id]];
       if (layers_[layer_id]->loss(top_id) ||
          (blobs_under_loss.find(blob_name) != blobs_under_loss.end())) {
         layer_contributes_loss = true;
+      }
+      if (blobs_skip_backp.find(blob_name) == blobs_skip_backp.end()) {
+        layer_skip_propagate_down = false;
+      }
+      if (layer_contributes_loss && !layer_skip_propagate_down)
        break;
+    }
+    // If this layer can skip backward computation, all its bottom blobs
+    // also don't need backpropagation.
+    if (layer_need_backward_[layer_id] && layer_skip_propagate_down) {
+      layer_need_backward_[layer_id] = false;
+      for (int bottom_id = 0; bottom_id < bottom_vecs_[layer_id].size();
+           ++bottom_id) {
+        bottom_need_backward_[layer_id][bottom_id] = false;
      }
    }
    if (!layer_contributes_loss) { layer_need_backward_[layer_id] = false; }
@@ -192,6 +203,11 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
      } else {
        bottom_need_backward_[layer_id][bottom_id] = false;
      }
+      if (!bottom_need_backward_[layer_id][bottom_id]) {
+        const string& blob_name =
+            blob_names_[bottom_id_vecs_[layer_id][bottom_id]];
+        blobs_skip_backp.insert(blob_name);
+      }
    }
  }
  // Handle force_backward if needed.
@@ -383,7 +399,7 @@ void Net<Dtype>::AppendTop(const NetParameter& param, const int layer_id,
 template <typename Dtype>
 int Net<Dtype>::AppendBottom(const NetParameter& param, const int layer_id,
     const int bottom_id, set<string>* available_blobs,
-    map<string, int>* blob_name_to_idx, bool skip_propagate_down) {
+    map<string, int>* blob_name_to_idx) {
   const LayerParameter& layer_param = param.layer(layer_id);
   const string& blob_name = layer_param.bottom(bottom_id);
   if (available_blobs->find(blob_name) == available_blobs->end()) {
@@ -395,8 +411,12 @@ int Net<Dtype>::AppendBottom(const NetParameter& param, const int layer_id,
   bottom_vecs_[layer_id].push_back(blobs_[blob_id].get());
   bottom_id_vecs_[layer_id].push_back(blob_id);
   available_blobs->erase(blob_name);
+  bool propagate_down = true;
+  // Check if the backpropagation on bottom_id should be skipped
+  if (layer_param.propagate_down_size() > 0)
+    propagate_down = layer_param.propagate_down(bottom_id);
   const bool need_backward = blob_need_backward_[blob_id] &
-      !skip_propagate_down;
+      propagate_down;
   bottom_need_backward_[layer_id].push_back(need_backward);
   return blob_id;
 }
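
The key addition is the reverse (output-to-input) sweep: a layer whose top blobs all sit in blobs_skip_backp no longer needs Backward, and its own bottoms are then marked as skipped so that layers further up can be pruned as well. A self-contained sketch of that idea, using hypothetical structures rather than the Caffe API:

    #include <set>
    #include <string>
    #include <vector>

    // Simplified sketch of the pruning pass (hypothetical types, not Caffe).
    struct LayerInfo {
      std::vector<std::string> bottoms, tops;
      std::vector<bool> bottom_need_backward;  // pre-filled from propagate_down
      bool need_backward;
    };

    void PruneBackward(std::vector<LayerInfo>* layers) {
      std::set<std::string> blobs_skip_backp;
      for (int i = static_cast<int>(layers->size()) - 1; i >= 0; --i) {
        LayerInfo& layer = (*layers)[i];
        // The layer can skip Backward only if every top blob was marked as
        // skipped by its consumers (visited earlier in this reverse loop).
        bool all_tops_skipped = true;
        for (size_t t = 0; t < layer.tops.size(); ++t) {
          if (blobs_skip_backp.count(layer.tops[t]) == 0) {
            all_tops_skipped = false;
          }
        }
        if (layer.need_backward && all_tops_skipped) {
          layer.need_backward = false;
          for (size_t b = 0; b < layer.bottoms.size(); ++b) {
            layer.bottom_need_backward[b] = false;
          }
        }
        // Record bottoms that will not receive a gradient so that the layers
        // producing them can be pruned too.
        for (size_t b = 0; b < layer.bottoms.size(); ++b) {
          if (!layer.bottom_need_backward[b]) {
            blobs_skip_backp.insert(layer.bottoms[b]);
          }
        }
      }
    }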

src/caffe/proto/caffe.proto

Lines changed: 1 addition & 1 deletion
@@ -283,7 +283,7 @@ message LayerParameter {

   // Specifies on which bottoms the backpropagation should be skipped.
   // The size must be either 0 or equals to the number of bottoms.
-  repeated bool skip_propagate_down = 11;
+  repeated bool propagate_down = 11;

   // Rules controlling whether and when a layer is included in the network,
   // based on the current NetState. You may specify a non-zero number of rules
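
After the rename, per-bottom backpropagation control is expressed with propagate_down: one boolean per bottom, or none at all. A hedged example of setting it through the generated C++ protobuf API; the layer name, type, and blob names are made up for illustration:

    #include "caffe/proto/caffe.pb.h"

    // Sketch: true keeps backprop into that bottom, false skips it.
    caffe::LayerParameter MakeLossLayer() {
      caffe::LayerParameter loss;
      loss.set_name("loss");
      loss.set_type("SoftmaxWithLoss");
      loss.add_bottom("ip");      // gradients flow into this bottom
      loss.add_bottom("label");   // no gradients for the label blob
      loss.add_propagate_down(true);
      loss.add_propagate_down(false);
      // Leaving propagate_down empty keeps the default behaviour; any other
      // size must equal the number of bottoms (enforced by the CHECK_EQ above).
      return loss;
    }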

src/caffe/test/test_net.cpp

Lines changed: 27 additions & 8 deletions
@@ -55,8 +55,6 @@ class NetTest : public MultiDeviceTest<TypeParam> {
     }
   }

-
-
   virtual void InitTinyNet(const bool force_backward = false,
                            const bool accuracy_layer = false) {
     string proto =
@@ -701,9 +699,9 @@ class NetTest : public MultiDeviceTest<TypeParam> {
        "  bottom: 'innerproduct' "
        "  bottom: 'label_argmax' ";
    if (test_skip_true)
-      proto += "  skip_propagate_down: [false, true] ";
+      proto += "  propagate_down: [true, false] ";
    else
-      proto += "  skip_propagate_down: [false, false] ";
+      proto += "  propagate_down: [true, true] ";
    proto +=
        "  top: 'cross_entropy_loss' "
        "  type: 'SigmoidCrossEntropyLoss' "
@@ -2324,29 +2322,50 @@ TYPED_TEST(NetTest, TestReshape) {
 }

 TYPED_TEST(NetTest, TestSkipPropagateDown) {
-  // check bottom_need_backward if skip_propagat_down is false
+  // check bottom_need_backward if propagate_down is true
   this->InitSkipPropNet(false);
+  vector<bool> vec_layer_need_backward = this->net_->layer_need_backward();
   for (int layer_id = 0; layer_id < this->net_->layers().size(); ++layer_id) {
     if (this->net_->layer_names()[layer_id] == "loss") {
       // access the bottom_need_backward entry corresponding to the label blob
       bool need_back = this->net_->bottom_need_backward()[layer_id][1];
-      // if skip_propagate_down is false, the loss layer will try to
+      // if propagate_down is true, the loss layer will try to
       // backpropagate on labels
       CHECK_EQ(need_back, true)
           << "bottom_need_backward should be True";
     }
+    if (this->net_->layer_names()[layer_id] == "ip_fake_labels")
+      CHECK_EQ(vec_layer_need_backward[layer_id], true)
+          << "layer_need_backward for ip_fake_labels should be True";
+    if (this->net_->layer_names()[layer_id] == "argmax")
+      CHECK_EQ(vec_layer_need_backward[layer_id], true)
+          << "layer_need_backward for argmax should be True";
+    if (this->net_->layer_names()[layer_id] == "innerproduct")
+      CHECK_EQ(vec_layer_need_backward[layer_id], true)
+          << "layer_need_backward for innerproduct should be True";
   }
-  // check bottom_need_backward if skip_propagat_down is true
+  // check bottom_need_backward if propagate_down is false
   this->InitSkipPropNet(true);
+  vec_layer_need_backward.clear();
+  vec_layer_need_backward = this->net_->layer_need_backward();
   for (int layer_id = 0; layer_id < this->net_->layers().size(); ++layer_id) {
     if (this->net_->layer_names()[layer_id] == "loss") {
       // access the bottom_need_backward entry corresponding to the label blob
       bool need_back = this->net_->bottom_need_backward()[layer_id][1];
-      // if skip_propagate_down is true, the loss layer will not try to
+      // if propagate_down is false, the loss layer will not try to
       // backpropagate on labels
       CHECK_EQ(need_back, false)
           << "bottom_need_backward should be False";
     }
+    if (this->net_->layer_names()[layer_id] == "ip_fake_labels")
+      CHECK_EQ(vec_layer_need_backward[layer_id], false)
+          << "layer_need_backward for ip_fake_labels should be False";
+    if (this->net_->layer_names()[layer_id] == "argmax")
+      CHECK_EQ(vec_layer_need_backward[layer_id], false)
+          << "layer_need_backward for argmax should be False";
+    if (this->net_->layer_names()[layer_id] == "innerproduct")
+      CHECK_EQ(vec_layer_need_backward[layer_id], true)
+          << "layer_need_backward for innerproduct should be True";
   }
 }
