Merged

43 commits (the diff below shows the changes from 2 of them)
36092aa  address clang-tidy lints (Feb 15, 2025)
ef9b91a  tool-call: massive refactoring (Feb 15, 2025)
2f683f0  rm minja dep from util & common (Feb 15, 2025)
7a04ebc  move minja to common/minja (Feb 15, 2025)
ece941b  Update utils.hpp (Feb 15, 2025)
aa09a3c  add common_chat_tool (Feb 15, 2025)
7ae7560  force utf8 encoding in get_chat_template (Feb 15, 2025)
646528a  fix json tools parsing (Feb 16, 2025)
db2b44e  add json tools / messages parsing helpers to common (Feb 16, 2025)
c7c8907  fix common_chat_msgs_parse_oaicompat (Feb 16, 2025)
5f17156  concat multipart content in legacy template path (Feb 16, 2025)
ee9b9d6  add name & tool_call_id to common_chat_msg (Feb 16, 2025)
07f0ad0  Update test-chat.cpp (Feb 16, 2025)
1acda5f  test & fix json<->msg conversions (Feb 16, 2025)
a58e1fc  fix typo (Feb 16, 2025)
103c840  fix content part string concat in legacy template branch (Feb 16, 2025)
c154c02  test tools json conversions (Feb 16, 2025)
3d41f1b  test content parts in test-chat (Feb 16, 2025)
59c8059  fix clang-tidy lints in [test-]chat.* (Feb 16, 2025)
1847cae  fix deepseek r1 slow test (no longer <think> opening w/ new template) (Feb 16, 2025)
8462a51  fix lints in test-chat-template.cpp (Feb 16, 2025)
80c432b  tweak test_calc_result expectations (Feb 16, 2025)
42b29e1  fix double bos/eos jinja avoidance hack (was preventing inner bos/eos… (Feb 16, 2025)
ce4ccf0  add common_chat_templates_source + rehab server template logs (Feb 16, 2025)
cb31f08  fix msg lints (Feb 16, 2025)
76f5d27  tool-call: allow empty tools w/ auto + grammar (Feb 16, 2025)
34e4e22  fix & test grammar & json_schema w/ & w/o --jinja (Feb 16, 2025)
1c6168b  Update test-chat-template.cpp (Feb 16, 2025)
ae6b870  test & fix array message.content (Feb 16, 2025)
1421037  fix links to prepare merge (Feb 16, 2025)
d95a17c  Merge remote-tracking branch 'origin/master' into chat-cleanups (Feb 16, 2025)
5a5ed7b  fix merge (Feb 16, 2025)
dd5ef85  rm trailing spaces (Feb 16, 2025)
2f2f0fa  Add missing <optional> include to chat.cpp (Feb 16, 2025)
a58b9e5  tiny fix: somehow llama_token being defined in an extern c makes it l… (Feb 16, 2025)
f999ff5  alternative fix for gcc c vs. c++ weirdness (Feb 16, 2025)
55a7614  add missing <regex> include to test-chat-template (Feb 16, 2025)
9d62f62  Update chat.hpp (Feb 16, 2025)
da0982a  have common_chat_templates_init return a unique_ptr (Feb 17, 2025)
7ddb454  chat.{hpp -> h} (Feb 17, 2025)
d2969b8  build common_chat_templates_ptr earlier (Feb 17, 2025)
fd2b8e1  use deleter functor for common_chat_templates_ptr (Feb 17, 2025)
9a85439  Merge remote-tracking branch 'origin/master' into chat-cleanups (Feb 18, 2025)
2 changes: 1 addition & 1 deletion Makefile
@@ -1364,7 +1364,7 @@ llama-server: \
examples/server/index.html.hpp \
examples/server/loading.html.hpp \
common/chat.cpp \
-common/chat.hpp \
+common/chat.h \
common/chat-template.hpp \
common/json.hpp \
common/minja.hpp \
2 changes: 1 addition & 1 deletion common/CMakeLists.txt
@@ -57,7 +57,7 @@ add_library(${TARGET} STATIC
arg.h
base64.hpp
chat.cpp
-chat.hpp
+chat.h
common.cpp
common.h
console.cpp
2 changes: 1 addition & 1 deletion common/arg.cpp
@@ -2,7 +2,7 @@

#include "log.h"
#include "sampling.h"
-#include "chat.hpp"
+#include "chat.h"

#include <algorithm>
#include <climits>
10 changes: 5 additions & 5 deletions common/chat.cpp
@@ -1,4 +1,4 @@
-#include "chat.hpp"
+#include "chat.h"
#include "json-schema-to-grammar.h"
#include "log.h"
#include "minja/chat-template.hpp"
@@ -269,12 +269,12 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
msg.role = "user";
msg.content = "test";

-    auto * tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl);
+    auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl);

common_chat_templates_inputs inputs;
inputs.messages = {msg};

-    common_chat_templates_apply(tmpls, inputs);
+    common_chat_templates_apply(tmpls.get(), inputs);
return true;
} catch (const std::exception & e) {
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
@@ -362,7 +362,7 @@ const char * common_chat_templates_source(const struct common_chat_templates * t
return tmpls->template_default->source().c_str();
}

-struct common_chat_templates * common_chat_templates_init(
+common_chat_templates_ptr common_chat_templates_init(
const struct llama_model * model,
const std::string & chat_template_override,
const std::string & bos_token_override,
@@ -426,7 +426,7 @@ struct common_chat_templates * common_chat_templates_init(
LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what());
}
}
-    return tmpls;
+    return {tmpls, common_chat_templates_free};
}

std::string common_chat_format_name(common_chat_format format) {
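Note on the `return {tmpls, common_chat_templates_free};` line above: it works because `common_chat_templates_ptr` is a `std::unique_ptr` whose deleter type is a function pointer, so the raw pointer and its deleter are handed back together. A minimal self-contained sketch of the same pattern, with illustrative names (`widget`, `widget_free`) standing in for the library types:

```cpp
#include <cstdio>
#include <memory>

struct widget { int id; };

// Stand-in for a C-style destructor like common_chat_templates_free.
static void widget_free(widget * w) {
    std::printf("freeing widget %d\n", w->id);
    delete w;
}

// unique_ptr with a function-pointer deleter, mirroring common_chat_templates_ptr.
typedef std::unique_ptr<widget, decltype(&widget_free)> widget_ptr;

static widget_ptr widget_init(int id) {
    auto * w = new widget{id};
    return {w, widget_free};  // same shape as `return {tmpls, common_chat_templates_free};`
}

int main() {
    widget_ptr p = widget_init(42);
    std::printf("id = %d\n", p->id);  // p.get() yields the raw pointer for C-style APIs
    return 0;
}  // widget_free runs here automatically
```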
8 changes: 5 additions & 3 deletions common/chat.hpp → common/chat.h
@@ -85,17 +85,19 @@ struct common_chat_params {
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);

-struct common_chat_templates * common_chat_templates_init(
+
+void common_chat_templates_free(struct common_chat_templates * tmpls);
+typedef std::unique_ptr<struct common_chat_templates, decltype(&common_chat_templates_free)> common_chat_templates_ptr;
+
+common_chat_templates_ptr common_chat_templates_init(
const struct llama_model * model,
const std::string & chat_template_override,
const std::string & bos_token_override = "",
const std::string & eos_token_override = "");

bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr);
-void common_chat_templates_free(struct common_chat_templates * tmpls);

-typedef std::unique_ptr<struct common_chat_templates, decltype(&common_chat_templates_free)> common_chat_templates_ptr;

struct common_chat_params common_chat_templates_apply(
const struct common_chat_templates * tmpls,
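With `common_chat_templates_init` returning an owning smart pointer, call sites collapse to a single `auto` line, and the raw-pointer APIs are reached through `.get()`. A sketch of the new call pattern, assuming only the declarations shown in the diff above (error handling elided):

```cpp
#include "chat.h"

static void demo(const llama_model * model) {
    // Owning pointer: common_chat_templates_free runs when tmpls leaves scope.
    auto tmpls = common_chat_templates_init(model, /* chat_template_override= */ "");

    common_chat_msg msg;
    msg.role    = "user";
    msg.content = "test";

    common_chat_templates_inputs inputs;
    inputs.messages = {msg};

    // The query/apply functions keep raw-pointer signatures; callers pass .get().
    common_chat_params params = common_chat_templates_apply(tmpls.get(), inputs);
    (void) params;
}
```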
6 changes: 2 additions & 4 deletions examples/main/main.cpp
@@ -4,7 +4,7 @@
#include "log.h"
#include "sampling.h"
#include "llama.h"
-#include "chat.hpp"
+#include "chat.h"

#include <cstdio>
#include <cstring>
@@ -158,9 +158,7 @@ int main(int argc, char ** argv) {
}

const llama_vocab * vocab = llama_model_get_vocab(model);
-    common_chat_templates_ptr chat_templates(
-        common_chat_templates_init(model, params.chat_template),
-        &common_chat_templates_free);
+    auto chat_templates = common_chat_templates_init(model, params.chat_template);

LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);

6 changes: 2 additions & 4 deletions examples/run/run.cpp
@@ -24,7 +24,7 @@
#include <string>
#include <vector>

-#include "chat.hpp"
+#include "chat.h"
#include "common.h"
#include "json.hpp"
#include "linenoise.cpp/linenoise.h"
@@ -1057,9 +1057,7 @@ static int get_user_input(std::string & user_input, const std::string & user) {
static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) {
int prev_len = 0;
llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
-    common_chat_templates_ptr chat_templates(
-        common_chat_templates_init(llama_data.model.get(), ""),
-        &common_chat_templates_free);
+    auto chat_templates = common_chat_templates_init(llama_data.model.get(), "");
static const bool stdout_a_terminal = is_stdout_a_terminal();
while (true) {
// Get user input
20 changes: 10 additions & 10 deletions examples/server/server.cpp
@@ -1804,7 +1804,9 @@ struct server_context {
// Necessary similarity of prompt for slot selection
float slot_prompt_similarity = 0.0f;

-    struct common_chat_templates * chat_templates = nullptr;
+    common_chat_templates_ptr chat_templates;
+
+    server_context() : chat_templates(nullptr, nullptr) {}

~server_context() {
// Clear any sampling context
@@ -1822,7 +1824,6 @@
}

llama_batch_free(batch);
-        common_chat_templates_free(chat_templates);
}

bool load_model(const common_params & params) {
@@ -1891,10 +1892,9 @@

chat_templates = common_chat_templates_init(model, params_base.chat_template);
try {
-            common_chat_format_example(chat_templates, params.use_jinja);
+            common_chat_format_example(chat_templates.get(), params.use_jinja);
} catch (const std::exception & e) {
SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
-            common_chat_templates_free(chat_templates);
chat_templates = common_chat_templates_init(model, "chatml");
}

@@ -3793,13 +3793,13 @@ int main(int argc, char ** argv) {
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
{ "total_slots", ctx_server.params_base.n_parallel },
{ "model_path", ctx_server.params_base.model },
-            { "chat_template", common_chat_templates_source(ctx_server.chat_templates) },
+            { "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) },
{ "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)},
{ "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)},
{ "build_info", build_info },
};
if (ctx_server.params_base.use_jinja) {
-        if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates, "tool_use")) {
+        if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) {
data["chat_template_tool_use"] = tool_use_src;
}
}
Expand Down Expand Up @@ -4036,7 +4036,7 @@ int main(int argc, char ** argv) {
}

auto body = json::parse(req.body);
-        json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates);
+        json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get());

return handle_completions_impl(
SERVER_TASK_TYPE_COMPLETION,
@@ -4049,7 +4049,7 @@ int main(int argc, char ** argv) {
// same with handle_chat_completions, but without inference part
const auto handle_apply_template = [&ctx_server, &params, &res_ok](const httplib::Request & req, httplib::Response & res) {
auto body = json::parse(req.body);
-        json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates);
+        json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get());
res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
};

@@ -4455,8 +4455,8 @@ int main(int argc, char ** argv) {

// print sample chat example to make it clear which template is used
LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
-        common_chat_templates_source(ctx_server.chat_templates),
-        common_chat_format_example(ctx_server.chat_templates, ctx_server.params_base.use_jinja).c_str());
+        common_chat_templates_source(ctx_server.chat_templates.get()),
+        common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja).c_str());

ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) {
ctx_server.process_single_task(task);
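One wrinkle visible above: a `std::unique_ptr` whose deleter type is a function pointer has no default constructor, which is why `server_context` gains the explicit `chat_templates(nullptr, nullptr)` initializer (the null deleter is harmless only because the member is reassigned before anything is freed). A tiny sketch of the constraint, with illustrative names:

```cpp
#include <memory>

struct thing { int id; };
static void thing_free(thing * t) { delete t; }

using thing_ptr = std::unique_ptr<thing, decltype(&thing_free)>;

struct holder {
    // `thing_ptr p{};` would not compile: a function-pointer deleter
    // cannot be default-constructed and must be supplied explicitly.
    thing_ptr p;
    holder() : p(nullptr, thing_free) {}  // or (nullptr, nullptr), as server_context does
};

int main() {
    holder h;
    h.p.reset(new thing{7});  // reset keeps the deleter passed at construction
    return 0;
}
```

Commit fd2b8e1 in the list above ("use deleter functor for common_chat_templates_ptr") later swaps the function pointer for a stateless functor, which restores default construction.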
2 changes: 1 addition & 1 deletion examples/server/utils.hpp
@@ -12,7 +12,7 @@
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
-#include "chat.hpp"
+#include "chat.h"

#include <random>
#include <sstream>
8 changes: 4 additions & 4 deletions tests/test-chat-template.cpp
@@ -8,7 +8,7 @@

#include "llama.h"
#include "common.h"
-#include "chat.hpp"
+#include "chat.h"

static std::string normalize_newlines(const std::string & s) {
#ifdef _WIN32
@@ -322,7 +322,7 @@ int main(void) {
}
printf("\n\n=== %s (jinja) ===\n\n", test_case.name.c_str());
try {
-            common_chat_templates_ptr tmpls(common_chat_templates_init(/* model= */ nullptr, test_case.template_str.c_str(), test_case.bos_token, test_case.eos_token), &common_chat_templates_free);
+            auto tmpls = common_chat_templates_init(/* model= */ nullptr, test_case.template_str.c_str(), test_case.bos_token, test_case.eos_token);
common_chat_templates_inputs inputs;
inputs.use_jinja = true;
inputs.messages = messages;
@@ -349,7 +349,7 @@ int main(void) {
auto sys_msg = simple_msg("system", "You are a helpful assistant");

auto fmt_sys = [&](std::string tmpl_str) {
-        common_chat_templates_ptr tmpls(common_chat_templates_init(/* model= */ nullptr, tmpl_str), &common_chat_templates_free);
+        auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str);
auto output = common_chat_format_single(tmpls.get(), chat2, sys_msg, false, /* use_jinja= */ false);
printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str());
printf("-------------------------\n");
@@ -376,7 +376,7 @@ int main(void) {
auto new_msg = simple_msg("user", "How are you");

auto fmt_single = [&](const std::string & tmpl_str) {
-        common_chat_templates_ptr tmpls(common_chat_templates_init(/* model= */ nullptr, tmpl_str.c_str()), &common_chat_templates_free);
+        auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str.c_str());
auto output = common_chat_format_single(tmpls.get(), chat2, new_msg, true, /* use_jinja= */ false);
printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str());
printf("-------------------------\n");
4 changes: 2 additions & 2 deletions tests/test-chat.cpp
@@ -10,7 +10,7 @@
#include <json.hpp>
#include <string>

-#include "chat.hpp"
+#include "chat.h"
#include "llama-grammar.h"
#include "unicode.h"

@@ -45,7 +45,7 @@ static std::string read_file(const std::string & path) {
}

static common_chat_templates_ptr read_templates(const std::string & path) {
-    return common_chat_templates_ptr(common_chat_templates_init(/* model= */ nullptr, read_file(path)), &common_chat_templates_free);
+    return common_chat_templates_ptr(common_chat_templates_init(/* model= */ nullptr, read_file(path)));
}

static std::unique_ptr<llama_grammar> build_grammar(const std::string & grammar_str) {
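For completeness, a sketch of the deleter-functor variant that commit fd2b8e1 moves to; names are illustrative and the real definition may differ in detail:

```cpp
#include <memory>

struct thing { int id; };
static void thing_free(thing * t) { delete t; }

// Stateless functor deleter, replacing decltype(&thing_free): the unique_ptr
// becomes default-constructible and stores no per-instance deleter pointer.
struct thing_deleter {
    void operator()(thing * t) const { thing_free(t); }
};
using thing_ptr = std::unique_ptr<thing, thing_deleter>;

struct holder {
    thing_ptr p;  // default-constructible again: no (nullptr, nullptr) ceremony
};

int main() {
    holder h;
    h.p.reset(new thing{7});
    return 0;
}
```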