diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 0f39def3794..4f9b78b48ae 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -3657,6 +3657,15 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    // validate batch size for embeddings
+    // embeddings require all tokens to be processed in a single ubatch
+    // see https://github.com/ggml-org/llama.cpp/issues/12836
+    if (params.embedding && params.n_batch > params.n_ubatch) {
+        LOG_WRN("%s: embeddings enabled with n_batch (%d) > n_ubatch (%d)\n", __func__, params.n_batch, params.n_ubatch);
+        LOG_WRN("%s: setting n_batch = n_ubatch = %d to avoid assertion failure\n", __func__, params.n_ubatch);
+        params.n_batch = params.n_ubatch;
+    }
+
     // TODO: should we have a separate n_parallel parameter for the server?
     // https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177
     // TODO: this is a common configuration that is suitable for most local use cases
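
For context, here is the guarded condition as a self-contained sketch. The `params_sketch` struct is a hypothetical stand-in for llama.cpp's `common_params`, not the real type; the actual clamp lives in `main()` of `tools/server/server.cpp` as shown in the hunk above.

```cpp
// Minimal standalone sketch of the embeddings batch-size clamp.
// params_sketch is a hypothetical stand-in for common_params.
#include <cstdio>

struct params_sketch {
    bool embedding = true;   // server started with --embedding
    int  n_batch   = 4096;   // logical batch size (-b)
    int  n_ubatch  = 512;    // physical micro-batch size (-ub)
};

int main() {
    params_sketch params;

    // Embedding (non-causal) inputs must be processed in a single ubatch,
    // so the logical batch size may not exceed the physical one.
    if (params.embedding && params.n_batch > params.n_ubatch) {
        std::printf("clamping n_batch: %d -> %d\n", params.n_batch, params.n_ubatch);
        params.n_batch = params.n_ubatch;
    }

    std::printf("effective n_batch = %d, n_ubatch = %d\n", params.n_batch, params.n_ubatch);
    return 0;
}
```

Without the clamp, launching the server with something like `llama-server --embedding -b 4096 -ub 512` could hit the non-causal attention assertion during decode once a prompt longer than `n_ubatch` tokens is batched (see the linked issue #12836); with it, the server degrades gracefully and logs a warning instead.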