@@ -499,10 +499,13 @@ void reshape_to_static(std::shared_ptr<ov::Model> model,
     model->reshape(new_shapes);
 }
 
-void reshape_sliced_head_to_static(std::shared_ptr<ov::Model> lm_head_model, const uint32_t& batch_dim) {
-    // We have only one input with dynamic shapes: output of Slice operation, and this output
-    // should have "1" for dimension representing number of embeddings to send to the matmul.
-    // Batch size should be also equal "1" for NPU.
+void reshape_sliced_head_to_static(std::shared_ptr<ov::Model> lm_head_model,
+                                   const uint32_t& batch_dim,
+                                   std::size_t max_generation_token_len) {
+    // We have only one input with dynamic shapes: output embeds.
+    // Output embeds should have "max_generation_token_len" for dimension representing
+    // number of embeddings to send to the matmul. Batch size should be equal to "1"
+    // for NPU.
     const auto& input = lm_head_model->input(0);
     const auto& partial_shape = input.get_partial_shape();
     NPUW_ASSERT(partial_shape.size() == 3);
@@ -512,7 +515,7 @@ void reshape_sliced_head_to_static(std::shared_ptr<ov::Model> lm_head_model, con
     // Left dynamic axis will be for number of embeddings
     for (auto i = 0; i < new_shape.rank().get_length(); i++) {
         if (new_shape[i].is_dynamic()) {
-            new_shape[i] = 1;
+            new_shape[i] = max_generation_token_len;
             // Sanity check that only one left dimension is dynamic, as
             // another one should contain embedding space rank
             break;
@@ -522,7 +525,9 @@ void reshape_sliced_head_to_static(std::shared_ptr<ov::Model> lm_head_model, con
     lm_head_model->reshape(new_shape);
 }
 
-void slice_out_embeds(std::shared_ptr<ov::Model> model, const uint32_t& batch_dim) {
+void slice_out_embeds(std::shared_ptr<ov::Model> model,
+                      const uint32_t& batch_dim,
+                      std::size_t max_generation_token_len) {
     std::shared_ptr<ov::Node> embed_result;
     for (auto&& output : model->outputs()) {
         if (output.get_any_name() == ov::npuw::LLMCompiledModel::output_embeds) {
@@ -533,15 +538,16 @@ void slice_out_embeds(std::shared_ptr<ov::Model> model, const uint32_t& batch_di
     if (embed_result) {
         auto shape = embed_result->input(0).get_shape();
         // If shape.size() is 3, then last axis should be the Vocab size.
-        // But 1st and 2nd axis can mean different things.
+        // But 1st and 2nd axes can mean different things.
         // 1st axis can represent the batch size, while 2nd - the number of embeddings,
         // or vice-versa (in chatglm)
         if (shape.size() == 3) {
             uint32_t num_embeds_dim = 1 - batch_dim;
-            if (shape[num_embeds_dim] > 1) {
-                std::vector<int32_t> start_pos{static_cast<int32_t>(batch_dim * (shape[num_embeds_dim] - 1)),
-                                               static_cast<int32_t>(num_embeds_dim * (shape[num_embeds_dim] - 1)),
-                                               0};
+            if (shape[num_embeds_dim] > max_generation_token_len) {
+                std::vector<int32_t> start_pos{
+                    static_cast<int32_t>(batch_dim * (shape[num_embeds_dim] - max_generation_token_len)),
+                    static_cast<int32_t>(num_embeds_dim * (shape[num_embeds_dim] - max_generation_token_len)),
+                    0};
                 std::vector<int32_t> stop_pos{static_cast<int32_t>(batch_dim * (shape[num_embeds_dim] - 1)) + 1,
                                               static_cast<int32_t>(num_embeds_dim * (shape[num_embeds_dim] - 1)) + 1,
                                               static_cast<int32_t>(shape[2])};
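For reference, a small self-contained sketch (not part of the patch) of what the new bounds evaluate to. The values batch_dim = 0, shape = {1, 1024, 2048} and max_generation_token_len = 8 are assumed examples; with them, the Slice keeps only the last 8 embeddings along the num_embeds axis:

#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
    // Assumed example layout: batch first, then number of embeddings, then embedding size.
    const uint32_t batch_dim = 0;
    const uint32_t num_embeds_dim = 1 - batch_dim;
    const std::size_t max_generation_token_len = 8;
    const std::vector<std::size_t> shape{1, 1024, 2048};

    // Same arithmetic as in slice_out_embeds() above.
    std::vector<int32_t> start_pos{
        static_cast<int32_t>(batch_dim * (shape[num_embeds_dim] - max_generation_token_len)),
        static_cast<int32_t>(num_embeds_dim * (shape[num_embeds_dim] - max_generation_token_len)),
        0};
    std::vector<int32_t> stop_pos{static_cast<int32_t>(batch_dim * (shape[num_embeds_dim] - 1)) + 1,
                                  static_cast<int32_t>(num_embeds_dim * (shape[num_embeds_dim] - 1)) + 1,
                                  static_cast<int32_t>(shape[2])};
    // start_pos == {0, 1016, 0}, stop_pos == {1, 1024, 2048}:
    // the trailing 8 embeddings are kept, together with the full last axis.
    return 0;
}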
@@ -673,6 +679,9 @@ ov::AnyMap get_default_generate_config(const std::optional<NPUDesc>& npudesc,
     if (hint == ::intel_npu::npuw::llm::GenerateHint::FAST_COMPILE) {
         config.emplace("NPUW_UNFOLD_IREQS", "YES");
     }
+    // We don't need slice-out for the kvcache model, especially for speculative decoding,
+    // which needs to generate more than 1 token per inference
+    config.erase("NPUW_SLICE_OUT");
     return config;
 }
@@ -849,6 +858,10 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
     KVAxesPosition axes{batch_dim, seq_len_dim};
     uint32_t max_prompt_len = align_to(m_cfg.get<::intel_npu::NPUW_LLM_MAX_PROMPT_LEN>(), 64u);
     const uint32_t min_response_len = align_to(m_cfg.get<::intel_npu::NPUW_LLM_MIN_RESPONSE_LEN>(), 64u);
+    uint32_t max_generation_token_len = m_cfg.get<::intel_npu::NPUW_LLM_MAX_GENERATION_TOKEN_LEN>();
+    if (max_generation_token_len != 1) {
+        max_generation_token_len = align_to(max_generation_token_len, 8u);
+    }
 
     // If chunk size covers the entire prompt, just follow the static behavior.
     // Otherwise, use chunking and align the prompt size to the chunk size.
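A minimal sketch of the round-up semantics assumed for align_to() in the lines above (the real helper is defined elsewhere in this file): with an alignment of 8u, a requested value of, say, 5 becomes 8, while a value of 1 bypasses alignment entirely via the branch above.

#include <cstdint>

// Assumed behavior of align_to(): round value up to the nearest multiple of alignment.
static uint32_t align_to_sketch(uint32_t value, uint32_t alignment) {
    return ((value + alignment - 1) / alignment) * alignment;  // align_to_sketch(5, 8) == 8
}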
@@ -872,7 +885,9 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
     LOG_VERB("Prefill chunk size: " << m_prefill_chunk_size);
     LOG_VERB("Maximum prompt length: " << max_prompt_len);
 
-    m_kvcache_desc = KVCacheDesc{max_prompt_len, max_prompt_len + min_response_len, 0u, seq_len_dim};
+    m_kvcache_desc =
+        KVCacheDesc{max_prompt_len, max_prompt_len + min_response_len, 0u, seq_len_dim, max_generation_token_len};
+
     LOG_DEBUG("Make prefill model with static shapes");
     m_max_lora_rank = m_cfg.get<::intel_npu::NPUW_LLM_MAX_LORA_RANK>();
     if (m_use_chunk_prefill) {
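For orientation, a hypothetical mirror of the KVCacheDesc aggregate implied by the braced initializer above; the real definition lives in the corresponding header, so field names and defaults here are assumptions. The new max_generation_token_len member sits after dim, which matches the serialize()/deserialize() order further down in this diff.

#include <cstdint>

// Hypothetical sketch of the aggregate, not the actual declaration.
struct KVCacheDescSketch {
    uint32_t max_prompt_size = 0u;
    uint32_t total_size = 0u;
    uint32_t num_stored_tokens = 0u;
    uint32_t dim = 0u;
    uint32_t max_generation_token_len = 1u;  // new member added by this change
    bool v_tensors_transposed = false;
};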
@@ -889,14 +904,18 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
                           m_max_lora_rank);
     }
     LOG_DEBUG("Make kvcache model with static shapes");
-    reshape_to_static(kvcache_model, 1u, m_kvcache_desc.total_size, axes, m_max_lora_rank);
+    reshape_to_static(kvcache_model,
+                      m_kvcache_desc.max_generation_token_len,
+                      m_kvcache_desc.total_size,
+                      axes,
+                      m_max_lora_rank);
     if (lm_head_model) {
         LOG_DEBUG("Shared LM head: slice the prefill output");
-        // KVCache model is already reshaped to [1, 1, embed size], so only apply slice to
-        // the Prefill model:
-        slice_out_embeds(prefill_model, axes.batch);
+        // KVCache model is already reshaped to [1, max_generation_token_len, embed size],
+        // so only apply slice to the Prefill model:
+        slice_out_embeds(prefill_model, axes.batch, m_kvcache_desc.max_generation_token_len);
         LOG_DEBUG("Make LM head model with static shapes");
-        reshape_sliced_head_to_static(lm_head_model, axes.batch);
+        reshape_sliced_head_to_static(lm_head_model, axes.batch, m_kvcache_desc.max_generation_token_len);
     }
 
     LOG_DEBUG("5.1, decompose GroupQueryAttention OP");
@@ -1089,6 +1108,7 @@ void ov::npuw::LLMCompiledModel::serialize(std::ostream& stream, const ov::npuw:
         write(model_stream, m_kvcache_desc.total_size);
         write(model_stream, m_kvcache_desc.num_stored_tokens);
         write(model_stream, m_kvcache_desc.dim);
+        write(model_stream, m_kvcache_desc.max_generation_token_len);
         write(model_stream, m_kvcache_desc.v_tensors_transposed);
         write(model_stream, m_prefill_chunk_size);
         write(model_stream, m_use_chunk_prefill);
@@ -1297,6 +1317,7 @@ std::shared_ptr<ov::npuw::LLMCompiledModel> ov::npuw::LLMCompiledModel::deserial
         read(model_stream, compiled->m_kvcache_desc.total_size);
         read(model_stream, compiled->m_kvcache_desc.num_stored_tokens);
         read(model_stream, compiled->m_kvcache_desc.dim);
+        read(model_stream, compiled->m_kvcache_desc.max_generation_token_len);
         read(model_stream, compiled->m_kvcache_desc.v_tensors_transposed);
         read(model_stream, compiled->m_prefill_chunk_size);
         read(model_stream, compiled->m_use_chunk_prefill);