Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/batched-bench/batched-bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ int main(int argc, char ** argv) {
}

LOG_TEE("\n");
LOG_TEE("%s: n_kv_max = %d, is_pp_shared = %d, n_gpu_layers = %d, mmq = %d\n", __func__, n_kv_max, is_pp_shared, n_gpu_layers, mmq);
LOG_TEE("%s: n_kv_max = %d, is_pp_shared = %d, n_gpu_layers = %d, mmq = %d, n_threads = %d, n_threads_batch = %d\n", __func__, n_kv_max, is_pp_shared, n_gpu_layers, mmq, ctx_params.n_threads, ctx_params.n_threads_batch);
LOG_TEE("\n");

LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
Expand Down
35 changes: 21 additions & 14 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4719,16 +4719,18 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int

// the CUDA soft max implementation differs from the CPU implementation
// instead of doubles floats are used
static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
const int row = blockDim.x*blockIdx.x + threadIdx.x;
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
const int rowx = blockDim.x*blockIdx.x + threadIdx.x;
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
const int block_size = blockDim.y;
const int tid = threadIdx.y;

float max_val = -INFINITY;

for (int col = tid; col < ncols; col += block_size) {
const int i = row*ncols + col;
max_val = max(max_val, x[i]);
const int ix = rowx*ncols + col;
const int iy = rowy*ncols + col;
max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
}

// find the max value in the block
Expand All @@ -4740,10 +4742,11 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol
float tmp = 0.f;

for (int col = tid; col < ncols; col += block_size) {
const int i = row*ncols + col;
const float val = expf(x[i] - max_val);
const int ix = rowx*ncols + col;
const int iy = rowy*ncols + col;
const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
tmp += val;
dst[i] = val;
dst[ix] = val;
}

// sum up partial sums
Expand All @@ -4755,7 +4758,7 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol
const float inv_tmp = 1.f / tmp;

for (int col = tid; col < ncols; col += block_size) {
const int i = row*ncols + col;
const int i = rowx*ncols + col;
dst[i] *= inv_tmp;
}
}
Expand Down Expand Up @@ -5792,10 +5795,10 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
}

static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
const dim3 block_dims(1, WARP_SIZE, 1);
const dim3 block_nums(nrows_x, 1, 1);
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
}

static void im2col_f32_f16_cuda(const float * x, half * dst,
Expand Down Expand Up @@ -6846,14 +6849,18 @@ inline void ggml_cuda_op_soft_max(
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional

const int64_t ne00 = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
const int64_t nrows_x = ggml_nrows(src0);
const int64_t nrows_y = src1 ? ggml_nrows(src1) : 0;

soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
float scale = 1.0f;
memcpy(&scale, dst->op_params, sizeof(float));

soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);

(void) src1;
(void) dst;
(void) src1_dd;
}

