context : remove logits_all flag
ggml-ci
ggerganov committed May 8, 2025
commit 6c0501adf78d6ae185027eaa439ea04e7ca6d9ae
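
With this commit, `logits_all` no longer exists as a context-level switch; callers that need logits for every token mark each token individually when building the batch. A minimal sketch of that pattern using the `common_batch_add` helper (illustrative only, not code from this commit; error handling and cleanup kept minimal):

    // Build a batch that requests logits for every token, replacing the
    // removed context-wide logits_all flag with per-token output flags.
    llama_batch batch = llama_batch_init((int32_t) tokens.size(), 0, 1);
    for (int32_t i = 0; i < (int32_t) tokens.size(); ++i) {
        // the last argument marks this token for logit output
        common_batch_add(batch, tokens[i], i, { 0 }, true);
    }
    if (llama_decode(ctx, batch) != 0) {
        LOG_ERR("%s: llama_decode() failed\n", __func__);
    }
    llama_batch_free(batch);
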
7 changes: 0 additions & 7 deletions common/arg.cpp
@@ -2097,13 +2097,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.cache_type_v = kv_cache_type_from_str(value);
         }
     ).set_env("LLAMA_ARG_CACHE_TYPE_V"));
-    add_opt(common_arg(
-        {"--perplexity", "--all-logits"},
-        string_format("return logits for all tokens in the batch (default: %s)", params.logits_all ? "true" : "false"),
-        [](common_params & params) {
-            params.logits_all = true;
-        }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
     add_opt(common_arg(
         {"--hellaswag"},
         "compute HellaSwag score over random tasks from datafile supplied with -f",
2 changes: 1 addition & 1 deletion common/common.cpp
@@ -1096,7 +1096,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.n_threads       = params.cpuparams.n_threads;
     cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ?
                                   params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
-    cparams.logits_all        = params.logits_all;
+    cparams.logits_all        = false;
     cparams.embeddings        = params.embedding;
     cparams.rope_scaling_type = params.rope_scaling_type;
     cparams.rope_freq_base    = params.rope_freq_base;
1 change: 0 additions & 1 deletion common/common.h
@@ -324,7 +324,6 @@ struct common_params {
     bool ctx_shift        = true;  // context shift on inifinite text generation

     bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
-    bool logits_all       = false; // return logits for all tokens in the batch
     bool use_mmap         = true;  // use mmap for faster loads
     bool use_mlock        = false; // use mlock to keep model in memory
     bool verbose_prompt   = false; // print prompt tokens before generation
4 changes: 1 addition & 3 deletions src/llama-context.cpp
@@ -116,8 +116,6 @@ llama_context::llama_context(
                 __func__, n_ctx_per_seq, hparams.n_ctx_train);
     }

-    logits_all = params.logits_all;
-
     if (!hparams.vocab_only) {
         // GPU backends
         for (auto * dev : model.devices) {
@@ -890,7 +888,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
             n_outputs_all += batch.logits[i] != 0;
         }
-    } else if (logits_all || embd_pooled) {
+    } else if (embd_pooled) {
         n_outputs_all = n_tokens_all;
     } else {
         // keep last output only
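
After this hunk, the number of outputs computed by `decode` is derived entirely from the batch: explicit per-token `logits` flags when provided, every token when embeddings are pooled, and only the last token otherwise. A paraphrased sketch of the resulting selection logic (simplified from the surrounding source, not a verbatim excerpt):

    // How n_outputs_all is chosen once logits_all is gone (simplified):
    int32_t n_outputs_all = 0;
    if (batch.logits) {
        for (uint32_t i = 0; i < n_tokens_all; ++i) {
            n_outputs_all += batch.logits[i] != 0; // count explicitly requested outputs
        }
    } else if (embd_pooled) {
        n_outputs_all = n_tokens_all;              // pooled embeddings need every token
    } else {
        n_outputs_all = 1;                         // default: last token's output only
    }
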
3 changes: 0 additions & 3 deletions src/llama-context.h
@@ -187,9 +187,6 @@ struct llama_context {

     std::unique_ptr<llama_memory_i> memory;

-    // TODO: remove
-    bool logits_all = false;
-
     // decode output (2-dimensional array: [n_outputs][n_vocab])
     size_t logits_size = 0; // capacity (of floats) for logits
     float * logits = nullptr;
1 change: 0 additions & 1 deletion tools/imatrix/imatrix.cpp
@@ -585,7 +585,6 @@ int main(int argc, char ** argv) {
     params.out_file = "imatrix.dat" ;

     params.n_ctx = 512;
-    params.logits_all = true;
     params.escape = false;

     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_IMATRIX, print_usage)) {
8 changes: 0 additions & 8 deletions tools/main/main.cpp
@@ -99,14 +99,6 @@ int main(int argc, char ** argv) {
     console::init(params.simple_io, params.use_color);
     atexit([]() { console::cleanup(); });

-    if (params.logits_all) {
-        LOG_ERR("************\n");
-        LOG_ERR("%s: please use the 'perplexity' tool for perplexity calculations\n", __func__);
-        LOG_ERR("************\n\n");
-
-        return 0;
-    }
-
     if (params.embedding) {
         LOG_ERR("************\n");
         LOG_ERR("%s: please use the 'embedding' tool for embedding calculations\n", __func__);
6 changes: 4 additions & 2 deletions tools/perplexity/perplexity.cpp
@@ -1554,7 +1554,10 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
             if (int(batch_indeces.size()) != num_answers) {
                 batch_indeces.resize(num_answers);
             }
-            for (int s = 0; s < num_answers; ++s) batch_indeces[s] = s0 + s;
+
+            for (int s = 0; s < num_answers; ++s) {
+                batch_indeces[s] = s0 + s;
+            }

             for (size_t i = 0; i < cur_task.common_prefix; ++i) {
                 //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
@@ -1970,7 +1973,6 @@ int main(int argc, char ** argv) {
     common_params params;

     params.n_ctx = 512;
-    params.logits_all = true;
     params.escape = false;

     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
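
On the consumer side, each token that was marked for output can be read back with `llama_get_logits_ith`. A minimal sketch, assuming every token was flagged as in the batch example above (the `model` handle and `n_tokens` count are assumed to be in scope):

    // Read back per-token logits after llama_decode(); index i corresponds
    // to the i-th token that was marked with logits = true.
    const llama_vocab * vocab   = llama_model_get_vocab(model);
    const int           n_vocab = llama_vocab_n_tokens(vocab);
    for (int32_t i = 0; i < n_tokens; ++i) {
        const float * logits_i = llama_get_logits_ith(ctx, i);
        // ... accumulate log-probabilities over logits_i[0..n_vocab-1]
    }
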