1
+ // Copyright (C) 2018-2025 Intel Corporation
2
+ // SPDX-License-Identifier: Apache-2.0
3
+ //
4
+
5
+ #include " openvino/pass/sdpa_to_vlsdpa.hpp"
6
+
7
+ #include < gtest/gtest.h>
8
+
9
+ #include " common_test_utils/ov_test_utils.hpp"
10
+ #include " openvino/core/model.hpp"
11
+ #include " openvino/opsets/opset13.hpp"
12
+ #include " openvino/pass/manager.hpp"
13
+ #include " openvino/runtime/core.hpp"
14
+ #include " ov_ops/vl_sdpa.hpp"
15
+
16
+ using namespace std ;
17
+ using namespace ov ;
18
+ using namespace ov ::opset13;
19
+
20
+ namespace {
21
+ std::shared_ptr<ov::Model> build_model (const string& mask_name) {
22
+ auto q = std::make_shared<Parameter>(element::f32 , PartialShape{-1 , 8 , 32 }); /* L,H,S */
23
+ auto k = std::make_shared<Parameter>(element::f32 , PartialShape{-1 , 8 , 32 });
24
+ auto v = std::make_shared<Parameter>(element::f32 , PartialShape{-1 , 8 , 32 });
25
+
26
+ q->set_friendly_name (" q" );
27
+ k->set_friendly_name (" k" );
28
+ v->set_friendly_name (" v" );
29
+
30
+ auto transpose_q = std::make_shared<Transpose>(q, Constant::create (element::i64 , Shape{3 }, {1 , 0 , 2 }));
31
+ auto transpose_k = std::make_shared<Transpose>(k, Constant::create (element::i64 , Shape{3 }, {1 , 0 , 2 }));
32
+ auto transpose_v = std::make_shared<Transpose>(v, Constant::create (element::i64 , Shape{3 }, {1 , 0 , 2 }));
33
+ transpose_q->set_friendly_name (" transpose_q" );
34
+ transpose_k->set_friendly_name (" transpose_k" );
35
+ transpose_v->set_friendly_name (" transpose_v" );
36
+
37
+ auto mask = std::make_shared<Parameter>(element::f32 , PartialShape{1 , -1 , -1 });
38
+ mask->set_friendly_name (mask_name);
39
+ mask->get_output_tensor (0 ).set_names ({mask_name});
40
+
41
+ const auto casual = false ;
42
+
43
+ auto sdpa =
44
+ std::make_shared<opset13::ScaledDotProductAttention>(transpose_q, transpose_k, transpose_v, mask, casual);
45
+ sdpa->set_friendly_name (" sdpa" );
46
+
47
+ auto transpose_o = std::make_shared<Transpose>(sdpa, Constant::create (element::i64 , Shape{3 }, {1 , 0 , 2 }));
48
+ transpose_o->set_friendly_name (" transpose_o" );
49
+
50
+ return std::make_shared<ov::Model>(OutputVector{transpose_o}, ParameterVector{q, k, v, mask});
51
+ }
52
+
53
+ std::shared_ptr<ov::Model> build_target_model (const string& mask_name) {
54
+ auto q = std::make_shared<Parameter>(element::f32 , PartialShape{-1 , 8 , 32 }); /* L,H,S */
55
+ auto k = std::make_shared<Parameter>(element::f32 , PartialShape{-1 , 8 , 32 });
56
+ auto v = std::make_shared<Parameter>(element::f32 , PartialShape{-1 , 8 , 32 });
57
+ q->set_friendly_name (" q" );
58
+ k->set_friendly_name (" k" );
59
+ v->set_friendly_name (" v" );
60
+
61
+ auto transpose_q = std::make_shared<Transpose>(q, Constant::create (element::i64 , Shape{3 }, {1 , 0 , 2 }));
62
+ auto transpose_k = std::make_shared<Transpose>(k, Constant::create (element::i64 , Shape{3 }, {1 , 0 , 2 }));
63
+ auto transpose_v = std::make_shared<Transpose>(v, Constant::create (element::i64 , Shape{3 }, {1 , 0 , 2 }));
64
+ transpose_q->set_friendly_name (" transpose_q" );
65
+ transpose_k->set_friendly_name (" transpose_k" );
66
+ transpose_v->set_friendly_name (" transpose_v" );
67
+
68
+ auto cuseq_mask = std::make_shared<Parameter>(element::i32 , PartialShape{-1 });
69
+ cuseq_mask->set_friendly_name (mask_name);
70
+ cuseq_mask->get_output_tensor (0 ).set_names ({mask_name});
71
+
72
+ auto vlsdpa =
73
+ std::make_shared<ov::op::internal::VLSDPA>(OutputVector{transpose_q, transpose_k, transpose_v, cuseq_mask});
74
+
75
+ auto transpose_o = std::make_shared<Transpose>(vlsdpa, Constant::create (element::i64 , Shape{3 }, {1 , 0 , 2 }));
76
+ transpose_o->set_friendly_name (" transpose_o" );
77
+
78
+ return std::make_shared<ov::Model>(OutputVector{transpose_o}, ParameterVector{q, k, v, cuseq_mask});
79
+ }
80
+ }; // namespace
81
+
82
+ TEST_F (TransformationTestsF, SDPA2VLSDPAAttentionMaskTest) {
83
+ disable_rt_info_check ();
84
+ {
85
+ model = build_model (" attention_mask" );
86
+ model->set_rt_info (" QWenVL" , " model_type_hint" ); // request_vl_sdpa_transformations
87
+ manager.register_pass <ov::pass::SDPAToVLSDPA>();
88
+ }
89
+ { model_ref = build_target_model (" cu_seq_lens" ); }
90
+
91
+ comparator.enable (FunctionsComparator::CmpValues::CONST_VALUES);
92
+ comparator.enable (FunctionsComparator::CmpValues::ATTRIBUTES);
93
+ comparator.enable (FunctionsComparator::CmpValues::NAMES);
94
+ }
95
+
96
+ TEST_F (TransformationTestsF, SDPA2VLSDPAWindowAttentionMaskTest) {
97
+ disable_rt_info_check ();
98
+ {
99
+ model = build_model (" window_attention_mask" );
100
+ model->set_rt_info (" QWenVL" , " model_type_hint" ); // request_vl_sdpa_transformations
101
+ manager.register_pass <ov::pass::SDPAToVLSDPA>();
102
+ }
103
+ { model_ref = build_target_model (" cu_window_seqlens" ); }
104
+
105
+ comparator.enable (FunctionsComparator::CmpValues::CONST_VALUES);
106
+ comparator.enable (FunctionsComparator::CmpValues::ATTRIBUTES);
107
+ comparator.enable (FunctionsComparator::CmpValues::NAMES);
108
+ }
0 commit comments