@@ -197,10 +197,10 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
197
197
sampler_tester tester (n_vocab);
198
198
199
199
llama_token min_token_id = 0 ;
200
- const llama_token max_token_id = n_vocab- 1 ;
200
+ const llama_token max_token_id = n_vocab - 1 ;
201
201
202
202
for (auto s : samplers_sequence) {
203
- switch (s){
203
+ switch (s) {
204
204
case ' k' : tester.apply (llama_sampler_init_top_k (top_k)); break ;
205
205
case ' y' : GGML_ABORT (" typical test not implemented" );
206
206
case ' p' : tester.apply (llama_sampler_init_top_p (top_p, 1 )); break ;
@@ -243,10 +243,10 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
243
243
}
244
244
245
245
GGML_ASSERT (size == expected_size);
246
- GGML_ASSERT (cur_p.data [0 ].id == max_token_id);
247
- GGML_ASSERT (cur_p.data [expected_size-1 ].id == min_token_id);
246
+ GGML_ASSERT (!cur_p. sorted || cur_p.data [0 ].id == max_token_id);
247
+ GGML_ASSERT (!cur_p. sorted || cur_p.data [expected_size-1 ].id == min_token_id);
248
248
} else if (s == ' m' ) {
249
- int expected_size = ceilf ((1 .0f - min_p) * n_vocab);
249
+ int expected_size = ceilf ((1 .0f - min_p) * n_vocab);
250
250
expected_size = std::max (expected_size, 1 );
251
251
expected_size = std::min (expected_size, size);
252
252
@@ -256,14 +256,14 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
256
256
min_token_id = std::min (min_token_id, (llama_token)(n_vocab - 1 ));
257
257
258
258
GGML_ASSERT (size == expected_size);
259
- GGML_ASSERT (cur_p.data [0 ].id == max_token_id);
260
- GGML_ASSERT (cur_p.data [expected_size-1 ].id == min_token_id);
259
+ GGML_ASSERT (!cur_p. sorted || cur_p.data [0 ].id == max_token_id);
260
+ GGML_ASSERT (!cur_p. sorted || cur_p.data [expected_size-1 ].id == min_token_id);
261
261
} else {
262
262
GGML_ABORT (" fatal error" );
263
263
}
264
264
}
265
265
266
- printf (" Sampler queue %3s OK with n_vocab=%05zu top_k=%05d top_p=%f min_p=%f\n " ,
266
+ printf (" Sampler queue %3s OK with n_vocab=%05zu top_k=%5d top_p=%f min_p=%f\n " ,
267
267
samplers_sequence.c_str (), n_vocab, top_k, top_p, min_p);
268
268
}
269
269
@@ -308,28 +308,28 @@ static void test_perf() {
308
308
int main (void ) {
309
309
ggml_time_init ();
310
310
311
- test_temp ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f , 0 .3f , 0 .2f , 0 .1f }, 1 .0f );
312
- test_temp ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {1 .0f , 0 .0f , 0 .0f , 0 .0f }, 0 .0f );
311
+ test_temp ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .1f , 0 .2f , 0 .3f , 0 .4f }, 1 .0f );
312
+ test_temp ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .0f , 0 .0f , 0 .0f , 1 .0f }, 0 .0f );
313
313
314
- test_temp_ext ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f , 0 .3f , 0 .2f , 0 .1f }, 1 .0f , 0 .0f , 1 .0f );
315
- test_temp_ext ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {1 .0f , 0 .0f , 0 .0f , 0 .0f }, 0 .0f , 0 .0f , 1 .0f );
314
+ test_temp_ext ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .1f , 0 .2f , 0 .3f , 0 .4f }, 1 .0f , 0 .0f , 1 .0f );
315
+ test_temp_ext ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .0f , 0 .0f , 0 .0f , 1 .0f }, 0 .0f , 0 .0f , 1 .0f );
316
316
317
317
test_top_k ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {1 .0f }, 1 );
318
318
test_top_k ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .44444f , 0 .33333f , 0 .22222f }, 3 );
319
319
test_top_k ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f , 0 .3f , 0 .2f , 0 .1f }, 4 );
320
- test_top_k ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f , 0 .3f , 0 .2f , 0 .1f }, 0 );
320
+ test_top_k ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .1f , 0 .2f , 0 .3f , 0 .4f }, 0 );
321
321
322
322
test_top_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {1 .0f }, 0 );
323
323
test_top_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .571429f , 0 .428571f }, 0 .7f );
324
324
test_top_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .44444f , 0 .33333f , 0 .22222f }, 0 .8f );
325
- test_top_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f , 0 .3f , 0 .2f , 0 .1f }, 1 .0f );
326
-
327
- test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f /1 .0f , 0 .3f /1 .0f , 0 .2f /1 .0f , 0 .1f /1 .0f }, 0 .00f );
328
- test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f /1 .0f , 0 .3f /1 .0f , 0 .2f /1 .0f , 0 .1f /1 .0f }, 0 .24f );
329
- test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f /0 .9f , 0 .3f /0 .9f , 0 .2f /0 .9f }, 0 .26f );
330
- test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f /0 .9f , 0 .3f /0 .9f , 0 .2f /0 .9f }, 0 .49f );
331
- test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f /0 .7f , 0 .3f /0 .7f }, 0 .51f );
332
- test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f /0 .7f , 0 .3f /0 .7f }, 0 .74f );
325
+ test_top_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .1f , 0 .2f , 0 .3f , 0 .4f }, 1 .0f );
326
+
327
+ test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .1f /1 .0f , 0 .2f /1 .0f , 0 .3f /1 .0f , 0 .4f /1 .0f }, 0 .00f );
328
+ test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .1f /1 .0f , 0 .2f /1 .0f , 0 .3f /1 .0f , 0 .4f /1 .0f }, 0 .24f );
329
+ test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .2f /0 .9f , 0 .3f /0 .9f , 0 .4f /0 .9f }, 0 .26f );
330
+ test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .2f /0 .9f , 0 .3f /0 .9f , 0 .4f /0 .9f }, 0 .49f );
331
+ test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .3f /0 .7f , 0 .4f /0 .7f }, 0 .51f );
332
+ test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .3f /0 .7f , 0 .4f /0 .7f }, 0 .74f );
333
333
test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f /0 .4f }, 0 .76f );
334
334
test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f /0 .4f }, 1 .00f );
335
335
test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f /0 .4f }, 1 .05f );
@@ -345,23 +345,23 @@ int main(void) {
345
345
test_typical ({0 .97f , 0 .01f , 0 .01f , 0 .01f }, {0 .97f }, 0 .5f );
346
346
test_typical ({0 .4f , 0 .2f , 0 .2f , 0 .2f }, {0 .2f , 0 .2f , 0 .2f }, 0 .5f );
347
347
348
- test_penalties ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 }, {0 . 25f , 0 .25f , 0 .25f , 0 .25f , 0 }, 50 .0f , 0 .0f , 0 .0f );
349
- test_penalties ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 }, {0 . 5f , 0 . 5f , 0 , 0 , 0 }, 50 .0f , 0 .0f , 0 .0f );
350
- test_penalties ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 , 0 , 0 }, {0 . 5f , 0 . 5f , 0 , 0 , 0 }, 50 .0f , 0 .0f , 0 .0f );
348
+ test_penalties ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 }, {0 , 0 .25f , 0 .25f , 0 .25f , 0 . 25f }, 50 .0f , 0 .0f , 0 .0f );
349
+ test_penalties ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 }, {0 , 0 , 0 , 0 . 5f , 0 . 5f }, 50 .0f , 0 .0f , 0 .0f );
350
+ test_penalties ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 , 0 , 0 }, {0 , 0 , 0 , 0 . 5f , 0 . 5f }, 50 .0f , 0 .0f , 0 .0f );
351
351
352
- test_penalties ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 }, {0 .249997f , 0 .249997f , 0 .249997f , 0 .249997f , 0 .000011f }, 1 .0f , 5 .0f , 5 .0f );
353
- test_penalties ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 }, {0 .499966f , 0 .499966f , 0 .000023f , 0 .000023f , 0 .000023f }, 1 .0f , 5 .0f , 5 .0f );
354
- test_penalties ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 , 0 , 0 }, {0 .499977f , 0 .499977f , 0 .000023f , 0 .000023f , 0 .000000f }, 1 .0f , 5 .0f , 5 .0f );
352
+ test_penalties ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 }, {0 .000011f , 0 .249997f , 0 .249997f , 0 .249997f , 0 .249997f }, 1 .0f , 5 .0f , 5 .0f );
353
+ test_penalties ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 }, {0 .000023f , 0 .000023f , 0 .000023f , 0 .499966f , 0 .499966f }, 1 .0f , 5 .0f , 5 .0f );
354
+ test_penalties ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 , 0 , 0 }, {0 .000000f , 0 .000023f , 0 .000023f , 0 .499977f , 0 .499977f }, 1 .0f , 5 .0f , 5 .0f );
355
355
356
356
357
357
test_dry ({0 .25f , 0 .25f , 0 .25f , 0 .25f }, {0 , 1 }, {0 .25f , 0 .25f , 0 .25f , 0 .25f }, 1 .0f , 1 .1f , 2 , 4 , {});
358
- test_dry ({0 .25f , 0 .25f , 0 .25f , 0 .25f }, {0 , 1 , 2 , 0 , 1 }, {0 .296923f , 0 .296923f , 0 .296923f , 0 .109232f }, 1 .0f , 1 .1f , 2 , 5 , {});
358
+ test_dry ({0 .25f , 0 .25f , 0 .25f , 0 .25f }, {0 , 1 , 2 , 0 , 1 }, {0 .296923f , 0 .296923f , 0 .109232f , 0 .296923f }, 1 .0f , 1 .1f , 2 , 5 , {});
359
359
test_dry ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 3 , 4 , 0 , 1 }, {0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, 1 .0f , 1 .1f , 2 , 6 , {{3 }});
360
- test_dry ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 , 0 , 1 }, {0 .241818f , 0 .241818f , 0 .241818f , 0 .241818f , 0 .032727f }, 2 .0f , 1 .1f , 2 , 5 , {});
360
+ test_dry ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 , 0 , 1 }, {0 .241818f , 0 .241818f , 0 .032727f , 0 .241818f , 0 .241818f }, 2 .0f , 1 .1f , 2 , 5 , {});
361
361
test_dry ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 , 3 , 4 , 0 , 1 }, {0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, 1 .0f , 1 .1f , 4 , 7 , {});
362
362
363
363
test_top_n_sigma ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .571429f , 0 .428571f , 0 .0f , 0 .0f }, 1 .00f );
364
- test_top_n_sigma ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f , 0 .3f , 0 .2f , 0 .1f }, 0 .00f ); // top_n_sigma == 0 now represents a no-op rather than greedy decoding as of PR#13345
364
+ test_top_n_sigma ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .1f , 0 .2f , 0 .3f , 0 .4f }, 0 .00f ); // top_n_sigma == 0 now represents a no-op rather than greedy decoding as of PR#13345
365
365
test_top_n_sigma ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f , 0 .3f , 0 .2f , 0 .1f }, 3 .00f );
366
366
367
367
test_sampler_queue (10000 , " k" , 10000 , 1 .0f , 1 .0f );
0 commit comments