Commit 8eb9769

Add support for numpy
1 parent 4c1b7f7 commit 8eb9769

2 files changed: 39 additions & 22 deletions

2 files changed

+39
-22
lines changed

llama_cpp/llama.py

Lines changed: 38 additions & 19 deletions
@@ -20,6 +20,9 @@
 from . import llama_cpp
 from .llama_types import *

+import numpy as np
+import numpy.typing as npt
+

 class LlamaCache:
     """Cache for a llama.cpp model."""
@@ -73,11 +76,15 @@ def __init__(
         self,
         eval_tokens: Deque[int],
         eval_logits: Deque[List[float]],
+        input_ids: npt.NDArray[np.intc],
+        scores: npt.NDArray[np.single],
         llama_state, # type: llama_cpp.Array[llama_cpp.c_uint8]
         llama_state_size: int,
     ):
         self.eval_tokens = eval_tokens
         self.eval_logits = eval_logits
+        self.input_ids = input_ids
+        self.scores = scores
         self.llama_state = llama_state
         self.llama_state_size = llama_state_size

@@ -207,27 +214,24 @@ def __init__(

         self._n_vocab = self.n_vocab()
         self._n_ctx = self.n_ctx()
-        data = (llama_cpp.llama_token_data * self._n_vocab)(
-            *[
-                llama_cpp.llama_token_data(
-                    id=llama_cpp.llama_token(i),
-                    logit=llama_cpp.c_float(0.0),
-                    p=llama_cpp.c_float(0.0),
-                )
-                for i in range(self._n_vocab)
-            ]
-        )
         size = llama_cpp.c_size_t(self._n_vocab)
-        sorted = False
+        sorted = llama_cpp.c_bool(False)
+        self._candidates_data = np.array(
+            [], dtype=[("id", np.intc), ("logit", np.single), ("p", np.single)]
+        )
+        self._candidates_data.resize(3, self._n_vocab)
         candidates = llama_cpp.llama_token_data_array(
-            data=data,
+            data=self._candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p),
             size=size,
             sorted=sorted,
         )
         self._candidates = candidates
         self._token_nl = Llama.token_nl()
         self._token_eos = Llama.token_eos()

+        self._input_ids = np.array([], dtype=np.intc)
+        self._scores = np.ndarray((0, self._n_vocab), dtype=np.single)
+
     def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
         """Tokenize a string.
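The removed block built the candidates buffer with one Python-level llama_token_data(...) constructor call per vocabulary entry; the replacement allocates a single numpy structured array whose dtype lines up field for field with llama.cpp's llama_token_data struct, so a pointer into the numpy buffer can be handed straight to the C API with no copy. A minimal standalone sketch of that layout trick, assuming the struct is {int32 id; float logit; float p;} with natural 4-byte alignment; llama_token_data and llama_token_data_p are redefined locally here for illustration:

import ctypes
import numpy as np

# Local stand-in for the struct declared in llama_cpp's ctypes bindings.
class llama_token_data(ctypes.Structure):
    _fields_ = [
        ("id", ctypes.c_int),       # token id
        ("logit", ctypes.c_float),  # raw logit for this token
        ("p", ctypes.c_float),      # probability (filled in by sampling)
    ]

llama_token_data_p = ctypes.POINTER(llama_token_data)

n_vocab = 4
candidates_data = np.zeros(
    n_vocab, dtype=[("id", np.intc), ("logit", np.single), ("p", np.single)]
)

# The structured dtype and the C struct occupy identical bytes per record.
assert candidates_data.itemsize == ctypes.sizeof(llama_token_data)

candidates_data["id"] = np.arange(n_vocab, dtype=np.intc)
candidates_data["logit"] = [0.1, 0.9, -1.0, 2.5]

# C code reading through this pointer sees the numpy buffer directly.
ptr = candidates_data.ctypes.data_as(llama_token_data_p)
assert ptr[1].id == 1 and abs(ptr[1].logit - 0.9) < 1e-6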
@@ -319,13 +323,19 @@ def eval(self, tokens: Sequence[int]):
                 raise RuntimeError(f"llama_eval returned {return_code}")
             # Save tokens
             self.eval_tokens.extend(batch)
+            self._input_ids: npt.NDArray[np.intc] = np.concatenate(
+                (self._input_ids, np.array(batch, dtype=np.intc)), axis=0
+            )
             # Save logits
             rows = n_tokens if self.params.logits_all else 1
             n_vocab = self._n_vocab
             cols = n_vocab
             logits_view = llama_cpp.llama_get_logits(self.ctx)
             logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)]
             self.eval_logits.extend(logits)
+            self._scores: npt.NDArray[np.single] = np.concatenate(
+                (self._scores, np.array(logits, dtype=np.single)), axis=0
+            )

     def _sample(
         self,
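eval() now keeps numpy mirrors of the two deques: _input_ids gains one token id per evaluated token and _scores one row of logits. A self-contained sketch of that bookkeeping, with the llama_eval call replaced by random numbers and one logits row kept per token (as when logits_all is set); the names mirror the diff but live outside the class:

import numpy as np
import numpy.typing as npt

n_vocab = 8
input_ids: npt.NDArray[np.intc] = np.array([], dtype=np.intc)
scores: npt.NDArray[np.single] = np.ndarray((0, n_vocab), dtype=np.single)

for batch in ([1, 2, 3], [4, 5]):
    # Stand-in for the logits llama_get_logits() would expose for this batch.
    logits = np.random.rand(len(batch), n_vocab)
    # Append token ids and score rows exactly as eval() does above.
    input_ids = np.concatenate((input_ids, np.array(batch, dtype=np.intc)), axis=0)
    scores = np.concatenate((scores, np.array(logits, dtype=np.single)), axis=0)

assert input_ids.shape == (5,)
assert scores.shape == (5, n_vocab)
# scores[-1, :] is the logits row for the most recent token; _sample() reads it.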
@@ -354,18 +364,23 @@ def _sample(
             if last_n_tokens_size.value < 0
             else last_n_tokens_size
         )
-        logits = self.eval_logits[-1]
+        logits: npt.NDArray[np.single] = self._scores[-1, :]

         if logits_processor is not None:
-            logits = logits_processor(list(self.eval_tokens), logits)
-            self.eval_logits[-1] = logits
+            logits = np.array(
+                logits_processor(list(self.eval_tokens), logits.tolist()),
+                dtype=np.single,
+            )
+            self._scores[-1, :] = logits
+            self.eval_logits[-1] = logits.tolist()

         nl_logit = logits[self._token_nl]
         candidates = self._candidates
-        for i, logit in enumerate(logits):
-            candidates.data[i].id = llama_cpp.llama_token(i)
-            candidates.data[i].logit = llama_cpp.c_float(logit)
-            candidates.data[i].p = llama_cpp.c_float(0.0)
+        candidates_data = self._candidates_data
+        candidates_data["id"] = np.arange(n_vocab, dtype=np.intc) # type: ignore
+        candidates_data["logit"] = logits
+        candidates_data["p"] = np.zeros(n_vocab, dtype=np.single)
+        candidates.data = candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p)
         candidates.sorted = llama_cpp.c_bool(False)
         candidates.size = llama_cpp.c_size_t(n_vocab)
         llama_cpp.llama_sample_repetition_penalty(
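The old per-token Python loop over the whole vocabulary was the hot spot of _sample(); it becomes three whole-field writes into the structured array, executed in C by numpy. A short sketch checking that the vectorized fill produces exactly the records the loop did:

import numpy as np

n_vocab = 5
logits = np.array([0.3, -1.2, 4.0, 0.0, 2.2], dtype=np.single)
dtype = [("id", np.intc), ("logit", np.single), ("p", np.single)]

# Vectorized fill, as in the new _sample().
vectorized = np.zeros(n_vocab, dtype=dtype)
vectorized["id"] = np.arange(n_vocab, dtype=np.intc)
vectorized["logit"] = logits
vectorized["p"] = np.zeros(n_vocab, dtype=np.single)

# The removed O(n_vocab) Python-level loop, one record at a time.
looped = np.zeros(n_vocab, dtype=dtype)
for i, logit in enumerate(logits):
    looped[i] = (i, logit, 0.0)

assert (vectorized == looped).all()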
@@ -1371,6 +1386,8 @@ def save_state(self) -> LlamaState:
         return LlamaState(
             eval_tokens=self.eval_tokens.copy(),
             eval_logits=self.eval_logits.copy(),
+            scores=self._scores.copy(),
+            input_ids=self._input_ids.copy(),
             llama_state=llama_state_compact,
             llama_state_size=n_bytes,
         )
@@ -1379,6 +1396,8 @@ def load_state(self, state: LlamaState) -> None:
         assert self.ctx is not None
         self.eval_tokens = state.eval_tokens.copy()
         self.eval_logits = state.eval_logits.copy()
+        self._scores = state.scores.copy()
+        self._input_ids = state.input_ids.copy()
         state_size = state.llama_state_size
         if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
             raise RuntimeError("Failed to set llama state data")
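save_state() and load_state() now snapshot and restore the numpy buffers alongside the existing deques. A usage sketch of the round-trip; the model path is hypothetical, and the private _input_ids is touched only to show what was restored:

from llama_cpp import Llama

llama = Llama(model_path="./models/7B/ggml-model.bin")
llama.eval(llama.tokenize(b"Hello"))
state = llama.save_state()      # also copies _input_ids and _scores

llama.eval(llama.tokenize(b", world"))
llama.load_state(state)         # rewinds the numpy buffers with the rest
assert llama._input_ids.shape == state.input_ids.shape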

setup.py

Lines changed: 1 addition & 3 deletions
@@ -16,9 +16,7 @@
     license="MIT",
     package_dir={"llama_cpp": "llama_cpp", "llama_cpp.server": "llama_cpp/server"},
     packages=["llama_cpp", "llama_cpp.server"],
-    install_requires=[
-        "typing-extensions>=4.5.0",
-    ],
+    install_requires=["typing-extensions>=4.5.0", "numpy>=1.24.2"],
     extras_require={
         "server": ["uvicorn>=0.21.1", "fastapi>=0.95.0", "sse-starlette>=1.3.3"],
     },
