Skip to content

Commit bddd04b

Browse files
committed
deprecate input fields and upgrade automagically
1 parent 00598ca commit bddd04b

File tree

3 files changed

+54
-3
lines changed

3 files changed

+54
-3
lines changed

include/caffe/util/upgrade_proto.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ bool UpgradeV1LayerParameter(const V1LayerParameter& v1_layer_param,
5959

6060
const char* UpgradeV1LayerType(const V1LayerParameter_LayerType type);
6161

62+
// Return true iff the Net contains input fields.
63+
bool NetNeedsInputUpgrade(const NetParameter& net_param);
64+
65+
// Perform all necessary transformations to upgrade input fields into layers.
66+
void UpgradeNetInput(NetParameter* net_param);
67+
6268
// Return true iff the solver contains any old solver_type specified as enums
6369
bool SolverNeedsTypeUpgrade(const SolverParameter& solver_param);
6470

src/caffe/proto/caffe.proto

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,12 @@ message FillerParameter {
6363

6464
message NetParameter {
6565
optional string name = 1; // consider giving the network a name
66-
// The input blobs to the network.
66+
// DEPRECATED. See InputParameter. The input blobs to the network.
6767
repeated string input = 3;
68-
// The shape of the input blobs.
68+
// DEPRECATED. See InputParameter. The shape of the input blobs.
6969
repeated BlobShape input_shape = 8;
7070

71-
// 4D input dimensions -- deprecated. Use "shape" instead.
71+
// 4D input dimensions -- deprecated. Use "input_shape" instead.
7272
// If specified, for each input blob there should be four
7373
// values specifying the num, channels, height and width of the input blob.
7474
// Thus, there should be a total of (4 * #input) numbers.

src/caffe/util/upgrade_proto.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,16 @@ bool UpgradeNetAsNeeded(const string& param_file, NetParameter* param) {
6060
<< "V1LayerParameter";
6161
}
6262
}
63+
// NetParameter uses old style input fields; try to upgrade it.
64+
if (NetNeedsInputUpgrade(*param)) {
65+
LOG(INFO) << "Attempting to upgrade input file specified using deprecated "
66+
<< "input fields: " << param_file;
67+
UpgradeNetInput(param);
68+
LOG(INFO) << "Successfully upgraded file specified using deprecated "
69+
<< "input fields.";
70+
LOG(WARNING) << "Note that future Caffe releases will only support "
71+
<< "input layers and not input fields.";
72+
}
6373
return success;
6474
}
6575

@@ -937,6 +947,41 @@ const char* UpgradeV1LayerType(const V1LayerParameter_LayerType type) {
937947
}
938948
}
939949

950+
bool NetNeedsInputUpgrade(const NetParameter& net_param) {
951+
return net_param.input_size() > 0;
952+
}
953+
954+
void UpgradeNetInput(NetParameter* net_param) {
955+
LayerParameter* layer_param = net_param->add_layer();
956+
layer_param->set_name("input");
957+
layer_param->set_type("Input");
958+
InputParameter* input_param = layer_param->mutable_input_param();
959+
bool has_shape = net_param->input_shape_size() > 0;
960+
// Convert input fields into a layer.
961+
for (int i = 0; i < net_param->input_size(); ++i) {
962+
layer_param->add_top(net_param->input(i));
963+
if (has_shape) {
964+
input_param->add_shape()->CopyFrom(net_param->input_shape(i));
965+
} else {
966+
// Turn legacy input dimensions into shape.
967+
BlobShape* shape = input_param->add_shape();
968+
int first_dim = i*4;
969+
int last_dim = first_dim + 4;
970+
for (int j = first_dim; j < last_dim; j++) {
971+
shape->add_dim(net_param->input_dim(j));
972+
}
973+
}
974+
}
975+
// Swap input layer to beginning of net to satisfy layer dependencies.
976+
for (int i = net_param->layer_size() - 1; i > 0; --i) {
977+
net_param->mutable_layer(i-1)->Swap(net_param->mutable_layer(i));
978+
}
979+
// Clear inputs.
980+
net_param->clear_input();
981+
net_param->clear_input_shape();
982+
net_param->clear_input_dim();
983+
}
984+
940985
// Return true iff the solver contains any old solver_type specified as enums
941986
bool SolverNeedsTypeUpgrade(const SolverParameter& solver_param) {
942987
if (solver_param.has_solver_type()) {

0 commit comments

Comments
 (0)