diff --git a/fastchat/serve/inference.py b/fastchat/serve/inference.py
index 99b8647aa..fdb62a69f 100644
--- a/fastchat/serve/inference.py
+++ b/fastchat/serve/inference.py
@@ -47,6 +47,13 @@ def prepare_logits_processor(
     return processor_list
 
 
+def partial_stop(output, stop_str):
+    for i in range(0, min(len(output), len(stop_str))):
+        if stop_str.startswith(output[-i:]):
+            return True
+    return False
+
+
 @torch.inference_mode()
 def generate_stream(
     model, tokenizer, params, device, context_len=2048, stream_interval=2
@@ -160,12 +167,16 @@ def generate_stream(
                 skip_special_tokens=True,
                 spaces_between_special_tokens=False,
             )
+
+            partially_stopped = False
             if stop_str:
                 if isinstance(stop_str, str):
                     pos = output.rfind(stop_str, rfind_start)
                     if pos != -1:
                         output = output[:pos]
                         stopped = True
+                    else:
+                        partially_stopped = partial_stop(output, stop_str)
                 elif isinstance(stop_str, Iterable):
                     for each_stop in stop_str:
                         pos = output.rfind(each_stop, rfind_start)
@@ -173,18 +184,24 @@ def generate_stream(
                             output = output[:pos]
                             stopped = True
                             break
+                        else:
+                            partially_stopped = partial_stop(output, each_stop)
+                            if partially_stopped:
+                                break
                 else:
                     raise ValueError("Invalid stop field type.")
-
-            yield {
-                "text": output,
-                "usage": {
-                    "prompt_tokens": input_echo_len,
-                    "completion_tokens": i,
-                    "total_tokens": input_echo_len + i,
-                },
-                "finish_reason": None,
-            }
+
+            # prevent yielding partial stop sequence
+            if not partially_stopped:
+                yield {
+                    "text": output,
+                    "usage": {
+                        "prompt_tokens": input_echo_len,
+                        "completion_tokens": i,
+                        "total_tokens": input_echo_len + i,
+                    },
+                    "finish_reason": None,
+                }
 
         if stopped:
             break
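
For context, the new `partial_stop` helper keeps the stream from yielding a chunk whose tail could still grow into a stop string on the next decode step. The standalone sketch below is not part of the diff; it simply replays the helper on sample outputs, with "Human:" as a hypothetical stop string, to show when a chunk would be held back.

# Standalone sketch (illustration only); "Human:" is a hypothetical stop string.
def partial_stop(output, stop_str):
    # True if the tail of `output` could still be the start of `stop_str`,
    # i.e. future tokens might complete the stop sequence.
    for i in range(0, min(len(output), len(stop_str))):
        if stop_str.startswith(output[-i:]):
            return True
    return False

print(partial_stop("Sure, here it is.\nHum", "Human:"))  # True  -> hold this chunk back
print(partial_stop("Sure, here it is.", "Human:"))       # False -> safe to yield

When the helper returns True, generate_stream skips the yield for that interval and waits for more tokens, so clients never see a dangling fragment of the stop sequence in the streamed text.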