-
Notifications
You must be signed in to change notification settings - Fork 434
custom dataset loader #324
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
f8287f0
bb1e021
9195e1e
c1bd607
157bd21
3edd74b
29b0d9a
ea15f9a
34b02ec
af379af
6e07bd5
2105de2
5f9a4c2
f238449
2f00e33
793dbe0
ae1c1f0
799f29e
97ea615
6172e24
d3e4269
92a54a5
7b167ca
601371d
9d0ed54
12aab83
474bfa7
7f746d1
7d91be2
7d2f976
9222066
5c172b2
11d2930
36c83b3
f6fb8c5
41c5ef5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,116 @@ | ||
| import collections | ||
| import random | ||
|
|
||
| import datasets | ||
|
|
||
| import textattack | ||
| from textattack.datasets import HuggingFaceDataset | ||
|
|
||
|
|
||
def _cb(s):
    """Color ``s`` blue for terminal output.

    The value is converted with ``str`` first, so any object can be passed.
    """
    text = str(s)
    return textattack.shared.utils.color_text(text, color="blue", method="ansi")
|
|
||
|
|
||
class CustomDataset(HuggingFaceDataset):
    """Load a dataset from local files and prepare it as a TextAttack dataset.

    The files are read with
    ``datasets.load_dataset(filetype, data_files=name)[split]``.

    Args:
        name (str): The dataset file name(s), passed as ``data_files``.
        filetype (str): The loader type handed to ``datasets.load_dataset``
            (for example ``"csv"``).
        split (str): Which split of the loaded dataset to use.
        label_map: Optional mapping applied to every ground-truth output as
            ``label_map[raw_output]``. Useful if the model was trained with a
            different label arrangement than the one stored in the files.
            For example ``{0: 1, 1: 0}`` swaps the two classes of a binary
            task. When label names are available they are re-ordered with
            ``label_map[i]`` for ``i`` in ``range(len(label_map))``, so the
            keys are expected to be ``0 .. len(label_map) - 1``.
        output_scale_factor (float): Factor to divide ground-truth outputs
            by. Generally, TextAttack goal functions require model outputs
            between 0 and 1. Some datasets test the model's correlation with
            ground-truth output, instead of its accuracy, so these outputs
            may be scaled arbitrarily.
        dataset_columns (sequence): ``dataset_columns[0]`` holds the input
            column names and ``dataset_columns[1]`` the output column name
            (or ``None`` when there is no output column), like
            ``(("premise", "hypothesis"), "label")``.
        shuffle (bool): Whether to shuffle the examples on load.

    Raises:
        ValueError: If an input column or the output column is not present
            in the loaded dataset.
    """

    def __init__(
        self,
        name,
        filetype="csv",
        split="train",
        label_map=None,
        output_scale_factor=None,
        # Immutable default: a list literal here would be shared across calls.
        dataset_columns=(("text",), None),
        shuffle=False,
    ):
        # NOTE(review): the parent ``__init__`` is intentionally not called
        # here, as in the original code; presumably it would try to load the
        # dataset by name instead of from local files -- confirm.
        self._name = name
        self._dataset = datasets.load_dataset(filetype, data_files=name)[split]

        textattack.shared.logger.info(
            f"Loading {_cb('datasets')} dataset {_cb(name)}, split {_cb(split)}."
        )

        # Input/output column order, like (('premise', 'hypothesis'), 'label')
        input_columns, output_column = dataset_columns[0], dataset_columns[1]
        available_columns = self._dataset.column_names

        if not set(input_columns) <= set(available_columns):
            raise ValueError(
                f"Could not find input column {input_columns} in CSV. Found keys: {available_columns}"
            )
        if output_column is not None and output_column not in available_columns:
            raise ValueError(
                f"Could not find output column {output_column} in CSV. Found keys: {available_columns}"
            )

        self.input_columns = input_columns
        self.output_column = output_column

        self._i = 0
        self.examples = list(self._dataset)

        self.label_map = label_map
        self.output_scale_factor = output_scale_factor

        try:
            self.label_names = self._dataset.features["label"].names
            # If labels are remapped, the label names have to be remapped as
            # well.
            if label_map:
                self.label_names = [
                    self.label_names[self.label_map[i]]
                    for i in range(len(self.label_map))
                ]
        except KeyError:
            # This happens when the dataset doesn't have 'features' or a
            # 'label' column.
            self.label_names = None
        except AttributeError:
            # This happens when self._dataset.features["label"] exists
            # but is a single value.
            self.label_names = ("label",)

        if shuffle:
            random.shuffle(self.examples)

    def _format_raw_example(self, raw_example):
        """Convert one raw row into an ``(input_dict, output)`` pair.

        ``input_dict`` is an ``OrderedDict`` of the configured input columns
        in order. ``output`` is the value of the output column after the
        optional ``label_map`` lookup and division by
        ``output_scale_factor``, or ``None`` when no output column is set.
        """
        input_dict = collections.OrderedDict(
            (column, raw_example[column]) for column in self.input_columns
        )

        if self.output_column is None:
            return (input_dict, None)

        output = raw_example[self.output_column]
        if self.label_map:
            output = self.label_map[output]
        if self.output_scale_factor:
            output = output / self.output_scale_factor
        return (input_dict, output)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you include an example of how a user should define this
`label_map` (also, is this a `dict`)? For example, should it look like `{"Positive": 1, "Negative": 0}`?