Commit 1317671

tests and comments fixed
1 parent 8c3625f commit 1317671

3 files changed (+28 -28 lines)

src/caffe/net.cpp

Lines changed: 3 additions & 3 deletions
@@ -82,8 +82,8 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
     if (layer_param.propagate_down_size() > 0) {
       CHECK_EQ(layer_param.propagate_down_size(),
           layer_param.bottom_size())
-          << "propagate_down param must be specified"
-          << "either 0 or bottom_size times ";
+          << "propagate_down param must be specified "
+          << "either 0 or bottom_size times ";
     }
     layers_.push_back(LayerRegistry<Dtype>::CreateLayer(layer_param));
     layer_names_.push_back(layer_param.name());
@@ -415,7 +415,7 @@ int Net<Dtype>::AppendBottom(const NetParameter& param, const int layer_id,
   // Check if the backpropagation on bottom_id should be skipped
   if (layer_param.propagate_down_size() > 0)
     propagate_down = layer_param.propagate_down(bottom_id);
-  const bool need_backward = blob_need_backward_[blob_id] &
+  const bool need_backward = blob_need_backward_[blob_id] &&
       propagate_down;
   bottom_need_backward_[layer_id].push_back(need_backward);
   return blob_id;
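
The AppendBottom hunk swaps bitwise & for logical &&. With the two bool operands used here (blob_need_backward_[blob_id] and propagate_down) the computed value is the same either way, so the fix is about stating the intended logic and gaining short-circuit evaluation; on wider integral operands the two operators can genuinely disagree. A minimal standalone sketch of the operator semantics (illustrative only, not Caffe code):

#include <cassert>

int main() {
  // On plain bools, bitwise & and logical && agree in value ...
  bool need = true, prop = false;
  assert((need & prop) == (need && prop));

  // ... but on integral operands they can disagree: 2 and 1 are both
  // "truthy", yet their bitwise AND is 0.
  assert((2 && 1) == 1);  // logical AND: true
  assert((2 & 1) == 0);   // bitwise AND: false
  return 0;
}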

src/caffe/proto/caffe.proto

Lines changed: 1 addition & 1 deletion
@@ -282,7 +282,7 @@ message LayerParameter {
   repeated BlobProto blobs = 7;
 
   // Specifies on which bottoms the backpropagation should be skipped.
-  // The size must be either 0 or equals to the number of bottoms.
+  // The size must be either 0 or equal to the number of bottoms.
   repeated bool propagate_down = 11;
 
   // Rules controlling whether and when a layer is included in the network,
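
In prototxt terms, the reworded comment means a layer either omits propagate_down entirely or lists exactly one entry per bottom. A hypothetical fragment matching the scenario the tests exercise (layer and blob names are made up):

layer {
  name: "loss"
  type: "SoftmaxWithLoss"
  bottom: "ip"       # predictions: gradients should flow back here
  bottom: "label"    # labels: skip backpropagation here
  top: "loss"
  # Either zero entries or exactly bottom_size entries, as the
  # CHECK_EQ in net.cpp above enforces:
  propagate_down: true
  propagate_down: false
}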

src/caffe/test/test_net.cpp

Lines changed: 24 additions & 24 deletions
@@ -2326,46 +2326,46 @@ TYPED_TEST(NetTest, TestSkipPropagateDown) {
   this->InitSkipPropNet(false);
   vector<bool> vec_layer_need_backward = this->net_->layer_need_backward();
   for (int layer_id = 0; layer_id < this->net_->layers().size(); ++layer_id) {
-    if (this->net_->layer_names()[layer_id] == "loss") {
+    string layer_name = this->net_->layer_names()[layer_id];
+    if (layer_name == "loss") {
       // access to bottom_need_backward coresponding to label's blob
       bool need_back = this->net_->bottom_need_backward()[layer_id][1];
       // if propagate_down is true, the loss layer will try to
       // backpropagate on labels
-      CHECK_EQ(need_back, true)
-          << "bottom_need_backward should be True";
+      EXPECT_TRUE(need_back) << "bottom_need_backward should be True";
+    }
+    // layer_need_backward should be True except for data and silence layers
+    if (layer_name.find("data") != std::string::npos ||
+        layer_name == "silence") {
+      EXPECT_FALSE(vec_layer_need_backward[layer_id])
+          << "layer_need_backward for " << layer_name << " should be False";
+    } else {
+      EXPECT_TRUE(vec_layer_need_backward[layer_id])
+          << "layer_need_backward for " << layer_name << " should be True";
     }
-    if (this->net_->layer_names()[layer_id] == "ip_fake_labels")
-      CHECK_EQ(vec_layer_need_backward[layer_id], true)
-          << "layer_need_backward for ip_fake_labels should be True";
-    if (this->net_->layer_names()[layer_id] == "argmax")
-      CHECK_EQ(vec_layer_need_backward[layer_id], true)
-          << "layer_need_backward for argmax should be True";
-    if (this->net_->layer_names()[layer_id] == "innerproduct")
-      CHECK_EQ(vec_layer_need_backward[layer_id], true)
-          << "layer_need_backward for innerproduct should be True";
   }
   // check bottom_need_backward if propagat_down is false
   this->InitSkipPropNet(true);
   vec_layer_need_backward.clear();
   vec_layer_need_backward = this->net_->layer_need_backward();
   for (int layer_id = 0; layer_id < this->net_->layers().size(); ++layer_id) {
-    if (this->net_->layer_names()[layer_id] == "loss") {
+    string layer_name = this->net_->layer_names()[layer_id];
+    if (layer_name == "loss") {
       // access to bottom_need_backward coresponding to label's blob
      bool need_back = this->net_->bottom_need_backward()[layer_id][1];
       // if propagate_down is false, the loss layer will not try to
       // backpropagate on labels
-      CHECK_EQ(need_back, false)
-          << "bottom_need_backward should be False";
+      EXPECT_FALSE(need_back) << "bottom_need_backward should be False";
+    }
+    // layer_need_backward should be False except for innerproduct and
+    // loss layers
+    if (layer_name == "innerproduct" || layer_name == "loss") {
+      EXPECT_TRUE(vec_layer_need_backward[layer_id])
+          << "layer_need_backward for " << layer_name << " should be True";
+    } else {
+      EXPECT_FALSE(vec_layer_need_backward[layer_id])
+          << "layer_need_backward for " << layer_name << " should be False";
     }
-    if (this->net_->layer_names()[layer_id] == "ip_fake_labels")
-      CHECK_EQ(vec_layer_need_backward[layer_id], false)
-          << "layer_need_backward for ip_fake_labels should be False";
-    if (this->net_->layer_names()[layer_id] == "argmax")
-      CHECK_EQ(vec_layer_need_backward[layer_id], false)
-          << "layer_need_backward for argmax should be False";
-    if (this->net_->layer_names()[layer_id] == "innerproduct")
-      CHECK_EQ(vec_layer_need_backward[layer_id], true)
-          << "layer_need_backward for innerproduct should be True";
   }
 }
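
Beyond the restructuring, the test changes replace glog's CHECK_EQ with gtest's EXPECT_TRUE/EXPECT_FALSE. CHECK_EQ aborts the whole process on failure, while EXPECT_* records a non-fatal failure with the streamed message and keeps the remaining loop iterations and tests running, which is the usual convention inside a gtest suite. A minimal sketch of the pattern (a hypothetical test, not part of test_net.cpp):

#include <gtest/gtest.h>

TEST(ExpectVsCheckSketch, NonFatalAssertions) {
  bool need_back = true;
  // EXPECT_* evaluates the condition; on failure it logs the streamed
  // message and continues with the rest of the test body.
  EXPECT_TRUE(need_back) << "bottom_need_backward should be True";
  EXPECT_FALSE(!need_back) << "bottom_need_backward should be False";
  // A failing CHECK_EQ(need_back, true) from glog would instead abort
  // the entire test binary, masking any later failures.
}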
