-
Notifications
You must be signed in to change notification settings - Fork 434
custom dataset loader #324
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 11 commits
f8287f0
bb1e021
9195e1e
c1bd607
157bd21
3edd74b
29b0d9a
ea15f9a
34b02ec
af379af
6e07bd5
2105de2
5f9a4c2
f238449
2f00e33
793dbe0
ae1c1f0
799f29e
97ea615
6172e24
d3e4269
92a54a5
7b167ca
601371d
9d0ed54
12aab83
474bfa7
7f746d1
7d91be2
7d2f976
9222066
5c172b2
11d2930
36c83b3
f6fb8c5
41c5ef5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,4 @@ | ||
| from .dataset import TextAttackDataset | ||
| from .huggingface_dataset import HuggingFaceDataset | ||
|
|
||
| from .custom_dataset import CustomDataset | ||
| from . import translation |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,166 @@ | ||
| import collections | ||
| import random | ||
|
|
||
| import datasets | ||
| import pandas as pd | ||
|
|
||
| import textattack | ||
| from textattack.datasets import TextAttackDataset | ||
|
|
||
|
|
||
| def _cb(s): | ||
| """Colors some text blue for printing to the terminal.""" | ||
| if not isinstance(s, str): | ||
| s = "custom " + str(type(s)) | ||
| return textattack.shared.utils.color_text(str(s), color="blue", method="ansi") | ||
|
|
||
|
|
||
| class CustomDataset(TextAttackDataset): | ||
| """Loads a Custom Dataset from a file/list of files and prepares it as a | ||
| TextAttack dataset. | ||
|
|
||
| - name(Union[str, dict, pd.DataFrame]): the user specified dataset file names, dicts or pandas dataframe | ||
|
||
| - infile_format(str): Specifies type of file for loading HuggingFaceDataset : csv, json, pandas, text | ||
| from local_files will be loaded as ``datasets.load_dataset(filetype, data_files=name)``. | ||
| - label_map: Mapping if output labels should be re-mapped. Useful | ||
|
||
| if model was trained with a different label arrangement than | ||
| provided in the ``datasets`` version of the dataset. | ||
| - output_scale_factor (float): Factor to divide ground-truth outputs by. | ||
|
||
| Generally, TextAttack goal functions require model outputs | ||
| between 0 and 1. Some datasets test the model's correlation | ||
| with ground-truth output, instead of its accuracy, so these | ||
| outputs may be scaled arbitrarily. | ||
| - dataset_columns (list[str]): dataset_columns[0]: input columns specified as a tuple or list, dataset_columns[1]: output_columns | ||
| - shuffle (bool): Whether to shuffle the dataset on load. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| name, | ||
| infile_format=None, | ||
|
||
| split=None, | ||
| label_map=None, | ||
| subset=None, | ||
|
||
| output_scale_factor=None, | ||
| dataset_columns=None, | ||
| shuffle=False, | ||
| ): | ||
|
|
||
| self._name = name | ||
|
|
||
| if infile_format in ["csv", "json", "text", "pandas"]: | ||
| self._dataset = datasets.load_dataset(infile_format, data_files=self._name) | ||
|
|
||
| else: | ||
| if isinstance(self._name, dict): | ||
| self._dataset = datasets.Dataset.from_dict(self._name) | ||
| elif isinstance(self._name, pd.DataFrame): | ||
| self._dataset = datasets.Dataset.from_pandas(self._name) | ||
| else: | ||
| raise ValueError( | ||
| "Only accepts csv, json, text, pandas file infile_format, dict and pandas DataFrame" | ||
| ) | ||
|
|
||
| # if no split in custom data, default split is None | ||
| # if user hasn't specified a split to use, raise error if the dataset has splits | ||
| if split is None: | ||
| if set(self._dataset.keys()) <= set(["train", "validation", "test"]): | ||
|
||
| raise ValueError(f"specify a split to use: {self._dataset.keys()}") | ||
| else: | ||
| self._dataset = self._dataset[split] | ||
|
|
||
| subset_print_str = f", subset {_cb(subset)}" if subset else "" | ||
|
|
||
| textattack.shared.logger.info( | ||
| f"Loading {_cb('datasets')} dataset {_cb(name)}{subset_print_str}, split {_cb(split)}." | ||
| ) | ||
| # Input/output column order, like (('premise', 'hypothesis'), 'label') | ||
|
|
||
| if dataset_columns is None: | ||
| # automatically infer from dataset | ||
| dataset_columns = [] | ||
| dataset_columns.append(self._dataset.column_names) | ||
|
||
|
|
||
| if not set(dataset_columns[0]) <= set(self._dataset.column_names): | ||
|
||
| raise ValueError( | ||
| f"Could not find input column {dataset_columns[0]}. Found keys: {self._dataset.column_names}" | ||
| ) | ||
| self.input_columns = dataset_columns[0] | ||
|
|
||
| if len(dataset_columns) == 1: | ||
| # if user hasnt specified an output column or dataset_columns is None, all dataset_columns are | ||
| # treated as input_columns | ||
| dataset_columns.append(None) | ||
| # if user has specified an output column, check if it exists in the inferred column names | ||
| # user can explicitly specify output column as None | ||
| if ( | ||
| dataset_columns[1] is not None | ||
| and dataset_columns[1] not in self._dataset.column_names | ||
| ): | ||
|
|
||
| raise ValueError( | ||
| f"Could not find output column {dataset_columns[1]}. Found keys: {self._dataset.column_names}" | ||
| ) | ||
| self.output_column = dataset_columns[1] | ||
|
|
||
| self._i = 0 | ||
| self.examples = list(self._dataset) | ||
|
|
||
| self.label_map = label_map | ||
|
|
||
| self.output_scale_factor = output_scale_factor | ||
|
|
||
| try: | ||
|
|
||
| self.label_names = self._dataset.features["label"].names | ||
|
|
||
| # If labels are remapped, the label names have to be remapped as | ||
| # well. | ||
| if label_map: | ||
|
|
||
| self.label_names = [ | ||
| self.label_names[self.label_map[i]] | ||
| for i in range(len(self.label_map)) | ||
| ] | ||
|
|
||
| except KeyError: | ||
| # This happens when the dataset doesn't have 'features' or a 'label' column. | ||
|
|
||
| self.label_names = None | ||
| except AttributeError: | ||
| # This happens when self._dataset.features["label"] exists | ||
| # but is a single value. | ||
| self.label_names = ("label",) | ||
|
|
||
| if shuffle: | ||
| random.shuffle(self.examples) | ||
|
|
||
| def _format_raw_example(self, raw_example): | ||
|
||
| input_dict = collections.OrderedDict( | ||
| [(c, raw_example[c]) for c in self.input_columns] | ||
| ) | ||
| if self.output_column is not None: | ||
| output = raw_example[self.output_column] | ||
| if self.label_map: | ||
| output = self.label_map[output] | ||
| if self.output_scale_factor: | ||
| output = output / self.output_scale_factor | ||
| return (input_dict, output) | ||
|
|
||
| else: | ||
| return (input_dict,) | ||
|
|
||
| def __next__(self): | ||
| if self._i >= len(self.examples): | ||
| raise StopIteration | ||
| raw_example = self.examples[self._i] | ||
| self._i += 1 | ||
| return self._format_raw_example(raw_example) | ||
|
|
||
| def __getitem__(self, i): | ||
| if isinstance(i, int): | ||
| return self._format_raw_example(self.examples[i]) | ||
| else: | ||
| # `i` could be a slice or an integer. if it's a slice, | ||
| # return the formatted version of the proper slice of the list | ||
| return [self._format_raw_example(ex) for ex in self.examples[i]] | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems like users can pass in a list of files (or a dictionary of files whose keys are the split names, e.g. "train", "test"). But the documentation for
the `name` argument suggests that it's just a string.