Commit 188b7f9

[Performance][ROCm] Add skinny gemms for unquantized linear on ROCm (#15830)
Signed-off-by: charlifu <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
1 parent b9b4746 commit 188b7f9

File tree

12 files changed: +1957, −95 lines


CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -678,6 +678,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
   #
   set(VLLM_ROCM_EXT_SRC
     "csrc/rocm/torch_bindings.cpp"
+    "csrc/rocm/skinny_gemms.cu"
     "csrc/rocm/attention.cu")
 
   define_gpu_extension_target(

csrc/rocm/ops.h

Lines changed: 9 additions & 0 deletions
@@ -2,6 +2,15 @@
 
 #include <torch/all.h>
 
+torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
+                    const int64_t rows_per_block);
+
+torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
+                       const int64_t CuCount);
+
+void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
+               at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount);
+
 void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
                      torch::Tensor& max_logits, torch::Tensor& tmp_out,
                      torch::Tensor& query, torch::Tensor& key_cache,
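
The header change above declares the three new skinny-GEMM entry points: LLMM1 and wvSplitK return a new tensor for unquantized inputs, while wvSplitKQ writes into a caller-provided out_c and takes scale_a/scale_b, which suggests it serves the scaled (quantized) path. Below is a minimal C++ sketch, not part of this commit, of how a caller might dispatch between LLMM1 and wvSplitK for a skinny (small-token-count) matmul; the helper name dispatch_skinny_gemm, the shape threshold, and the rows_per_block value are illustrative assumptions rather than vLLM's actual selection logic.

// Sketch only: assumes the declarations from csrc/rocm/ops.h shown above.
#include <torch/all.h>
#include "csrc/rocm/ops.h"  // LLMM1, wvSplitK, wvSplitKQ declarations

// in_a: weight matrix, in_b: skinny activation matrix (few tokens).
// cu_count: number of GPU compute units, consumed by the split-K kernels.
torch::Tensor dispatch_skinny_gemm(at::Tensor& in_a, at::Tensor& in_b,
                                   int64_t cu_count) {
  const int64_t n = in_b.size(0);  // assumed: token (batch) dimension
  if (n <= 4) {
    // Assumed heuristic: very skinny inputs use the simple fused kernel,
    // which processes a fixed number of output rows per thread block.
    const int64_t rows_per_block = 4;  // illustrative value
    return LLMM1(in_a, in_b, rows_per_block);
  }
  // Otherwise use the split-K variant, which spreads the reduction
  // across the available compute units.
  return wvSplitK(in_a, in_b, cu_count);
}

The actual dispatch thresholds and the wiring into the unquantized linear layer live in the Python and binding changes elsewhere in this commit's 12 files.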
