@@ -940,10 +940,11 @@ llama_token llama_sample_token(
940
940
struct llama_context * ctx,
941
941
struct llama_context * ctx_guidance,
942
942
struct llama_grammar * grammar,
943
- const struct gpt_params & params,
943
+ struct gpt_params & params,
944
944
const std::vector<llama_token> & last_tokens,
945
945
std::vector<llama_token_data> & candidates,
946
- int idx) {
946
+ const int idx,
947
+ llama_seq_id seq) {
947
948
const int n_ctx = llama_n_ctx (ctx);
948
949
const int n_vocab = llama_n_vocab (llama_get_model (ctx));
949
950
@@ -1011,15 +1012,23 @@ llama_token llama_sample_token(
1011
1012
// Greedy sampling
1012
1013
id = llama_sample_token_greedy (ctx, &cur_p);
1013
1014
} else {
1015
+ float * mirostat_mu = NULL ;
1016
+ if (mirostat > 0 ) {
1017
+ seq = std::max (0 , seq); // Deal with people passing -1 or something.
1018
+ auto mu_it = params.sampler_state .find (seq);
1019
+ if (mu_it == params.sampler_state .end ()) {
1020
+ const llama_sampler_state new_state = { 2 .0f * mirostat_tau };
1021
+ mu_it = params.sampler_state .insert ({seq, new_state}).first ;
1022
+ }
1023
+ mirostat_mu = &mu_it->second .mirostat_mu ;
1024
+ }
1014
1025
if (mirostat == 1 ) {
1015
- static float mirostat_mu = 2 .0f * mirostat_tau;
1016
1026
const int mirostat_m = 100 ;
1017
1027
llama_sample_temp (ctx, &cur_p, temp);
1018
- id = llama_sample_token_mirostat (ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, & mirostat_mu);
1028
+ id = llama_sample_token_mirostat (ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, mirostat_mu);
1019
1029
} else if (mirostat == 2 ) {
1020
- static float mirostat_mu = 2 .0f * mirostat_tau;
1021
1030
llama_sample_temp (ctx, &cur_p, temp);
1022
- id = llama_sample_token_mirostat_v2 (ctx, &cur_p, mirostat_tau, mirostat_eta, & mirostat_mu);
1031
+ id = llama_sample_token_mirostat_v2 (ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_mu);
1023
1032
} else {
1024
1033
// Temperature sampling
1025
1034
size_t min_keep = std::max (1 , params.n_probs );
0 commit comments