diff --git a/llamafile/flags.cpp b/llamafile/flags.cpp index c0e3bb3b74..45916ffab6 100644 --- a/llamafile/flags.cpp +++ b/llamafile/flags.cpp @@ -79,6 +79,7 @@ int FLAG_flash_attn = false; int FLAG_gpu = 0; int FLAG_http_ibuf_size = 5 * 1024 * 1024; int FLAG_http_obuf_size = 1024 * 1024; +int FLAG_http_write_timeout = 60000; int FLAG_keepalive = 5; int FLAG_main_gpu = 0; int FLAG_n_gpu_layers = -1; @@ -346,6 +347,13 @@ void llamafile_get_flags(int argc, char **argv) { continue; } + if (!strcmp(flag, "--http-write-timeout")) { + if (i == argc) + missing("--http-write-timeout"); + FLAG_http_write_timeout = atoi(argv[i++]); + continue; + } + ////////////////////////////////////////////////////////////////////// // sampling flags diff --git a/llamafile/llamafile.h b/llamafile/llamafile.h index b74dda60dd..559936cbdd 100644 --- a/llamafile/llamafile.h +++ b/llamafile/llamafile.h @@ -50,6 +50,7 @@ extern int FLAG_gpu; extern int FLAG_gpu; extern int FLAG_http_ibuf_size; extern int FLAG_http_obuf_size; +extern int FLAG_http_write_timeout; extern int FLAG_keepalive; extern int FLAG_main_gpu; extern int FLAG_n_gpu_layers; diff --git a/llamafile/server/client.cpp b/llamafile/server/client.cpp index e142a5a219..7afb821e85 100644 --- a/llamafile/server/client.cpp +++ b/llamafile/server/client.cpp @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -520,14 +521,38 @@ Client::send_response_finish() // // unlike send() this won't fail if binary content is detected. bool -Client::send_binary(const void* p, size_t n) -{ - ssize_t sent; - if ((sent = write(fd_, p, n)) != n) { - if (sent == -1 && errno != EAGAIN && errno != ECONNRESET) - SLOG("write failed %m"); - close_connection_ = true; - return false; +Client::send_binary(const void* p, size_t n) { + const char* buf = (const char*)p; + size_t written = 0; + while (written < n) { + ssize_t sent = write(fd_, buf + written, n - written); + if (sent == -1) { + if (errno == EINTR) + continue; + if (errno == EAGAIN || errno == EWOULDBLOCK) { + struct pollfd pfd = { .fd = fd_, .events = POLLOUT }; + int ret = poll(&pfd, 1, FLAG_http_write_timeout); + if (ret < 0) { + if (errno == EINTR) + continue; + SLOG("poll failed %m"); + close_connection_ = true; + return false; + } + if (ret == 0) { + SLOG("write timed out"); + close_connection_ = true; + return false; + } + continue; + } + if (errno != ECONNRESET) + SLOG("write failed %m"); + close_connection_ = true; + return false; + } + // sent ≥ 0 + written += sent; } return true; } @@ -775,7 +800,7 @@ Client::dispatcher() should_send_error_if_canceled_ = false; if (!send(std::string_view(obuf_.p, p - obuf_.p))) return false; - char buf[512]; + char buf[16384]; size_t i, chunk; for (i = 0; i < size; i += chunk) { chunk = size - i; diff --git a/llamafile/server/main.1 b/llamafile/server/main.1 index e5d01adc2a..08aac519f5 100644 --- a/llamafile/server/main.1 +++ b/llamafile/server/main.1 @@ -171,6 +171,11 @@ supported by the host operating system. The default keepalive is 5. Size of HTTP output buffer size, in bytes. Default is 1048576. .It Fl Fl http-ibuf-size Ar N Size of HTTP input buffer size, in bytes. Default is 1048576. +.It Fl Fl http-write-timeout Ar MS +Socket write timeout in milliseconds. When sending data to a client, if +the socket buffer is full and the client is not reading, the server will +wait up to this many milliseconds for the socket to become writable before +closing the connection. Default is 60000 (60 seconds). .It Fl Fl chat-template Ar NAME Specifies or overrides chat template for model. .Pp diff --git a/llamafile/server/main.1.asc b/llamafile/server/main.1.asc index ab99e21913..68f6a49184 100644 --- a/llamafile/server/main.1.asc +++ b/llamafile/server/main.1.asc @@ -200,6 +200,13 @@ --http-ibuf-size N Size of HTTP input buffer size, in bytes. Default is 1048576. + --http-write-timeout MS + Socket write timeout in milliseconds. When sending data to a + client, if the socket buffer is full and the client is not + reading, the server will wait up to this many milliseconds for + the socket to become writable before closing the connection. + Default is 60000 (60 seconds). + --chat-template NAME Specifies or overrides chat template for model. diff --git a/llamafile/server/worker.cpp b/llamafile/server/worker.cpp index a016c62218..84ce56e2ed 100644 --- a/llamafile/server/worker.cpp +++ b/llamafile/server/worker.cpp @@ -56,13 +56,6 @@ Worker::begin() tokens = tokenbucket_acquire(client_.client_ip_); server_->lock(); dll_remove(&server_->idle_workers, &elem_); - if (dll_is_empty(server_->idle_workers)) { - Dll* slowbro; - if ((slowbro = dll_last(server_->active_workers))) { - SLOG("all threads active! dropping oldest client"); - WORKER(slowbro)->kill(); - } - } working_ = true; if (tokens > FLAG_token_burst) { dll_make_last(&server_->active_workers, &elem_); diff --git a/llamafile/server/writev.cpp b/llamafile/server/writev.cpp index 841af57679..726ae09cfc 100644 --- a/llamafile/server/writev.cpp +++ b/llamafile/server/writev.cpp @@ -16,9 +16,12 @@ // limitations under the License. #include "llamafile/server/log.h" +#include "llamafile/llamafile.h" #include "utils.h" #include +#include #include +#include namespace lf { namespace server { @@ -26,6 +29,7 @@ namespace server { ssize_t safe_writev(int fd, const iovec* iov, int iovcnt) { + // Security check for binary content in headers for (int i = 0; i < iovcnt; ++i) { bool has_binary = false; size_t n = iov[i].iov_len; @@ -39,7 +43,50 @@ safe_writev(int fd, const iovec* iov, int iovcnt) return -1; } } - return writev(fd, iov, iovcnt); + + ssize_t total = 0; + // Create a mutable copy of iovecs to track progress + std::vector copy(iov, iov + iovcnt); + int i = 0; // Current iovec index + + while (i < iovcnt) { + ssize_t sent = writev(fd, copy.data() + i, iovcnt - i); + if (sent == -1) { + if (errno == EINTR) + continue; + if (errno == EAGAIN || errno == EWOULDBLOCK) { + struct pollfd pfd = { .fd = fd, .events = POLLOUT }; + int rc = poll(&pfd, 1, FLAG_http_write_timeout); + if (rc == 0) { + errno = ETIMEDOUT; + return -1; + } + if (rc == -1) { + if (errno == EINTR) + continue; + return -1; + } + continue; + } + return -1; + } + + total += sent; + size_t got = sent; + + // Advance the iovecs based on bytes written + while (got > 0 && i < iovcnt) { + if (got >= copy[i].iov_len) { + got -= copy[i].iov_len; + ++i; + } else { + copy[i].iov_base = (char*)copy[i].iov_base + got; + copy[i].iov_len -= got; + got = 0; + } + } + } + return total; } } // namespace server