Commit 255b070

sampling : do not sort in dist sampler
ggml-ci
1 parent 0b57c88

2 files changed: +32 -32 lines


src/llama-sampling.cpp (2 additions & 2 deletions)

@@ -594,8 +594,8 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*
 static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     auto * ctx = (llama_sampler_dist *) smpl->ctx;
 
-    // sorting is not necessary here, but for now we are doing it
-    llama_sampler_softmax_impl(cur_p, true);
+    // sorting is not necessary here
+    llama_sampler_softmax_impl(cur_p, false);
 
     cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
 }
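For background, drawing a token from the softmax distribution only needs the candidates' probabilities, not their order, which is why the sort can be skipped here. Below is a minimal standalone sketch of that idea, not the llama.cpp implementation; std::discrete_distribution, the fixed seed, and the toy logits are illustrative assumptions.

// softmax_then_sample.cpp -- multinomial sampling from softmax probabilities
// works on candidates in any order; sorting is only needed by samplers that
// truncate the distribution (top-k, top-p, ...).
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <random>
#include <vector>

int main() {
    std::vector<float> logits = {0.1f, 2.0f, -1.0f, 0.5f}; // arbitrary, unsorted

    // softmax without sorting: subtract the max logit for numerical stability
    const float max_l = *std::max_element(logits.begin(), logits.end());
    std::vector<float> probs(logits.size());
    float sum = 0.0f;
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp(logits[i] - max_l);
        sum += probs[i];
    }
    for (float & p : probs) {
        p /= sum;
    }

    // draw an index proportionally to the probabilities; the result refers
    // to the original, unsorted candidate array
    std::mt19937 rng(42);
    std::discrete_distribution<int> dist(probs.begin(), probs.end());
    const int selected = dist(rng);

    std::printf("selected candidate %d with p = %f\n", selected, probs[selected]);
    return 0;
}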

tests/test-sampling.cpp (30 additions & 30 deletions)

@@ -197,10 +197,10 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
     sampler_tester tester(n_vocab);
 
     llama_token min_token_id = 0;
-    const llama_token max_token_id = n_vocab-1;
+    const llama_token max_token_id = n_vocab - 1;
 
     for (auto s : samplers_sequence) {
-        switch (s){
+        switch (s) {
             case 'k': tester.apply(llama_sampler_init_top_k(top_k)); break;
             case 'y': GGML_ABORT("typical test not implemented");
             case 'p': tester.apply(llama_sampler_init_top_p(top_p, 1)); break;
@@ -243,10 +243,10 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
             }
 
             GGML_ASSERT(size == expected_size);
-            GGML_ASSERT(cur_p.data[0].id == max_token_id);
-            GGML_ASSERT(cur_p.data[expected_size-1].id == min_token_id);
+            GGML_ASSERT(!cur_p.sorted || cur_p.data[0].id == max_token_id);
+            GGML_ASSERT(!cur_p.sorted || cur_p.data[expected_size-1].id == min_token_id);
         } else if (s == 'm') {
-            int expected_size = ceilf((1.0f-min_p) * n_vocab);
+            int expected_size = ceilf((1.0f - min_p) * n_vocab);
             expected_size = std::max(expected_size, 1);
             expected_size = std::min(expected_size, size);
 
@@ -256,14 +256,14 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
             min_token_id = std::min(min_token_id, (llama_token)(n_vocab - 1));
 
             GGML_ASSERT(size == expected_size);
-            GGML_ASSERT(cur_p.data[0].id == max_token_id);
-            GGML_ASSERT(cur_p.data[expected_size-1].id == min_token_id);
+            GGML_ASSERT(!cur_p.sorted || cur_p.data[0].id == max_token_id);
+            GGML_ASSERT(!cur_p.sorted || cur_p.data[expected_size-1].id == min_token_id);
         } else {
             GGML_ABORT("fatal error");
         }
     }
 
-    printf("Sampler queue %3s OK with n_vocab=%05zu top_k=%05d top_p=%f min_p=%f\n",
+    printf("Sampler queue %3s OK with n_vocab=%05zu top_k=%5d top_p=%f min_p=%f\n",
            samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p);
 }
 
@@ -308,28 +308,28 @@ static void test_perf() {
 int main(void) {
     ggml_time_init();
 
-    test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
-    test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f);
+    test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}, 1.0f);
+    test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {0.0f, 0.0f, 0.0f, 1.0f}, 0.0f);
 
-    test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f, 0.0f, 1.0f);
-    test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f, 0.0f, 1.0f);
+    test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}, 1.0f, 0.0f, 1.0f);
+    test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {0.0f, 0.0f, 0.0f, 1.0f}, 0.0f, 0.0f, 1.0f);
 
     test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 1);
     test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3);
     test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
-    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0);
+    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}, 0);
 
     test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 0);
     test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f}, 0.7f);
     test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 0.8f);
-    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
-
-    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.00f);
-    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.24f);
-    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.9f, 0.3f/0.9f, 0.2f/0.9f}, 0.26f);
-    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.9f, 0.3f/0.9f, 0.2f/0.9f}, 0.49f);
-    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.7f, 0.3f/0.7f}, 0.51f);
-    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.7f, 0.3f/0.7f}, 0.74f);
+    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}, 1.0f);
+
+    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.1f/1.0f, 0.2f/1.0f, 0.3f/1.0f, 0.4f/1.0f}, 0.00f);
+    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.1f/1.0f, 0.2f/1.0f, 0.3f/1.0f, 0.4f/1.0f}, 0.24f);
+    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.2f/0.9f, 0.3f/0.9f, 0.4f/0.9f}, 0.26f);
+    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.2f/0.9f, 0.3f/0.9f, 0.4f/0.9f}, 0.49f);
+    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.3f/0.7f, 0.4f/0.7f}, 0.51f);
+    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.3f/0.7f, 0.4f/0.7f}, 0.74f);
     test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 0.76f);
     test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f);
     test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.05f);
@@ -345,23 +345,23 @@ int main(void) {
     test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
     test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
 
-    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f);
-    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
-    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0, 0.25f, 0.25f, 0.25f, 0.25f}, 50.0f, 0.0f, 0.0f);
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0, 0, 0, 0.5f, 0.5f}, 50.0f, 0.0f, 0.0f);
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0, 0, 0, 0.5f, 0.5f}, 50.0f, 0.0f, 0.0f);
 
-    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f);
-    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
-    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.000011f, 0.249997f, 0.249997f, 0.249997f, 0.249997f}, 1.0f, 5.0f, 5.0f);
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.000023f, 0.000023f, 0.000023f, 0.499966f, 0.499966f}, 1.0f, 5.0f, 5.0f);
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.000000f, 0.000023f, 0.000023f, 0.499977f, 0.499977f}, 1.0f, 5.0f, 5.0f);
 
 
     test_dry({0.25f, 0.25f, 0.25f, 0.25f}, {0, 1}, {0.25f, 0.25f, 0.25f, 0.25f}, 1.0f, 1.1f, 2, 4, {});
-    test_dry({0.25f, 0.25f, 0.25f, 0.25f}, {0, 1, 2, 0, 1}, {0.296923f, 0.296923f, 0.296923f, 0.109232f}, 1.0f, 1.1f, 2, 5, {});
+    test_dry({0.25f, 0.25f, 0.25f, 0.25f}, {0, 1, 2, 0, 1}, {0.296923f, 0.296923f, 0.109232f, 0.296923f}, 1.0f, 1.1f, 2, 5, {});
     test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 2, 6, {{3}});
-    test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.241818f, 0.241818f, 0.032727f}, 2.0f, 1.1f, 2, 5, {});
+    test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.032727f, 0.241818f, 0.241818f}, 2.0f, 1.1f, 2, 5, {});
     test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 4, 7, {});
 
     test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f, 0.0f, 0.0f}, 1.00f);
-    test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.00f); // top_n_sigma == 0 now represents a no-op rather than greedy decoding as of PR#13345
+    test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}, 0.00f); // top_n_sigma == 0 now represents a no-op rather than greedy decoding as of PR#13345
     test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 3.00f);
 
     test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
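The test updates follow the same idea: with the dist sampler no longer sorting, the candidate array can stay in input order, so expected probabilities are now listed in input order and the head/tail assertions are guarded on cur_p.sorted. A simplified sketch of that guard pattern follows; the struct and toy values are assumptions for illustration, not the llama.h definitions.

// order_agnostic_check.cpp -- only verify positional ordering when the
// candidate array is actually marked as sorted.
#include <cassert>
#include <vector>

struct token_data {
    int   id;
    float p;
};

struct token_data_array {
    std::vector<token_data> data;
    bool sorted = false; // samplers may now leave candidates unsorted
};

static void check_bounds(const token_data_array & cur_p, int max_id, int min_id) {
    // if the array is unsorted, skip the positional checks: the set of
    // surviving tokens matters, not the order they are stored in
    assert(!cur_p.sorted || cur_p.data.front().id == max_id);
    assert(!cur_p.sorted || cur_p.data.back().id  == min_id);
}

int main() {
    token_data_array cur_p;
    cur_p.data   = {{2, 0.3f}, {0, 0.5f}, {1, 0.2f}}; // unsorted candidates
    cur_p.sorted = false;

    // passes: the array is unsorted, so the positional checks are skipped
    check_bounds(cur_p, /*max_id =*/ 0, /*min_id =*/ 1);
    return 0;
}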
