@@ -2326,46 +2326,46 @@ TYPED_TEST(NetTest, TestSkipPropagateDown) {
2326
2326
this ->InitSkipPropNet (false );
2327
2327
vector<bool > vec_layer_need_backward = this ->net_ ->layer_need_backward ();
2328
2328
for (int layer_id = 0 ; layer_id < this ->net_ ->layers ().size (); ++layer_id) {
2329
- if (this ->net_ ->layer_names ()[layer_id] == " loss" ) {
2329
+ string layer_name = this ->net_ ->layer_names ()[layer_id];
2330
+ if (layer_name == " loss" ) {
2330
2331
// access to bottom_need_backward coresponding to label's blob
2331
2332
bool need_back = this ->net_ ->bottom_need_backward ()[layer_id][1 ];
2332
2333
// if propagate_down is true, the loss layer will try to
2333
2334
// backpropagate on labels
2334
- CHECK_EQ (need_back, true )
2335
- << " bottom_need_backward should be True" ;
2335
+ EXPECT_TRUE (need_back) << " bottom_need_backward should be True" ;
2336
+ }
2337
+ // layer_need_backward should be True except for data and silence layers
2338
+ if (layer_name.find (" data" ) != std::string::npos ||
2339
+ layer_name == " silence" ) {
2340
+ EXPECT_FALSE (vec_layer_need_backward[layer_id])
2341
+ << " layer_need_backward for " << layer_name << " should be False" ;
2342
+ } else {
2343
+ EXPECT_TRUE (vec_layer_need_backward[layer_id])
2344
+ << " layer_need_backward for " << layer_name << " should be True" ;
2336
2345
}
2337
- if (this ->net_ ->layer_names ()[layer_id] == " ip_fake_labels" )
2338
- CHECK_EQ (vec_layer_need_backward[layer_id], true )
2339
- << " layer_need_backward for ip_fake_labels should be True" ;
2340
- if (this ->net_ ->layer_names ()[layer_id] == " argmax" )
2341
- CHECK_EQ (vec_layer_need_backward[layer_id], true )
2342
- << " layer_need_backward for argmax should be True" ;
2343
- if (this ->net_ ->layer_names ()[layer_id] == " innerproduct" )
2344
- CHECK_EQ (vec_layer_need_backward[layer_id], true )
2345
- << " layer_need_backward for innerproduct should be True" ;
2346
2346
}
2347
2347
// check bottom_need_backward if propagat_down is false
2348
2348
this ->InitSkipPropNet (true );
2349
2349
vec_layer_need_backward.clear ();
2350
2350
vec_layer_need_backward = this ->net_ ->layer_need_backward ();
2351
2351
for (int layer_id = 0 ; layer_id < this ->net_ ->layers ().size (); ++layer_id) {
2352
- if (this ->net_ ->layer_names ()[layer_id] == " loss" ) {
2352
+ string layer_name = this ->net_ ->layer_names ()[layer_id];
2353
+ if (layer_name == " loss" ) {
2353
2354
// access to bottom_need_backward coresponding to label's blob
2354
2355
bool need_back = this ->net_ ->bottom_need_backward ()[layer_id][1 ];
2355
2356
// if propagate_down is false, the loss layer will not try to
2356
2357
// backpropagate on labels
2357
- CHECK_EQ (need_back, false )
2358
- << " bottom_need_backward should be False" ;
2358
+ EXPECT_FALSE (need_back) << " bottom_need_backward should be False" ;
2359
+ }
2360
+ // layer_need_backward should be False except for innerproduct and
2361
+ // loss layers
2362
+ if (layer_name == " innerproduct" || layer_name == " loss" ) {
2363
+ EXPECT_TRUE (vec_layer_need_backward[layer_id])
2364
+ << " layer_need_backward for " << layer_name << " should be True" ;
2365
+ } else {
2366
+ EXPECT_FALSE (vec_layer_need_backward[layer_id])
2367
+ << " layer_need_backward for " << layer_name << " should be False" ;
2359
2368
}
2360
- if (this ->net_ ->layer_names ()[layer_id] == " ip_fake_labels" )
2361
- CHECK_EQ (vec_layer_need_backward[layer_id], false )
2362
- << " layer_need_backward for ip_fake_labels should be False" ;
2363
- if (this ->net_ ->layer_names ()[layer_id] == " argmax" )
2364
- CHECK_EQ (vec_layer_need_backward[layer_id], false )
2365
- << " layer_need_backward for argmax should be False" ;
2366
- if (this ->net_ ->layer_names ()[layer_id] == " innerproduct" )
2367
- CHECK_EQ (vec_layer_need_backward[layer_id], true )
2368
- << " layer_need_backward for innerproduct should be True" ;
2369
2369
}
2370
2370
}
2371
2371
0 commit comments