@@ -458,23 +458,24 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
458458 return true ;
459459}
460460
461+ #define K_TOKEN_CHUNK 4
462+
461463static void compute_logprobs (const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
462464 const std::vector<std::pair<size_t , llama_token>>& eval_pairs, std::vector<float >& eval_results) {
463- constexpr int k_token_chunk = 4 ;
464465 if (eval_results.size () != eval_pairs.size ()) {
465466 eval_results.resize (eval_pairs.size ());
466467 }
467468 if (eval_pairs.empty ()) return ;
468469
469- size_t max_threads = std::min ((eval_pairs.size () + k_token_chunk - 1 )/k_token_chunk , workers.size ());
470+ size_t max_threads = std::min ((eval_pairs.size () + K_TOKEN_CHUNK - 1 )/K_TOKEN_CHUNK , workers.size ());
470471
471472 std::atomic<int > counter (0 );
472473 auto compute = [&counter, &eval_pairs, &eval_results, batch_logits, n_vocab] () {
473- float local_logprobs[k_token_chunk ];
474+ float local_logprobs[K_TOKEN_CHUNK ];
474475 while (true ) {
475- size_t first = counter.fetch_add (k_token_chunk , std::memory_order_relaxed);
476+ size_t first = counter.fetch_add (K_TOKEN_CHUNK , std::memory_order_relaxed);
476477 if (first >= eval_results.size ()) break ;
477- size_t last = std::min (first + k_token_chunk , eval_results.size ());
478+ size_t last = std::min (first + K_TOKEN_CHUNK , eval_results.size ());
478479 for (size_t i = first; i < last; ++i) {
479480 auto logits = batch_logits + eval_pairs[i].first * n_vocab;
480481 float max_logit = logits[0 ];
@@ -497,7 +498,6 @@ static void compute_logprobs(const float * batch_logits, int n_vocab, std::vecto
497498 for (size_t it = 0 ; it < max_threads; ++it) {
498499 workers[it].join ();
499500 }
500-
501501}
502502
503503static void hellaswag_score (llama_context * ctx, const gpt_params & params) {