@@ -316,6 +316,7 @@ def _sample(
         mirostat_tau: llama_cpp.c_float,
         mirostat_eta: llama_cpp.c_float,
         penalize_nl: bool = True,
+        logits_processors=None,
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
@@ -328,6 +329,10 @@ def _sample(
             else last_n_tokens_size
         )
         logits = self.eval_logits[-1]
+        for processor in logits_processors or []:
+            logits = processor(list(self.eval_tokens), logits)
+
+        self.eval_logits[-1] = logits
         nl_logit = logits[self._token_nl]
         candidates = self._candidates
         for i, logit in enumerate(logits):
@@ -436,6 +441,8 @@ def sample(
         mirostat_eta: float = 0.1,
         mirostat_tau: float = 5.0,
         penalize_nl: bool = True,
+        logits_processors=None,
+
     ):
         """Sample a token from the model.

@@ -468,6 +475,8 @@ def sample(
             mirostat_tau=llama_cpp.c_float(mirostat_tau),
             mirostat_eta=llama_cpp.c_float(mirostat_eta),
             penalize_nl=penalize_nl,
+            logits_processors=logits_processors,
+
         )

     def generate(
@@ -484,6 +493,7 @@ def generate(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        logits_processors=None,
     ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.

@@ -541,6 +551,7 @@ def generate(
                 mirostat_mode=mirostat_mode,
                 mirostat_tau=mirostat_tau,
                 mirostat_eta=mirostat_eta,
+                logits_processors=logits_processors,
             )
             tokens_or_none = yield token
             tokens = [token]
@@ -637,6 +648,8 @@ def _create_completion(
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
+        logits_processors=None,
+        stopping_criterias=None,
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
@@ -700,13 +713,22 @@ def _create_completion(
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
+            logits_processors=logits_processors,
         ):
             if token == self._token_eos:
                 text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
                 break

             completion_tokens.append(token)
+            for stopping_crit in stopping_criterias or []:
+                if stopping_crit(completion_tokens, None):
+                    text = self.detokenize(completion_tokens)
+                    finish_reason = "stop"
+                    break
+
+            if finish_reason == "stop":
+                break

             all_text = self.detokenize(completion_tokens)

@@ -1006,6 +1028,8 @@ def create_completion(
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
+        logits_processors=None,
+        stopping_criterias=None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.

@@ -1048,6 +1072,9 @@ def create_completion(
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
             model=model,
+            logits_processors=logits_processors,
+            stopping_criterias=stopping_criterias,
+
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks
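For reference, a minimal usage sketch (not part of the diff) of the logits_processors hook added above. The diff calls each processor as processor(list(self.eval_tokens), logits) and feeds the return value back as the new logits, so a processor is a callable taking the token ids generated so far plus the current logits list and returning an adjusted logits list. The processor name, model path, and token id below are hypothetical.

from llama_cpp import Llama

llm = Llama(model_path="./ggml-model.bin")  # hypothetical model path

def ban_token_13(tokens, logits):
    # Push the logit of token id 13 (llama's newline token; adjust for
    # other vocabularies) to -inf so it is never sampled. Per the diff,
    # `logits` is a plain Python list of floats.
    logits[13] = -float("inf")
    return logits

output = llm.create_completion(
    "Q: Name the planets in the solar system. A:",
    max_tokens=32,
    logits_processors=[ban_token_13],
)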
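Likewise for stopping_criterias: the diff invokes each criterion as stopping_crit(completion_tokens, None) and treats a truthy return as a stop signal, so a criterion is a callable of (generated token ids, logits-or-None) returning a bool. A hypothetical length cap, continuing the sketch above:

def stop_after_16_tokens(tokens, logits):
    # Signal "stop" once 16 tokens have been generated; the second
    # argument is unused because the diff passes None for the logits.
    return len(tokens) >= 16

output = llm.create_completion(
    "Q: Name the planets in the solar system. A:",
    max_tokens=128,
    stopping_criterias=[stop_after_16_tokens],
)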