Skip to content

Conversation

shalinib-ibm
Copy link
Contributor

This patch improves GEMM for FP32 Data Type on PowerPC

Implements GEMM on large blocks with configurable block size mc, nc, kc (default: 256, 256, 256).
Packing Function optimized to access blocks as per memory layout. GEMM Optimized to work on larger blocks.
Isolated Packing from GEMM Operations for better MMA utilization.

Verified functionality and correctness uing llama-cli and stand alone test case (performs matmul and compares final mattrix C result with base).

Performance Testing:

Observed 50% ~ 70% improvement in Prompt Processing Speed mesured using llama-bench with Meta-Llama3-8B FP32 Model. Similar gains observed with Mistral-7b-Instruct-v0.3 Model.

model                    Size Params Backend Threads Test Patch Base
llama 8B all F32        29.92 GiB     8.03 B CPU        20 pp512 98.58 60.3
llama 8B all F32        29.92 GiB     8.03 B CPU        20 pp1024 95.88 57.36
llama 8B all F32        29.92 GiB     8.03 B CPU        20 pp2048 85.46 53.26
llama 8B all F32        29.92 GiB     8.03 B CPU        20 pp4096 68.66 45.78
llama 8B all F32        29.92 GiB     8.03 B CPU        20 pp6144 57.35 40.44

25 ~ 30% improvement in llama-batched-bench with Metla-Llama3-8B in Prompt Processing Speed for large prompts (256, 512, 1024, 2048, 4096)tokens with various batch sizes ( 1, 2, 4, 8, 16)

Make sure to read the contributing guidelines before submitting a PR

@github-actions github-actions bot added the ggml changes relating to the ggml tensor library for machine learning label Aug 25, 2025
@shalinib-ibm
Copy link
Contributor Author

@ggerganov Can you please review this patch ?

Copy link
Collaborator

@taronaeo taronaeo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall code intention is inconsistent. Some are indented with 4 spaces, while some are 3 spaces. Can you make them all consistently 4 spaces?

@shalinib-ibm shalinib-ibm force-pushed the powerpc_sgemm_opt branch 2 times, most recently from 64837c8 to ffd0f79 Compare August 26, 2025 06:43
Copy link
Collaborator

@taronaeo taronaeo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

General formatting between 4 spaces vs. 3 spaces across your PR. Decide between float* a vs. float * a although you should follow the latter as it is described in CONTRIBUTING.md.

@shalinib-ibm shalinib-ibm force-pushed the powerpc_sgemm_opt branch 4 times, most recently from a7b7058 to a4187cb Compare August 26, 2025 07:58
@shalinib-ibm
Copy link
Contributor Author

@taronaeo Thank you for the review comments. I have made the indentation to 4 spaces across the class tinyBLAS_PPC and this PR. Changed float* a to float * a . Kindly re review the PR.

Copy link
Collaborator

@taronaeo taronaeo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Last 2 changes. Will approve once complete :)

This patch improves GEMM for FP32 Data Type on PowerPC

Implements GEMM on large blocks with configurable block size mc, nc, kc
(default: 256, 256, 256).
Packing Function optimized to access blocks as per memory layout.
GEMM Optimized to work on larger blocks.
Isolated Packing from GEMM Operations for better MMA utilization.

Verified functionality and correctness uing llama-cli and stand alone
test case (performs matmul and compares final mattrix C result with base).

Minor code refactoring changes:
Replace macro with inline function
Code Indent made consistent with 4 spaces

Performance Testing:

Observed 50% ~ 70% improvement in Prompt Processing Speed mesured using
llama-bench with Meta-Llama3-8B FP32 Model.  Similar gains observed with
Mistral-7b-Instruct-v0.3 Model.

model                   Size                Params     Backend       Threads   Test    Patch   Base
llama 8B all F32        29.92 GiB           8.03 B      CPU           20       pp512   98.58   60.3
llama 8B all F32        29.92 GiB           8.03 B      CPU           20       pp1024  95.88   57.36
llama 8B all F32        29.92 GiB           8.03 B      CPU           20       pp2048  85.46   53.26
llama 8B all F32        29.92 GiB           8.03 B      CPU           20       pp4096  68.66   45.78
llama 8B all F32        29.92 GiB           8.03 B      CPU           20       pp6144  57.35   40.44

25 ~ 30% improvement in llama-batched-bench with Metla-Llama3-8B in
Prompt Processing Speed for large prompts (256, 512, 1024, 2048, 4096)tokens with various batch
sizes ( 1, 2, 4, 8, 16)

