
Commit a5e6abf

leiwen83 and wenlei03 authored
add best_of and use_beam_search for completions interface (lm-sys#2372)
Signed-off-by: Lei Wen <[email protected]>
Co-authored-by: Lei Wen <[email protected]>
1 parent dc3dd12 commit a5e6abf

4 files changed: +79 −26 lines

fastchat/protocol/api_protocol.py

Lines changed: 1 addition & 1 deletion

@@ -150,7 +150,7 @@ class CompletionResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[CompletionResponseChoice]
-    usage: UsageInfo
+    usage: Union[UsageInfo, List[UsageInfo]]


 class CompletionResponseStreamChoice(BaseModel):
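
Note: `usage` on `CompletionResponse` can now be either a single `UsageInfo` or a list with one entry per returned candidate, so consumers of this protocol have to handle both shapes. A minimal sketch of such a consumer, assuming the usage entries carry the usual prompt/completion/total token fields; the `total_usage` helper is illustrative and not part of this commit:

from typing import Any, Dict, List, Union


def total_usage(usage: Union[Dict[str, Any], List[Dict[str, Any]]]) -> Dict[str, int]:
    # Accept either one usage entry or a list of per-candidate entries.
    entries = usage if isinstance(usage, list) else [usage]
    totals = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
    for entry in entries:
        for key in totals:
            totals[key] += int(entry.get(key, 0))
    return totals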

fastchat/protocol/openai_api_protocol.py

Lines changed: 3 additions & 1 deletion

@@ -151,11 +151,13 @@ class CompletionRequest(BaseModel):
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
     user: Optional[str] = None
+    use_beam_search: Optional[bool] = False
+    best_of: Optional[int] = None


 class CompletionResponseChoice(BaseModel):
     index: int
-    text: str
+    text: Union[str, List[str]]
     logprobs: Optional[int] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
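
With `use_beam_search` and `best_of` exposed on `CompletionRequest`, and `text` widened to `Union[str, List[str]]`, a client can ask for several beam-search candidates in one call. A minimal sketch, assuming a FastChat OpenAI-compatible server at http://localhost:8000/v1/completions and a placeholder model name; neither the endpoint URL nor the model is specified by this commit:

import requests

payload = {
    "model": "vicuna-7b-v1.5",        # placeholder model name
    "prompt": "Write a haiku about the sea.",
    "max_tokens": 64,
    "n": 2,                           # candidates to return
    "best_of": 4,                     # candidates generated and scored by the worker
    "use_beam_search": True,          # new field from this commit
    "temperature": 0.0,
}

resp = requests.post("http://localhost:8000/v1/completions", json=payload, timeout=120)
choice = resp.json()["choices"][0]

# With best_of set, `text` may be a list of candidate strings instead of one string.
texts = choice["text"] if isinstance(choice["text"], list) else [choice["text"]]
for i, t in enumerate(texts):
    print(f"candidate {i}: {t!r}")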

fastchat/serve/openai_api_server.py

Lines changed: 26 additions & 3 deletions

@@ -241,6 +241,9 @@ async def get_gen_params(
     max_tokens: Optional[int],
     echo: Optional[bool],
     stop: Optional[Union[str, List[str]]],
+    best_of: Optional[int] = None,
+    n: Optional[int] = 1,
+    use_beam_search: Optional[bool] = None,
 ) -> Dict[str, Any]:
     conv = await get_conv(model_name, worker_addr)
     conv = Conversation(
@@ -287,6 +290,11 @@ async def get_gen_params(
         "stop_token_ids": conv.stop_token_ids,
     }

+    if best_of is not None:
+        gen_params.update({"n": n, "best_of": best_of})
+    if use_beam_search is not None:
+        gen_params.update({"use_beam_search": use_beam_search})
+
     new_stop = set()
     _add_to_set(stop, new_stop)
     _add_to_set(conv.stop_str, new_stop)
@@ -494,12 +502,18 @@ async def create_completion(request: CompletionRequest):
         max_tokens=request.max_tokens,
         echo=request.echo,
         stop=request.stop,
+        best_of=request.best_of,
+        n=request.n,
+        use_beam_search=request.use_beam_search,
     )
     for i in range(request.n):
         content = asyncio.create_task(
            generate_completion(gen_params, worker_addr)
         )
         text_completions.append(content)
+        # when use with best_of, only need send one request
+        if request.best_of:
+            break

     try:
         all_tasks = await asyncio.gather(*text_completions)
@@ -519,9 +533,18 @@ async def create_completion(request: CompletionRequest):
                 finish_reason=content.get("finish_reason", "stop"),
             )
         )
-        task_usage = UsageInfo.parse_obj(content["usage"])
-        for usage_key, usage_value in task_usage.dict().items():
-            setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
+        idx = 0
+        while True:
+            info = content["usage"]
+            if isinstance(info, list):
+                info = info[idx]
+
+            task_usage = UsageInfo.parse_obj(info)
+
+            for usage_key, usage_value in task_usage.dict().items():
+                setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
+            idx += 1
+            break

     return CompletionResponse(
         model=request.model, choices=choices, usage=UsageInfo.parse_obj(usage)
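
The request loop above changes shape depending on `best_of`: without it, `n` independent single-candidate requests are created and gathered; with it, a single request is sent and the worker returns all candidates (plus a per-candidate usage list) in one response. A minimal sketch of that dispatch decision in isolation, with a stubbed `send_to_worker` standing in for `generate_completion`; both helper names are illustrative:

from typing import Any, Dict, List, Optional


def send_to_worker(gen_params: Dict[str, Any]) -> Dict[str, Any]:
    # Stub for generate_completion(); a real call would hit the model worker.
    return {"text": "...", "usage": {}}


def dispatch(gen_params: Dict[str, Any], n: int, best_of: Optional[int]) -> List[Dict[str, Any]]:
    if best_of:
        # One request: the worker samples/scores best_of candidates and returns n of them.
        return [send_to_worker(gen_params)]
    # Otherwise: n independent single-candidate requests (gathered concurrently upstream).
    return [send_to_worker(gen_params) for _ in range(n)]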

fastchat/serve/vllm_worker.py

Lines changed: 49 additions & 21 deletions

@@ -18,6 +18,7 @@
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid

+from fastchat.constants import ErrorCode, SERVER_ERROR_MSG
 from fastchat.serve.model_worker import (
     BaseModelWorker,
     logger,
@@ -74,6 +75,9 @@ async def generate_stream(self, params):
         if self.tokenizer.eos_token_id is not None:
             stop_token_ids.append(self.tokenizer.eos_token_id)
         echo = params.get("echo", True)
+        use_beam_search = params.get("use_beam_search", False)
+        best_of = params.get("best_of", None)
+        n = params.get("n", 1)

         # Handle stop_str
         stop = set()
@@ -90,27 +94,51 @@ async def generate_stream(self, params):
         top_p = max(top_p, 1e-5)
         if temperature <= 1e-5:
             top_p = 1.0
-        sampling_params = SamplingParams(
-            n=1,
-            temperature=temperature,
-            top_p=top_p,
-            use_beam_search=False,
-            stop=list(stop),
-            max_tokens=max_new_tokens,
-        )
-        results_generator = engine.generate(context, sampling_params, request_id)
-
-        async for request_output in results_generator:
-            prompt = request_output.prompt
-            if echo:
-                text_outputs = [
-                    prompt + output.text for output in request_output.outputs
-                ]
-            else:
-                text_outputs = [output.text for output in request_output.outputs]
-            text_outputs = " ".join(text_outputs)
-            # Note: usage is not supported yet
-            ret = {"text": text_outputs, "error_code": 0, "usage": {}}
+        try:
+            sampling_params = SamplingParams(
+                n=n,
+                temperature=temperature,
+                top_p=top_p,
+                use_beam_search=use_beam_search,
+                stop=list(stop),
+                max_tokens=max_new_tokens,
+                best_of=best_of,
+            )
+
+            results_generator = engine.generate(context, sampling_params, request_id)
+
+            async for request_output in results_generator:
+                prompt = request_output.prompt
+                prompt_tokens = len(request_output.prompt_token_ids)
+                output_usage = []
+                for out in request_output.outputs:
+                    completion_tokens = len(out.token_ids)
+                    total_tokens = prompt_tokens + completion_tokens
+                    output_usage.append(
+                        {
+                            "prompt_tokens": prompt_tokens,
+                            "completion_tokens": completion_tokens,
+                            "total_tokens": total_tokens,
+                        }
+                    )
+
+                if echo:
+                    text_outputs = [
+                        prompt + output.text for output in request_output.outputs
+                    ]
+                else:
+                    text_outputs = [output.text for output in request_output.outputs]
+
+                if sampling_params.best_of is None:
+                    text_outputs = [" ".join(text_outputs)]
+                ret = {"text": text_outputs, "error_code": 0, "usage": output_usage}
+                yield (json.dumps(ret) + "\0").encode()
+        except (ValueError, RuntimeError) as e:
+            ret = {
+                "text": f"{e}",
+                "error_code": ErrorCode.PARAM_OUT_OF_RANGE,
+                "usage": {},
+            }
             yield (json.dumps(ret) + "\0").encode()

     async def generate(self, params):
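
On the worker side, the new parameters map directly onto vLLM's SamplingParams, and the new try/except reports validation errors (for example n larger than best_of, or an unsupported beam-search combination) back to the API server as PARAM_OUT_OF_RANGE instead of breaking the stream. A minimal sketch of a beam-search configuration, assuming a vLLM release from this period whose SamplingParams still accepts use_beam_search (later releases removed it); the concrete values are illustrative:

from vllm.sampling_params import SamplingParams

# best_of sets how many candidates the engine generates and scores;
# n sets how many of them are returned (n <= best_of).
sampling_params = SamplingParams(
    n=2,
    best_of=4,
    use_beam_search=True,
    temperature=0.0,   # beam search is deterministic; a greedy temperature is expected here
    top_p=1.0,
    max_tokens=64,
    stop=["</s>"],
)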
