Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ for data augmentation:
- `textattack.CharSwapAugmenter` augments text by substituting, deleting, inserting, and swapping adjacent characters
- `textattack.EasyDataAugmenter` augments text with a combination of word insertions, substitutions and deletions.
- `textattack.CheckListAugmenter` augments text by contraction/extension and by substituting names, locations, numbers.
- `textattack.CLAREAugmenter` augments text by replacing, inserting, and merging with a pre-trained masked language model.

#### Augmentation Command-Line Interface
The easiest way to use our data augmentation tools is with `textattack augment <args>`. `textattack augment`
Expand All @@ -323,6 +324,9 @@ The command `textattack augment --csv examples.csv --input-column text --recipe
will augment the `text` column by altering 10% of each example's words, generating twice as many augmentations as original inputs, and exclude the original inputs from the
output CSV. (All of this will be saved to `augment.csv` by default.)

> **Tip:** Just as running attacks interactively, you can also pass `--interactive` to augment samples inputted by the user to quickly try out different augmentation recipes!


After augmentation, here are the contents of `augment.csv`:
```csv
text,label
Expand Down
2 changes: 2 additions & 0 deletions tests/sample_outputs/list_augmentation_recipes.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
charswap (textattack.augmentation.CharSwapAugmenter)
checklist (textattack.augmentation.CheckListAugmenter)
clare (textattack.augmentation.CLAREAugmenter)
eda (textattack.augmentation.EasyDataAugmenter)
embedding (textattack.augmentation.EmbeddingAugmenter)
wordnet (textattack.augmentation.WordNetAugmenter)

1 change: 1 addition & 0 deletions textattack/augmentation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@
EasyDataAugmenter,
CheckListAugmenter,
DeletionAugmenter,
CLAREAugmenter,
)
69 changes: 66 additions & 3 deletions textattack/augmentation/recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
RepeatModification,
StopwordModification,
)
from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder

from . import Augmenter

Expand All @@ -36,9 +37,7 @@ class EasyDataAugmenter(Augmenter):
"""

def __init__(self, pct_words_to_swap=0.1, transformations_per_example=4):
assert (
pct_words_to_swap >= 0.0 and pct_words_to_swap <= 1.0
), "pct_words_to_swap must be in [0., 1.]"
assert 0.0 <= pct_words_to_swap <= 1.0, "pct_words_to_swap must be in [0., 1.]"
assert (
transformations_per_example > 0
), "transformations_per_example must be a positive integer"
Expand Down Expand Up @@ -185,3 +184,67 @@ def __init__(self, **kwargs):
constraints = [DEFAULT_CONSTRAINTS[0]]

super().__init__(transformation, constraints=constraints, **kwargs)


class CLAREAugmenter(Augmenter):
"""Li, Zhang, Peng, Chen, Brockett, Sun, Dolan.

"Contextualized Perturbation for Textual Adversarial Attack" (Li et al., 2020)

https://arxiv.org/abs/2009.07502

CLARE builds on a pre-trained masked language model and modifies the inputs in a contextaware manner.
We propose three contextualized perturbations, Replace, Insert and Merge, allowing for generating outputs
of varied lengths.
"""

def __init__(
self, model="distilroberta-base", tokenizer="distilroberta-base", **kwargs
):
import transformers

from textattack.transformations import (
CompositeTransformation,
WordInsertionMaskedLM,
WordMergeMaskedLM,
WordSwapMaskedLM,
)

shared_masked_lm = transformers.AutoModelForCausalLM.from_pretrained(model)
shared_tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer)

transformation = CompositeTransformation(
[
WordSwapMaskedLM(
method="bae",
masked_language_model=shared_masked_lm,
tokenizer=shared_tokenizer,
max_candidates=50,
min_confidence=5e-4,
),
WordInsertionMaskedLM(
masked_language_model=shared_masked_lm,
tokenizer=shared_tokenizer,
max_candidates=50,
min_confidence=0.0,
),
WordMergeMaskedLM(
masked_language_model=shared_masked_lm,
tokenizer=shared_tokenizer,
max_candidates=50,
min_confidence=5e-3,
),
]
)

use_constraint = UniversalSentenceEncoder(
threshold=0.7,
metric="cosine",
compare_against_original=True,
window_size=15,
skip_text_shorter_than_window=True,
)

constraints = DEFAULT_CONSTRAINTS + [use_constraint]

super().__init__(transformation, constraints=constraints, **kwargs)
1 change: 1 addition & 0 deletions textattack/commands/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"charswap": "textattack.augmentation.CharSwapAugmenter",
"eda": "textattack.augmentation.EasyDataAugmenter",
"checklist": "textattack.augmentation.CheckListAugmenter",
"clare": "textattack.augmentation.CLAREAugmenter",
}


Expand Down