20
20
from . import llama_cpp
21
21
from .llama_types import *
22
22
23
+ import numpy as np
24
+ import numpy .typing as npt
25
+
23
26
24
27
class LlamaCache :
25
28
"""Cache for a llama.cpp model."""
@@ -73,11 +76,15 @@ def __init__(
73
76
self ,
74
77
eval_tokens : Deque [int ],
75
78
eval_logits : Deque [List [float ]],
79
+ input_ids : npt .NDArray [np .intc ],
80
+ scores : npt .NDArray [np .single ],
76
81
llama_state , # type: llama_cpp.Array[llama_cpp.c_uint8]
77
82
llama_state_size : int ,
78
83
):
79
84
self .eval_tokens = eval_tokens
80
85
self .eval_logits = eval_logits
86
+ self .input_ids = input_ids
87
+ self .scores = scores
81
88
self .llama_state = llama_state
82
89
self .llama_state_size = llama_state_size
83
90
@@ -207,27 +214,24 @@ def __init__(
207
214
208
215
self ._n_vocab = self .n_vocab ()
209
216
self ._n_ctx = self .n_ctx ()
210
- data = (llama_cpp .llama_token_data * self ._n_vocab )(
211
- * [
212
- llama_cpp .llama_token_data (
213
- id = llama_cpp .llama_token (i ),
214
- logit = llama_cpp .c_float (0.0 ),
215
- p = llama_cpp .c_float (0.0 ),
216
- )
217
- for i in range (self ._n_vocab )
218
- ]
219
- )
220
217
size = llama_cpp .c_size_t (self ._n_vocab )
221
- sorted = False
218
+ sorted = llama_cpp .c_bool (False )
219
+ self ._candidates_data = np .array (
220
+ [], dtype = [("id" , np .intc ), ("logit" , np .single ), ("p" , np .single )]
221
+ )
222
+ self ._candidates_data .resize (3 , self ._n_vocab )
222
223
candidates = llama_cpp .llama_token_data_array (
223
- data = data ,
224
+ data = self . _candidates_data . ctypes . data_as ( llama_cpp . llama_token_data_p ) ,
224
225
size = size ,
225
226
sorted = sorted ,
226
227
)
227
228
self ._candidates = candidates
228
229
self ._token_nl = Llama .token_nl ()
229
230
self ._token_eos = Llama .token_eos ()
230
231
232
+ self ._input_ids = np .array ([], dtype = np .intc )
233
+ self ._scores = np .ndarray ((0 , self ._n_vocab ), dtype = np .single )
234
+
231
235
def tokenize (self , text : bytes , add_bos : bool = True ) -> List [int ]:
232
236
"""Tokenize a string.
233
237
@@ -319,13 +323,19 @@ def eval(self, tokens: Sequence[int]):
319
323
raise RuntimeError (f"llama_eval returned { return_code } " )
320
324
# Save tokens
321
325
self .eval_tokens .extend (batch )
326
+ self ._input_ids : npt .NDArray [np .intc ] = np .concatenate (
327
+ (self ._input_ids , np .array (batch , dtype = np .intc )), axis = 0
328
+ )
322
329
# Save logits
323
330
rows = n_tokens if self .params .logits_all else 1
324
331
n_vocab = self ._n_vocab
325
332
cols = n_vocab
326
333
logits_view = llama_cpp .llama_get_logits (self .ctx )
327
334
logits = [logits_view [i * cols : (i + 1 ) * cols ] for i in range (rows )]
328
335
self .eval_logits .extend (logits )
336
+ self ._scores : npt .NDArray [np .single ] = np .concatenate (
337
+ (self ._scores , np .array (logits , dtype = np .single )), axis = 0
338
+ )
329
339
330
340
def _sample (
331
341
self ,
@@ -354,18 +364,23 @@ def _sample(
354
364
if last_n_tokens_size .value < 0
355
365
else last_n_tokens_size
356
366
)
357
- logits = self .eval_logits [- 1 ]
367
+ logits : npt . NDArray [ np . single ] = self ._scores [- 1 , : ]
358
368
359
369
if logits_processor is not None :
360
- logits = logits_processor (list (self .eval_tokens ), logits )
361
- self .eval_logits [- 1 ] = logits
370
+ logits = np .array (
371
+ logits_processor (list (self .eval_tokens ), logits .tolist ()),
372
+ dtype = np .single ,
373
+ )
374
+ self ._scores [- 1 , :] = logits
375
+ self .eval_logits [- 1 ] = logits .tolist ()
362
376
363
377
nl_logit = logits [self ._token_nl ]
364
378
candidates = self ._candidates
365
- for i , logit in enumerate (logits ):
366
- candidates .data [i ].id = llama_cpp .llama_token (i )
367
- candidates .data [i ].logit = llama_cpp .c_float (logit )
368
- candidates .data [i ].p = llama_cpp .c_float (0.0 )
379
+ candidates_data = self ._candidates_data
380
+ candidates_data ["id" ] = np .arange (n_vocab , dtype = np .intc ) # type: ignore
381
+ candidates_data ["logit" ] = logits
382
+ candidates_data ["p" ] = np .zeros (n_vocab , dtype = np .single )
383
+ candidates .data = candidates_data .ctypes .data_as (llama_cpp .llama_token_data_p )
369
384
candidates .sorted = llama_cpp .c_bool (False )
370
385
candidates .size = llama_cpp .c_size_t (n_vocab )
371
386
llama_cpp .llama_sample_repetition_penalty (
@@ -1371,6 +1386,8 @@ def save_state(self) -> LlamaState:
1371
1386
return LlamaState (
1372
1387
eval_tokens = self .eval_tokens .copy (),
1373
1388
eval_logits = self .eval_logits .copy (),
1389
+ scores = self ._scores .copy (),
1390
+ input_ids = self ._input_ids .copy (),
1374
1391
llama_state = llama_state_compact ,
1375
1392
llama_state_size = n_bytes ,
1376
1393
)
@@ -1379,6 +1396,8 @@ def load_state(self, state: LlamaState) -> None:
1379
1396
assert self .ctx is not None
1380
1397
self .eval_tokens = state .eval_tokens .copy ()
1381
1398
self .eval_logits = state .eval_logits .copy ()
1399
+ self ._scores = state .scores .copy ()
1400
+ self ._input_ids = state .input_ids .copy ()
1382
1401
state_size = state .llama_state_size
1383
1402
if llama_cpp .llama_set_state_data (self .ctx , state .llama_state ) != state_size :
1384
1403
raise RuntimeError ("Failed to set llama state data" )
0 commit comments