Signed-off-by: Shalini Salomi Bodapati <[email protected]>
@taronaeo taronaeo merged commit a6a58d6 into ggml-org:master Aug 26, 2025
52 of 56 checks passed
Minh141120 pushed a commit to menloresearch/llama.cpp that referenced this pull request Aug 27, 2025
This patch improves GEMM for FP32 Data Type on PowerPC

Implements GEMM on large blocks with configurable block size mc, nc, kc
(default: 256, 256, 256).
Packing Function optimized to access blocks as per memory layout.
GEMM Optimized to work on larger blocks.
Isolated Packing from GEMM Operations for better MMA utilization.

Verified functionality and correctness uing llama-cli and stand alone
test case (performs matmul and compares final mattrix C result with base).

Minor code refactoring changes:
Replace macro with inline function
Code Indent made consistent with 4 spaces

Performance Testing:

Observed 50% ~ 70% improvement in Prompt Processing Speed mesured using
llama-bench with Meta-Llama3-8B FP32 Model.  Similar gains observed with
Mistral-7b-Instruct-v0.3 Model.

model                   Size                Params     Backend       Threads   Test    Patch   Base
llama 8B all F32        29.92 GiB           8.03 B      CPU           20       pp512   98.58   60.3
llama 8B all F32        29.92 GiB           8.03 B      CPU           20       pp1024  95.88   57.36
llama 8B all F32        29.92 GiB           8.03 B      CPU           20       pp2048  85.46   53.26
llama 8B all F32        29.92 GiB           8.03 B      CPU           20       pp4096  68.66   45.78
llama 8B all F32        29.92 GiB           8.03 B      CPU           20       pp6144  57.35   40.44

25 ~ 30% improvement in llama-batched-bench with Metla-Llama3-8B in
Prompt Processing Speed for large prompts (256, 512, 1024, 2048, 4096)tokens with various batch
sizes ( 1, 2, 4, 8, 16)

Signed-off-by: Shalini Salomi Bodapati <[email protected]>
gabe-l-hart added a commit to gabe-l-hart/llama.cpp that referenced this pull request Aug 28, 2025
…nemotron-nano-15409

