Skip to content

Commit e89db1c

Browse files
authored
Fix "Unexpected number of outputs after override_all_outputs" (#9454)
1 parent f255c19 commit e89db1c

File tree

3 files changed

+173
-20
lines changed

3 files changed

+173
-20
lines changed

src/bindings/python/tests/test_frontend/test_frontend_onnx_editor.py

Lines changed: 101 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,16 @@
4949
# | |
5050
# out1 out2
5151
#
52+
#
53+
# ------Test input model 3------
54+
# in1 in2
55+
# | / \
56+
# +--------+ +------+
57+
# | Add | | Relu |
58+
# +--------+ +------+
59+
# | |
60+
# out1 out2
61+
#
5262
def create_test_onnx_models():
5363
models = {}
5464
# Input model 1
@@ -91,6 +101,23 @@ def create_test_onnx_models():
91101
models["input_model_2.onnx"] = make_model(graph, producer_name="ONNX Importer",
92102
opset_imports=[onnx.helper.make_opsetid("", 13)])
93103

104+
# Input model 3
105+
add_2 = onnx.helper.make_node("Add", inputs=["in1", "in2"], outputs=["out1"], name="onnx_add_op")
106+
relu_2 = onnx.helper.make_node("Relu", inputs=["in2"], outputs=["out2"])
107+
108+
input_tensors = [
109+
make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
110+
make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (2, 2)),
111+
]
112+
output_tensors = [
113+
make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (2, 2)),
114+
make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (2, 2)),
115+
make_tensor_value_info("out2", onnx.TensorProto.FLOAT, (2, 2)),
116+
]
117+
graph = make_graph([add_2, relu_2], "test_graph_3", input_tensors, output_tensors)
118+
models["input_model_3.onnx"] = make_model(graph, producer_name="ONNX Importer",
119+
opset_imports=[onnx.helper.make_opsetid("", 13)])
120+
94121
# Expected for extract_subgraph
95122
input_tensors = [
96123
make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
@@ -188,6 +215,19 @@ def create_test_onnx_models():
188215
models["test_override_all_outputs_2.onnx"] = make_model(graph, producer_name="ONNX Importer",
189216
opset_imports=[onnx.helper.make_opsetid("", 13)])
190217

218+
# Expected for test_override_all_outputs 3
219+
input_tensors = [
220+
make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
221+
make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (2, 2)),
222+
]
223+
output_tensors = [
224+
make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (2, 2)),
225+
make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (2, 2)),
226+
]
227+
graph = make_graph([add_2], "test_graph_3", input_tensors, output_tensors)
228+
models["test_override_all_outputs_3.onnx"] = make_model(graph, producer_name="ONNX Importer",
229+
opset_imports=[onnx.helper.make_opsetid("", 13)])
230+
191231
# Expected for test_override_all_inputs
192232
input_tensors = [
193233
make_tensor_value_info("in3", onnx.TensorProto.FLOAT, (2, 2)),
@@ -594,6 +634,50 @@ def test_override_all_outputs_2():
594634
assert res
595635

596636

637+
def test_override_all_outputs_3():
638+
skip_if_onnx_frontend_is_disabled()
639+
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
640+
assert fe
641+
642+
model = fe.load("input_model_3.onnx")
643+
assert model
644+
645+
place1 = model.get_place_by_tensor_name(tensor_name="out1")
646+
place2 = model.get_place_by_tensor_name(tensor_name="out1")
647+
model.override_all_outputs(outputs=[place1, place2])
648+
result_func = fe.convert(model)
649+
650+
expected_model = fe.load("test_override_all_outputs_3.onnx")
651+
expected_func = fe.convert(expected_model)
652+
653+
res = compare_functions(result_func, expected_func)
654+
assert res
655+
656+
657+
def test_override_all_outputs_invalid_place():
658+
skip_if_onnx_frontend_is_disabled()
659+
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
660+
assert fe
661+
662+
model = fe.load("input_model_3.onnx")
663+
assert model
664+
665+
model2 = fe.load("input_model.onnx")
666+
assert model2
667+
invalid_place = model2.get_place_by_tensor_name(tensor_name="out3")
668+
669+
place1 = model.get_place_by_tensor_name(tensor_name="out1")
670+
place2 = model.get_place_by_tensor_name(tensor_name="out1")
671+
model.override_all_outputs(outputs=[place1, place2, invalid_place])
672+
result_func = fe.convert(model)
673+
674+
expected_model = fe.load("test_override_all_outputs_3.onnx")
675+
expected_func = fe.convert(expected_model)
676+
677+
res = compare_functions(result_func, expected_func)
678+
assert res
679+
680+
597681
def test_override_all_inputs():
598682
skip_if_onnx_frontend_is_disabled()
599683
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
@@ -618,26 +702,31 @@ def test_override_all_inputs():
618702
assert res
619703

620704

621-
def test_override_all_inputs_exceptions():
705+
def test_override_all_inputs_invalid_place():
622706
skip_if_onnx_frontend_is_disabled()
623707
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
624708
assert fe
625709

626-
model = fe.load("input_model.onnx")
710+
model = fe.load("input_model_3.onnx")
627711
assert model
628712

629-
place1 = model.get_place_by_tensor_name(tensor_name="in1")
630-
place2 = model.get_place_by_tensor_name(tensor_name="in2")
631-
place3 = model.get_place_by_operation_name_and_input_port(operation_name="split1", input_port_index=0)
632-
place4 = model.get_place_by_tensor_name(tensor_name="in3")
713+
model2 = fe.load("input_model.onnx")
714+
assert model2
633715

634-
with pytest.raises(Exception) as e:
635-
model.override_all_inputs(inputs=[place1, place2])
636-
assert "Unexpected number of inputs after override_all_inputs" in str(e)
716+
out3_tensor = model2.get_place_by_tensor_name(tensor_name="out3")
717+
invalid_place = out3_tensor.get_producing_operation().get_input_port(input_port_index=0)
637718

638-
with pytest.raises(Exception) as e:
639-
model.override_all_inputs(inputs=[place3, place4])
640-
assert "Unexpected number of inputs after override_all_inputs" in str(e)
719+
out1_tensor = model.get_place_by_tensor_name(tensor_name="out1")
720+
place1 = out1_tensor.get_producing_operation().get_input_port(input_port_index=0)
721+
place2 = out1_tensor.get_producing_operation().get_input_port(input_port_index=1)
722+
model.override_all_inputs(inputs=[place1, place2, invalid_place])
723+
result_func = fe.convert(model)
724+
725+
expected_model = fe.load("input_model_3.onnx")
726+
expected_func = fe.convert(expected_model)
727+
728+
res = compare_functions(result_func, expected_func)
729+
assert res
641730

642731

643732
def test_is_input_output():

src/frontends/onnx/frontend/src/input_model.cpp

Lines changed: 71 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <openvino/frontend/exception.hpp>
88
#include <openvino/util/file_util.hpp>
99

10+
#include "ngraph/log.hpp"
1011
#include "place.hpp"
1112

1213
using namespace ov;
@@ -202,28 +203,90 @@ std::shared_ptr<Model> InputModel::convert() {
202203
}
203204

204205
// Editor features
206+
bool InputModel::is_correct_place(const ov::frontend::Place::Ptr& place) const {
207+
if (const auto tensor = std::dynamic_pointer_cast<PlaceTensor>(place)) {
208+
return m_editor->is_correct_tensor_name(tensor->get_names()[0]);
209+
}
210+
if (const auto op = std::dynamic_pointer_cast<PlaceOp>(place)) {
211+
return m_editor->is_correct_and_unambiguous_node(op->get_editor_node());
212+
}
213+
if (const auto input_edge = std::dynamic_pointer_cast<PlaceInputEdge>(place)) {
214+
if (auto tensor = std::dynamic_pointer_cast<PlaceTensor>(input_edge->get_source_tensor())) {
215+
return m_editor->is_correct_tensor_name(tensor->get_names()[0]);
216+
}
217+
}
218+
if (const auto output_edge = std::dynamic_pointer_cast<PlaceOutputEdge>(place)) {
219+
if (auto tensor = std::dynamic_pointer_cast<PlaceTensor>(output_edge->get_target_tensor())) {
220+
return m_editor->is_correct_tensor_name(tensor->get_names()[0]);
221+
}
222+
}
223+
return false;
224+
}
225+
205226
void InputModel::override_all_outputs(const std::vector<ov::frontend::Place::Ptr>& outputs) {
206-
extract_subgraph({}, outputs);
207-
NGRAPH_CHECK(m_editor->model_outputs().size() == outputs.size(),
208-
"Unexpected number of outputs after override_all_outputs");
209-
NGRAPH_CHECK(std::all_of(std::begin(outputs),
210-
std::end(outputs),
227+
std::vector<Place::Ptr> expected_valid_outputs;
228+
for (const auto& output : outputs) {
229+
bool is_correct = is_correct_place(output);
230+
if (!is_correct)
231+
NGRAPH_WARN << "Name " << output->get_names().at(0)
232+
<< " of output node is not a correct node name. Ignoring this parameter.";
233+
else
234+
expected_valid_outputs.push_back(output);
235+
}
236+
237+
extract_subgraph({}, expected_valid_outputs);
238+
239+
NGRAPH_CHECK(std::all_of(std::begin(expected_valid_outputs),
240+
std::end(expected_valid_outputs),
211241
[](const ov::frontend::Place::Ptr& place) {
212242
return place->is_output();
213243
}),
214244
"Not all provided arguments of override_all_outputs are new outputs of the model");
245+
246+
const auto current_outputs = get_outputs();
247+
NGRAPH_CHECK(std::all_of(std::begin(current_outputs),
248+
std::end(current_outputs),
249+
[&](const Place::Ptr& current_out) {
250+
return std::find_if(std::begin(expected_valid_outputs),
251+
std::end(expected_valid_outputs),
252+
[&](const Place::Ptr& expected_out) {
253+
return expected_out->is_equal(current_out);
254+
}) != std::end(current_outputs);
255+
}),
256+
"Some other than expected outputs were created during override_all_outputs");
215257
}
216258

217259
void InputModel::override_all_inputs(const std::vector<ov::frontend::Place::Ptr>& inputs) {
260+
std::vector<Place::Ptr> expected_valid_inputs;
261+
for (const auto& input : inputs) {
262+
bool is_correct = is_correct_place(input);
263+
if (!is_correct)
264+
NGRAPH_WARN << "Name " << input->get_names().at(0)
265+
<< " of input node is not a correct node. Ignoring this parameter.";
266+
else
267+
expected_valid_inputs.push_back(input);
268+
}
269+
218270
const auto outputs_before_extraction = m_editor->model_outputs();
219-
extract_subgraph({inputs}, {});
271+
extract_subgraph({expected_valid_inputs}, {});
272+
220273
NGRAPH_CHECK(std::equal(std::begin(outputs_before_extraction),
221274
std::end(outputs_before_extraction),
222275
std::begin(m_editor->model_outputs())),
223276
"All outputs should be preserved after override_all_inputs. Provided inputs does "
224277
"not satisfy all outputs");
225-
NGRAPH_CHECK(m_editor->model_inputs().size() == inputs.size(),
226-
"Unexpected number of inputs after override_all_inputs");
278+
279+
const auto current_inputs = get_inputs();
280+
NGRAPH_CHECK(std::all_of(std::begin(current_inputs),
281+
std::end(current_inputs),
282+
[&](const Place::Ptr& current_in) {
283+
return std::find_if(std::begin(expected_valid_inputs),
284+
std::end(expected_valid_inputs),
285+
[&](const Place::Ptr& expected_in) {
286+
return expected_in->is_equal(current_in);
287+
}) != std::end(current_inputs);
288+
}),
289+
"Some other than expected inputs were created during override_all_inputs");
227290
}
228291

229292
void InputModel::extract_subgraph(const std::vector<ov::frontend::Place::Ptr>& inputs,

src/frontends/onnx/frontend/src/input_model.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ class InputModel : public ov::frontend::InputModel {
7878

7979
private:
8080
std::shared_ptr<ov::onnx_editor::ONNXModelEditor> m_editor;
81+
bool is_correct_place(const ov::frontend::Place::Ptr& place) const;
8182

8283
std::unordered_map<std::string, std::unordered_set<std::string>> m_additional_tensor_names;
8384
void add_tensor_names(std::shared_ptr<Model>& model);

0 commit comments

Comments
 (0)