 #           |            |
 #          out1         out2
 #
+#
+# ------Test input model 3------
+#    in1        in2
+#     |        /   \
+#  +--------+  +------+
+#  |  Add   |  | Relu |
+#  +--------+  +------+
+#      |          |
+#     out1       out2
+#
 def create_test_onnx_models():
     models = {}
     # Input model 1
@@ -91,6 +101,23 @@ def create_test_onnx_models():
     models["input_model_2.onnx"] = make_model(graph, producer_name="ONNX Importer",
                                               opset_imports=[onnx.helper.make_opsetid("", 13)])

+    # Input model 3
+    add_2 = onnx.helper.make_node("Add", inputs=["in1", "in2"], outputs=["out1"], name="onnx_add_op")
+    relu_2 = onnx.helper.make_node("Relu", inputs=["in2"], outputs=["out2"])
+
+    input_tensors = [
+        make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
+        make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (2, 2)),
+    ]
+    output_tensors = [
+        make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (2, 2)),
+        make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (2, 2)),
+        make_tensor_value_info("out2", onnx.TensorProto.FLOAT, (2, 2)),
+    ]
+    graph = make_graph([add_2, relu_2], "test_graph_3", input_tensors, output_tensors)
+    models["input_model_3.onnx"] = make_model(graph, producer_name="ONNX Importer",
+                                              opset_imports=[onnx.helper.make_opsetid("", 13)])
+
     # Expected for extract_subgraph
     input_tensors = [
         make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
@@ -188,6 +215,19 @@ def create_test_onnx_models():
     models["test_override_all_outputs_2.onnx"] = make_model(graph, producer_name="ONNX Importer",
                                                             opset_imports=[onnx.helper.make_opsetid("", 13)])

+    # Expected for test_override_all_outputs 3
+    input_tensors = [
+        make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
+        make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (2, 2)),
+    ]
+    output_tensors = [
+        make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (2, 2)),
+        make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (2, 2)),
+    ]
+    graph = make_graph([add_2], "test_graph_3", input_tensors, output_tensors)
+    models["test_override_all_outputs_3.onnx"] = make_model(graph, producer_name="ONNX Importer",
+                                                            opset_imports=[onnx.helper.make_opsetid("", 13)])
+
     # Expected for test_override_all_inputs
     input_tensors = [
         make_tensor_value_info("in3", onnx.TensorProto.FLOAT, (2, 2)),
@@ -594,6 +634,50 @@ def test_override_all_outputs_2():
     assert res


+def test_override_all_outputs_3():
+    skip_if_onnx_frontend_is_disabled()
+    fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
+    assert fe
+
+    model = fe.load("input_model_3.onnx")
+    assert model
+
+    place1 = model.get_place_by_tensor_name(tensor_name="out1")
+    place2 = model.get_place_by_tensor_name(tensor_name="out1")
+    model.override_all_outputs(outputs=[place1, place2])
+    result_func = fe.convert(model)
+
+    expected_model = fe.load("test_override_all_outputs_3.onnx")
+    expected_func = fe.convert(expected_model)
+
+    res = compare_functions(result_func, expected_func)
+    assert res
+
+
+def test_override_all_outputs_invalid_place():
+    skip_if_onnx_frontend_is_disabled()
+    fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
+    assert fe
+
+    model = fe.load("input_model_3.onnx")
+    assert model
+
+    model2 = fe.load("input_model.onnx")
+    assert model2
+    invalid_place = model2.get_place_by_tensor_name(tensor_name="out3")
+
+    place1 = model.get_place_by_tensor_name(tensor_name="out1")
+    place2 = model.get_place_by_tensor_name(tensor_name="out1")
+    model.override_all_outputs(outputs=[place1, place2, invalid_place])
+    result_func = fe.convert(model)
+
+    expected_model = fe.load("test_override_all_outputs_3.onnx")
+    expected_func = fe.convert(expected_model)
+
+    res = compare_functions(result_func, expected_func)
+    assert res
+
+
 def test_override_all_inputs():
     skip_if_onnx_frontend_is_disabled()
     fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
@@ -618,26 +702,31 @@ def test_override_all_inputs():
     assert res


-def test_override_all_inputs_exceptions():
+def test_override_all_inputs_invalid_place():
     skip_if_onnx_frontend_is_disabled()
     fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
     assert fe

-    model = fe.load("input_model.onnx")
+    model = fe.load("input_model_3.onnx")
     assert model

-    place1 = model.get_place_by_tensor_name(tensor_name="in1")
-    place2 = model.get_place_by_tensor_name(tensor_name="in2")
-    place3 = model.get_place_by_operation_name_and_input_port(operation_name="split1", input_port_index=0)
-    place4 = model.get_place_by_tensor_name(tensor_name="in3")
+    model2 = fe.load("input_model.onnx")
+    assert model2

-    with pytest.raises(Exception) as e:
-        model.override_all_inputs(inputs=[place1, place2])
-    assert "Unexpected number of inputs after override_all_inputs" in str(e)
+    out3_tensor = model2.get_place_by_tensor_name(tensor_name="out3")
+    invalid_place = out3_tensor.get_producing_operation().get_input_port(input_port_index=0)

-    with pytest.raises(Exception) as e:
-        model.override_all_inputs(inputs=[place3, place4])
-    assert "Unexpected number of inputs after override_all_inputs" in str(e)
+    out1_tensor = model.get_place_by_tensor_name(tensor_name="out1")
+    place1 = out1_tensor.get_producing_operation().get_input_port(input_port_index=0)
+    place2 = out1_tensor.get_producing_operation().get_input_port(input_port_index=1)
+    model.override_all_inputs(inputs=[place1, place2, invalid_place])
+    result_func = fe.convert(model)
+
+    expected_model = fe.load("input_model_3.onnx")
+    expected_func = fe.convert(expected_model)
+
+    res = compare_functions(result_func, expected_func)
+    assert res


 def test_is_input_output():
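
For reference, a minimal standalone sketch (not part of the PR) of how the new "input_model_3.onnx" asset defined above could be materialized on disk; in the real suite the models returned by create_test_onnx_models() are presumably written out by the module's setup code, and the duplicated "out1" output entry is kept here exactly as it appears in the diff:

    import onnx
    from onnx.helper import make_graph, make_model, make_tensor_value_info

    # Nodes of test input model 3: Add(in1, in2) -> out1, Relu(in2) -> out2
    add_2 = onnx.helper.make_node("Add", inputs=["in1", "in2"], outputs=["out1"], name="onnx_add_op")
    relu_2 = onnx.helper.make_node("Relu", inputs=["in2"], outputs=["out2"])

    input_tensors = [
        make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
        make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (2, 2)),
    ]
    # "out1" is listed twice, as in the diff, so the override tests can pass
    # the same output place more than once.
    output_tensors = [
        make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (2, 2)),
        make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (2, 2)),
        make_tensor_value_info("out2", onnx.TensorProto.FLOAT, (2, 2)),
    ]
    graph = make_graph([add_2, relu_2], "test_graph_3", input_tensors, output_tensors)
    model = make_model(graph, producer_name="ONNX Importer",
                       opset_imports=[onnx.helper.make_opsetid("", 13)])
    onnx.save(model, "input_model_3.onnx")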