inline void ggml_cuda_op_scale(
Expand Down
15 changes: 10 additions & 5 deletions ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -1036,11 +1036,16 @@ void ggml_metal_graph_compute(
nth /= 2;
[encoder setComputePipelineState:ctx->pipeline_soft_max];
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];

const float scale = ((float *) dst->op_params)[0];

[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
[encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];

[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
Expand Down
26 changes: 16 additions & 10 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,12 @@ kernel void kernel_gelu(

kernel void kernel_soft_max(
device const float * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant float & scale,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
Expand All @@ -194,14 +196,15 @@ kernel void kernel_soft_max(
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
device const float * pmask = src1 ? src1 + i01*ne00 : nullptr;
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

// parallel max
float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
float lmax = (tpitg < ne00) ? (psrc0[tpitg]*scale + (pmask ? pmask[tpitg] : 0.0f)) : -INFINITY;

for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
lmax = MAX(lmax, psrc0[i00]);
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
}

float max = simd_max(lmax);
Expand All @@ -225,7 +228,7 @@ kernel void kernel_soft_max(
// parallel sum
float lsum = 0.0f;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
const float exp_psrc0 = exp(psrc0[i00] - max);
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max);
lsum += exp_psrc0;
// Remember the result of exp here. exp is expensive, so we really do not
// wish to compute it twice.
Expand Down Expand Up @@ -257,10 +260,12 @@ kernel void kernel_soft_max(

kernel void kernel_soft_max_4(
device const float * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant float & scale,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
Expand All @@ -271,14 +276,15 @@ kernel void kernel_soft_max_4(
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);

// parallel max
float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
float4 lmax4 = tpitg < ne00/4 ? (psrc4[tpitg]*scale + (pmask ? pmask[tpitg] : 0.0f)) : -INFINITY;

for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
lmax4 = fmax(lmax4, psrc4[i00]);
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
}

const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
Expand All @@ -303,7 +309,7 @@ kernel void kernel_soft_max_4(
// parallel sum
float4 lsum4 = 0.0f;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
const float4 exp_psrc4 = exp(psrc4[i00] - max);
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max);
lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4;
}
Expand Down
76 changes: 60 additions & 16 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -4826,7 +4826,17 @@ struct ggml_tensor * ggml_diag_mask_zero_inplace(
static struct ggml_tensor * ggml_soft_max_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * mask,
float scale,
bool inplace) {
GGML_ASSERT(ggml_is_contiguous(a));
if (mask) {
GGML_ASSERT(ggml_is_contiguous(mask));
GGML_ASSERT(mask->ne[2] == 1);
GGML_ASSERT(mask->ne[3] == 1);
GGML_ASSERT(ggml_can_repeat_rows(mask, a));
}

bool is_node = false;

if (a->grad) {
Expand All @@ -4835,23 +4845,35 @@ static struct ggml_tensor * ggml_soft_max_impl(

struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);

float params[] = { scale };
ggml_set_op_params(result, params, sizeof(params));

result->op = GGML_OP_SOFT_MAX;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = mask;

return result;
}

struct ggml_tensor * ggml_soft_max(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_soft_max_impl(ctx, a, false);
return ggml_soft_max_impl(ctx, a, NULL, 1.0f, false);
}

struct ggml_tensor * ggml_soft_max_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_soft_max_impl(ctx, a, true);
return ggml_soft_max_impl(ctx, a, NULL, 1.0f, true);
}

struct ggml_tensor * ggml_soft_max_ext(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * mask,
float scale) {
return ggml_soft_max_impl(ctx, a, mask, scale, false);
}

// ggml_soft_max_back
Expand Down Expand Up @@ -10551,20 +10573,25 @@ static void ggml_compute_forward_diag_mask_zero(
static void ggml_compute_forward_soft_max_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_are_same_shape(src0, dst));
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
assert(ggml_is_contiguous(dst));
assert(ggml_are_same_shape(src0, dst));

if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}

float scale = 1.0f;
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));

// TODO: handle transposed/permuted matrices

const int ith = params->ith;
const int nth = params->nth;

const int64_t ne11 = src1 ? src1->ne[1] : 1;

const int nc = src0->ne[0];
const int nr = ggml_nrows(src0);

Expand All @@ -10575,29 +10602,39 @@ static void ggml_compute_forward_soft_max_f32(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);

float * wdata = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;

for (int i1 = ir0; i1 < ir1; i1++) {
float *sp = (float *)((char *) src0->data + i1*src0->nb[1]);
float *dp = (float *)((char *) dst->data + i1*dst->nb[1]);
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);

// broadcast the mask across rows
float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;

float * wp = wdata;
for (int i = 0; i < nc; i++) {
wp[i] = sp[i]*scale + (mp ? mp[i] : 0.0f);
}

#ifndef NDEBUG
for (int i = 0; i < nc; ++i) {
//printf("p[%d] = %f\n", i, p[i]);
assert(!isnan(sp[i]));
assert(!isnan(wp[i]));
}
#endif

float max = -INFINITY;
ggml_vec_max_f32(nc, &max, sp);
ggml_vec_max_f32(nc, &max, wp);

ggml_float sum = 0.0;

uint16_t scvt;
for (int i = 0; i < nc; i++) {
if (sp[i] == -INFINITY) {
if (wp[i] == -INFINITY) {
dp[i] = 0.0f;
} else {
// const float val = (sp[i] == -INFINITY) ? 0.0 : exp(sp[i] - max);
ggml_fp16_t s = GGML_FP32_TO_FP16(sp[i] - max);
// const float val = (wp[i] == -INFINITY) ? 0.0 : exp(wp[i] - max);
ggml_fp16_t s = GGML_FP32_TO_FP16(wp[i] - max);
memcpy(&scvt, &s, sizeof(scvt));
const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
sum += (ggml_float)val;
Expand All @@ -10622,11 +10659,12 @@ static void ggml_compute_forward_soft_max_f32(
static void ggml_compute_forward_soft_max(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_soft_max_f32(params, src0, dst);
ggml_compute_forward_soft_max_f32(params, src0, src1, dst);
} break;
default:
{
Expand Down Expand Up @@ -13863,7 +13901,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break;
case GGML_OP_SOFT_MAX:
{
ggml_compute_forward_soft_max(params, tensor->src[0], tensor);
ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_SOFT_MAX_BACK:
{
Expand Down Expand Up @@ -15899,6 +15937,12 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
}
} break;
case GGML_OP_SOFT_MAX:
{
n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));

cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
} break;
case GGML_OP_CONV_TRANSPOSE_1D:
{
GGML_ASSERT(node->src[0]->ne[3] == 1);
Expand Down
8 changes: 8 additions & 0 deletions ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -1282,6 +1282,14 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);

// fused soft_max(a*scale + mask)
// mask is optional
GGML_API struct ggml_tensor * ggml_soft_max_ext(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * mask,
float scale);

GGML_API struct ggml_tensor * ggml_soft_max_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
Expand Down
Loading