Rewrite special token handling from #1931
staviq committed Oct 8, 2023
commit b592c70deb98c7be52b9b3fd2c378a7e2e68fab3
common/common.cpp: 12 changes (7 additions, 5 deletions)
@@ -862,21 +862,23 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
std::vector<llama_token> llama_tokenize(
    const struct llama_context * ctx,
    const std::string & text,
-   bool add_bos) {
-   return llama_tokenize(llama_get_model(ctx), text, add_bos);
+   bool add_bos,
+   bool allow_special_tokens) {
+   return llama_tokenize(llama_get_model(ctx), text, add_bos, allow_special_tokens);
}

std::vector<llama_token> llama_tokenize(
    const struct llama_model * model,
    const std::string & text,
-   bool add_bos) {
+   bool add_bos,
+   bool allow_special_tokens) {
    // upper limit for the number of tokens
    int n_tokens = text.length() + add_bos;
    std::vector<llama_token> result(n_tokens);
-   n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
+   n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, allow_special_tokens);
    if (n_tokens < 0) {
        result.resize(-n_tokens);
-       int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
+       int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, allow_special_tokens);
        GGML_ASSERT(check == -n_tokens);
    } else {
        result.resize(n_tokens);
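For reference, the wrapper above follows the size-negotiation convention of the underlying C API: when the output buffer is too small, llama_tokenize returns the negative of the required token count, so the caller resizes and retries. A minimal sketch of driving the C API directly with the new flag; it assumes llama.h is included, `model` points at a loaded model, and `text` holds the prompt, so the variable names and flag values are illustrative only, not part of this commit:

// Two-pass tokenization, assuming `model` is a loaded llama_model * and
// `text` is a std::string holding the prompt.
std::vector<llama_token> tokens(text.length() + 1);   // generous first guess
int n = llama_tokenize(model, text.data(), text.length(),
                       tokens.data(), tokens.size(),
                       /*add_bos*/ true, /*allow_special_tokens*/ true);
if (n < 0) {
    tokens.resize(-n);                                 // -n is the exact size needed
    n = llama_tokenize(model, text.data(), text.length(),
                       tokens.data(), tokens.size(), true, true);
}
tokens.resize(n);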
common/common.h: 6 changes (4 additions, 2 deletions)
@@ -151,12 +151,14 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
std::vector<llama_token> llama_tokenize(
    const struct llama_context * ctx,
    const std::string & text,
-   bool add_bos);
+   bool add_bos,
+   bool allow_special_tokens = false);

std::vector<llama_token> llama_tokenize(
    const struct llama_model * model,
    const std::string & text,
-   bool add_bos);
+   bool add_bos,
+   bool allow_special_tokens = false);

// tokenizes a token into a piece
// should work similar to Python's `tokenizer.id_to_piece`
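Because the new parameter is defaulted to false in these declarations, every existing caller keeps compiling with the old behaviour and only opts in where special-token parsing is wanted. A short usage sketch, assuming common.h is included, `ctx` is an initialized llama_context *, and `user_text` is arbitrary input text:

// Default: same behaviour as before this commit.
std::vector<llama_token> plain = ::llama_tokenize(ctx, user_text, /*add_bos*/ true);

// Opt in: special tokens occurring in the text may be parsed as such.
std::vector<llama_token> special = ::llama_tokenize(ctx, user_text, /*add_bos*/ true,
                                                    /*allow_special_tokens*/ true);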
common/train.cpp: 8 changes (4 additions, 4 deletions)
@@ -863,7 +863,7 @@ size_t tokenize_file(
            (int) buf.size(),
            out_tokens.data(),
            (int) out_tokens.size(),
-           false);
+           false, false);
    if (n_tokens < 0) {
        out_tokens.resize(-n_tokens);
        n_tokens = llama_tokenize(
@@ -872,7 +872,7 @@
            (int) buf.size(),
            out_tokens.data(),
            (int) out_tokens.size(),
-           false);
+           false, false);
    }
    if (n_tokens >= 0) {
        out_tokens.resize(n_tokens);
@@ -966,15 +966,15 @@ size_t tokenize_file(
            (int) buf_sample.size(),
            tok_sample.data(),
            (int) tok_sample.size(),
-           false);
+           false, false);
        if (n_tokens < 0) {
            tok_sample.resize(-n_tokens);
            n_tokens = llama_tokenize(llama_get_model(lctx),
                    buf_sample.data(),
                    (int) buf_sample.size(),
                    tok_sample.data(),
                    (int) tok_sample.size(),
-                   false);
+                   false, false);
            GGML_ASSERT(n_tokens >= 0);
        }
        GGML_ASSERT(n_tokens <= (int) tok_sample.size());
examples/main/main.cpp: 14 changes (7 additions, 7 deletions)
@@ -237,7 +237,7 @@ int main(int argc, char ** argv) {

    if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
        LOG("tokenize the prompt\n");
-       embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
+       embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
    } else {
        LOG("use session tokens\n");
        embd_inp = session_tokens;
@@ -259,10 +259,10 @@
    if (ctx_guidance) {
        LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(params.cfg_negative_prompt));

-       guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, add_bos);
+       guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, add_bos, true);
        LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp));

-       std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
+       std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
        LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp));

        original_prompt_len = original_inp.size();
@@ -316,8 +316,8 @@
    }

    // prefix & suffix for instruct mode
-   const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", add_bos);
-   const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false);
+   const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", add_bos, true);
+   const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false, true);

    LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx));
    LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx));
@@ -715,7 +715,7 @@ int main(int argc, char ** argv) {
    if (params.interactive) {
        if (!params.antiprompt.empty()) {
            // tokenize and inject first reverse prompt
-           const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
+           const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false, true);
            embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
            is_antiprompt = true;
        }
@@ -780,7 +780,7 @@ int main(int argc, char ** argv) {
                embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
            }

-           const auto line_inp = ::llama_tokenize(ctx, buffer, false);
+           const auto line_inp = ::llama_tokenize(ctx, buffer, false, true);
            LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp));

            embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
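Taken together, main.cpp now enables special-token parsing for user-facing text (prompt, guidance prompt, instruct prefix and suffix, reverse prompts, interactive input), while train.cpp keeps it disabled for raw training data. A hedged illustration of the difference; the marker string below is hypothetical and only resolves to a single token if the loaded model's vocabulary actually defines it as a special token:

// Hypothetical special-token text, e.g. a chat-template marker.
const std::string marker = "<|endoftext|>";

// As in main.cpp: special-token parsing enabled, so the marker may come back
// as one special token id.
const auto as_special = ::llama_tokenize(ctx, marker, /*add_bos*/ false, /*allow_special_tokens*/ true);

// As in train.cpp (and the default): the same characters are tokenized as
// ordinary text pieces.
const auto as_plain   = ::llama_tokenize(ctx, marker, /*add_bos*/ false);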