
Commit 34eb95e

Reapply "CUDA: fix logic for clearing padding with -ngl 0 (ggml-org#13320)"
1 parent 6d94964

File tree: 4 files changed, +18 −6 lines

  ggml/include/ggml-backend.h
  ggml/src/ggml-backend.cpp
  ggml/src/ggml-cuda/ggml-cuda.cu
  ggml/src/ggml-cuda/quantize.cu

ggml/include/ggml-backend.h

Lines changed: 2 additions & 2 deletions
@@ -38,7 +38,7 @@ extern "C" {
     GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer  (ggml_backend_buffer_type_t buft, size_t size);
     GGML_API size_t                ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft);
     GGML_API size_t                ggml_backend_buft_get_max_size  (ggml_backend_buffer_type_t buft);
-    GGML_API size_t                ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor);
+    GGML_API size_t                ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor);
     GGML_API bool                  ggml_backend_buft_is_host       (ggml_backend_buffer_type_t buft);
     GGML_API ggml_backend_dev_t    ggml_backend_buft_get_device    (ggml_backend_buffer_type_t buft);
 
@@ -59,7 +59,7 @@ extern "C" {
     GGML_API enum ggml_status      ggml_backend_buffer_init_tensor   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
     GGML_API size_t                ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
     GGML_API size_t                ggml_backend_buffer_get_max_size  (ggml_backend_buffer_t buffer);
-    GGML_API size_t                ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+    GGML_API size_t                ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor);
     GGML_API void                  ggml_backend_buffer_clear         (ggml_backend_buffer_t buffer, uint8_t value);
     GGML_API bool                  ggml_backend_buffer_is_host       (ggml_backend_buffer_t buffer);
     GGML_API void                  ggml_backend_buffer_set_usage     (ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);
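The header change is pure const-correctness: both alloc-size queries now take a const struct ggml_tensor *, so read-only code paths can call them without a cast. A minimal sketch of a caller this unblocks (the helper name is made up for illustration):

    // Hypothetical helper: how many bytes would this buffer type allocate
    // for a tensor we only hold a const pointer to? Compiles without a
    // const_cast only after this change.
    static size_t query_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * t) {
        return ggml_backend_buft_get_alloc_size(buft, t);
    }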

ggml/src/ggml-backend.cpp

Lines changed: 2 additions & 2 deletions
@@ -56,7 +56,7 @@ size_t ggml_backend_buft_get_max_size(ggml_backend_buffer_type_t buft) {
     return SIZE_MAX;
 }
 
-size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor) {
+size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
     // get_alloc_size is optional, defaults to ggml_nbytes
     if (buft->iface.get_alloc_size) {
         size_t size = buft->iface.get_alloc_size(buft, tensor);
 
@@ -152,7 +152,7 @@ size_t ggml_backend_buffer_get_max_size(ggml_backend_buffer_t buffer) {
     return ggml_backend_buft_get_max_size(ggml_backend_buffer_get_type(buffer));
 }
 
-size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor) {
     return ggml_backend_buft_get_alloc_size(ggml_backend_buffer_get_type(buffer), tensor);
 }
 
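The comment kept in the first hunk states the contract: the per-buffer-type get_alloc_size hook is optional, and the query falls back to ggml_nbytes. A small sketch of what that implies (hypothetical helper; it relies only on the two public calls above):

    // Padding a buffer type adds on top of the raw tensor bytes. Backends
    // that leave iface.get_alloc_size unset report 0 here; the CUDA buffer
    // type below reports the row padding it needs for quantized types.
    static size_t padding_bytes(ggml_backend_buffer_type_t buft, const struct ggml_tensor * t) {
        return ggml_backend_buft_get_alloc_size(buft, t) - ggml_nbytes(t);
    }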

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 13 additions & 2 deletions
@@ -589,8 +589,8 @@ static enum ggml_status ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer
 
     if (ggml_is_quantized(tensor->type) && tensor->view_src == nullptr && ggml_backend_buffer_get_usage(buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
         // initialize padding to 0 to avoid possible NaN values
-        size_t original_size = ggml_nbytes(tensor);
-        size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
+        const size_t original_size = ggml_nbytes(tensor);
+        const size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
 
         if (padded_size > original_size) {
             ggml_cuda_set_device(ctx->device);
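This hunk only adds const qualifiers, but it sits in the code path the commit title refers to: when a quantized, non-view tensor is initialized in a buffer not flagged as a compute buffer, the padding reported by the buffer type is zeroed so kernels reading past the row end never see NaNs. A sketch of the clear that follows in the function body (the exact memset call and stream choice are assumptions, not part of this hunk):

    if (padded_size > original_size) {
        ggml_cuda_set_device(ctx->device);
        // zero only the tail: [original_size, padded_size) holds no real data
        CUDA_CHECK(cudaMemsetAsync((char *) tensor->data + original_size, 0,
                                   padded_size - original_size, cudaStreamPerThread));
    }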
@@ -712,6 +712,7 @@ static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_t
 
     if (ggml_is_quantized(tensor->type)) {
         if (ne0 % MATRIX_ROW_PADDING != 0) {
+            GGML_ASSERT(tensor->nb[0] == ggml_element_size(tensor));
             size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
         }
     }
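The new assert makes the implicit layout assumption explicit: the padding arithmetic below it is only valid when nb[0] equals the tensor's natural element size (for quantized types, the block byte size), i.e. the row is not a strided view. A worked example of that arithmetic, assuming MATRIX_ROW_PADDING == 512 and the standard 32-element, 18-byte Q4_0 block (the concrete shape is made up):

    // ne0 == 704: 704 % 512 == 192, so the row is padded by 320 elements.
    // 320 elements == 10 Q4_0 blocks == 10 * 18 == 180 extra bytes.
    const size_t extra = ggml_row_size(GGML_TYPE_Q4_0, 512 - 704 % 512); // 180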
@@ -833,6 +834,7 @@ static void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buffer_t buff
 
 static enum ggml_status ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
     GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported
+    GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
 
     ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context;
     ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
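The same contiguity assertion is added to all four split-buffer entry points (it reappears in the set_tensor, get_tensor, and get_alloc_size hunks below). The reasoning: split buffers shard rows across devices with plain pointer arithmetic, which is only sound when byte strides follow directly from the shape. An illustrative, float-only version of the property being asserted; the real ggml_is_contiguous also accounts for quantized block types:

    static bool is_contiguous_f32(const struct ggml_tensor * t) {
        return t->nb[0] == ggml_type_size(t->type)  // packed elements
            && t->nb[1] == t->nb[0]*t->ne[0]        // packed rows
            && t->nb[2] == t->nb[1]*t->ne[1]        // packed matrices
            && t->nb[3] == t->nb[2]*t->ne[2];       // packed batches
    }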
@@ -884,6 +886,7 @@ static void ggml_backend_cuda_split_buffer_set_tensor(ggml_backend_buffer_t buff
     // split tensors must always be set in their entirety at once
     GGML_ASSERT(offset == 0);
     GGML_ASSERT(size == ggml_nbytes(tensor));
+    GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
 
     ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
 
@@ -922,6 +925,7 @@ static void ggml_backend_cuda_split_buffer_get_tensor(ggml_backend_buffer_t buff
     // split tensors must always be set in their entirety at once
     GGML_ASSERT(offset == 0);
     GGML_ASSERT(size == ggml_nbytes(tensor));
+    GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
 
     ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
 
@@ -1003,6 +1007,7 @@ static size_t ggml_backend_cuda_split_buffer_type_get_alignment(ggml_backend_buf
 
 static size_t ggml_backend_cuda_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
     ggml_backend_cuda_split_buffer_type_context * ctx = (ggml_backend_cuda_split_buffer_type_context *)buft->context;
+    GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
 
     size_t total_size = 0;
 
@@ -2391,6 +2396,12 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
         ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
         CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
 
+        ggml_tensor src0_slice = *src0;
+        src0_slice.ne[2] = 1;
+        src0_slice.nb[3] = src0_slice.nb[2];
+        src0_slice.data  = (char *) src0->data + i02*nb02;
+        GGML_ASSERT(!ggml_cuda_should_use_mmq(src0->type, cc, ne11) || ne00 % MATRIX_ROW_PADDING == 0);
+
         {
             dim3 block_dims(std::min((unsigned int)ne10, 768u));
             dim3 grid_dims(ids->ne[1], n_ids);
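The added lines build a per-expert view for mul_mat_id by shallow-copying the tensor struct: the data pointer is advanced to matrix i02, ne[2] collapses to 1, and nb[3] is patched so the stride invariants still hold; the assert then rules out MMQ for rows that lack the padding it expects. A hypothetical helper mirroring that construction:

    // Narrow a 3D tensor to the single matrix at index i02 along dim 2.
    // Shares the parent's data; only metadata changes.
    static ggml_tensor slice_dim2(const ggml_tensor * t, int64_t i02) {
        ggml_tensor s = *t;                         // shallow copy
        s.ne[2] = 1;                                // one matrix along dim 2
        s.nb[3] = s.nb[2];                          // keep strides consistent
        s.data  = (char *) t->data + i02*t->nb[2];  // jump to matrix i02
        return s;
    }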

ggml/src/ggml-cuda/quantize.cu

Lines changed: 1 addition & 0 deletions
@@ -160,6 +160,7 @@ void quantize_mmq_q8_1_cuda(
     const float * x, const int32_t * ids, void * vy, const ggml_type type_src0,
     const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
     const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
+    GGML_ASSERT(ne00 % 4 == 0);
     GGML_ASSERT(ne0 % (4*QK8_1) == 0);
 
     const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
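The new assert tightens the entry contract of quantize_mmq_q8_1_cuda: ne00, the unpadded source-row length, must be divisible by 4, presumably because the kernel consumes the row in 4-float vector loads; the pre-existing assert already forced the padded length ne0 to whole q8_1 super-blocks. A quick standalone check that CUDA's usual row padding satisfies both preconditions, assuming QK8_1 == 32 and MATRIX_ROW_PADDING == 512 (their values in ggml-cuda at this commit):

    #include <cassert>
    #include <cstdint>

    int main() {
        const int64_t QK8_1 = 32, MATRIX_ROW_PADDING = 512;
        const int64_t ne0 = MATRIX_ROW_PADDING;  // a row padded to 512 elements
        assert(ne0 % 4 == 0);                    // the assert added here
        assert(ne0 % (4*QK8_1) == 0);            // 512 % 128 == 0
        return 0;
    }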
