diff --git a/.gitignore b/.gitignore
index 68cdf7ff..4992bbd4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -41,3 +41,4 @@ htmlcov
 
 Cargo.lock
 target/
+.chat/
diff --git a/pyproject.toml b/pyproject.toml
index 7cc7cb10..59a12451 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,7 +5,7 @@ description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models"
 readme = "README.md"
 license = {file = "LICENSE"}
 authors = [{name = "Shantanu Jain"}, {email = "shantanu@openai.com"}]
-dependencies = ["regex>=2022.1.18", "requests>=2.26.0"]
+dependencies = ["requests>=2.26.0"]
 optional-dependencies = {blobfile = ["blobfile>=2"]}
 requires-python = ">=3.8"
 
@@ -16,29 +16,4 @@ changelog = "https://github.com/openai/tiktoken/blob/main/CHANGELOG.md"
 
 [build-system]
 build-backend = "setuptools.build_meta"
-requires = ["setuptools>=62.4", "wheel", "setuptools-rust>=1.5.2"]
-
-[tool.cibuildwheel]
-build-frontend = "build"
-build-verbosity = 1
-
-linux.before-all = "curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y"
-linux.environment = { PATH = "$PATH:$HOME/.cargo/bin" }
-macos.before-all = "rustup target add aarch64-apple-darwin"
-
-skip = [
-    "*-manylinux_i686",
-    "*-musllinux_i686",
-    "*-win32",
-]
-macos.archs = ["x86_64", "arm64"]
-# When cross-compiling on Intel, it is not possible to test arm64 wheels.
-# Warnings will be silenced with following CIBW_TEST_SKIP
-test-skip = "*-macosx_arm64"
-
-before-test = "pip install pytest hypothesis"
-test-command = "pytest {project}/tests --import-mode=append"
-
-[[tool.cibuildwheel.overrides]]
-select = "*linux_aarch64"
-test-command = """python -c 'import tiktoken; enc = tiktoken.get_encoding("gpt2"); assert enc.encode("hello world") == [31373, 995]'"""
+requires = ["setuptools>=62.4", "wheel"]
diff --git a/setup.py b/setup.py
index a22e8e5d..5031b72d 100644
--- a/setup.py
+++ b/setup.py
@@ -1,17 +1,17 @@
 from setuptools import setup
-from setuptools_rust import Binding, RustExtension
+# from setuptools_rust import Binding, RustExtension
 
 setup(
     name="tiktoken",
-    rust_extensions=[
-        RustExtension(
-            "tiktoken._tiktoken",
-            binding=Binding.PyO3,
-            # Between our use of editable installs and wanting to use Rust for performance sensitive
-            # code, it makes sense to just always use --release
-            debug=False,
-        )
-    ],
+    # rust_extensions=[
+    #     RustExtension(
+    #         "tiktoken._tiktoken",
+    #         binding=Binding.PyO3,
+    #         # Between our use of editable installs and wanting to use Rust for performance sensitive
+    #         # code, it makes sense to just always use --release
+    #         debug=False,
+    #     )
+    # ],
     package_data={"tiktoken": ["py.typed"]},
     packages=["tiktoken", "tiktoken_ext"],
     zip_safe=False,
diff --git a/tiktoken/_tiktoken.py b/tiktoken/_tiktoken.py
new file mode 100644
index 00000000..1bee253c
--- /dev/null
+++ b/tiktoken/_tiktoken.py
@@ -0,0 +1,101 @@
+import re
+from typing import Dict, List, Set, Tuple
+
+Rank = int
+
+
+def _byte_pair_merge(ranks: Dict[bytes, Rank], piece: bytes) -> List[Tuple[int, Rank]]:
+    # Each part is (start offset, rank of the pair starting at that offset).
+    parts: List[Tuple[int, Rank]] = []
+    min_rank = (float("inf"), len(piece))
+    for i in range(len(piece) - 1):
+        rank = ranks.get(piece[i:i + 2], float("inf"))
+        if rank < min_rank[0]:
+            min_rank = (rank, i)
+        parts.append((i, rank))
+    parts.append((len(piece) - 1, float("inf")))
+    parts.append((len(piece), float("inf")))
+
+    def get_rank(parts: List[Tuple[int, Rank]], i: int) -> Rank:
+        # Rank of merging parts[i] with its successor, once parts[i + 1] is removed.
+        if (i + 3) < len(parts):
+            return ranks.get(piece[parts[i][0]:parts[i + 3][0]], float("inf"))
+        return float("inf")
+
+    while min_rank[0] != float("inf"):
+        i = min_rank[1]
+        # Merge parts[i] and parts[i + 1], then refresh the neighbouring ranks.
+        if i > 0:
+            parts[i - 1] = (parts[i - 1][0], get_rank(parts, i - 1))
+        parts[i] = (parts[i][0], get_rank(parts, i))
+        del parts[i + 1]
+
+        min_rank = (float("inf"), len(piece))
+        for j, (_, rank) in enumerate(parts[:-1]):
+            if rank < min_rank[0]:
+                min_rank = (rank, j)
+    return parts
+
+
+def byte_pair_encode(piece: bytes, ranks: Dict[bytes, Rank]) -> List[Rank]:
+    assert len(piece) > 1
+    parts = _byte_pair_merge(ranks, piece)
+    return [ranks[piece[parts[i][0]:parts[i + 1][0]]] for i in range(len(parts) - 1)]
+
+
+def byte_pair_split(piece: bytes, ranks: Dict[bytes, Rank]) -> List[bytes]:
+    assert len(piece) > 1
+    parts = _byte_pair_merge(ranks, piece)
+    return [piece[parts[i][0]:parts[i + 1][0]] for i in range(len(parts) - 1)]
+
+
+class CoreBPE:
+    def __init__(self, encoder: Dict[bytes, Rank], special_tokens_encoder: Dict[str, Rank], pattern: str):
+        self.encoder = encoder
+        self.special_tokens_encoder = special_tokens_encoder
+        self.decoder = {v: k for k, v in encoder.items()}
+        self.special_tokens_decoder = {v: k.encode("utf-8") for k, v in special_tokens_encoder.items()}
+        self.regex = re.compile(pattern)
+        # An empty alternation would match everywhere, so only build the
+        # special-token regex when there are special tokens.
+        self.special_regex = (
+            re.compile("|".join(map(re.escape, special_tokens_encoder))) if special_tokens_encoder else None
+        )
+        self.sorted_token_bytes = sorted(encoder)
+
+    def encode_ordinary(self, text: str) -> List[Rank]:
+        tokens: List[Rank] = []
+        for piece in self.regex.findall(text):
+            piece_bytes = piece.encode("utf-8")
+            token = self.encoder.get(piece_bytes)
+            if token is not None:
+                tokens.append(token)
+            else:
+                tokens.extend(byte_pair_encode(piece_bytes, self.encoder))
+        return tokens
+
+    def encode(self, text: str, allowed_special: Set[str]) -> List[Rank]:
+        if self.special_regex is None:
+            return self.encode_ordinary(text)
+        tokens: List[Rank] = []
+        start = 0
+        for match in self.special_regex.finditer(text):
+            if match.group(0) not in allowed_special:
+                # Disallowed special tokens are encoded as ordinary text.
+                continue
+            if match.start() > start:
+                tokens.extend(self.encode_ordinary(text[start:match.start()]))
+            tokens.append(self.special_tokens_encoder[match.group(0)])
+            start = match.end()
+        if start < len(text):
+            tokens.extend(self.encode_ordinary(text[start:]))
+        return tokens
+
+    def decode_bytes(self, tokens: List[Rank]) -> bytes:
+        return b"".join(
+            self.decoder[token] if token in self.decoder else self.special_tokens_decoder[token]
+            for token in tokens
+        )
+
+    def token_byte_values(self) -> List[bytes]:
+        return self.sorted_token_bytes
diff --git a/tiktoken/core.py b/tiktoken/core.py
index 3e32d4ce..c5af7f4c 100644
--- a/tiktoken/core.py
+++ b/tiktoken/core.py
@@ -4,7 +4,7 @@
 from concurrent.futures import ThreadPoolExecutor
 from typing import AbstractSet, Collection, Literal, NoReturn, Optional, Union
 
-import regex
+import re as regex
 
 from tiktoken import _tiktoken
 
diff --git a/tiktoken_ext/openai_public.py b/tiktoken_ext/openai_public.py
index 6b29a711..2e12dd5a 100644
--- a/tiktoken_ext/openai_public.py
+++ b/tiktoken_ext/openai_public.py
@@ -18,9 +18,9 @@ def gpt2():
         "name": "gpt2",
         "explicit_n_vocab": 50257,
         # The pattern in the original GPT-2 release is:
-        # r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
+        # r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\w]+| ?[\d]+| ?[^\s\w]+|\s+(?!\S)|\s+"""
         # This is equivalent, but executes faster:
-        "pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
+        "pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\w+| ?\d+| ?[^\s\w]+|\s+(?!\S)|\s+""",
         "mergeable_ranks": mergeable_ranks,
         "special_tokens": {ENDOFTEXT: 50256},
     }
@@ -34,7 +34,7 @@ def r50k_base():
     return {
         "name": "r50k_base",
         "explicit_n_vocab": 50257,
-        "pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
+        "pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\w+| ?\d+| ?[^\s\w]+|\s+(?!\S)|\s+""",
         "mergeable_ranks": mergeable_ranks,
         "special_tokens": {ENDOFTEXT: 50256},
     }
@@ -48,7 +48,7 @@ def p50k_base():
     return {
         "name": "p50k_base",
         "explicit_n_vocab": 50281,
-        "pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
+        "pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\w+| ?\d+| ?[^\s\w]+|\s+(?!\S)|\s+""",
         "mergeable_ranks": mergeable_ranks,
         "special_tokens": {ENDOFTEXT: 50256},
     }
@@ -62,7 +62,7 @@ def p50k_edit():
     special_tokens = {ENDOFTEXT: 50256, FIM_PREFIX: 50281, FIM_MIDDLE: 50282, FIM_SUFFIX: 50283}
     return {
         "name": "p50k_edit",
-        "pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
+        "pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\w+| ?\d+| ?[^\s\w]+|\s+(?!\S)|\s+""",
         "mergeable_ranks": mergeable_ranks,
         "special_tokens": special_tokens,
     }
@@ -82,7 +82,7 @@ def cl100k_base():
     }
     return {
         "name": "cl100k_base",
-        "pat_str": r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""",
+        "pat_str": r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\w]?\w+|\d{1,3}| ?[^\s\w]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
         "mergeable_ranks": mergeable_ranks,
         "special_tokens": special_tokens,
     }
@@ -100,10 +100,10 @@ def o200k_base():
     # This regex could be made more efficient
     pat_str = "|".join(
         [
-            r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
-            r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
-            r"""\p{N}{1,3}""",
-            r""" ?[^\s\p{L}\p{N}]+[\r\n/]*""",
+            r"""[^\r\n\w]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
+            r"""[^\r\n\w]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
+            r"""\d{1,3}""",
+            r""" ?[^\s\w]+[\r\n/]*""",
             r"""\s*[\r\n]+""",
             r"""\s+(?!\S)""",
            r"""\s+""",
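A minimal smoke test for the pure-Python `tiktoken/_tiktoken.py` added above. This is a sketch: the toy merge table, split pattern, and token ids are illustrative stand-ins, not the real GPT-2 vocabulary.

    # toy_bpe_check.py -- illustrative toy vocabulary, not the real GPT-2 merge table
    from tiktoken._tiktoken import CoreBPE, byte_pair_encode

    ranks = {b"h": 0, b"e": 1, b"l": 2, b"o": 3, b"he": 4, b"ll": 5, b"llo": 6, b"hello": 7}
    # Successive lowest-rank merges: he, ll, llo, hello -> a single token.
    assert byte_pair_encode(b"hello", ranks) == [7]

    bpe = CoreBPE(ranks, {"<|endoftext|>": 8}, r"\w+|\s+|[^\s\w]+")
    assert bpe.encode("hello<|endoftext|>", allowed_special={"<|endoftext|>"}) == [7, 8]
    assert bpe.decode_bytes([7, 8]) == b"hello<|endoftext|>"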