27 changes: 8 additions & 19 deletions src/cpp/src/gguf_utils/building_blocks.cpp
@@ -672,11 +672,6 @@ ov::Output<ov::Node> make_int4_weights(
zero_point_data[i] = (bias2 << 4) | (bias1 & 0x0F);
}

// CVS-166438: A GGUF Q4_0 zp array (U4) whose values are all the same (8) is converted to a single U4 scalar via the ConvertU4WeightsZeroPointToScalar transformation.
// The CPU plugin handles this corner case properly, but it triggers a compilation error on the GPU plugin.
// Temporary WA: add a small bias to one element so the zp array keeps its shape for the GPU plugin; confirmed no accuracy impact on final LLM generation results.
zero_point_data[0] += 1;

auto zero_points_node = std::make_shared<ov::op::v0::Constant>(zero_point_tensor);
auto zero_points_f16 = std::make_shared<ov::op::v0::Convert>(zero_points_node, ov::element::f16);
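
A note on the packing scheme touched by this hunk: each byte of the zero-point tensor holds two U4 values, and for GGUF Q4_0 every zero point is 8. That uniform array is exactly what the ConvertU4WeightsZeroPointToScalar transformation collapses into a scalar, which the removed WA was defending against. A minimal Python sketch of the packing, for illustration only (not code from this PR):

def pack_u4_pairs(zero_points):
    """Pack pairs of 4-bit zero points into bytes, low nibble first,
    mirroring the (bias2 << 4) | (bias1 & 0x0F) expression above."""
    packed = bytearray()
    for lo, hi in zip(zero_points[0::2], zero_points[1::2]):
        packed.append(((hi & 0x0F) << 4) | (lo & 0x0F))
    return bytes(packed)

# GGUF Q4_0 uses a fixed zero point of 8, so the packed array is all 0x88 bytes.
assert pack_u4_pairs([8] * 8) == bytes([0x88] * 4)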

@@ -748,23 +743,17 @@ ov::Output<ov::Node> make_lm_head(
const ov::Output<ov::Node>& input,
const std::unordered_map<std::string, ov::Tensor>& consts,
const ov::Output<ov::Node>& embeddings_node,
gguf_tensor_type qtype,
bool shared_embedding) {
gguf_tensor_type qtype) {

ov::Output<ov::Node> w_f32;
if (shared_embedding){
w_f32 = embeddings_node;
}
else {
if (consts.count(key + ".weight")) {
gguf_tensor_type lm_qtype = qtype;
if (!consts.count(key + ".scales")) {
lm_qtype = gguf_tensor_type::GGUF_TYPE_F16;
}
w_f32 = make_weights_subgraph(key, consts, lm_qtype, false, -1);
} else {
w_f32 = embeddings_node;
if (consts.count(key + ".weight")) {
Copilot AI commented on Jul 18, 2025: [nitpick] The logic structure has changed after removing the shared_embedding condition, but the original fallback logic (using embeddings_node when key + ".weight" doesn't exist) is preserved. Consider adding a comment to clarify this fallback behavior for future maintainers.

gguf_tensor_type lm_qtype = qtype;
if (!consts.count(key + ".scales")) {
lm_qtype = gguf_tensor_type::GGUF_TYPE_F16;
}
w_f32 = make_weights_subgraph(key, consts, lm_qtype, false, -1);
} else {
w_f32 = embeddings_node;
}
return std::make_shared<ov::op::v0::MatMul>(
input, w_f32, false, true);
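
As the inline review comment above notes, the fallback (reusing the embedding weight when no separate output.weight exists, i.e. tied weights) is preserved after dropping shared_embedding. A small Python sketch of the selection rule, written as pseudocode for clarity rather than taken from the C++ sources:

def pick_lm_head_weight(consts, key, embeddings_node, qtype):
    """Mirror of the assumed weight-selection behavior in make_lm_head."""
    if key + ".weight" in consts:
        # Without a separate scales tensor the weight is treated as plain F16.
        lm_qtype = qtype if key + ".scales" in consts else "GGUF_TYPE_F16"
        return ("dedicated", key, lm_qtype)
    # Tied weights: reuse the token-embedding constant for the LM head.
    return ("shared", embeddings_node, None)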
3 changes: 1 addition & 2 deletions src/cpp/src/gguf_utils/building_blocks.hpp
@@ -18,8 +18,7 @@ ov::Output<ov::Node> make_lm_head(
const ov::Output<ov::Node>& input,
const std::unordered_map<std::string, ov::Tensor>& consts,
const ov::Output<ov::Node>& embeddings_node,
gguf_tensor_type qtype,
bool shared_embedding);
gguf_tensor_type qtype);

ov::Output<ov::Node> make_rms_norm(
const std::string& key,
18 changes: 2 additions & 16 deletions src/cpp/src/gguf_utils/gguf_modeling.cpp
@@ -27,20 +27,10 @@ auto set_name = [](auto node, const std::string& name) {
node->set_friendly_name(name);
};

// Also valid for other models, e.g. SmolLMs
// CVS-166108: shared_embedding is set to true by default for the following two reasons:
// 1. For an optimum-cli converted LLM OpenVINO IR, the original input embedding weight is reused for the final make_lm_head layer,
//    which reduces both model size on disk and runtime memory usage by storing only a single embedding constant
//    (e.g. Qwen2.5-7B-Instruct-Q4_0 token_embd.weight & output.weight shape [3584, 152064]).
// 2. Some GGUF models that contain both token_embd.weight & output.weight, e.g. Qwen2.5-3B-Instruct Q4_0,
//    hit an accuracy issue on MTL/LNL GPU when both token_embd.weight & output.weight are used in the OpenVINO IR.
// Known issue with this WA: Qwen2.5-3B-Instruct-Q4_K_M hits an accuracy issue on MTL/LNL CPU if only token_embd.weight is reused.

std::shared_ptr<ov::Model> create_language_model(
const std::map<std::string, GGUFMetaData>& configs,
std::unordered_map<std::string, ov::Tensor>& consts,
std::unordered_map<std::string, gguf_tensor_type>& qtypes,
bool shared_embedding = false) {
std::unordered_map<std::string, gguf_tensor_type>& qtypes) {
// Create input parameters
auto input_ids = std::make_shared<ov::op::v0::Parameter>(
ov::element::i64, ov::PartialShape{-1, -1});
@@ -127,8 +117,7 @@ std::shared_ptr<ov::Model> create_language_model(
final_norm,
consts,
embeddings,
qtypes.at("lm_head.qtype"),
shared_embedding);
qtypes.at("lm_head.qtype"));

// Create results
auto logits = std::make_shared<ov::op::v0::Result>(embed_out);
@@ -143,9 +132,6 @@
model->set_rt_info(ov::element::f16, {"runtime_options", ov::hint::kv_cache_precision.name()});
}
model->set_rt_info(8.0f, {"runtime_options", ov::hint::activations_scale_factor.name()});
// CVS-166554: Dynamic quantization, enabled by default with group size 32 on the MTL platform, causes a runtime issue.
// Apply a WA to disable dynamic quantization via rt_info to fix the GPU plugin issue on MTL.
model->set_rt_info(0, {"runtime_options", ov::hint::dynamic_quantization_group_size.name()});

return model;
}
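
With the hard-coded dynamic_quantization_group_size rt_info override removed above, a caller that still needs a specific group size can pass it as a pipeline property instead. A sketch of that, assuming the DYNAMIC_QUANTIZATION_GROUP_SIZE key used by the test utilities updated below (the model path is a placeholder):

import openvino_genai as ov_genai

# Placeholder path; any GGUF model supported by the GGUF reader would do.
gguf_path = "models/qwen2.5-0.5b-instruct-q4_0.gguf"

# Properties can be passed as keyword arguments to LLMPipeline, as the test
# utilities below do for ATTENTION_BACKEND.
pipe = ov_genai.LLMPipeline(gguf_path, "GPU", DYNAMIC_QUANTIZATION_GROUP_SIZE="64")
print(pipe.generate("Why is the Sun yellow?", max_new_tokens=16))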
10 changes: 7 additions & 3 deletions tests/python_tests/data/models.py
@@ -27,14 +27,18 @@ def get_gguf_model_list():
{
"hf_model_id": "HuggingFaceTB/SmolLM2-135M",
"gguf_model_id": "prithivMLmods/SmolLM2-135M-GGUF",
"gguf_filename": "SmolLM2-135M.F16.gguf"
"gguf_filename": "SmolLM2-135M.F16.gguf",
"dynamic_quantization_group_size": None,
},
{
"gguf_model_id": "Qwen/Qwen2.5-0.5B-Instruct-GGUF",
"gguf_filename": "qwen2.5-0.5b-instruct-q4_0.gguf"
"gguf_filename": "qwen2.5-0.5b-instruct-q4_0.gguf",
"dynamic_quantization_group_size": None,
},
{
"gguf_model_id": "sammysun0711/tiny-random-deepseek-distill-qwen-gguf",
"gguf_filename": "tiny-random-deepseek-distill-qwen_q8_0.gguf"
"gguf_filename": "tiny-random-deepseek-distill-qwen_q8_0.gguf",
# The dummy GGUF model's accuracy is sensitive to dynamic quantization with the default small group size of 32, so set the group size to 64 explicitly instead
"dynamic_quantization_group_size": "64",
},
]
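
The new dynamic_quantization_group_size field is consumed by the tests below: None means "leave the plugin default alone", while a string value is forwarded to the pipeline config. A tiny sketch of that consumption pattern (hypothetical helper, for illustration only):

def group_size_overrides(model_list):
    """Collect explicit per-model dynamic-quantization group-size overrides."""
    return {
        m["gguf_filename"]: m["dynamic_quantization_group_size"]
        for m in model_list
        if m["dynamic_quantization_group_size"] is not None
    }

# e.g. {"tiny-random-deepseek-distill-qwen_q8_0.gguf": "64"} for the list above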
8 changes: 5 additions & 3 deletions tests/python_tests/test_llm_pipeline.py
@@ -826,6 +826,7 @@ def test_pipelines_with_gguf_generate(pipeline_type, model_ids):
pytest.skip(reason="168882: Sporadic segmentation fault failure on MacOS.")
gguf_model_id = model_ids["gguf_model_id"]
gguf_filename = model_ids["gguf_filename"]
dynamic_quantization_group_size = model_ids["dynamic_quantization_group_size"]
prompt = 'Why is the Sun yellow?'

opt_model = load_hf_model_from_gguf(gguf_model_id, gguf_filename)
@@ -850,7 +851,7 @@ def test_pipelines_with_gguf_generate(pipeline_type, model_ids):
res_string_input_1 = all_text_batch[0]

gguf_full_path = download_gguf_model(gguf_model_id, gguf_filename)
ov_pipe_gguf = create_ov_pipeline(gguf_full_path, pipeline_type=pipeline_type)
ov_pipe_gguf = create_ov_pipeline(gguf_full_path, pipeline_type=pipeline_type, dynamic_quantization_group_size=dynamic_quantization_group_size)
encoded_result = ov_pipe_gguf.generate(ov.Tensor(input_ids.numpy()), generation_config=ov_generation_config)
del ov_pipe_gguf
gc.collect()
@@ -868,6 +869,7 @@ def test_full_gguf_pipeline(pipeline_type, model_ids, enable_save_ov_model):
pytest.skip(reason="168882: Sporadic segmentation fault failure on MacOS.")
gguf_model_id = model_ids["gguf_model_id"]
gguf_filename = model_ids["gguf_filename"]
dynamic_quantization_group_size = model_ids["dynamic_quantization_group_size"]
prompt = 'Why is the Sun yellow?'

opt_model = load_hf_model_from_gguf(gguf_model_id, gguf_filename)
@@ -895,14 +897,14 @@ def test_full_gguf_pipeline(pipeline_type, model_ids, enable_save_ov_model):
res_string_input_1 = all_text_batch[0]

gguf_full_path = download_gguf_model(gguf_model_id, gguf_filename)
ov_pipe_gguf = create_ov_pipeline(gguf_full_path, pipeline_type=pipeline_type, enable_save_ov_model=enable_save_ov_model)
ov_pipe_gguf = create_ov_pipeline(gguf_full_path, pipeline_type=pipeline_type, enable_save_ov_model=enable_save_ov_model, dynamic_quantization_group_size=dynamic_quantization_group_size)
res_string_input_2 = ov_pipe_gguf.generate(prompt, generation_config=ov_generation_config)
del ov_pipe_gguf
gc.collect()

if enable_save_ov_model:
gguf_full_path = Path(gguf_full_path)
ov_pipe_native = create_ov_pipeline(gguf_full_path.parent, pipeline_type=pipeline_type)
ov_pipe_native = create_ov_pipeline(gguf_full_path.parent, pipeline_type=pipeline_type, dynamic_quantization_group_size=dynamic_quantization_group_size)
res_string_input_3 = ov_pipe_native.generate(prompt, generation_config=ov_generation_config)
del ov_pipe_native
gc.collect()
5 changes: 4 additions & 1 deletion tests/python_tests/utils/ov_genai_pipelines.py
@@ -79,15 +79,18 @@ def create_ov_pipeline(models_path: Path,
ov_config: dict = get_default_llm_properties(),
scheduler_config: SchedulerConfig = SchedulerConfig(),
draft_model_path: Path = None,
enable_save_ov_model: bool = None):
enable_save_ov_model: bool = None,
dynamic_quantization_group_size: str = None):
local_ov_config = ov_config.copy()
if pipeline_type == PipelineType.AUTO:
return LLMPipeline(models_path, device, ov_config)
elif pipeline_type == PipelineType.STATEFUL:
if enable_save_ov_model is not None: local_ov_config["enable_save_ov_model"] = enable_save_ov_model
if dynamic_quantization_group_size is not None: local_ov_config["DYNAMIC_QUANTIZATION_GROUP_SIZE"] = dynamic_quantization_group_size
return LLMPipeline(models_path, device, local_ov_config, ATTENTION_BACKEND="SDPA")
elif pipeline_type == PipelineType.PAGED_ATTENTION:
if enable_save_ov_model is not None: local_ov_config["enable_save_ov_model"] = enable_save_ov_model
if dynamic_quantization_group_size is not None: local_ov_config["DYNAMIC_QUANTIZATION_GROUP_SIZE"] = dynamic_quantization_group_size
return LLMPipeline(models_path, device, local_ov_config, scheduler_config=scheduler_config, ATTENTION_BACKEND="PA")
elif pipeline_type == PipelineType.CONTINUOUS_BATCHING:
return ContinuousBatchingPipeline(models_path, scheduler_config, device, ov_config)
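
An example invocation with the new parameter, as a sketch (values mirror the tiny-random model entry in models.py above). Note that the override only takes effect on the STATEFUL and PAGED_ATTENTION branches, the two that build a local_ov_config:

gguf_full_path = download_gguf_model(
    "sammysun0711/tiny-random-deepseek-distill-qwen-gguf",
    "tiny-random-deepseek-distill-qwen_q8_0.gguf",
)
pipe = create_ov_pipeline(
    gguf_full_path,
    pipeline_type=PipelineType.PAGED_ATTENTION,
    dynamic_quantization_group_size="64",
)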