Merged
fastchat/serve/openai_api_server.py (28 additions, 4 deletions)
@@ -165,7 +165,15 @@ async def check_length(request, prompt, max_tokens, worker_addr):
         {"model": request.model, "prompt": prompt},
         "count",
     )
-    return min(max_tokens, context_len - token_num)
+    length = min(max_tokens, context_len - token_num)
+
+    if length <= 0:
+        return None, create_error_response(
+            ErrorCode.CONTEXT_OVERFLOW,
+            f"This model's maximum context length is {context_len} tokens. However, your messages resulted in {token_num} tokens. Please reduce the length of the messages.",

Member:
Suggested change:
-            f"This model's maximum context length is {context_len} tokens. However, your messages resulted in {token_num} tokens. Please reduce the length of the messages.",
+            f"Your message has {token_num} tokens which exceed the model's maximum context length ({context_len} tokens). Please shorten the message and try again.",

Contributor Author:
The original error message in this PR is taken from OpenAI's error message. For example: https://community.openai.com/t/error-this-models-maximum-context-length-is-x-tokens/328860

Member:
Oh, I see. Then it makes sense to be consistent with the OpenAI error message.

+        )
+
+    return length, None


 def check_requests(request) -> Optional[JSONResponse]:
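
For reference, a rough reconstruction of the helper after this change, pieced together from the hunk above and the surrounding file. The two fetch_remote calls and their endpoint names are assumptions drawn from the existing module rather than part of this diff, so treat this as a sketch, not the exact merged code:

async def check_length(request, prompt, max_tokens, worker_addr):
    # Assumed: fetch_remote, create_error_response, and ErrorCode are the
    # module-level helpers already used elsewhere in this file.
    context_len = await fetch_remote(
        worker_addr + "/model_details", {"model": request.model}, "context_length"
    )
    token_num = await fetch_remote(
        worker_addr + "/count_token",
        {"model": request.model, "prompt": prompt},
        "count",
    )
    length = min(max_tokens, context_len - token_num)

    if length <= 0:
        # New behavior: instead of handing the caller a non-positive token
        # budget, surface an OpenAI-style context-overflow error and let the
        # caller return it to the client.
        return None, create_error_response(
            ErrorCode.CONTEXT_OVERFLOW,
            f"This model's maximum context length is {context_len} tokens. "
            f"However, your messages resulted in {token_num} tokens. "
            f"Please reduce the length of the messages.",
        )

    return length, None

Callers now unpack a (length, error) tuple and return the error response early, as the hunks below show for each endpoint.
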
@@ -392,13 +400,19 @@ async def create_chat_completion(request: ChatCompletionRequest):
         echo=False,
         stop=request.stop,
     )
-    gen_params["max_new_tokens"] = await check_length(
+
+    max_new_tokens, error_check_ret = await check_length(
         request,
         gen_params["prompt"],
         gen_params["max_new_tokens"],
         worker_addr,
     )
+
+    if error_check_ret is not None:
+        return error_check_ret
+
+    gen_params["max_new_tokens"] = max_new_tokens

     if request.stream:
         generator = chat_completion_stream_generator(
             request.model, gen_params, request.n, worker_addr
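
With this early return in place, a request whose messages already fill the context window gets a structured error back instead of a failed generation. A minimal sketch of what that looks like from the client side, assuming a FastChat API server at http://localhost:8000 serving some model (the model name here is just an example) and an error body with message and code fields, as create_error_response produces elsewhere in this file:

import requests

# Hypothetical local deployment; adjust URL and model to your setup.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "vicuna-7b-v1.5",
        "messages": [{"role": "user", "content": "some very long prompt ..."}],
        "max_tokens": 512,
    },
)

body = resp.json()
if "code" in body:
    # e.g. a context-overflow rejection with the "maximum context length" message
    print("request rejected:", body.get("message"))
else:
    print(body["choices"][0]["message"]["content"])
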
@@ -502,7 +516,12 @@ async def create_completion(request: CompletionRequest):

     worker_addr = await get_worker_address(request.model)
     for text in request.prompt:
-        max_tokens = await check_length(request, text, request.max_tokens, worker_addr)
+        max_tokens, error_check_ret = await check_length(
+            request, text, request.max_tokens, worker_addr
+        )
+        if error_check_ret is not None:
+            return error_check_ret
+
         if isinstance(max_tokens, int) and max_tokens < request.max_tokens:
             request.max_tokens = max_tokens

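The completion endpoint takes a list of prompts, so the loop above both rejects any prompt that already overflows the context and clamps request.max_tokens to the tightest budget among the prompts. A simplified, self-contained illustration of that clamping logic (the token counting is faked with a word count purely for demonstration; the real code asks the worker):

def clamp_max_tokens(prompts, requested_max_tokens, context_len=2048):
    """Toy stand-in for the per-prompt check in create_completion."""
    max_tokens = requested_max_tokens
    for text in prompts:
        token_num = len(text.split())  # placeholder for the worker's token count
        budget = min(requested_max_tokens, context_len - token_num)
        if budget <= 0:
            return None, f"prompt of {token_num} tokens exceeds the {context_len}-token context"
        max_tokens = min(max_tokens, budget)
    return max_tokens, None


print(clamp_max_tokens(["short prompt", "a slightly longer prompt"], 512))
# -> (512, None) as long as no prompt eats into the budget
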
@@ -772,13 +791,18 @@ async def create_chat_completion(request: APIChatCompletionRequest):
     if request.repetition_penalty is not None:
         gen_params["repetition_penalty"] = request.repetition_penalty

-    gen_params["max_new_tokens"] = await check_length(
+    max_new_tokens, error_check_ret = await check_length(
         request,
         gen_params["prompt"],
         gen_params["max_new_tokens"],
         worker_addr,
     )
+
+    if error_check_ret is not None:
+        return error_check_ret
+
+    gen_params["max_new_tokens"] = max_new_tokens

     if request.stream:
         generator = chat_completion_stream_generator(
             request.model, gen_params, request.n, worker_addr