forked from EuroEval/EuroEval
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_data_loading.py
More file actions
193 lines (171 loc) · 7.24 KB
/
test_data_loading.py
File metadata and controls
193 lines (171 loc) · 7.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
"""Tests for the `data_loading` module."""
import os
from collections.abc import Generator
from functools import partial
from pathlib import Path
import pytest
from datasets import DatasetDict
from numpy.random import default_rng
from transformers.models.auto.tokenization_auto import AutoTokenizer
from euroeval.benchmark_modules.litellm import LiteLLMModel
from euroeval.constants import MAX_CONTEXT_LENGTH
from euroeval.data_loading import load_data, load_raw_data
from euroeval.data_models import BenchmarkConfig, DatasetConfig
from euroeval.dataset_configs import get_all_dataset_configs
from euroeval.enums import GenerativeType
from euroeval.generation_utils import apply_prompt, extract_few_shot_examples
from euroeval.tasks import RC
@pytest.fixture(scope="module")
def tokeniser_id() -> Generator[str, None, None]:
    """Provide the Hugging Face repository ID of the tokeniser used in tests.

    The same ID is shared by every test in this module.

    Yields:
        The tokeniser repository ID.
    """
    tokeniser_repo = "EuroEval/gemma-3-tokenizer"
    yield tokeniser_repo
class TestLoadData:
    """Tests for the `load_data` function."""

    @pytest.fixture(scope="class")
    def datasets(
        self, benchmark_config: BenchmarkConfig
    ) -> Generator[list[DatasetDict], None, None]:
        """Load the `multi-wiki-qa-da` dataset once for every test in the class.

        Yields:
            The list of `DatasetDict` objects returned by `load_data`, seeded
            with a fixed random generator so the result is reproducible.
        """
        rng = default_rng(seed=4242)
        all_dataset_configs = get_all_dataset_configs(
            custom_datasets_file=Path("custom_datasets.py"),
            dataset_ids=[],
            api_key=os.getenv("HF_TOKEN"),
            cache_dir=Path(".euroeval_cache"),
            trust_remote_code=True,
            run_with_cli=True,
        )
        loaded_datasets = load_data(
            rng=rng,
            dataset_config=all_dataset_configs["multi-wiki-qa-da"],
            benchmark_config=benchmark_config,
        )
        yield loaded_datasets

    def test_load_data_is_list_of_dataset_dicts(
        self, datasets: list[DatasetDict]
    ) -> None:
        """The loader returns a list whose every element is a `DatasetDict`."""
        assert isinstance(datasets, list)
        for dataset_dict in datasets:
            assert isinstance(dataset_dict, DatasetDict)

    def test_split_names_are_correct(self, datasets: list[DatasetDict]) -> None:
        """Every loaded `DatasetDict` has exactly the train, val and test splits."""
        expected_splits = {"train", "val", "test"}
        for dataset_dict in datasets:
            assert set(dataset_dict) == expected_splits

    def test_number_of_iterations_is_correct(
        self, datasets: list[DatasetDict], benchmark_config: BenchmarkConfig
    ) -> None:
        """One `DatasetDict` is produced per configured benchmark iteration."""
        num_loaded = len(datasets)
        assert num_loaded == benchmark_config.num_iterations

    def test_no_empty_examples(self, datasets: list[DatasetDict]) -> None:
        """No `text` or `tokens` value in any split has length zero."""
        for dataset_dict in datasets:
            for split in dataset_dict.values():
                for column in ["text", "tokens"]:
                    # Only some tasks carry each of these columns.
                    if column not in split.features:
                        continue
                    assert all(len(value) > 0 for value in split[column])
def _dataset_configs_to_check() -> list[DatasetConfig]:
    """Return the dataset configs selected by the `CHECK_DATASET` env variable.

    `CHECK_DATASET` is a comma-separated list whose entries are dataset names,
    language codes, or the special value "all". Whitespace around entries is
    ignored. When the variable is unset, an empty list is returned and the
    dataset configs are not loaded at all. This function runs at import
    (collection) time, because it is evaluated inside the `parametrize`
    decorator.

    Returns:
        The selected dataset configs, in the order produced by
        `get_all_dataset_configs`.
    """
    check_dataset = os.getenv("CHECK_DATASET")
    if check_dataset is None:
        # Nothing will be checked, so skip the (potentially slow) config load.
        return []

    # Parse once into a set instead of re-splitting for every config.
    requested = {entry.strip() for entry in check_dataset.split(",")}
    all_configs = get_all_dataset_configs(
        custom_datasets_file=Path("custom_datasets.py"),
        dataset_ids=[],
        api_key=os.getenv("HF_TOKEN"),
        cache_dir=Path(".euroeval_cache"),
        trust_remote_code=True,
        run_with_cli=True,
    ).values()
    if "all" in requested:
        return list(all_configs)
    return [
        dataset_config
        for dataset_config in all_configs
        if dataset_config.name in requested
        or any(language.code in requested for language in dataset_config.languages)
    ]


@pytest.mark.parametrize(
    argnames="dataset_config",
    argvalues=_dataset_configs_to_check(),
    ids=lambda dc: dc.name,
)
class TestAllDatasets:
    """Tests run on every dataset selected via the `CHECK_DATASET` env variable."""

    def test_examples_in_official_datasets_are_not_too_long(
        self,
        dataset_config: DatasetConfig,
        benchmark_config: BenchmarkConfig,
        tokeniser_id: str,
    ) -> None:
        """Test that the examples are not too long in official datasets.

        For each of the first 10 iterations, the test split is rendered into
        prompts for both instruction-tuned and base models. The longest prompt
        (in tokens) plus the dataset's `max_generated_tokens` must not exceed
        `MAX_CONTEXT_LENGTH`.
        """
        dummy_model_config = LiteLLMModel.get_model_config(
            model_id="model", benchmark_config=benchmark_config
        )
        tokeniser = AutoTokenizer.from_pretrained(tokeniser_id)
        dataset = load_raw_data(
            dataset_config=dataset_config,
            cache_dir=benchmark_config.cache_dir,
            api_key=benchmark_config.api_key,
        )

        # Neither of these depends on the iteration or the model type.
        max_output_length = dataset_config.max_generated_tokens
        # Few-shot examples are only drawn when a train split exists and the
        # task does not require zero-shot evaluation.
        use_few_shot = (
            "train" in dataset_config.splits
            and not dataset_config.task.requires_zero_shot
        )

        for itr_idx in range(10):
            few_shot_examples = (
                extract_few_shot_examples(
                    dataset=dataset,
                    dataset_config=dataset_config,
                    benchmark_config=benchmark_config,
                    itr_idx=itr_idx,
                )
                if use_few_shot
                else []
            )
            for instruction_model in [True, False]:
                generative_type = (
                    GenerativeType.INSTRUCTION_TUNED
                    if instruction_model
                    else GenerativeType.BASE
                )
                prepared_test = dataset["test"].map(
                    partial(
                        apply_prompt,
                        few_shot_examples=few_shot_examples,
                        model_config=dummy_model_config,
                        dataset_config=dataset_config,
                        generative_type=generative_type,
                        always_populate_text_field=True,
                        tokeniser=tokeniser,
                    ),
                    batched=True,
                    load_from_cache_file=False,
                    keep_in_memory=True,
                )

                # Materialise as a plain list so the tokeniser treats it as a
                # batch (newer `datasets` versions return a lazy column object).
                prompts = list(prepared_test["text"])
                # One batched call instead of one call per prompt. Without
                # padding or truncation, the per-prompt token ids are identical.
                input_ids = tokeniser(prompts)["input_ids"] if prompts else []
                max_input_length = max((len(ids) for ids in input_ids), default=0)
                max_length = max_input_length + max_output_length
                assert max_length <= MAX_CONTEXT_LENGTH, (
                    f"Max length of {max_length:,} exceeds the maximum context length "
                    f"({MAX_CONTEXT_LENGTH:,}) for dataset {dataset_config.name} in "
                    f"iteration {itr_idx} and when instruction_model="
                    f"{instruction_model}."
                )

    def test_reading_comprehension_datasets_have_id_column(
        self, dataset_config: DatasetConfig, benchmark_config: BenchmarkConfig
    ) -> None:
        """Test that reading comprehension datasets have an ID column.

        Every split listed in `dataset_config.splits` must have an `id`
        feature. Datasets of any other task are skipped.
        """
        # Skip if the dataset is not a reading comprehension dataset
        if dataset_config.task != RC:
            pytest.skip(reason="Skipping test for non-reading comprehension dataset.")
        dataset = load_raw_data(
            dataset_config=dataset_config,
            cache_dir=benchmark_config.cache_dir,
            api_key=benchmark_config.api_key,
        )
        for split in dataset_config.splits:
            assert "id" in dataset[split].features, (
                f"Dataset {dataset_config.name} is a reading comprehension dataset but "
                f"the {split} split does not have an 'id' column."
            )