
Commit 20efbd5

cwzrad, AsyaPronina, dmatveev authored
[NPUW] Support generate more than 1 token per inference (#31578)
### Details:
- For the KV-cache model, support generating more than one token per inference, which is needed by speculative decoding.
- Also update the KV cache according to the position IDs for fast draft. If 20 KV-cache entries are already saved, the next position ID should be 20. Assume we then feed 3 input tokens: the position IDs are [20, 21, 22], and after inference 3 more KV-cache entries are saved, 23 in total. If verification on the application side then finds that the token at position 22 is not correct, the position IDs for the next inference are [22, 23, 24], so the starting position ID increased by only 2. From that we know the last KV-cache entry written in the previous inference is dirty.

### Tickets:
- https://jira.devtools.intel.com/browse/CVS-172014

Signed-off-by: wenzengc <[email protected]>
Co-authored-by: Anastasiya Pronina <[email protected]>
Co-authored-by: Dmitry Matveev <[email protected]>
1 parent 390790e commit 20efbd5
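The position-id bookkeeping described in the Details can be sketched roughly as follows. This is an illustrative snippet only, not the plugin's actual code; the names `KVState` and `update_kv_on_new_positions` are made up for the example.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical illustration of the rule above: if the first position id of a
// new request is smaller than the number of KV entries already stored, the
// tail of the cache was produced for rejected draft tokens and is dropped.
struct KVState {
    uint32_t num_stored_tokens = 0u;  // KV entries accepted so far
};

void update_kv_on_new_positions(KVState& kv, const std::vector<int64_t>& position_ids) {
    const auto first_pos = static_cast<uint32_t>(position_ids.front());
    if (first_pos < kv.num_stored_tokens) {
        // e.g. 23 entries stored, next positions are [22, 23, 24]:
        // entry 22 belongs to a rejected token, so keep only the first 22.
        kv.num_stored_tokens = first_pos;
    }
    // After this inference, position_ids.size() new entries are appended.
    kv.num_stored_tokens += static_cast<uint32_t>(position_ids.size());
}
```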

11 files changed, +142 -47 lines changed

src/plugins/intel_npu/src/al/include/intel_npu/config/npuw.hpp

Lines changed: 1 addition & 0 deletions
@@ -138,6 +138,7 @@ DEFINE_OPT(NPUW_LLM, bool, false, npuw::llm::enabled, RunTime);
 DEFINE_OPT(NPUW_LLM_BATCH_DIM, uint32_t, 0, npuw::llm::batch_dim, RunTime);
 DEFINE_OPT(NPUW_LLM_SEQ_LEN_DIM, uint32_t, 2, npuw::llm::seq_len_dim, RunTime);
 DEFINE_OPT(NPUW_LLM_MAX_PROMPT_LEN, uint32_t, 1024, npuw::llm::max_prompt_len, RunTime);
+DEFINE_OPT(NPUW_LLM_MAX_GENERATION_TOKEN_LEN, uint32_t, 1, npuw::llm::max_generation_token_len, RunTime);
 DEFINE_OPT(NPUW_LLM_MIN_RESPONSE_LEN, uint32_t, 128, npuw::llm::min_response_len, RunTime);
 DEFINE_OPT(NPUW_LLM_OPTIMIZE_V_TENSORS, bool, true, npuw::llm::optimize_v_tensors, RunTime);
 DEFINE_OPT(NPUW_LLM_CACHE_ROPE, bool, true, npuw::llm::cache_rope, RunTime);

src/plugins/intel_npu/src/al/include/intel_npu/npuw_private_properties.hpp

Lines changed: 8 additions & 0 deletions
@@ -423,6 +423,14 @@ static constexpr ov::Property<uint32_t> seq_len_dim{"NPUW_LLM_SEQ_LEN_DIM"};
  */
 static constexpr ov::Property<uint32_t> max_prompt_len{"NPUW_LLM_MAX_PROMPT_LEN"};
 
+/**
+ * @brief
+ * Type: uint32_t.
+ * Desirable max input token length for generation.
+ * Default value: 1.
+ */
+static constexpr ov::Property<uint32_t> max_generation_token_len{"NPUW_LLM_MAX_GENERATION_TOKEN_LEN"};
+
 /**
  * @brief
  * Type: uint32_t.
src/plugins/intel_npu/src/al/src/config/npuw.cpp

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@ void intel_npu::registerNPUWLLMOptions(OptionsDesc& desc) {
     desc.add<NPUW_LLM_OPTIMIZE_V_TENSORS>();
     desc.add<NPUW_LLM_CACHE_ROPE>();
     desc.add<NPUW_LLM_PREFILL_CHUNK_SIZE>();
+    desc.add<NPUW_LLM_MAX_GENERATION_TOKEN_LEN>();
     desc.add<NPUW_LLM_PREFILL_HINT>();
     desc.add<NPUW_LLM_GENERATE_HINT>();
     desc.add<NPUW_LLM_SHARED_HEAD>();

src/plugins/intel_npu/src/plugin/include/properties.hpp

Lines changed: 1 addition & 0 deletions
@@ -111,6 +111,7 @@ class Properties final {
        ov::intel_npu::npuw::llm::batch_dim.name(),
        ov::intel_npu::npuw::llm::seq_len_dim.name(),
        ov::intel_npu::npuw::llm::max_prompt_len.name(),
+       ov::intel_npu::npuw::llm::max_generation_token_len.name(),
        ov::intel_npu::npuw::llm::min_response_len.name(),
        ov::intel_npu::npuw::llm::optimize_v_tensors.name(),
        ov::intel_npu::npuw::llm::cache_rope.name(),

src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.cpp

Lines changed: 38 additions & 17 deletions
@@ -499,10 +499,13 @@ void reshape_to_static(std::shared_ptr<ov::Model> model,
     model->reshape(new_shapes);
 }
 
-void reshape_sliced_head_to_static(std::shared_ptr<ov::Model> lm_head_model, const uint32_t& batch_dim) {
-    // We have only one input with dynamic shapes: output of Slice operation, and this output
-    // should have "1" for dimension representing number of embeddings to send to the matmul.
-    // Batch size should be also equal "1" for NPU.
+void reshape_sliced_head_to_static(std::shared_ptr<ov::Model> lm_head_model,
+                                   const uint32_t& batch_dim,
+                                   std::size_t max_generation_token_len) {
+    // We have only one input with dynamic shapes: output embeds.
+    // Output embeds should have "max_generation_token_len" for dimension representing
+    // number of embeddings to send to the matmul. Batch size should be equal to "1"
+    // for NPU.
     const auto& input = lm_head_model->input(0);
     const auto& partial_shape = input.get_partial_shape();
     NPUW_ASSERT(partial_shape.size() == 3);
@@ -512,7 +515,7 @@ void reshape_sliced_head_to_static(std::shared_ptr<ov::Model> lm_head_model, con
     // Left dynamic axis will be for number of embeddings
     for (auto i = 0; i < new_shape.rank().get_length(); i++) {
         if (new_shape[i].is_dynamic()) {
-            new_shape[i] = 1;
+            new_shape[i] = max_generation_token_len;
             // Sanity check that only one left dimension is dynamic, as
             // another one should contain embedding space rank
             break;
@@ -522,7 +525,9 @@ void reshape_sliced_head_to_static(std::shared_ptr<ov::Model> lm_head_model, con
     lm_head_model->reshape(new_shape);
 }
 
-void slice_out_embeds(std::shared_ptr<ov::Model> model, const uint32_t& batch_dim) {
+void slice_out_embeds(std::shared_ptr<ov::Model> model,
+                      const uint32_t& batch_dim,
+                      std::size_t max_generation_token_len) {
     std::shared_ptr<ov::Node> embed_result;
     for (auto&& output : model->outputs()) {
         if (output.get_any_name() == ov::npuw::LLMCompiledModel::output_embeds) {
@@ -533,15 +538,16 @@ void slice_out_embeds(std::shared_ptr<ov::Model> model, const uint32_t& batch_di
     if (embed_result) {
         auto shape = embed_result->input(0).get_shape();
         // If shape.size() is 3, then last axis should be the Vocab size.
-        // But 1st and 2nd axis can mean different things.
+        // But 1st and 2nd axes can mean different things.
         // 1st axis can represent the batch size, while 2nd - the number of embeddings,
         // or vice-versa (in chatglm)
         if (shape.size() == 3) {
             uint32_t num_embeds_dim = 1 - batch_dim;
-            if (shape[num_embeds_dim] > 1) {
-                std::vector<int32_t> start_pos{static_cast<int32_t>(batch_dim * (shape[num_embeds_dim] - 1)),
-                                               static_cast<int32_t>(num_embeds_dim * (shape[num_embeds_dim] - 1)),
-                                               0};
+            if (shape[num_embeds_dim] > max_generation_token_len) {
+                std::vector<int32_t> start_pos{
+                    static_cast<int32_t>(batch_dim * (shape[num_embeds_dim] - max_generation_token_len)),
+                    static_cast<int32_t>(num_embeds_dim * (shape[num_embeds_dim] - max_generation_token_len)),
+                    0};
                 std::vector<int32_t> stop_pos{static_cast<int32_t>(batch_dim * (shape[num_embeds_dim] - 1)) + 1,
                                               static_cast<int32_t>(num_embeds_dim * (shape[num_embeds_dim] - 1)) + 1,
                                               static_cast<int32_t>(shape[2])};
@@ -673,6 +679,9 @@ ov::AnyMap get_default_generate_config(const std::optional<NPUDesc>& npudesc,
     if (hint == ::intel_npu::npuw::llm::GenerateHint::FAST_COMPILE) {
         config.emplace("NPUW_UNFOLD_IREQS", "YES");
     }
+    // We don't need slice out for kv cache model, especially for speculative decoding which need
+    // to generate more than 1 token for each inference
+    config.erase("NPUW_SLICE_OUT");
     return config;
 }
 
@@ -849,6 +858,10 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
     KVAxesPosition axes{batch_dim, seq_len_dim};
     uint32_t max_prompt_len = align_to(m_cfg.get<::intel_npu::NPUW_LLM_MAX_PROMPT_LEN>(), 64u);
     const uint32_t min_response_len = align_to(m_cfg.get<::intel_npu::NPUW_LLM_MIN_RESPONSE_LEN>(), 64u);
+    uint32_t max_generation_token_len = m_cfg.get<::intel_npu::NPUW_LLM_MAX_GENERATION_TOKEN_LEN>();
+    if (max_generation_token_len != 1) {
+        max_generation_token_len = align_to(max_generation_token_len, 8u);
+    }
 
     // If chunk size covers the entire prompt, just follow the static behavior.
     // Otherwise, use chunking and align the prompt size to the chunk size.
@@ -872,7 +885,9 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
     LOG_VERB("Prefill chunk size: " << m_prefill_chunk_size);
     LOG_VERB("Maximum prompt length: " << max_prompt_len);
 
-    m_kvcache_desc = KVCacheDesc{max_prompt_len, max_prompt_len + min_response_len, 0u, seq_len_dim};
+    m_kvcache_desc =
+        KVCacheDesc{max_prompt_len, max_prompt_len + min_response_len, 0u, seq_len_dim, max_generation_token_len};
+
     LOG_DEBUG("Make prefill model with static shapes");
     m_max_lora_rank = m_cfg.get<::intel_npu::NPUW_LLM_MAX_LORA_RANK>();
     if (m_use_chunk_prefill) {
@@ -889,14 +904,18 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
                           m_max_lora_rank);
     }
     LOG_DEBUG("Make kvcache model with static shapes");
-    reshape_to_static(kvcache_model, 1u, m_kvcache_desc.total_size, axes, m_max_lora_rank);
+    reshape_to_static(kvcache_model,
+                      m_kvcache_desc.max_generation_token_len,
+                      m_kvcache_desc.total_size,
+                      axes,
+                      m_max_lora_rank);
     if (lm_head_model) {
         LOG_DEBUG("Shared LM head: slice the prefill output");
-        // KVCache model is already reshaped to [1, 1, embed size], so only apply slice to
-        // the Prefill model:
-        slice_out_embeds(prefill_model, axes.batch);
+        // KVCache model is already reshaped to [1, max_generation_token_len, embed size],
+        // so only apply slice to the Prefill model:
+        slice_out_embeds(prefill_model, axes.batch, m_kvcache_desc.max_generation_token_len);
         LOG_DEBUG("Make LM head model with static shapes");
-        reshape_sliced_head_to_static(lm_head_model, axes.batch);
+        reshape_sliced_head_to_static(lm_head_model, axes.batch, m_kvcache_desc.max_generation_token_len);
     }
 
     LOG_DEBUG("5.1, decompose GroupQueryAttention OP");
@@ -1089,6 +1108,7 @@ void ov::npuw::LLMCompiledModel::serialize(std::ostream& stream, const ov::npuw:
     write(model_stream, m_kvcache_desc.total_size);
     write(model_stream, m_kvcache_desc.num_stored_tokens);
     write(model_stream, m_kvcache_desc.dim);
+    write(model_stream, m_kvcache_desc.max_generation_token_len);
    write(model_stream, m_kvcache_desc.v_tensors_transposed);
    write(model_stream, m_prefill_chunk_size);
    write(model_stream, m_use_chunk_prefill);
@@ -1297,6 +1317,7 @@ std::shared_ptr<ov::npuw::LLMCompiledModel> ov::npuw::LLMCompiledModel::deserial
    read(model_stream, compiled->m_kvcache_desc.total_size);
    read(model_stream, compiled->m_kvcache_desc.num_stored_tokens);
    read(model_stream, compiled->m_kvcache_desc.dim);
+   read(model_stream, compiled->m_kvcache_desc.max_generation_token_len);
    read(model_stream, compiled->m_kvcache_desc.v_tensors_transposed);
    read(model_stream, compiled->m_prefill_chunk_size);
    read(model_stream, compiled->m_use_chunk_prefill);
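Note on the alignment above: when NPUW_LLM_MAX_GENERATION_TOKEN_LEN is not 1, it is passed through align_to(..., 8u). The sketch below assumes align_to rounds up to the nearest multiple, which is the usual convention for such a helper; the plugin's actual implementation is not shown in this diff.

```cpp
#include <cstdint>
#include <iostream>

// Presumed behavior of the align_to() helper: round `value` up to the
// nearest multiple of `align`. The plugin's real helper may differ.
static uint32_t align_to(uint32_t value, uint32_t align) {
    return (value + align - 1) / align * align;
}

int main() {
    // A value of 1 is kept as-is (single-token decode); any other value is
    // padded so the static kvcache model gets a fixed token dimension.
    for (uint32_t requested : {1u, 3u, 5u, 8u, 9u}) {
        const uint32_t effective = (requested == 1u) ? 1u : align_to(requested, 8u);
        std::cout << requested << " -> " << effective << "\n";  // 1->1, 3->8, 5->8, 8->8, 9->16
    }
    return 0;
}
```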

src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.hpp

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ class LLMCompiledModel : public ov::npuw::ICompiledModel {
        uint32_t total_size = 0u;
        uint32_t num_stored_tokens = 0u;
        uint32_t dim = 0u;
+       uint32_t max_generation_token_len = 0u;
        bool v_tensors_transposed = false;
    };
