Merged
fastchat/protocol/openai_api_protocol.py (2 additions, 0 deletions)
@@ -151,6 +151,8 @@ class CompletionRequest(BaseModel):
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    user: Optional[str] = None
+    use_beam_search: Optional[bool] = False
+    best_of: Optional[int] = None


class CompletionResponseChoice(BaseModel):
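The two new fields are plain Pydantic attributes with backward-compatible defaults, so existing clients that never send them are unaffected. A minimal sketch of the extended model, trimmed to the fields visible in this hunk (the model name below is a placeholder):

```python
from typing import Optional
from pydantic import BaseModel


# Trimmed stand-in for fastchat.protocol.openai_api_protocol.CompletionRequest,
# limited to the fields shown in this hunk.
class CompletionRequest(BaseModel):
    model: str
    prompt: str
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    user: Optional[str] = None
    use_beam_search: Optional[bool] = False  # new in this PR
    best_of: Optional[int] = None  # new in this PR


# Clients that omit the new fields get the old behavior.
req = CompletionRequest(model="vicuna-7b-v1.5", prompt="Hello")
assert req.use_beam_search is False and req.best_of is None

# Clients opt in to beam search explicitly.
req = CompletionRequest(
    model="vicuna-7b-v1.5", prompt="Hello", use_beam_search=True, best_of=4
)
```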
fastchat/serve/openai_api_server.py (16 additions, 0 deletions)
@@ -72,6 +72,13 @@ async def fetch_remote(url, pload=None, name=None):
    async with aiohttp.ClientSession(timeout=fetch_timeout) as session:
        async with session.post(url, json=pload) as response:
            chunks = []
+            if response.status != 200:
+                ret = {
+                    "text": f"{response.reason}",
+                    "error_code": ErrorCode.INTERNAL_ERROR,
+                }
+                return json.dumps(ret)
+
            async for chunk, _ in response.content.iter_chunks():
                chunks.append(chunk)
        output = b"".join(chunks)
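With the early return, a non-200 worker response now comes back as a JSON error payload instead of an empty or half-streamed body. A rough sketch of how downstream code can tell it apart from a normal result (the error-code constant and reason string below are illustrative stand-ins, not taken from the diff):

```python
import json

# Stand-in for ErrorCode.INTERNAL_ERROR; the real value lives in fastchat.constants.
INTERNAL_ERROR = 50001

# Shape of the payload fetch_remote now returns on a non-200 status.
fetched = json.dumps({"text": "Not Found", "error_code": INTERNAL_ERROR})

payload = json.loads(fetched)
if payload.get("error_code", 0) != 0:
    # The API server can surface this as an OpenAI-style error response.
    print(f"worker error {payload['error_code']}: {payload['text']}")
else:
    print("worker output:", payload["text"])
```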
@@ -236,6 +243,8 @@ async def get_gen_params(
    max_tokens: Optional[int],
    echo: Optional[bool],
    stop: Optional[Union[str, List[str]]],
+    best_of: Optional[int] = None,
+    use_beam_search: Optional[bool] = None,
) -> Dict[str, Any]:
    conv = await get_conv(model_name, worker_addr)
    conv = Conversation(
@@ -280,6 +289,11 @@ async def get_gen_params(
"stop_token_ids": conv.stop_token_ids,
}

+    if best_of is not None:
+        gen_params.update({"best_of": best_of})
+    if use_beam_search is not None:
+        gen_params.update({"use_beam_search": use_beam_search})

    new_stop = set()
    _add_to_set(stop, new_stop)
    _add_to_set(conv.stop_str, new_stop)
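The conditional update keeps the payload sent to the worker identical to before when a client never supplies the new fields, so workers that predate this change see no unknown keys. A small sketch of that behavior (the surrounding gen_params contents are abbreviated and the values are placeholders):

```python
def build_gen_params(best_of=None, use_beam_search=None):
    # Abbreviated stand-in for the tail of get_gen_params.
    gen_params = {"model": "vicuna-7b-v1.5", "temperature": 0.7, "top_p": 1.0}
    if best_of is not None:
        gen_params.update({"best_of": best_of})
    if use_beam_search is not None:
        gen_params.update({"use_beam_search": use_beam_search})
    return gen_params


# No opt-in: the dict looks exactly as it did before this change.
assert "best_of" not in build_gen_params()
assert "use_beam_search" not in build_gen_params()

# Opt-in: both keys are forwarded to the worker.
print(build_gen_params(best_of=4, use_beam_search=True))
```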
@@ -487,6 +501,8 @@ async def create_completion(request: CompletionRequest):
            max_tokens=request.max_tokens,
            echo=request.echo,
            stop=request.stop,
+            best_of=request.best_of,
+            use_beam_search=request.use_beam_search,
        )
        for i in range(request.n):
            content = asyncio.create_task(
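Once create_completion forwards the fields, they can be sent straight through the OpenAI-compatible /v1/completions endpoint. A sketch using plain requests; the host, port, and model name are placeholders, and the sampling settings follow vLLM's usual beam-search expectations (greedy temperature, best_of greater than 1):

```python
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "vicuna-7b-v1.5",
        "prompt": "Once upon a time",
        "max_tokens": 64,
        "temperature": 0.0,       # beam search is typically run without sampling
        "use_beam_search": True,  # forwarded by create_completion -> get_gen_params
        "best_of": 4,             # number of candidate sequences vLLM keeps
    },
    timeout=120,
)
print(resp.json()["choices"][0]["text"])
```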
fastchat/serve/vllm_worker.py (13 additions, 6 deletions)
@@ -74,6 +74,8 @@ async def generate_stream(self, params):
        if self.tokenizer.eos_token_id is not None:
            stop_token_ids.append(self.tokenizer.eos_token_id)
        echo = params.get("echo", True)
+        use_beam_search = params.get("use_beam_search", False)
+        best_of = params.get("best_of", None)

        # Handle stop_str
        stop = set()
@@ -94,9 +96,10 @@
            n=1,
            temperature=temperature,
            top_p=top_p,
-            use_beam_search=False,
+            use_beam_search=use_beam_search,
            stop=list(stop),
            max_tokens=max_new_tokens,
+            best_of=best_of,
        )
        results_generator = engine.generate(context, sampling_params, request_id)
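For context, vLLM validates these two arguments together: with use_beam_search=True it expects best_of greater than 1 and non-sampling settings, while with beam search off best_of only controls how many candidates are generated before the best n are returned. A sketch of the resulting SamplingParams; the keyword arguments mirror the diff, but the exact validation rules depend on the installed vLLM version:

```python
from vllm import SamplingParams

# Mirrors how generate_stream builds its request after this change.
# vLLM releases of this era expect best_of > 1, temperature == 0, and
# top_p == 1 when use_beam_search is True; adjust to your vLLM version.
sampling_params = SamplingParams(
    n=1,
    temperature=0.0,
    top_p=1.0,
    use_beam_search=True,
    stop=["</s>"],
    max_tokens=64,
    best_of=4,
)
print(sampling_params)
```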

@@ -110,17 +113,21 @@
                text_outputs = [output.text for output in request_output.outputs]
            text_outputs = " ".join(text_outputs)
            # Note: usage is not supported yet
+            prompt_tokens = len(request_output.prompt_token_ids)
+            completion_tokens = sum(
+                len(output.token_ids) for output in request_output.outputs
+            )
            ret = {
                "text": text_outputs,
                "error_code": 0,
-                "usage": {},
+                "usage": {
+                    "prompt_tokens": prompt_tokens,
+                    "completion_tokens": completion_tokens,
+                    "total_tokens": prompt_tokens + completion_tokens,
+                },
                "cumulative_logprob": [
                    output.cumulative_logprob for output in request_output.outputs
                ],
-                "prompt_token_len": len(request_output.prompt_token_ids),
-                "output_token_len": [
-                    len(output.token_ids) for output in request_output.outputs
-                ],
                "finish_reason": request_output.outputs[0].finish_reason
                if len(request_output.outputs) == 1
                else [output.finish_reason for output in request_output.outputs],
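The usage block is computed from the token ids vLLM already returns, taking the place of the prompt_token_len / output_token_len fields shown as removed above. A small self-contained sketch of the same accounting (the RequestOutput stand-ins below are illustrative, not the real vLLM classes):

```python
from dataclasses import dataclass
from typing import List


# Illustrative stand-ins for vLLM's CompletionOutput / RequestOutput.
@dataclass
class FakeCompletionOutput:
    token_ids: List[int]


@dataclass
class FakeRequestOutput:
    prompt_token_ids: List[int]
    outputs: List[FakeCompletionOutput]


request_output = FakeRequestOutput(
    prompt_token_ids=[1, 2, 3, 4],
    outputs=[FakeCompletionOutput([5, 6, 7]), FakeCompletionOutput([8, 9])],
)

# Same arithmetic as the new code in generate_stream.
prompt_tokens = len(request_output.prompt_token_ids)
completion_tokens = sum(len(o.token_ids) for o in request_output.outputs)
usage = {
    "prompt_tokens": prompt_tokens,
    "completion_tokens": completion_tokens,
    "total_tokens": prompt_tokens + completion_tokens,
}
print(usage)  # {'prompt_tokens': 4, 'completion_tokens': 5, 'total_tokens': 9}
```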