
Commit 789c697

CUDA: add bf16 and f32 support to cublas_mul_mat_batched
1 parent 1b809ce commit 789c697

4 files changed: +202 additions, -51 deletions

ggml/src/ggml-cuda/convert.cu

Lines changed: 11 additions & 0 deletions
@@ -728,3 +728,14 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
             return nullptr;
     }
 }
+
+to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_F32:
+            return convert_unary_cuda<float, nv_bfloat16>;
+        case GGML_TYPE_F16:
+            return convert_unary_cuda<half, nv_bfloat16>;
+        default:
+            return nullptr;
+    }
+}
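
Note: a minimal usage sketch (not part of the commit) of the converter added above: look up the non-contiguous BF16 converter for the source type, then launch it on the CUDA stream. The wrapper name is hypothetical; the parameters (ne10..ne13 dimensions, s11..s13 element strides, main_stream) mirror the call site added in ggml-cuda.cu further down.

#include "convert.cuh" // assumed to pull in the ggml-cuda common headers (ggml.h, CUDA types)

static void convert_src1_to_bf16_sketch(const ggml_tensor * src1, nv_bfloat16 * dst,
        int64_t ne10, int64_t ne11, int64_t ne12, int64_t ne13,
        int64_t s11, int64_t s12, int64_t s13, cudaStream_t main_stream) {
    // the lookup returns nullptr for source types other than F32/F16
    const to_bf16_nc_cuda_t to_bf16_cuda = ggml_get_to_bf16_nc_cuda(src1->type);
    GGML_ASSERT(to_bf16_cuda != nullptr);
    to_bf16_cuda(src1->data, dst, ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
}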

ggml/src/ggml-cuda/convert.cuh

Lines changed: 2 additions & 0 deletions
@@ -23,4 +23,6 @@ using to_t_nc_cuda_t = void (*)(const void * x, T * y,
     int64_t s01, int64_t s02, int64_t s03, cudaStream_t stream);
 
 typedef to_t_nc_cuda_t<half> to_fp16_nc_cuda_t;
+typedef to_t_nc_cuda_t<nv_bfloat16> to_bf16_nc_cuda_t;
 to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);
+to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type);

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 185 additions & 49 deletions
@@ -1745,8 +1745,9 @@ static void ggml_cuda_op_mul_mat(
     }
 }
 
+template<typename T>
 static __global__ void k_compute_batched_ptrs(
-        const half * src0_as_f16, const half * src1_as_f16, char * dst,
+        const T * src0_as_f16, const T * src1_as_f16, char * dst,
         const void ** ptrs_src, void ** ptrs_dst,
         int64_t ne12, int64_t ne13,
         int64_t ne23,
@@ -1774,7 +1775,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     GGML_ASSERT(!ggml_is_transposed(src1));
 
     GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft));
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32);
 
     // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
     // As long as dst is contiguous this does not matter though.
@@ -1788,64 +1789,153 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
 
     CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));
 
-    const half * src0_f16 = (const half *) src0->data;
-    float * dst_ddf = (float *) dst->data;
+    const ggml_type src0_type = src0->type;
+    const bool use_f32_path = src0_type == GGML_TYPE_F32;
+    const bool use_bf16_path = src0_type == GGML_TYPE_BF16;
 
-    const half * src1_f16 = (const half *) src1->data;
+    float * dst_ddf = (float *) dst->data;
     const size_t ts_src1 = ggml_type_size(src1->type);
     GGML_ASSERT(nb10 == ts_src1);
     int64_t s11 = nb11 / ts_src1;
     int64_t s12 = nb12 / ts_src1;
     int64_t s13 = nb13 / ts_src1;
+
+    const half * src0_f16 = nullptr;
+    const half * src1_f16 = nullptr;
+    const nv_bfloat16 * src0_bf16 = nullptr;
+    const nv_bfloat16 * src1_bf16 = nullptr;
+    const float * src0_f32 = nullptr;
+    const float * src1_f32 = nullptr;
+
+    ggml_cuda_pool_alloc<half> src0_f16_alloc(ctx.pool());
     ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());
+    ggml_cuda_pool_alloc<nv_bfloat16> src0_bf16_alloc(ctx.pool());
+    ggml_cuda_pool_alloc<nv_bfloat16> src1_bf16_alloc(ctx.pool());
+    ggml_cuda_pool_alloc<float> src0_f32_alloc(ctx.pool());
+    ggml_cuda_pool_alloc<float> src1_f32_alloc(ctx.pool());
+
+    if (use_f32_path) {
+        // F32 path
+        src0_f32 = (const float *) src0->data;
+        if (src1->type == GGML_TYPE_F32) {
+            src1_f32 = (const float *) src1->data;
+        } else {
+            // Convert src1 to F32
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src1->type);
+            const int64_t ne_src1 = ggml_nelements(src1);
+            src1_f32_alloc.alloc(ne_src1);
+            GGML_ASSERT(to_fp32_cuda != nullptr);
 
-    // convert src1 to fp16
-    if (src1->type != GGML_TYPE_F16) {
-        const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
-        const int64_t ne_src1 = ggml_nelements(src1);
-        src1_f16_alloc.alloc(ne_src1);
-        GGML_ASSERT(to_fp16_cuda != nullptr);
+            to_fp32_cuda((const void*)((const char*)src1->data), src1_f32_alloc.get(), ne_src1, main_stream);
+            src1_f32 = src1_f32_alloc.get();
+            s11 = ne10;
+            s12 = ne11*s11;
+            s13 = ne12*s12;
+        }
+    } else if (use_bf16_path) {
+        // BF16 path
+        src0_bf16 = (const nv_bfloat16 *) src0->data;
+        if (src1->type == GGML_TYPE_BF16) {
+            src1_bf16 = (const nv_bfloat16 *) src1->data;
+        } else {
+            // Convert src1 to BF16
+            const to_bf16_nc_cuda_t to_bf16_cuda = ggml_get_to_bf16_nc_cuda(src1->type);
+            const int64_t ne_src1 = ggml_nelements(src1);
+            src1_bf16_alloc.alloc(ne_src1);
+            GGML_ASSERT(to_bf16_cuda != nullptr);
 
-        to_fp16_cuda(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+            to_bf16_cuda((const void*)((const char*)src1->data), src1_bf16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+            src1_bf16 = src1_bf16_alloc.get();
+            s11 = ne10;
+            s12 = ne11*s11;
+            s13 = ne12*s12;
+        }
+    } else {
+        // F16 path (default)
+        src0_f16 = (const half *) src0->data;
+        if (src1->type == GGML_TYPE_F16) {
+            src1_f16 = (const half *) src1->data;
+        } else {
+            // Convert src1 to F16
+            const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
+            const int64_t ne_src1 = ggml_nelements(src1);
+            src1_f16_alloc.alloc(ne_src1);
+            GGML_ASSERT(to_fp16_cuda != nullptr);
 
-        src1_f16 = src1_f16_alloc.get();
-        s11 = ne10;
-        s12 = ne11*s11;
-        s13 = ne12*s12;
+            to_fp16_cuda((const void*)((const char*)src1->data), src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+            src1_f16 = src1_f16_alloc.get();
+            s11 = ne10;
+            s12 = ne11*s11;
+            s13 = ne12*s12;
+        }
     }
 
     ggml_cuda_pool_alloc<half> dst_f16(ctx.pool());
+    ggml_cuda_pool_alloc<nv_bfloat16> dst_bf16(ctx.pool());
     char * dst_t;
 
-    cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
-    cudaDataType_t cu_data_type = CUDA_R_16F;
+    cublasComputeType_t cu_compute_type;
+    cudaDataType_t cu_data_type;
+    cudaDataType_t cu_data_type_a;
+    cudaDataType_t cu_data_type_b;
+
+    if (use_f32_path) {
+        cu_compute_type = CUBLAS_COMPUTE_32F;
+        cu_data_type = CUDA_R_32F;
+        cu_data_type_a = CUDA_R_32F;
+        cu_data_type_b = CUDA_R_32F;
+    } else if (use_bf16_path) {
+        cu_compute_type = CUBLAS_COMPUTE_32F;
+        cu_data_type = CUDA_R_16BF;
+        cu_data_type_a = CUDA_R_16BF;
+        cu_data_type_b = CUDA_R_16BF;
+    } else {
+        cu_compute_type = CUBLAS_COMPUTE_16F;
+        cu_data_type = CUDA_R_16F;
+        cu_data_type_a = CUDA_R_16F;
+        cu_data_type_b = CUDA_R_16F;
+    }
 
-    // dst strides
     size_t nbd2 = dst->nb[2];
     size_t nbd3 = dst->nb[3];
 
     const half alpha_f16 = 1.0f;
     const half beta_f16 = 0.0f;
-
     const float alpha_f32 = 1.0f;
     const float beta_f32 = 0.0f;
 
-    const void * alpha = &alpha_f16;
-    const void * beta = &beta_f16;
+    const void * alpha;
+    const void * beta;
 
-    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
-        dst_t = (char *) dst_f16.alloc(ne_dst);
+    if (use_f32_path || cu_compute_type == CUBLAS_COMPUTE_32F) {
+        alpha = &alpha_f32;
+        beta = &beta_f32;
+    } else if (use_bf16_path) {
+        alpha = &alpha_f32;
+        beta = &beta_f32;
+    } else {
+        alpha = &alpha_f16;
+        beta = &beta_f16;
+    }
 
-        nbd2 /= sizeof(float) / sizeof(half);
-        nbd3 /= sizeof(float) / sizeof(half);
+    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+        if (use_f32_path) {
+            dst_t = (char *) dst_ddf; // Direct F32 output
+        } else if (use_bf16_path) {
+            dst_t = (char *) dst_bf16.alloc(ne_dst);
+            nbd2 /= sizeof(float) / sizeof(nv_bfloat16);
+            nbd3 /= sizeof(float) / sizeof(nv_bfloat16);
+        } else {
+            dst_t = (char *) dst_f16.alloc(ne_dst);
+            nbd2 /= sizeof(float) / sizeof(half);
+            nbd3 /= sizeof(float) / sizeof(half);
+        }
     } else {
         dst_t = (char *) dst_ddf;
-
         cu_compute_type = CUBLAS_COMPUTE_32F;
-        cu_data_type = CUDA_R_32F;
-
+        cu_data_type = CUDA_R_32F;
         alpha = &alpha_f32;
-        beta = &beta_f32;
+        beta = &beta_f32;
     }
 
     int id = ggml_cuda_get_device();
@@ -1886,11 +1976,16 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
+        const void * src0_ptr = use_f32_path ? (const void*)src0_f32 :
+                                use_bf16_path ? (const void*)src0_bf16 : (const void*)src0_f16;
+        const void * src1_ptr = use_f32_path ? (const void*)src1_f32 :
+                                use_bf16_path ? (const void*)src1_bf16 : (const void*)src1_f16;
+
         CUBLAS_CHECK(
         cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, src0_f16, CUDA_R_16F, nb01/nb00, nb02/nb00, // strideA
-                       src1_f16, CUDA_R_16F, s11, s12, // strideB
+                alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
+                       src1_ptr, cu_data_type_b, s11, s12, // strideB
                 beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
                 ne12*ne13,
                 cu_compute_type,
@@ -1902,34 +1997,74 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
         ggml_cuda_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
 
+        const void * src0_ptr = use_f32_path ? (const void*)src0_f32 :
+                                use_bf16_path ? (const void*)src0_bf16 : (const void*)src0_f16;
+        const void * src1_ptr = use_f32_path ? (const void*)src1_f32 :
+                                use_bf16_path ? (const void*)src1_bf16 : (const void*)src1_f16;
+
+        size_t src1_stride_size = use_f32_path ? sizeof(float) :
+                                  use_bf16_path ? sizeof(nv_bfloat16) : sizeof(half);
+
         dim3 block_dims(ne13, ne12);
-        k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
-                src0_f16, src1_f16, dst_t,
-                ptrs_src.get(), ptrs_dst.get(),
-                ne12, ne13,
-                ne23,
-                nb02, nb03,
-                src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half),
-                src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half),
-                nbd2, nbd3,
-                r2, r3);
+        if( use_f32_path ) {
+            k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+                    (const float*)src0_ptr, (const float*)src1_ptr, dst_t,
+                    ptrs_src.get(), ptrs_dst.get(),
+                    ne12, ne13,
+                    ne23,
+                    nb02, nb03,
+                    (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
+                    (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
+                    nbd2, nbd3,
+                    r2, r3);
+        } else if (use_bf16_path) {
+            k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+                    (const nv_bfloat16*)src0_ptr, (const nv_bfloat16*)src1_ptr, dst_t,
+                    ptrs_src.get(), ptrs_dst.get(),
+                    ne12, ne13,
+                    ne23,
+                    nb02, nb03,
+                    (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
+                    (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
+                    nbd2, nbd3,
+                    r2, r3);
+        } else {
+            k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+                    (const half*)src0_ptr, (const half*)src1_ptr, dst_t,
+                    ptrs_src.get(), ptrs_dst.get(),
+                    ne12, ne13,
+                    ne23,
+                    nb02, nb03,
+                    (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
+                    (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
+                    nbd2, nbd3,
+                    r2, r3);
+        }
+
        CUDA_CHECK(cudaGetLastError());
 
         CUBLAS_CHECK(
         cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00,
-                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, s11,
+                alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00,
+                       (const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11,
                 beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
                 ne23,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     }
 #endif
 
-    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) {
-        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-        to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
+    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+        if (use_f32_path) {
+            //already in f32
+        } else if (use_bf16_path && cu_data_type == CUDA_R_16BF) {
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
+            to_fp32_cuda(dst_bf16.get(), dst_ddf, ne_dst, main_stream);
+        } else if (cu_data_type == CUDA_R_16F) {
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+            to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
+        }
     }
 }
 
@@ -1989,8 +2124,9 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_q) {
         ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
-        !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+    } else if (!split && (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32)
+        && (src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_BF16 || src1->type == GGML_TYPE_F32)
+        && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else if (use_mul_mat_vec) {
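
Note: the heart of the change above is the per-type dispatch for the batched cuBLAS GEMM. The following condensed restatement (not part of the commit; the helper is hypothetical) summarizes which cuBLAS compute and storage types each path selects: the F16 path keeps 16-bit accumulation by default, while the BF16 and F32 paths accumulate in FP32.

#include <cublas_v2.h> // cublasComputeType_t, cudaDataType_t, CUDA_R_* and CUBLAS_COMPUTE_* enums

struct batched_gemm_types {
    cublasComputeType_t compute; // accumulation precision passed to cublasGemm*Ex
    cudaDataType_t      type_a;  // storage type of src0
    cudaDataType_t      type_b;  // storage type of src1 after conversion
    cudaDataType_t      type_c;  // GEMM output type before the final copy to FP32
};

static batched_gemm_types select_batched_gemm_types(bool use_f32_path, bool use_bf16_path) {
    if (use_f32_path) {
        return { CUBLAS_COMPUTE_32F, CUDA_R_32F, CUDA_R_32F, CUDA_R_32F };
    }
    if (use_bf16_path) {
        return { CUBLAS_COMPUTE_32F, CUDA_R_16BF, CUDA_R_16BF, CUDA_R_16BF };
    }
    return { CUBLAS_COMPUTE_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_16F };
}

With GGML_PREC_DEFAULT the F16 and BF16 paths still produce a 16-bit GEMM result that is converted to FP32 afterwards, while the F32 path writes directly into dst and skips the final conversion.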

tests/test-backend-ops.cpp

Lines changed: 4 additions & 2 deletions
@@ -4332,8 +4332,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     for (auto nr : {1,4}) {
         for (uint32_t m = 0; m < 2; ++m) {
             for (uint32_t k = 0; k < 2; ++k) {
-                test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056 + m, 1, 128 + k, {bs, 1}, {nr, 1}, {0, 2, 1, 3}));
-                test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128 + m, 1, 1056 + k, {bs, 1}, {nr, 1}, {0, 1, 2, 3}, true));
+                for(ggml_type type: {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_F32}){
+                    test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 1056 + m, 1, 128 + k, {bs, 1}, {nr, 1}, {0, 2, 1, 3}));
+                    test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 128 + m, 1, 1056 + k, {bs, 1}, {nr, 1}, {0, 1, 2, 3}, true));
+                }
             }
         }
     }
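
Note: each new test case multiplies an F16/BF16/F32 weight matrix by F32 activations with a batch dimension bs taken from the surrounding loop (not shown in the hunk), comparing the CUDA backend against the CPU reference. A rough sketch of the graph shape exercised (omitting the permutation and repeat factors the tests also vary; the helper name is illustrative, not the test-harness API):

#include "ggml.h"

// a: [k, m, bs, 1] weights in F16/BF16/F32, b: [k, n, bs, 1] activations in F32.
// With bs > 1 this satisfies the src1->ne[2]*src1->ne[3] > 1 condition in the
// extended dispatch, so the CUDA backend can take ggml_cuda_mul_mat_batched_cublas.
static ggml_tensor * build_batched_mul_mat(ggml_context * ctx, ggml_type type_a, int64_t bs) {
    ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a,        128, 1056, bs, 1);
    ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 128,    1, bs, 1);
    return ggml_mul_mat(ctx, a, b); // F32 result of shape [1056, 1, bs, 1]
}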
