36 commits
f8287f0
custom data loader
ArshdeepSekhon Oct 29, 2020
bb1e021
custom textattack dataset from local files or in memory using hugging…
ArshdeepSekhon Oct 30, 2020
9195e1e
load user dataset from local files and convert to TextAttack dataset …
ArshdeepSekhon Nov 4, 2020
c1bd607
load user dataset from local files and convert to TextAttack dataset …
ArshdeepSekhon Nov 4, 2020
157bd21
load user dataset from local files and convert to textattack dataset …
ArshdeepSekhon Nov 4, 2020
3edd74b
load user dataset from local files and convert to textattack dataset …
ArshdeepSekhon Nov 4, 2020
29b0d9a
custom dataset: add attribute error
ArshdeepSekhon Nov 4, 2020
ea15f9a
custom dataset: remove stray prints
ArshdeepSekhon Nov 4, 2020
34b02ec
fix output column for custom dataset
ArshdeepSekhon Nov 4, 2020
af379af
custom dataset: add support for dict
ArshdeepSekhon Nov 4, 2020
6e07bd5
custom dataset: checks
ArshdeepSekhon Nov 4, 2020
2105de2
option to test on entire dataset
ArshdeepSekhon Oct 22, 2020
5f9a4c2
eval on entire dataset, checks
ArshdeepSekhon Oct 22, 2020
f238449
fix failed checks
ArshdeepSekhon Oct 22, 2020
2f00e33
custom data loader
ArshdeepSekhon Oct 29, 2020
793dbe0
custom textattack dataset from local files or in memory using hugging…
ArshdeepSekhon Oct 30, 2020
ae1c1f0
load user dataset from local files and convert to TextAttack dataset …
ArshdeepSekhon Nov 4, 2020
799f29e
load user dataset from local files and convert to TextAttack dataset …
ArshdeepSekhon Nov 4, 2020
97ea615
load user dataset from local files and convert to textattack dataset …
ArshdeepSekhon Nov 4, 2020
6172e24
load user dataset from local files and convert to textattack dataset …
ArshdeepSekhon Nov 4, 2020
d3e4269
custom dataset: add attribute error
ArshdeepSekhon Nov 4, 2020
92a54a5
custom dataset: remove stray prints
ArshdeepSekhon Nov 4, 2020
7b167ca
fix output column for custom dataset
ArshdeepSekhon Nov 4, 2020
601371d
custom dataset: add support for dict
ArshdeepSekhon Nov 4, 2020
9d0ed54
custom dataset: checks
ArshdeepSekhon Nov 4, 2020
12aab83
skeleton code for custom dataset
ArshdeepSekhon Nov 24, 2020
474bfa7
Merge branch 'custom_dataset' of https://github.com/ArshdeepSekhon/Te…
ArshdeepSekhon Nov 24, 2020
7f746d1
add utils for reading from files
ArshdeepSekhon Nov 25, 2020
7d91be2
add support for reading from csv, df, txt
ArshdeepSekhon Nov 25, 2020
7d2f976
fix format errors
ArshdeepSekhon Dec 4, 2020
9222066
update the confusing word"Successes" to "True Positive/Positive"
qiyanjun Dec 4, 2020
5c172b2
update the confusing uses of "Successes" to "True Positive/Positive"
qiyanjun Dec 4, 2020
11d2930
Merge branch 'master' into custom_dataset
ArshdeepSekhon Dec 4, 2020
36c83b3
black,isort formatting
ArshdeepSekhon Dec 4, 2020
f6fb8c5
Update dataset.py
qiyanjun Dec 5, 2020
41c5ef5
fix a wrong typo
qiyanjun Dec 5, 2020
56 changes: 26 additions & 30 deletions textattack/commands/augment.py
@@ -7,6 +7,7 @@

import textattack
from textattack.commands import TextAttackCommand
from textattack.datasets import CustomDataset

AUGMENTATION_RECIPE_NAMES = {
"wordnet": "textattack.augmentation.WordNetAugmenter",
@@ -98,13 +99,13 @@ def run(self, args):
else:
textattack.shared.utils.set_seed(args.random_seed)
start_time = time.time()
if not (args.csv and args.input_column):
if not (args.infile and args.input_column):
raise ArgumentError(
"The following arguments are required: --csv, --input-column/--i"
"The following arguments are required: --infile, --input-column/--i"
)
# Validate input/output paths.
if not os.path.exists(args.csv):
raise FileNotFoundError(f"Can't find CSV at location {args.csv}")
if not os.path.exists(args.infile):
raise FileNotFoundError(f"Can't find file at location {args.infile}")
if os.path.exists(args.outfile):
if args.overwrite:
textattack.shared.logger.info(
@@ -114,25 +115,9 @@
raise OSError(
f"Outfile {args.outfile} exists and --overwrite not set."
)
# Read in CSV file as a list of dictionaries. Use the CSV sniffer to
# try and automatically infer the correct CSV format.
csv_file = open(args.csv, "r")
dialect = csv.Sniffer().sniff(csv_file.readline(), delimiters=";,")
csv_file.seek(0)
rows = [
row
for row in csv.DictReader(
csv_file, dialect=dialect, skipinitialspace=True
)
]
# Validate input column.
row_keys = set(rows[0].keys())
if args.input_column not in row_keys:
raise ValueError(
f"Could not find input column {args.input_column} in CSV. Found keys: {row_keys}"
)
textattack.shared.logger.info(
f"Read {len(rows)} rows from {args.csv}. Found columns {row_keys}."
# Read in file using huggingface dataloader
dataset_to_augment = CustomDataset(
infile_format=args.infile_format, name=args.infile
)

augmenter = eval(AUGMENTATION_RECIPE_NAMES[args.recipe])(
@@ -141,24 +126,28 @@
)

output_rows = []
for row in tqdm.tqdm(rows, desc="Augmenting rows"):
text_input = row[args.input_column]

for row in tqdm.tqdm(dataset_to_augment, desc="Augmenting rows"):

text_input = row[0][args.input_column]
if not args.exclude_original:
output_rows.append(row)
for augmentation in augmenter.augment(text_input):
augmented_row = row.copy()
augmented_row[args.input_column] = augmentation
augmented_row = row

augmented_row[0][args.input_column] = augmentation
output_rows.append(augmented_row)
# Print to file.
with open(args.outfile, "w") as outfile:
csv_writer = csv.writer(
outfile, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL
)
# Write header.
csv_writer.writerow(output_rows[0].keys())
csv_writer.writerow(output_rows[0][0].keys())
# Write rows.
for row in output_rows:
csv_writer.writerow(row.values())

csv_writer.writerow(row[0].values())
textattack.shared.logger.info(
f"Wrote {len(output_rows)} augmentations to {args.outfile} in {time.time() - start_time}s."
)
@@ -171,7 +160,14 @@ def register_subcommand(main_parser: ArgumentParser):
formatter_class=ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--csv",
"--infile_format",
help="input file type ",
type=str,
required=False,
default="csv",
)
parser.add_argument(
"--infile",
help="input csv file to augment",
type=str,
required=False,
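Note on the rewritten loop above: CustomDataset yields tuples whose first element is an OrderedDict of input columns, which is why the new code indexes row[0][args.input_column]. A minimal sketch of the loading step, assuming a hypothetical examples.csv with a text column; as the class is currently written, a split argument appears to be needed because datasets.load_dataset("csv", ...) returns a DatasetDict keyed by "train":

from textattack.datasets import CustomDataset

# Load a local CSV through the HuggingFace `datasets` loader; with no explicit
# dataset_columns, every column of the file ends up in the input OrderedDict.
dataset_to_augment = CustomDataset(name="examples.csv", infile_format="csv", split="train")

for row in dataset_to_augment:
    # row is a tuple; row[0] is an OrderedDict keyed by column name.
    text_input = row[0]["text"]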
2 changes: 1 addition & 1 deletion textattack/datasets/__init__.py
@@ -1,4 +1,4 @@
from .dataset import TextAttackDataset
from .huggingface_dataset import HuggingFaceDataset

from .custom_dataset import CustomDataset
from . import translation
166 changes: 166 additions & 0 deletions textattack/datasets/custom_dataset.py
@@ -0,0 +1,166 @@
import collections
import random

import datasets
import pandas as pd

import textattack
from textattack.datasets import TextAttackDataset


def _cb(s):
"""Colors some text blue for printing to the terminal."""
if not isinstance(s, str):
s = "custom " + str(type(s))
return textattack.shared.utils.color_text(str(s), color="blue", method="ansi")


class CustomDataset(TextAttackDataset):
"""Loads a Custom Dataset from a file/list of files and prepares it as a
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like users can pass int a list of files (or a dictionary of files where keys are the split name (e.g. "train", "test"). But the documentation for name argument suggests that it's just a string.

TextAttack dataset.

- name (Union[str, dict, pd.DataFrame]): the user-specified dataset file name(s), dict, or pandas DataFrame
Collaborator: Don't you think the name argument would be too confusing? Maybe file_name_or_data?

- infile_format (str): Specifies the type of local file to load as a HuggingFace dataset: csv, json, pandas, or text.
Local files will be loaded as ``datasets.load_dataset(infile_format, data_files=name)``.
- label_map: Mapping if output labels should be re-mapped. Useful
if model was trained with a different label arrangement than
provided in the ``datasets`` version of the dataset.

Collaborator: Could you include an example of how a user should define this label_map (also, is this a dict)? For example, should it look like {"Positive": 1, "Negative": 0}?

- output_scale_factor (float): Factor to divide ground-truth outputs by.
Generally, TextAttack goal functions require model outputs
between 0 and 1. Some datasets test the model's correlation
with ground-truth output, instead of its accuracy, so these
outputs may be scaled arbitrarily.

Collaborator: I'm confused about what output_scale_factor is for. Can you maybe give an example of when this would be used?

- dataset_columns (list): dataset_columns[0] is the input columns, specified as a tuple or list; dataset_columns[1] is the output column.
- shuffle (bool): Whether to shuffle the dataset on load.
"""

def __init__(
self,
name,
infile_format=None,
Collaborator: I think file_format is an easier term to remember.

split=None,
label_map=None,
subset=None,
Collaborator: I don't see the subset argument being used anywhere in the code when loading the examples (other than in the logging part). I think subset is used in HuggingFaceDataset because datasets like glue have subsets (e.g. sst2), but I don't see the point here. Do you think it is redundant?

output_scale_factor=None,
dataset_columns=None,
shuffle=False,
):

self._name = name

if infile_format in ["csv", "json", "text", "pandas"]:
self._dataset = datasets.load_dataset(infile_format, data_files=self._name)

else:
if isinstance(self._name, dict):
self._dataset = datasets.Dataset.from_dict(self._name)
elif isinstance(self._name, pd.DataFrame):
self._dataset = datasets.Dataset.from_pandas(self._name)
else:
raise ValueError(
"Only accepts csv, json, text, pandas file infile_format, dict and pandas DataFrame"
)

# if no split in custom data, default split is None
# if user hasn't specified a split to use, raise error if the dataset has splits
if split is None:
if set(self._dataset.keys()) <= set(["train", "validation", "test"]):
Collaborator: Two cases pop into my mind:

  1. What happens if set(self._dataset.keys()) is empty? Then the condition is true and will raise the ValueError. But what if the CSV/JSON file we loaded is for the train split, and thus self._dataset doesn't really have any concept of "splits"?
  2. What happens if set(self._dataset.keys()) doesn't contain split names that are explicitly train, validation, and test? The user could have passed in train, val, test, which would cause the condition to be False.

Collaborator: One solution could be that if there are splits (i.e. set(self._dataset.keys()) is non-empty), we require split to not be None.

raise ValueError(f"specify a split to use: {self._dataset.keys()}")
else:
self._dataset = self._dataset[split]

subset_print_str = f", subset {_cb(subset)}" if subset else ""

textattack.shared.logger.info(
f"Loading {_cb('datasets')} dataset {_cb(name)}{subset_print_str}, split {_cb(split)}."
)
# Input/output column order, like (('premise', 'hypothesis'), 'label')

if dataset_columns is None:
# automatically infer from dataset
dataset_columns = []
dataset_columns.append(self._dataset.column_names)
Collaborator: It seems that self._dataset.column_names is a list. Why are we appending it to another list instead of setting dataset_columns = self._dataset.column_names?


if not set(dataset_columns[0]) <= set(self._dataset.column_names):
Collaborator: Maybe I'm just not familiar with how datasets work, but does this mean the first column of the data should always be input data and the second column is output?

My main concern is that this would be too restrictive for users. For example, NLI datasets have two input columns "premise" and "hypothesis".

raise ValueError(
f"Could not find input column {dataset_columns[0]}. Found keys: {self._dataset.column_names}"
)
self.input_columns = dataset_columns[0]

if len(dataset_columns) == 1:
# if the user hasn't specified an output column or dataset_columns is None, all dataset_columns are
# treated as input_columns
dataset_columns.append(None)
# if user has specified an output column, check if it exists in the inferred column names
# user can explicitly specify output column as None
if (
dataset_columns[1] is not None
and dataset_columns[1] not in self._dataset.column_names
):

raise ValueError(
f"Could not find output column {dataset_columns[1]}. Found keys: {self._dataset.column_names}"
)
self.output_column = dataset_columns[1]

self._i = 0
self.examples = list(self._dataset)

self.label_map = label_map

self.output_scale_factor = output_scale_factor

try:

self.label_names = self._dataset.features["label"].names

# If labels are remapped, the label names have to be remapped as
# well.
if label_map:

self.label_names = [
self.label_names[self.label_map[i]]
for i in range(len(self.label_map))
]

except KeyError:
# This happens when the dataset doesn't have 'features' or a 'label' column.

self.label_names = None
except AttributeError:
# This happens when self._dataset.features["label"] exists
# but is a single value.
self.label_names = ("label",)

if shuffle:
random.shuffle(self.examples)

def _format_raw_example(self, raw_example):
Collaborator: Since the last three methods are just copied from HuggingFaceDataset, would it be more appropriate to just inherit from HuggingFaceDataset?

input_dict = collections.OrderedDict(
[(c, raw_example[c]) for c in self.input_columns]
)
if self.output_column is not None:
output = raw_example[self.output_column]
if self.label_map:
output = self.label_map[output]
if self.output_scale_factor:
output = output / self.output_scale_factor
return (input_dict, output)

else:
return (input_dict,)

def __next__(self):
if self._i >= len(self.examples):
raise StopIteration
raw_example = self.examples[self._i]
self._i += 1
return self._format_raw_example(raw_example)

def __getitem__(self, i):
if isinstance(i, int):
return self._format_raw_example(self.examples[i])
else:
# `i` could be a slice or an integer. if it's a slice,
# return the formatted version of the proper slice of the list
return [self._format_raw_example(ex) for ex in self.examples[i]]
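
For reference, a minimal usage sketch of CustomDataset as implemented above; the file name, column names, and label re-mapping here are hypothetical, not part of the PR:

from textattack.datasets import CustomDataset

# Load a local CSV with hypothetical "text" and "label" columns; split="train"
# selects the single split that datasets.load_dataset("csv", ...) creates.
ds = CustomDataset(
    name="reviews.csv",
    infile_format="csv",
    split="train",
    dataset_columns=[["text"], "label"],  # [input columns, output column]
    label_map={0: 1, 1: 0},               # optional: re-map ground-truth labels
)

# Indexing and iteration both return (OrderedDict of inputs, output) tuples.
first_inputs, first_label = ds[0]
texts = [inputs["text"] for inputs, label in ds]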