Skip to content

Commit 8604740

Browse files
authored
update stub for typing (#1896)
* update stub for typing * up * add ty type checker * update stub * up * some update * add owner to stub? * update * no print * uptime funk * mm * wtf * fix * fix more * some fixes are manual but come on * up * # type: ignore[import] * reduce the scope of ty for less changes * ups * up?
1 parent a5e03ba commit 8604740

34 files changed

+2299
-125
lines changed

.github/workflows/python.yml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ jobs:
116116
python -m venv .env
117117
source .env/bin/activate
118118
pip install -U pip
119-
pip install pytest requests setuptools_rust numpy pyarrow datasets
119+
pip install pytest requests setuptools_rust numpy pyarrow datasets ty
120120
pip install -e .[dev]
121121
122122
- name: Check style
@@ -125,6 +125,12 @@ jobs:
125125
source .env/bin/activate
126126
make check-style
127127
128+
- name: Type check
129+
working-directory: ./bindings/python
130+
run: |
131+
source .env/bin/activate
132+
ty check py_src tests
133+
128134
- name: Run tests
129135
working-directory: ./bindings/python
130136
run: |

bindings/python/Makefile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@ style:
1010
python stub.py
1111
ruff check $(check_dirs) --fix
1212
ruff format $(check_dirs)
13+
ty check py_src tests
1314

1415
# Check the source code is formatted correctly
1516
check-style:
1617
python stub.py --check
1718
ruff check $(check_dirs)
1819
ruff format --check $(check_dirs)
20+
ty check py_src tests
1921

2022
TESTS_RESOURCES = $(DATA_DIR)/small.txt $(DATA_DIR)/roberta.json
2123

bindings/python/README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,3 +168,7 @@ tokenizer = Tokenizer.from_file("byte-level-bpe.tokenizer.json")
168168

169169
encoded = tokenizer.encode("I can feel the magic, can you?")
170170
```
171+
172+
### Typing support and `stub.py`
173+
174+
The compiled PyO3 extension does not expose type annotations, so without stubs, editors and type checkers see most objects as `Any`. The `stub.py` helper walks the loaded extension modules, renders `.pyi` stub files (plus minimal forwarding `__init__.py` shims), and formats them so that tools like mypy and pyright can understand the public API. Run `python stub.py` whenever you change the Python-visible surface to keep the generated stubs in sync.

bindings/python/benches/test_tiktoken.py

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
import time
33
import argparse
44
from datasets import load_dataset
5-
from tiktoken.load import load_tiktoken_bpe
6-
import tiktoken
5+
from tiktoken.load import load_tiktoken_bpe # type: ignore[import]
6+
import tiktoken # type: ignore[import]
77
from tokenizers import Tokenizer
88
from huggingface_hub import hf_hub_download
99
from typing import Tuple, List
@@ -30,7 +30,9 @@ def benchmark_batch(model: str, documents: list[str], num_threads: int, document
3030
num_bytes = sum(map(len, map(str.encode, documents)))
3131
readable_size, unit = format_byte_size(num_bytes)
3232
print(f"==============")
33-
print(f"num_threads: {num_threads}, data size: {readable_size}, documents: {len(documents)} Avg Length: {document_length:.0f}")
33+
print(
34+
f"num_threads: {num_threads}, data size: {readable_size}, documents: {len(documents)} Avg Length: {document_length:.0f}"
35+
)
3436
filename = hf_hub_download(MODEL_ID, "original/tokenizer.model")
3537
mergeable_ranks = load_tiktoken_bpe(filename)
3638
pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
@@ -46,20 +48,15 @@ def benchmark_batch(model: str, documents: list[str], num_threads: int, document
4648
"<|end_header_id|>",
4749
"<|reserved_special_token_4|>",
4850
"<|eot_id|>", # end of turn
49-
] + [
50-
f"<|reserved_special_token_{i}|>"
51-
for i in range(5, num_reserved_special_tokens - 5)
52-
]
51+
] + [f"<|reserved_special_token_{i}|>" for i in range(5, num_reserved_special_tokens - 5)]
5352
num_base_tokens = len(mergeable_ranks)
54-
special_tokens = {
55-
token: num_base_tokens + i for i, token in enumerate(special_tokens)
56-
}
53+
special_tokens = {token: num_base_tokens + i for i, token in enumerate(special_tokens)}
5754
enc = tiktoken.Encoding(
58-
name=model,
59-
pat_str=pat_str,
60-
mergeable_ranks=mergeable_ranks,
61-
special_tokens=special_tokens,
62-
)
55+
name=model,
56+
pat_str=pat_str,
57+
mergeable_ranks=mergeable_ranks,
58+
special_tokens=special_tokens,
59+
)
6360
out = enc.encode("This is a test")
6461

6562
hf_enc = Tokenizer.from_pretrained(model)
@@ -74,7 +71,6 @@ def benchmark_batch(model: str, documents: list[str], num_threads: int, document
7471
readable_size, unit = format_byte_size(num_bytes / (end - start) * 1e9)
7572
print(f"tiktoken \t{readable_size} / s")
7673

77-
7874
start = time.perf_counter_ns()
7975
hf_enc.encode_batch_fast(documents)
8076
end = time.perf_counter_ns()
@@ -98,7 +94,7 @@ def test(model: str, dataset: str, dataset_config: str, threads: List[int]):
9894
else:
9995
documents.append(item["premise"]["en"])
10096
if fuse:
101-
documents=["".join(documents)]
97+
documents = ["".join(documents)]
10298

10399
document_length = sum(len(d) for d in documents) / len(documents)
104100

@@ -115,15 +111,14 @@ def test(model: str, dataset: str, dataset_config: str, threads: List[int]):
115111

116112

117113
def main():
118-
119114
parser = argparse.ArgumentParser(
120-
prog='bench_tokenizer',
121-
description='Getting a feel for speed when tokenizing',
115+
prog="bench_tokenizer",
116+
description="Getting a feel for speed when tokenizing",
122117
)
123-
parser.add_argument('-m', '--model', default=MODEL_ID, type=str)
124-
parser.add_argument('-d', '--dataset', default=DATASET, type=str)
125-
parser.add_argument('-ds', '--dataset-config', default=DATASET_CONFIG, type=str)
126-
parser.add_argument('-t', '--threads', nargs='+', default=DEFAULT_THREADS, type=int)
118+
parser.add_argument("-m", "--model", default=MODEL_ID, type=str)
119+
parser.add_argument("-d", "--dataset", default=DATASET, type=str)
120+
parser.add_argument("-ds", "--dataset-config", default=DATASET_CONFIG, type=str)
121+
parser.add_argument("-t", "--threads", nargs="+", default=DEFAULT_THREADS, type=int)
127122
args = parser.parse_args()
128123
test(args.model, args.dataset, args.dataset_config, args.threads)
129124

bindings/python/docs/pyo3.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# PyO3 Usage Notes
2+
3+
## Why we take `self_: PyRef<'_, Self>`
4+
5+
Most of the Python-facing structs are declared with `#[pyclass(extends = ...)]`. The actual data (for example the `processor` field in `PyPostProcessor`) lives in the base class, while the derived Rust structs are often just markers so that Python sees a proper subclass. When we implement a method on the subclass, we still need to reach into the base storage without downcasting a `PyAny` or re-wrapping objects.
6+
7+
Using `self_: PyRef<'_, Self>` gives us a borrowed reference to the Python-owned value that keeps the GIL lifetime, reference counts, and the inheritance chain intact. With it we can call `self_.as_ref()` to view the base `PyPostProcessor` directly and access shared helpers like the processor getters/setters. A plain `&self`, by contrast, would only expose the zero-sized derived struct, forcing a conversion through a super type just to touch the processors — extra boilerplate that also loses the link to the Python inheritance model. In that sense `PyRef` is the PyO3 equivalent of Python's `super()`: it keeps the Rust type information while letting us operate on the underlying parent.

bindings/python/examples/example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from tokenizers.models import BPE, WordPiece
99
from tokenizers.normalizers import BertNormalizer
1010
from tokenizers.processors import BertProcessing
11-
from transformers import BertTokenizer, GPT2Tokenizer
11+
from transformers import BertTokenizer, GPT2Tokenizer # type: ignore[import]
1212

1313
logging.getLogger("transformers").disabled = True
1414
logging.getLogger("transformers.tokenization_utils").disabled = True

bindings/python/examples/train_with_datasets.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
# Build an iterator over this dataset
1616
def batch_iterator():
1717
batch_size = 1000
18-
for batch in dataset.iter(batch_size=batch_size):
18+
for batch in dataset.iter(batch_size=batch_size): # type: ignore[attr-defined]
1919
yield batch["text"]
2020

2121

2222
# And finally train
23-
bpe_tokenizer.train_from_iterator(batch_iterator(), length=len(dataset))
23+
bpe_tokenizer.train_from_iterator(batch_iterator(), length=len(dataset)) # type: ignore[arg-type]

bindings/python/py_src/tokenizers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class SplitDelimiterBehavior(Enum):
7575
CONTIGUOUS = "contiguous"
7676

7777

78-
from .tokenizers import (
78+
from .tokenizers import ( # type: ignore[import]
7979
AddedToken,
8080
Encoding,
8181
NormalizedString,

0 commit comments

Comments
 (0)