Skip to content
Merged
6 changes: 6 additions & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
params.yarn_beta_slow = std::stof(argv[i]);
} else if (arg == "--memory-f32") {
params.memory_f16 = false;
} else if (arg == "--sampling-seq") {
if (++i >= argc) {
invalid_param = true;
break;
}
sparams.samplers_sequence = argv[i];
} else if (arg == "--top-p") {
if (++i >= argc) {
invalid_param = true;
Expand Down
78 changes: 58 additions & 20 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,54 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
return std::string(result);
}

std::string llama_sampling_order_print(const llama_sampling_params & params) {
std::string result = "CFG -> Penalties ";
if (params.mirostat == 0){
for (auto s : params.samplers_sequence){
switch (s){
case 'k': result += "-> top_k "; break;
case 'f': result += "-> tfs_z "; break;
case 'y': result += "-> typical_p "; break;
case 'p': result += "-> top_p "; break;
case 'm': result += "-> min_p "; break;
case 't': result += "-> temp "; break;
default : break;
}
}
} else result += "-> mirostat ";

return result;
}

// no reasons to expose this function in header
void sampler_queue(
struct llama_context * ctx_main,
const llama_sampling_params & params,
llama_token_data_array & cur_p,
size_t & min_keep) {
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));

const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
const float top_p = params.top_p;
const float min_p = params.min_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
const std::string samplers_sequence = params.samplers_sequence;

for (auto s : samplers_sequence){
switch (s){
case 'k': llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break;
case 'f': llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break;
case 'y': llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break;
case 'p': llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break;
case 'm': llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break;
case 't': llama_sample_temp (ctx_main, &cur_p, temp); break;
default : break;
}
}
}

llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
Expand All @@ -108,20 +156,15 @@ llama_token llama_sampling_sample(

const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));

const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
const float top_p = params.top_p;
const float min_p = params.min_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
const float penalty_repeat = params.penalty_repeat;
const float penalty_freq = params.penalty_freq;
const float penalty_present = params.penalty_present;
const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
const bool penalize_nl = params.penalize_nl;
const float temp = params.temp;
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
const float penalty_repeat = params.penalty_repeat;
const float penalty_freq = params.penalty_freq;
const float penalty_present = params.penalty_present;
const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
const bool penalize_nl = params.penalize_nl;

auto & prev = ctx_sampling->prev;
auto & cur = ctx_sampling->cur;
Expand Down Expand Up @@ -188,12 +231,7 @@ llama_token llama_sampling_sample(
// temperature sampling
size_t min_keep = std::max(1, params.n_probs);

llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep);
llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep);
llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep);
llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep);
llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep);
llama_sample_temp (ctx_main, &cur_p, temp);
sampler_queue(ctx_main, params, cur_p, min_keep);

id = llama_sample_token(ctx_main, &cur_p);

Expand Down
36 changes: 20 additions & 16 deletions common/sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,23 @@

// sampling parameters
typedef struct llama_sampling_params {
int32_t n_prev = 64; // number of previous tokens to remember
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
float tfs_z = 1.00f; // 1.0 = disabled
float typical_p = 1.00f; // 1.0 = disabled
float temp = 0.80f; // 1.0 = disabled
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat = 1.10f; // 1.0 = disabled
float penalty_freq = 0.00f; // 0.0 = disabled
float penalty_present = 0.00f; // 0.0 = disabled
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
bool penalize_nl = true; // consider newlines as a repeatable token
int32_t n_prev = 64; // number of previous tokens to remember
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
float tfs_z = 1.00f; // 1.0 = disabled
float typical_p = 1.00f; // 1.0 = disabled
float temp = 0.80f; // 1.0 = disabled
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat = 1.10f; // 1.0 = disabled
float penalty_freq = 0.00f; // 0.0 = disabled
float penalty_present = 0.00f; // 0.0 = disabled
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
bool penalize_nl = true; // consider newlines as a repeatable token
std::string samplers_sequence = "kfypmt"; // top_k, tail_free, typical_p, top_p, min_p, temp

std::string grammar; // optional BNF-like grammar to constrain sampling

Expand Down Expand Up @@ -80,6 +81,9 @@ std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama
// Print sampling parameters into a string
std::string llama_sampling_print(const llama_sampling_params & params);

// Print sampling order into a string
std::string llama_sampling_order_print(const llama_sampling_params & params);

// this is a common sampling function used across the examples for convenience
// it can serve as a starting point for implementing your own sampling function
// Note: When using multiple sequences, it is the caller's responsibility to call
Expand Down
1 change: 1 addition & 0 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ int main(int argc, char ** argv) {
}
}
LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str());
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
LOG_TEE("\n\n");

Expand Down