@@ -60,6 +60,16 @@ bool UpgradeNetAsNeeded(const string& param_file, NetParameter* param) {
60
60
<< " V1LayerParameter" ;
61
61
}
62
62
}
63
+ // NetParameter uses old style input fields; try to upgrade it.
64
+ if (NetNeedsInputUpgrade (*param)) {
65
+ LOG (INFO) << " Attempting to upgrade input file specified using deprecated "
66
+ << " input fields: " << param_file;
67
+ UpgradeNetInput (param);
68
+ LOG (INFO) << " Successfully upgraded file specified using deprecated "
69
+ << " input fields." ;
70
+ LOG (WARNING) << " Note that future Caffe releases will only support "
71
+ << " input layers and not input fields." ;
72
+ }
63
73
return success;
64
74
}
65
75
@@ -937,6 +947,41 @@ const char* UpgradeV1LayerType(const V1LayerParameter_LayerType type) {
937
947
}
938
948
}
939
949
950
+ bool NetNeedsInputUpgrade (const NetParameter& net_param) {
951
+ return net_param.input_size () > 0 ;
952
+ }
953
+
954
+ void UpgradeNetInput (NetParameter* net_param) {
955
+ LayerParameter* layer_param = net_param->add_layer ();
956
+ layer_param->set_name (" input" );
957
+ layer_param->set_type (" Input" );
958
+ InputParameter* input_param = layer_param->mutable_input_param ();
959
+ bool has_shape = net_param->input_shape_size () > 0 ;
960
+ // Convert input fields into a layer.
961
+ for (int i = 0 ; i < net_param->input_size (); ++i) {
962
+ layer_param->add_top (net_param->input (i));
963
+ if (has_shape) {
964
+ input_param->add_shape ()->CopyFrom (net_param->input_shape (i));
965
+ } else {
966
+ // Turn legacy input dimensions into shape.
967
+ BlobShape* shape = input_param->add_shape ();
968
+ int first_dim = i*4 ;
969
+ int last_dim = first_dim + 4 ;
970
+ for (int j = first_dim; j < last_dim; j++) {
971
+ shape->add_dim (net_param->input_dim (j));
972
+ }
973
+ }
974
+ }
975
+ // Swap input layer to beginning of net to satisfy layer dependencies.
976
+ for (int i = net_param->layer_size () - 1 ; i > 0 ; --i) {
977
+ net_param->mutable_layer (i-1 )->Swap (net_param->mutable_layer (i));
978
+ }
979
+ // Clear inputs.
980
+ net_param->clear_input ();
981
+ net_param->clear_input_shape ();
982
+ net_param->clear_input_dim ();
983
+ }
984
+
940
985
// Return true iff the solver contains any old solver_type specified as enums
941
986
bool SolverNeedsTypeUpgrade (const SolverParameter& solver_param) {
942
987
if (solver_param.has_solver_type ()) {
0 commit comments