Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 34 additions & 25 deletions src/caffe/util/upgrade_proto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
namespace caffe {

bool NetNeedsUpgrade(const NetParameter& net_param) {
return NetNeedsV0ToV1Upgrade(net_param) || NetNeedsV1ToV2Upgrade(net_param);
return NetNeedsV0ToV1Upgrade(net_param) || NetNeedsV1ToV2Upgrade(net_param)
|| NetNeedsDataUpgrade(net_param) || NetNeedsInputUpgrade(net_param);
}

bool UpgradeNetAsNeeded(const string& param_file, NetParameter* param) {
Expand Down Expand Up @@ -655,12 +656,14 @@ void UpgradeNetDataTransformation(NetParameter* net_param) {
}

bool UpgradeV1Net(const NetParameter& v1_net_param, NetParameter* net_param) {
bool is_fully_compatible = true;
if (v1_net_param.layer_size() > 0) {
LOG(ERROR) << "Input NetParameter to be upgraded already specifies 'layer' "
<< "fields; these will be ignored for the upgrade.";
is_fully_compatible = false;
LOG(FATAL) << "Refusing to upgrade inconsistent NetParameter input; "
<< "the definition includes both 'layer' and 'layers' fields. "
<< "The current format defines 'layer' fields with string type like "
<< "layer { type: 'Layer' ... } and not layers { type: LAYER ... }. "
<< "Manually switch the definition to 'layer' format to continue.";
}
bool is_fully_compatible = true;
net_param->CopyFrom(v1_net_param);
net_param->clear_layers();
net_param->clear_layer();
Expand Down Expand Up @@ -952,29 +955,35 @@ bool NetNeedsInputUpgrade(const NetParameter& net_param) {
}

void UpgradeNetInput(NetParameter* net_param) {
LayerParameter* layer_param = net_param->add_layer();
layer_param->set_name("input");
layer_param->set_type("Input");
InputParameter* input_param = layer_param->mutable_input_param();
// Collect inputs and convert to Input layer definitions.
// If the NetParameter holds an input alone, without shape/dim, then
// it's a legacy caffemodel and simply stripping the input field is enough.
bool has_shape = net_param->input_shape_size() > 0;
// Convert input fields into a layer.
for (int i = 0; i < net_param->input_size(); ++i) {
layer_param->add_top(net_param->input(i));
if (has_shape) {
input_param->add_shape()->CopyFrom(net_param->input_shape(i));
} else {
// Turn legacy input dimensions into shape.
BlobShape* shape = input_param->add_shape();
int first_dim = i*4;
int last_dim = first_dim + 4;
for (int j = first_dim; j < last_dim; j++) {
shape->add_dim(net_param->input_dim(j));
bool has_dim = net_param->input_dim_size() > 0;
if (has_shape || has_dim) {
LayerParameter* layer_param = net_param->add_layer();
layer_param->set_name("input");
layer_param->set_type("Input");
InputParameter* input_param = layer_param->mutable_input_param();
// Convert input fields into a layer.
for (int i = 0; i < net_param->input_size(); ++i) {
layer_param->add_top(net_param->input(i));
if (has_shape) {
input_param->add_shape()->CopyFrom(net_param->input_shape(i));
} else {
// Turn legacy input dimensions into shape.
BlobShape* shape = input_param->add_shape();
int first_dim = i*4;
int last_dim = first_dim + 4;
for (int j = first_dim; j < last_dim; j++) {
shape->add_dim(net_param->input_dim(j));
}
}
}
}
// Swap input layer to beginning of net to satisfy layer dependencies.
for (int i = net_param->layer_size() - 1; i > 0; --i) {
net_param->mutable_layer(i-1)->Swap(net_param->mutable_layer(i));
// Swap input layer to beginning of net to satisfy layer dependencies.
for (int i = net_param->layer_size() - 1; i > 0; --i) {
net_param->mutable_layer(i-1)->Swap(net_param->mutable_layer(i));
}
}
// Clear inputs.
net_param->clear_input();
Expand Down
5 changes: 3 additions & 2 deletions tools/upgrade_net_proto_binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ using std::ofstream;
using namespace caffe; // NOLINT(build/namespaces)

int main(int argc, char** argv) {
FLAGS_alsologtostderr = 1; // Print output to stderr (while still logging)
::google::InitGoogleLogging(argv[0]);
if (argc != 3) {
LOG(ERROR) << "Usage: "
Expand All @@ -39,11 +40,11 @@ int main(int argc, char** argv) {
<< "see details above.";
}
} else {
LOG(ERROR) << "File already in V1 proto format: " << argv[1];
LOG(ERROR) << "File already in latest proto format: " << input_filename;
}

WriteProtoToBinaryFile(net_param, argv[2]);

LOG(ERROR) << "Wrote upgraded NetParameter binary proto to " << argv[2];
LOG(INFO) << "Wrote upgraded NetParameter binary proto to " << argv[2];
return !success;
}
8 changes: 2 additions & 6 deletions tools/upgrade_net_proto_text.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ using std::ofstream;
using namespace caffe; // NOLINT(build/namespaces)

int main(int argc, char** argv) {
FLAGS_alsologtostderr = 1; // Print output to stderr (while still logging)
::google::InitGoogleLogging(argv[0]);
if (argc != 3) {
LOG(ERROR) << "Usage: "
Expand All @@ -31,7 +32,6 @@ int main(int argc, char** argv) {
return 2;
}
bool need_upgrade = NetNeedsUpgrade(net_param);
bool need_data_upgrade = NetNeedsDataUpgrade(net_param);
bool success = true;
if (need_upgrade) {
success = UpgradeNetAsNeeded(input_filename, &net_param);
Expand All @@ -43,13 +43,9 @@ int main(int argc, char** argv) {
LOG(ERROR) << "File already in latest proto format: " << input_filename;
}

if (need_data_upgrade) {
UpgradeNetDataTransformation(&net_param);
}

// Save new format prototxt.
WriteProtoToTextFile(net_param, argv[2]);

LOG(ERROR) << "Wrote upgraded NetParameter text proto to " << argv[2];
LOG(INFO) << "Wrote upgraded NetParameter text proto to " << argv[2];
return !success;
}
3 changes: 2 additions & 1 deletion tools/upgrade_solver_proto_text.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ using std::ofstream;
using namespace caffe; // NOLINT(build/namespaces)

int main(int argc, char** argv) {
FLAGS_alsologtostderr = 1; // Print output to stderr (while still logging)
::google::InitGoogleLogging(argv[0]);
if (argc != 3) {
LOG(ERROR) << "Usage: upgrade_solver_proto_text "
Expand Down Expand Up @@ -45,6 +46,6 @@ int main(int argc, char** argv) {
// Save new format prototxt.
WriteProtoToTextFile(solver_param, argv[2]);

LOG(ERROR) << "Wrote upgraded SolverParameter text proto to " << argv[2];
LOG(INFO) << "Wrote upgraded SolverParameter text proto to " << argv[2];
return !success;
}