* origin/master: (59 commits)
scripts: add sqlite3 check for compare-commits.sh (ggml-org#15633)
kv-cache : remove LLAMA_SET_ROWS checks (ggml-org#15505)
gguf-py: byteswapping improvements (ggml-org#12851)
cli : change log to warning to explain reason for stopping (ggml-org#15604)
model-conversion : add mmproj conversion target (ggml-org#15628)
cuda: Add cublasLt_static linking when GGML_STATIC is enabled (ggml-org#15622)
server: higher timeout for tests (ggml-org#15621)
presets : add qwen3-30B-a3b FIM (ggml-org#15616)
HIP: Enable support for ggml_backend_cuda_register_host_buffer (ggml-org#15615)
kv-cache : better estimate of n_kv for multi-sequence batches (ggml-org#15610)
CANN: refactor mask handling and improve performance in FA (ggml-org#15561)
ggml-cpu : add basic RVV support for vector f32 ops (ggml-org#15057)
common : add -m to bash completion for --model [no ci] (ggml-org#15591)
OpenCL: add fused group_norm/norm, mul, add (ggml-org#15314)
tests : fix test-opt with GGML_BACKEND_DL (ggml-org#15599)
SYCL: fix rms_norm_mul_add for tensor dim not a multiple of sg_size (ggml-org#15592)
mtmd : fix mtmd ios build (ggml-org#15579)
tests: add performance test for mul mat id (ggml-org#15543)
llamafile: PowerPC Sgemm Optimization (ggml-org#15558)
graph : fix assert in memory-less build_attn (ggml-org#15590)
...
gabe-l-hart added a commit to gabe-l-hart/llama.cpp that referenced this pull request Aug 28, 2025
…upport

* origin/master: (61 commits)
scripts: add sqlite3 check for compare-commits.sh (ggml-org#15633)
kv-cache : remove LLAMA_SET_ROWS checks (ggml-org#15505)
gguf-py: byteswapping improvements (ggml-org#12851)
cli : change log to warning to explain reason for stopping (ggml-org#15604)
model-conversion : add mmproj conversion target (ggml-org#15628)
cuda: Add cublasLt_static linking when GGML_STATIC is enabled (ggml-org#15622)
server: higher timeout for tests (ggml-org#15621)
presets : add qwen3-30B-a3b FIM (ggml-org#15616)
HIP: Enable support for ggml_backend_cuda_register_host_buffer (ggml-org#15615)
kv-cache : better estimate of n_kv for multi-sequence batches (ggml-org#15610)
CANN: refactor mask handling and improve performance in FA (ggml-org#15561)
ggml-cpu : add basic RVV support for vector f32 ops (ggml-org#15057)
common : add -m to bash completion for --model [no ci] (ggml-org#15591)
OpenCL: add fused group_norm/norm, mul, add (ggml-org#15314)
tests : fix test-opt with GGML_BACKEND_DL (ggml-org#15599)
SYCL: fix rms_norm_mul_add for tensor dim not a multiple of sg_size (ggml-org#15592)
mtmd : fix mtmd ios build (ggml-org#15579)
tests: add performance test for mul mat id (ggml-org#15543)
llamafile: PowerPC Sgemm Optimization (ggml-org#15558)
graph : fix assert in memory-less build_attn (ggml-org#15590)
...
gabe-l-hart added a commit to gabe-l-hart/llama.cpp that referenced this pull request Aug 28, 2025
…g-model-disabled-agent-prefill

* origin/master: (76 commits)
scripts: add sqlite3 check for compare-commits.sh (ggml-org#15633)
kv-cache : remove LLAMA_SET_ROWS checks (ggml-org#15505)
gguf-py: byteswapping improvements (ggml-org#12851)
cli : change log to warning to explain reason for stopping (ggml-org#15604)
model-conversion : add mmproj conversion target (ggml-org#15628)
cuda: Add cublasLt_static linking when GGML_STATIC is enabled (ggml-org#15622)
server: higher timeout for tests (ggml-org#15621)
presets : add qwen3-30B-a3b FIM (ggml-org#15616)
HIP: Enable support for ggml_backend_cuda_register_host_buffer (ggml-org#15615)
kv-cache : better estimate of n_kv for multi-sequence batches (ggml-org#15610)
CANN: refactor mask handling and improve performance in FA (ggml-org#15561)
ggml-cpu : add basic RVV support for vector f32 ops (ggml-org#15057)
common : add -m to bash completion for --model [no ci] (ggml-org#15591)
OpenCL: add fused group_norm/norm, mul, add (ggml-org#15314)
tests : fix test-opt with GGML_BACKEND_DL (ggml-org#15599)
SYCL: fix rms_norm_mul_add for tensor dim not a multiple of sg_size (ggml-org#15592)
mtmd : fix mtmd ios build (ggml-org#15579)
tests: add performance test for mul mat id (ggml-org#15543)
llamafile: PowerPC Sgemm Optimization (ggml-org#15558)
graph : fix assert in memory-less build_attn (ggml-org#15590)
...
gabe-l-hart added a commit to gabe-l-hart/llama.cpp that referenced this pull request Aug 28, 2025
…nemotron-nano-15409

* origin/master: (59 commits)
scripts: add sqlite3 check for compare-commits.sh (ggml-org#15633)
kv-cache : remove LLAMA_SET_ROWS checks (ggml-org#15505)
gguf-py: byteswapping improvements (ggml-org#12851)
cli : change log to warning to explain reason for stopping (ggml-org#15604)
model-conversion : add mmproj conversion target (ggml-org#15628)
cuda: Add cublasLt_static linking when GGML_STATIC is enabled (ggml-org#15622)
server: higher timeout for tests (ggml-org#15621)
presets : add qwen3-30B-a3b FIM (ggml-org#15616)
HIP: Enable support for ggml_backend_cuda_register_host_buffer (ggml-org#15615)
kv-cache : better estimate of n_kv for multi-sequence batches (ggml-org#15610)
CANN: refactor mask handling and improve performance in FA (ggml-org#15561)
ggml-cpu : add basic RVV support for vector f32 ops (ggml-org#15057)
common : add -m to bash completion for --model [no ci] (ggml-org#15591)
OpenCL: add fused group_norm/norm, mul, add (ggml-org#15314)
tests : fix test-opt with GGML_BACKEND_DL (ggml-org#15599)
SYCL: fix rms_norm_mul_add for tensor dim not a multiple of sg_size (ggml-org#15592)
mtmd : fix mtmd ios build (ggml-org#15579)
tests: add performance test for mul mat id (ggml-org#15543)
llamafile: PowerPC Sgemm Optimization (ggml-org#15558)
graph : fix assert in memory-less build_attn (ggml-org#15590)
...
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ggml changes relating to the ggml tensor library for machine learning
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants