
Commit 5bb780d

Implemented logits processors and stopping criteria

1 parent e5d596e commit 5bb780d

File tree

1 file changed: +27 -0 lines


llama_cpp/llama.py

Lines changed: 27 additions & 0 deletions
@@ -316,6 +316,7 @@ def _sample(
         mirostat_tau: llama_cpp.c_float,
         mirostat_eta: llama_cpp.c_float,
         penalize_nl: bool = True,
+        logits_processors=None
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
@@ -328,6 +329,10 @@ def _sample(
             else last_n_tokens_size
         )
         logits = self.eval_logits[-1]
+        for processor in logits_processors:
+            logits = processor(list(self.eval_tokens), logits)
+
+        self.eval_logits[-1] = logits
         nl_logit = logits[self._token_nl]
         candidates = self._candidates
         for i, logit in enumerate(logits):
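The processor hook added to _sample runs each callable in logits_processors over the logits of the most recently evaluated token before sampling, passing it the token ids evaluated so far and the current logits, then storing the returned logits back. A minimal sketch of a compatible processor, assuming logits arrive as the plain list of floats kept in self.eval_logits (the function name and token id are illustrative):

def ban_token_processor(tokens, logits):
    # Hypothetical processor: keep token id 123 from being sampled by
    # pushing its logit to -inf (zero probability after softmax).
    logits = list(logits)  # work on a copy; _sample stores the return value itself
    logits[123] = float("-inf")
    return logits

Note that the loop iterates logits_processors directly, so in this revision a caller would pass a list (an empty list rather than the None default) to avoid a TypeError.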
@@ -436,6 +441,8 @@ def sample(
         mirostat_eta: float = 0.1,
         mirostat_tau: float = 5.0,
         penalize_nl: bool = True,
+        logits_processors=None
+
     ):
         """Sample a token from the model.
 
@@ -468,6 +475,8 @@ def sample(
             mirostat_tau=llama_cpp.c_float(mirostat_tau),
             mirostat_eta=llama_cpp.c_float(mirostat_eta),
             penalize_nl=penalize_nl,
+            logits_processors=logits_processors
+
         )
 
     def generate(
@@ -484,6 +493,7 @@ def generate(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        logits_processors=None
     ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.
 
@@ -541,6 +551,7 @@ def generate(
                 mirostat_mode=mirostat_mode,
                 mirostat_tau=mirostat_tau,
                 mirostat_eta=mirostat_eta,
+                logits_processors=logits_processors
             )
             tokens_or_none = yield token
             tokens = [token]
@@ -637,6 +648,8 @@ def _create_completion(
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
+        logits_processors=None,
+        stopping_criterias=None
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
@@ -700,13 +713,22 @@ def _create_completion(
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
+            logits_processors=logits_processors
         ):
             if token == self._token_eos:
                 text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
                 break
 
             completion_tokens.append(token)
+            for stopping_crit in stopping_criterias:
+                if stopping_crit(completion_tokens, None):
+                    text = self.detokenize(completion_tokens)
+                    finish_reason = "stop"
+                    break
+
+            if finish_reason == "stop":
+                break
 
             all_text = self.detokenize(completion_tokens)
 
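Each stopping criterion is checked right after a token is appended: it receives the completion tokens so far plus a second positional argument (None here, apparently mirroring the (input_ids, scores) shape of transformers-style stopping criteria) and returns True to end generation with finish_reason "stop". A minimal sketch of a compatible criterion (the name and limit are illustrative):

def max_len_criteria(tokens, scores=None):
    # Hypothetical criterion: stop once 32 completion tokens have been produced.
    return len(tokens) >= 32

As with logits_processors, stopping_criterias is iterated unconditionally, so passing a (possibly empty) list is assumed here as well.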
@@ -1006,6 +1028,8 @@ def create_completion(
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
+        logits_processors=None,
+        stopping_criterias=None
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 
@@ -1048,6 +1072,9 @@ def create_completion(
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
             model=model,
+            logits_processors=logits_processors,
+            stopping_criterias=stopping_criterias
+
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks
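Putting the two hooks together, a sketch of calling create_completion with the new parameters; the model path, prompt, and the two callables from the sketches above are illustrative, and both arguments are passed as lists per the notes above:

from llama_cpp import Llama

llm = Llama(model_path="./models/7B/ggml-model.bin")  # path is illustrative

output = llm.create_completion(
    "Q: Name the planets in the solar system. A: ",
    max_tokens=64,
    logits_processors=[ban_token_processor],  # hypothetical, sketched above
    stopping_criterias=[max_len_criteria],    # hypothetical, sketched above
)
print(output["choices"][0]["text"])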
