@@ -1745,8 +1745,9 @@ static void ggml_cuda_op_mul_mat(
     }
 }
 
+template <typename T>
 static __global__ void k_compute_batched_ptrs(
-        const half * src0_as_f16, const half * src1_as_f16, char * dst,
+        const T * src0_as_f16, const T * src1_as_f16, char * dst,
         const void ** ptrs_src, void ** ptrs_dst,
         int64_t ne12, int64_t ne13,
         int64_t ne23,
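This hunk only templates the signature of k_compute_batched_ptrs; its body is unchanged by the patch and not shown here. As a rough, hypothetical sketch of what such a pointer-setup kernel does (the names and offset arithmetic below are assumptions, not taken from this diff), templating on the element type lets one kernel build the per-batch pointer tables for half, nv_bfloat16 and float alike, since all offsets are applied in bytes:

#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cstdint>

// Hypothetical sketch: fill per-batch pointer tables for a batched GEMM.
// One thread per (i12, i13) batch index; r2/r3 are the broadcast ratios of src1 over src0.
template <typename T>
static __global__ void sketch_compute_batched_ptrs(
        const T * src0, const T * src1, char * dst,
        const void ** ptrs_src, void ** ptrs_dst,
        int64_t ne12, int64_t ne13, int64_t ne23,
        size_t nb02, size_t nb03,   // src0 byte strides over dims 2, 3
        size_t nb12, size_t nb13,   // src1 byte strides over dims 2, 3
        size_t nbd2, size_t nbd3,   // dst  byte strides over dims 2, 3
        int64_t r2, int64_t r3) {
    const int64_t i13 = blockIdx.x*blockDim.x + threadIdx.x;
    const int64_t i12 = blockIdx.y*blockDim.y + threadIdx.y;
    if (i13 >= ne13 || i12 >= ne12) {
        return;
    }
    // src0 may be broadcast over the batch dims of src1.
    const int64_t i03 = i13 / r3;
    const int64_t i02 = i12 / r2;

    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0 + i02*nb02 + i03*nb03;
    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1 + i12*nb12 + i13*nb13;
    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *) dst  + i12*nbd2 + i13*nbd3;
}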
@@ -1774,7 +1775,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     GGML_ASSERT(!ggml_is_transposed(src1));
 
     GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft));
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32);
 
     // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
     // As long as dst is contiguous this does not matter though.
@@ -1788,64 +1789,153 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
 
     CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));
 
-    const half * src0_f16 = (const half *) src0->data;
-    float * dst_ddf = (float *) dst->data;
+    const ggml_type src0_type = src0->type;
+    const bool use_f32_path  = src0_type == GGML_TYPE_F32;
+    const bool use_bf16_path = src0_type == GGML_TYPE_BF16;
 
-    const half * src1_f16 = (const half *) src1->data;
+    float * dst_ddf = (float *) dst->data;
     const size_t ts_src1 = ggml_type_size(src1->type);
     GGML_ASSERT(nb10 == ts_src1);
     int64_t s11 = nb11 / ts_src1;
     int64_t s12 = nb12 / ts_src1;
     int64_t s13 = nb13 / ts_src1;
+
+    const half        * src0_f16  = nullptr;
+    const half        * src1_f16  = nullptr;
+    const nv_bfloat16 * src0_bf16 = nullptr;
+    const nv_bfloat16 * src1_bf16 = nullptr;
+    const float       * src0_f32  = nullptr;
+    const float       * src1_f32  = nullptr;
+
+    ggml_cuda_pool_alloc<half> src0_f16_alloc(ctx.pool());
     ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());
+    ggml_cuda_pool_alloc<nv_bfloat16> src0_bf16_alloc(ctx.pool());
+    ggml_cuda_pool_alloc<nv_bfloat16> src1_bf16_alloc(ctx.pool());
+    ggml_cuda_pool_alloc<float> src0_f32_alloc(ctx.pool());
+    ggml_cuda_pool_alloc<float> src1_f32_alloc(ctx.pool());
+
+    if (use_f32_path) {
+        // F32 path
+        src0_f32 = (const float *) src0->data;
+        if (src1->type == GGML_TYPE_F32) {
+            src1_f32 = (const float *) src1->data;
+        } else {
+            // Convert src1 to F32
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src1->type);
+            const int64_t ne_src1 = ggml_nelements(src1);
+            src1_f32_alloc.alloc(ne_src1);
+            GGML_ASSERT(to_fp32_cuda != nullptr);
 
-    // convert src1 to fp16
-    if (src1->type != GGML_TYPE_F16) {
-        const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
-        const int64_t ne_src1 = ggml_nelements(src1);
-        src1_f16_alloc.alloc(ne_src1);
-        GGML_ASSERT(to_fp16_cuda != nullptr);
+            to_fp32_cuda((const void *)((const char *)src1->data), src1_f32_alloc.get(), ne_src1, main_stream);
+            src1_f32 = src1_f32_alloc.get();
+            s11 = ne10;
+            s12 = ne11*s11;
+            s13 = ne12*s12;
+        }
+    } else if (use_bf16_path) {
+        // BF16 path
+        src0_bf16 = (const nv_bfloat16 *) src0->data;
+        if (src1->type == GGML_TYPE_BF16) {
+            src1_bf16 = (const nv_bfloat16 *) src1->data;
+        } else {
+            // Convert src1 to BF16
+            const to_bf16_nc_cuda_t to_bf16_cuda = ggml_get_to_bf16_nc_cuda(src1->type);
+            const int64_t ne_src1 = ggml_nelements(src1);
+            src1_bf16_alloc.alloc(ne_src1);
+            GGML_ASSERT(to_bf16_cuda != nullptr);
 
-        to_fp16_cuda(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+            to_bf16_cuda((const void *)((const char *)src1->data), src1_bf16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+            src1_bf16 = src1_bf16_alloc.get();
+            s11 = ne10;
+            s12 = ne11*s11;
+            s13 = ne12*s12;
+        }
+    } else {
+        // F16 path (default)
+        src0_f16 = (const half *) src0->data;
+        if (src1->type == GGML_TYPE_F16) {
+            src1_f16 = (const half *) src1->data;
+        } else {
+            // Convert src1 to F16
+            const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
+            const int64_t ne_src1 = ggml_nelements(src1);
+            src1_f16_alloc.alloc(ne_src1);
+            GGML_ASSERT(to_fp16_cuda != nullptr);
 
-        src1_f16 = src1_f16_alloc.get();
-        s11 = ne10;
-        s12 = ne11*s11;
-        s13 = ne12*s12;
+            to_fp16_cuda((const void *)((const char *)src1->data), src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+            src1_f16 = src1_f16_alloc.get();
+            s11 = ne10;
+            s12 = ne11*s11;
+            s13 = ne12*s12;
+        }
     }
 
     ggml_cuda_pool_alloc<half> dst_f16(ctx.pool());
+    ggml_cuda_pool_alloc<nv_bfloat16> dst_bf16(ctx.pool());
     char * dst_t;
 
-    cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
-    cudaDataType_t      cu_data_type    = CUDA_R_16F;
+    cublasComputeType_t cu_compute_type;
+    cudaDataType_t      cu_data_type;
+    cudaDataType_t      cu_data_type_a;
+    cudaDataType_t      cu_data_type_b;
+
+    if (use_f32_path) {
+        cu_compute_type = CUBLAS_COMPUTE_32F;
+        cu_data_type    = CUDA_R_32F;
+        cu_data_type_a  = CUDA_R_32F;
+        cu_data_type_b  = CUDA_R_32F;
+    } else if (use_bf16_path) {
+        cu_compute_type = CUBLAS_COMPUTE_32F;
+        cu_data_type    = CUDA_R_16BF;
+        cu_data_type_a  = CUDA_R_16BF;
+        cu_data_type_b  = CUDA_R_16BF;
+    } else {
+        cu_compute_type = CUBLAS_COMPUTE_16F;
+        cu_data_type    = CUDA_R_16F;
+        cu_data_type_a  = CUDA_R_16F;
+        cu_data_type_b  = CUDA_R_16F;
+    }
 
-    // dst strides
     size_t nbd2 = dst->nb[2];
     size_t nbd3 = dst->nb[3];
 
     const half  alpha_f16 = 1.0f;
     const half  beta_f16  = 0.0f;
-
     const float alpha_f32 = 1.0f;
     const float beta_f32  = 0.0f;
 
-    const void * alpha = &alpha_f16;
-    const void * beta  = &beta_f16;
+    const void * alpha;
+    const void * beta;
 
-    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
-        dst_t = (char *) dst_f16.alloc(ne_dst);
+    if (use_f32_path || cu_compute_type == CUBLAS_COMPUTE_32F) {
+        alpha = &alpha_f32;
+        beta  = &beta_f32;
+    } else if (use_bf16_path) {
+        alpha = &alpha_f32;
+        beta  = &beta_f32;
+    } else {
+        alpha = &alpha_f16;
+        beta  = &beta_f16;
+    }
 
-        nbd2 /= sizeof(float) / sizeof(half);
-        nbd3 /= sizeof(float) / sizeof(half);
+    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+        if (use_f32_path) {
+            dst_t = (char *) dst_ddf; // Direct F32 output
+        } else if (use_bf16_path) {
+            dst_t = (char *) dst_bf16.alloc(ne_dst);
+            nbd2 /= sizeof(float) / sizeof(nv_bfloat16);
+            nbd3 /= sizeof(float) / sizeof(nv_bfloat16);
+        } else {
+            dst_t = (char *) dst_f16.alloc(ne_dst);
+            nbd2 /= sizeof(float) / sizeof(half);
+            nbd3 /= sizeof(float) / sizeof(half);
+        }
     } else {
         dst_t = (char *) dst_ddf;
-
         cu_compute_type = CUBLAS_COMPUTE_32F;
-        cu_data_type    = CUDA_R_32F;
-
+        cu_data_type    = CUDA_R_32F;
         alpha = &alpha_f32;
-        beta  = &beta_f32;
+        beta  = &beta_f32;
     }
 
     int id = ggml_cuda_get_device();
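The type selection added above boils down to: F32 and BF16 operands accumulate in FP32, while the default F16 path keeps FP16 compute. A minimal standalone sketch of that mapping (the helper struct and function are illustrative, not part of the patch, and assume CUDA 11+ headers):

#include <cublas_v2.h>

// Hypothetical helper mirroring the selection above: GEMM inputs keep their
// storage type; BF16/F32 use FP32 accumulation, F16 keeps FP16 accumulation.
struct cublas_gemm_types {
    cublasComputeType_t compute;
    cudaDataType_t      a, b, c;
};

static cublas_gemm_types select_gemm_types(bool use_f32_path, bool use_bf16_path) {
    if (use_f32_path) {
        return { CUBLAS_COMPUTE_32F, CUDA_R_32F,  CUDA_R_32F,  CUDA_R_32F  };
    }
    if (use_bf16_path) {
        return { CUBLAS_COMPUTE_32F, CUDA_R_16BF, CUDA_R_16BF, CUDA_R_16BF };
    }
    return { CUBLAS_COMPUTE_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_16F };
}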
@@ -1886,11 +1976,16 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
+        const void * src0_ptr = use_f32_path ? (const void *)src0_f32 :
+                                use_bf16_path ? (const void *)src0_bf16 : (const void *)src0_f16;
+        const void * src1_ptr = use_f32_path ? (const void *)src1_f32 :
+                                use_bf16_path ? (const void *)src1_bf16 : (const void *)src1_f16;
+
         CUBLAS_CHECK(
         cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, src0_f16, CUDA_R_16F,     nb01/nb00, nb02/nb00, // strideA
-                       src1_f16, CUDA_R_16F,     s11,       s12,       // strideB
+                alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
+                       src1_ptr, cu_data_type_b, s11,       s12,       // strideB
                 beta,     dst_t, cu_data_type,   ne0,       ne1*ne0,   // strideC
                 ne12*ne13,
                 cu_compute_type,
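For the contiguous, non-broadcast case the new F32 path reduces to an ordinary strided-batched SGEMM. Below is a self-contained example of cublasGemmStridedBatchedEx used the way the F32 path uses it, with A/B/C all CUDA_R_32F and CUBLAS_COMPUTE_32F; the sizes, dense column-major layout, and error macros are illustrative assumptions (cuBLAS 11+), not code from the patch:

#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

#define CHECK_CUDA(x)   do { cudaError_t e = (x); if (e != cudaSuccess) { printf("CUDA error %d\n", (int) e); return 1; } } while (0)
#define CHECK_CUBLAS(x) do { cublasStatus_t s = (x); if (s != CUBLAS_STATUS_SUCCESS) { printf("cuBLAS error %d\n", (int) s); return 1; } } while (0)

int main() {
    const int m = 4, n = 3, k = 2, batch = 5;   // arbitrary small sizes
    const long long sA = (long long) m*k, sB = (long long) k*n, sC = (long long) m*n;

    std::vector<float> hA(sA*batch, 1.0f), hB(sB*batch, 2.0f), hC(sC*batch, 0.0f);

    float *dA, *dB, *dC;
    CHECK_CUDA(cudaMalloc(&dA, hA.size()*sizeof(float)));
    CHECK_CUDA(cudaMalloc(&dB, hB.size()*sizeof(float)));
    CHECK_CUDA(cudaMalloc(&dC, hC.size()*sizeof(float)));
    CHECK_CUDA(cudaMemcpy(dA, hA.data(), hA.size()*sizeof(float), cudaMemcpyHostToDevice));
    CHECK_CUDA(cudaMemcpy(dB, hB.data(), hB.size()*sizeof(float), cudaMemcpyHostToDevice));

    cublasHandle_t handle;
    CHECK_CUBLAS(cublasCreate(&handle));

    const float alpha = 1.0f, beta = 0.0f;
    // C_i = A_i * B_i for each batch i: FP32 operands, FP32 accumulation.
    CHECK_CUBLAS(cublasGemmStridedBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
            m, n, k,
            &alpha, dA, CUDA_R_32F, m, sA,
                    dB, CUDA_R_32F, k, sB,
            &beta,  dC, CUDA_R_32F, m, sC,
            batch,
            CUBLAS_COMPUTE_32F,
            CUBLAS_GEMM_DEFAULT));

    CHECK_CUDA(cudaMemcpy(hC.data(), dC, hC.size()*sizeof(float), cudaMemcpyDeviceToHost));
    printf("C[0] = %f (expected %f)\n", hC[0], 2.0f*k);

    cublasDestroy(handle);
    cudaFree(dA); cudaFree(dB); cudaFree(dC);
    return 0;
}

In the patch the leading dimensions and strides come from the ggml tensor nb/ne fields rather than the dense layout assumed here.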
@@ -1902,34 +1997,74 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
         ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
 
+        const void * src0_ptr = use_f32_path ? (const void *)src0_f32 :
+                                use_bf16_path ? (const void *)src0_bf16 : (const void *)src0_f16;
+        const void * src1_ptr = use_f32_path ? (const void *)src1_f32 :
+                                use_bf16_path ? (const void *)src1_bf16 : (const void *)src1_f16;
+
+        size_t src1_stride_size = use_f32_path ? sizeof(float) :
+                                  use_bf16_path ? sizeof(nv_bfloat16) : sizeof(half);
+
         dim3 block_dims(ne13, ne12);
-        k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
-                src0_f16, src1_f16, dst_t,
-                ptrs_src.get(), ptrs_dst.get(),
-                ne12, ne13,
-                ne23,
-                nb02, nb03,
-                src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half),
-                src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half),
-                nbd2, nbd3,
-                r2, r3);
+        if (use_f32_path) {
+            k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+                    (const float *)src0_ptr, (const float *)src1_ptr, dst_t,
+                    ptrs_src.get(), ptrs_dst.get(),
+                    ne12, ne13,
+                    ne23,
+                    nb02, nb03,
+                    (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
+                    (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
+                    nbd2, nbd3,
+                    r2, r3);
+        } else if (use_bf16_path) {
+            k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+                    (const nv_bfloat16 *)src0_ptr, (const nv_bfloat16 *)src1_ptr, dst_t,
+                    ptrs_src.get(), ptrs_dst.get(),
+                    ne12, ne13,
+                    ne23,
+                    nb02, nb03,
+                    (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
+                    (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
+                    nbd2, nbd3,
+                    r2, r3);
+        } else {
+            k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+                    (const half *)src0_ptr, (const half *)src1_ptr, dst_t,
+                    ptrs_src.get(), ptrs_dst.get(),
+                    ne12, ne13,
+                    ne23,
+                    nb02, nb03,
+                    (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
+                    (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
+                    nbd2, nbd3,
+                    r2, r3);
+        }
+
         CUDA_CHECK(cudaGetLastError());
 
         CUBLAS_CHECK(
         cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F,     nb01/nb00,
-                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F,     s11,
+                alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00,
+                       (const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11,
                 beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type,   ne0,
                 ne23,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     }
 #endif
 
-    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) {
-        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-        to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
+    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+        if (use_f32_path) {
+            // already in f32
+        } else if (use_bf16_path && cu_data_type == CUDA_R_16BF) {
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
+            to_fp32_cuda(dst_bf16.get(), dst_ddf, ne_dst, main_stream);
+        } else if (cu_data_type == CUDA_R_16F) {
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+            to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
+        }
     }
 }
 
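When the output was computed in BF16 (or F16) under GGML_PREC_DEFAULT, the epilogue above widens it back to F32 through ggml's existing to_fp32_cuda converters. Conceptually that BF16 epilogue amounts to an elementwise widening kernel like the following sketch (illustrative only; the patch does not add such a kernel):

#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <cstdint>

// Illustrative sketch of the BF16 -> F32 widening done by to_fp32_cuda(GGML_TYPE_BF16).
static __global__ void bf16_to_f32(const nv_bfloat16 * x, float * y, int64_t n) {
    const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) {
        y[i] = __bfloat162float(x[i]);
    }
}

static void launch_bf16_to_f32(const nv_bfloat16 * x, float * y, int64_t n, cudaStream_t stream) {
    const int block = 256;
    const int64_t grid = (n + block - 1) / block;
    bf16_to_f32<<<(unsigned) grid, block, 0, stream>>>(x, y, n);
}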
@@ -1989,8 +2124,9 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_q) {
         ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
-        !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+    } else if (!split && (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32)
+        && (src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_BF16 || src1->type == GGML_TYPE_F32)
+        && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else if (use_mul_mat_vec) {