Skip to content
Merged
156 changes: 72 additions & 84 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#include "sampling.h"

#include <functional>

struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
struct llama_sampling_context * result = new llama_sampling_context();

Expand Down Expand Up @@ -103,96 +101,89 @@ std::string llama_sampling_print(const llama_sampling_params & params) {

std::string llama_sampling_order_print(const llama_sampling_params & params) {
std::string result = "CFG -> Penalties ";

std::unordered_map<char, std::string> samplers_map_display {
{'k', "-> top_k "},
{'f', "-> tfs_z "},
{'y', "-> typical_p "},
{'p', "-> top_p "},
{'m', "-> min_p "},
{'t', "-> temp "}
};

if (params.mirostat == 0){
for (auto s : params.samplers_sequence){
result += samplers_map_display[s];
switch (s){
case 'k':{
result += "-> top_k ";
break;
}
case 'f':{
result += "-> tfs_z ";
break;
}
case 'y':{
result += "-> typical_p ";
break;
}
case 'p':{
result += "-> top_p ";
break;
}
case 'm':{
result += "-> min_p ";
break;
}
case 't':{
result += "-> temp ";
break;
}
default: break;
}
}
} else result += "-> mirostat ";

return result;
}

void sample_top_k(
const llama_sampling_params & params,
struct llama_context * ctx_main,
llama_token_data_array & cur_p,
size_t & min_keep){

const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep);
}

void sample_top_p(
const llama_sampling_params & params,
struct llama_context * ctx_main,
llama_token_data_array & cur_p,
size_t & min_keep){

const float top_p = params.top_p;
llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep);
}

void sample_tfs_z(
const llama_sampling_params & params,
struct llama_context * ctx_main,
llama_token_data_array & cur_p,
size_t & min_keep){

const float tfs_z = params.tfs_z;
llama_sample_tail_free (ctx_main, &cur_p, tfs_z, min_keep);
}
// no reasons to expose this function in header
void sampler_queue(
struct llama_context * ctx_main,
const llama_sampling_params & params,
llama_token_data_array & cur_p,
size_t & min_keep) {
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));

void sample_typical_p(
const llama_sampling_params & params,
struct llama_context * ctx_main,
llama_token_data_array & cur_p,
size_t & min_keep){
const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
const float top_p = params.top_p;
const float min_p = params.min_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
const std::string samplers_sequence = params.samplers_sequence;

for (auto s : samplers_sequence){
switch (s){
case 'k':{
llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep);
break;
}
case 'f':{
llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep);
break;
}
case 'y':{
llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep);
break;
}
case 'p':{
llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep);
break;
}
case 'm':{
llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep);
break;
}
case 't':{
llama_sample_temp (ctx_main, &cur_p, temp);
break;
}
default: break;
}
}

const float typical_p = params.typical_p;
llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep);
}

void sample_min_p(
const llama_sampling_params & params,
struct llama_context * ctx_main,
llama_token_data_array & cur_p,
size_t & min_keep){

const float min_p = params.min_p;
llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep);
}

void sample_temp(
const llama_sampling_params & params,
struct llama_context * ctx_main,
llama_token_data_array & cur_p,
size_t & min_keep){

const float temp = params.temp;
llama_sample_temp (ctx_main, &cur_p, temp);
}

std::unordered_map<char, std::function<void(const llama_sampling_params &, struct llama_context *, llama_token_data_array&, size_t&)>> samplers_map
{
{'k', sample_top_k},
{'f', sample_tfs_z},
{'y', sample_typical_p},
{'p', sample_top_p},
{'m', sample_min_p},
{'t', sample_temp}
};

llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
Expand All @@ -211,7 +202,6 @@ llama_token llama_sampling_sample(
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
const bool penalize_nl = params.penalize_nl;
const std::string samplers_sequence = params.samplers_sequence;

auto & prev = ctx_sampling->prev;
auto & cur = ctx_sampling->cur;
Expand Down Expand Up @@ -278,9 +268,7 @@ llama_token llama_sampling_sample(
// temperature sampling
size_t min_keep = std::max(1, params.n_probs);

for (auto s : samplers_sequence){
samplers_map[s](params, ctx_main, cur_p, min_keep);
}
sampler_queue(ctx_main, params, cur_p, min_keep);

id = llama_sample_token(ctx_main, &cur_p);

Expand Down
36 changes: 0 additions & 36 deletions common/sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,42 +84,6 @@ std::string llama_sampling_print(const llama_sampling_params & params);
// Print sampling order into a string
std::string llama_sampling_order_print(const llama_sampling_params & params);

void sample_top_k(
const llama_sampling_params & params,
struct llama_context * ctx_main,
llama_token_data_array & cur_p,
size_t & min_keep);

void sample_top_p(
const llama_sampling_params & params,
struct llama_context * ctx_main,
llama_token_data_array & cur_p,
size_t & min_keep);

void sample_tfs_z(
const llama_sampling_params & params,
struct llama_context * ctx_main,
llama_token_data_array & cur_p,
size_t & min_keep);

void sample_typical_p(
const llama_sampling_params & params,
struct llama_context * ctx_main,
llama_token_data_array & cur_p,
size_t & min_keep);

void sample_min_p(
const llama_sampling_params & params,
struct llama_context * ctx_main,
llama_token_data_array & cur_p,
size_t & min_keep);

void sample_temp(
const llama_sampling_params & params,
struct llama_context * ctx_main,
llama_token_data_array & cur_p,
size_t & min_keep);

// this is a common sampling function used across the examples for convenience
// it can serve as a starting point for implementing your own sampling function
// Note: When using multiple sequences, it is the caller's responsibility to call
Expand Down