Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions src/frontends/onnx/frontend/src/op/identity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,41 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/op/identity.hpp"

#include "core/operator_set.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "utils/common.hpp"

using namespace ov::op;

namespace ov {
namespace frontend {
namespace onnx {
namespace ai_onnx {
namespace opset_1 {
ov::OutputVector identity(const ov::frontend::onnx::Node& node) {
ov::OutputVector outputs = node.get_ov_inputs();
for (auto& out : outputs) {
common::mark_as_optimized_out(out);
// This operator will be optimized out in EliminateSlice pass
// Need this to avoid data not being copied out when Identity connects from input to result
// in some cases like:
// Input->Identity->Result
ov::Output<ov::Node> input = node.get_ov_inputs().at(0);
const auto& start = v0::Constant::create(element::i64, {1}, {0});
auto input_shape = input.get_partial_shape();
bool need_squeeze = (input_shape.rank().is_dynamic() || input_shape.rank().get_length() == 0);
const auto& end = v0::Constant::create(element::i64, {1}, {std::numeric_limits<int64_t>::max()});
const auto& step = v0::Constant::create(element::i64, {1}, {1});
if (need_squeeze) {
input = std::make_shared<v0::Unsqueeze>(input, v0::Constant::create(element::i64, {1}, {0}));
}
ov::Output<ov::Node> output = std::make_shared<v8::Slice>(input, start, end, step);
if (need_squeeze) {
output = std::make_shared<v15::Squeeze>(output, v0::Constant::create(element::i64, {1}, {0}));
}
return outputs;
return {output};
}
ONNX_OP("Identity", OPSET_SINCE(1), ai_onnx::opset_1::identity);
} // namespace opset_1
Expand Down
4 changes: 1 addition & 3 deletions src/frontends/onnx/frontend/src/op/loop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,7 @@ ov::OutputVector loop(const ov::frontend::onnx::Node& node) {
") is not greater than number of outputs. Required at least: ",
loop_carried_dependencies.size() + 1);

ov::ParameterVector body_params(body_inputs.begin() + 2, body_inputs.end());
body_params.emplace(body_params.begin(),
body_inputs[0]); // current iteration body input
ov::ParameterVector body_params(body_inputs.begin(), body_inputs.end());
const auto body = std::make_shared<ov::Model>(body_outputs, body_params);
auto loop = std::make_shared<v5::Loop>(trip_count, termination_cond);
v5::Loop::SpecialBodyPorts spec_ports{0, 0};
Expand Down
Loading