diff --git a/README.md b/README.md index cfda3f413..7a0d5655e 100644 --- a/README.md +++ b/README.md @@ -301,6 +301,7 @@ for data augmentation: - `textattack.CharSwapAugmenter` augments text by substituting, deleting, inserting, and swapping adjacent characters - `textattack.EasyDataAugmenter` augments text with a combination of word insertions, substitutions and deletions. - `textattack.CheckListAugmenter` augments text by contraction/extension and by substituting names, locations, numbers. +- `textattack.CLAREAugmenter` augments text by replacing, inserting, and merging with a pre-trained masked language model. #### Augmentation Command-Line Interface The easiest way to use our data augmentation tools is with `textattack augment `. `textattack augment` @@ -323,6 +324,9 @@ The command `textattack augment --csv examples.csv --input-column text --recipe will augment the `text` column by altering 10% of each example's words, generating twice as many augmentations as original inputs, and exclude the original inputs from the output CSV. (All of this will be saved to `augment.csv` by default.) +> **Tip:** Just as running attacks interactively, you can also pass `--interactive` to augment samples inputted by the user to quickly try out different augmentation recipes! + + After augmentation, here are the contents of `augment.csv`: ```csv text,label diff --git a/tests/sample_outputs/list_augmentation_recipes.txt b/tests/sample_outputs/list_augmentation_recipes.txt index cbb2e83c5..3fa3f9351 100644 --- a/tests/sample_outputs/list_augmentation_recipes.txt +++ b/tests/sample_outputs/list_augmentation_recipes.txt @@ -1,5 +1,7 @@ charswap (textattack.augmentation.CharSwapAugmenter) checklist (textattack.augmentation.CheckListAugmenter) +clare (textattack.augmentation.CLAREAugmenter) eda (textattack.augmentation.EasyDataAugmenter) embedding (textattack.augmentation.EmbeddingAugmenter) wordnet (textattack.augmentation.WordNetAugmenter) + diff --git a/textattack/augmentation/__init__.py b/textattack/augmentation/__init__.py index 628ea10a2..593f946d6 100644 --- a/textattack/augmentation/__init__.py +++ b/textattack/augmentation/__init__.py @@ -13,4 +13,5 @@ EasyDataAugmenter, CheckListAugmenter, DeletionAugmenter, + CLAREAugmenter, ) diff --git a/textattack/augmentation/recipes.py b/textattack/augmentation/recipes.py index 9bc6497b4..ec009f822 100644 --- a/textattack/augmentation/recipes.py +++ b/textattack/augmentation/recipes.py @@ -11,6 +11,7 @@ RepeatModification, StopwordModification, ) +from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder from . import Augmenter @@ -36,9 +37,7 @@ class EasyDataAugmenter(Augmenter): """ def __init__(self, pct_words_to_swap=0.1, transformations_per_example=4): - assert ( - pct_words_to_swap >= 0.0 and pct_words_to_swap <= 1.0 - ), "pct_words_to_swap must be in [0., 1.]" + assert 0.0 <= pct_words_to_swap <= 1.0, "pct_words_to_swap must be in [0., 1.]" assert ( transformations_per_example > 0 ), "transformations_per_example must be a positive integer" @@ -185,3 +184,67 @@ def __init__(self, **kwargs): constraints = [DEFAULT_CONSTRAINTS[0]] super().__init__(transformation, constraints=constraints, **kwargs) + + +class CLAREAugmenter(Augmenter): + """Li, Zhang, Peng, Chen, Brockett, Sun, Dolan. + + "Contextualized Perturbation for Textual Adversarial Attack" (Li et al., 2020) + + https://arxiv.org/abs/2009.07502 + + CLARE builds on a pre-trained masked language model and modifies the inputs in a contextaware manner. + We propose three contextualized perturbations, Replace, Insert and Merge, allowing for generating outputs + of varied lengths. + """ + + def __init__( + self, model="distilroberta-base", tokenizer="distilroberta-base", **kwargs + ): + import transformers + + from textattack.transformations import ( + CompositeTransformation, + WordInsertionMaskedLM, + WordMergeMaskedLM, + WordSwapMaskedLM, + ) + + shared_masked_lm = transformers.AutoModelForCausalLM.from_pretrained(model) + shared_tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer) + + transformation = CompositeTransformation( + [ + WordSwapMaskedLM( + method="bae", + masked_language_model=shared_masked_lm, + tokenizer=shared_tokenizer, + max_candidates=50, + min_confidence=5e-4, + ), + WordInsertionMaskedLM( + masked_language_model=shared_masked_lm, + tokenizer=shared_tokenizer, + max_candidates=50, + min_confidence=0.0, + ), + WordMergeMaskedLM( + masked_language_model=shared_masked_lm, + tokenizer=shared_tokenizer, + max_candidates=50, + min_confidence=5e-3, + ), + ] + ) + + use_constraint = UniversalSentenceEncoder( + threshold=0.7, + metric="cosine", + compare_against_original=True, + window_size=15, + skip_text_shorter_than_window=True, + ) + + constraints = DEFAULT_CONSTRAINTS + [use_constraint] + + super().__init__(transformation, constraints=constraints, **kwargs) diff --git a/textattack/commands/augment.py b/textattack/commands/augment.py index 4dab1ba99..1ca6dbb74 100644 --- a/textattack/commands/augment.py +++ b/textattack/commands/augment.py @@ -21,6 +21,7 @@ "charswap": "textattack.augmentation.CharSwapAugmenter", "eda": "textattack.augmentation.EasyDataAugmenter", "checklist": "textattack.augmentation.CheckListAugmenter", + "clare": "textattack.augmentation.CLAREAugmenter", }