From ca5f15ce90fea7b0d27463ff4afd403c329ca7a1 Mon Sep 17 00:00:00 2001 From: Hanyu Liu Date: Thu, 17 Dec 2020 17:20:23 -0500 Subject: [PATCH 01/10] add clare augmentation recipe --- textattack/augmentation/__init__.py | 2 + textattack/augmentation/recipes.py | 61 +++++++++++++++++++++++++++++ textattack/commands/augment.py | 1 + 3 files changed, 64 insertions(+) diff --git a/textattack/augmentation/__init__.py b/textattack/augmentation/__init__.py index 628ea10a2..3b648cbc3 100644 --- a/textattack/augmentation/__init__.py +++ b/textattack/augmentation/__init__.py @@ -13,4 +13,6 @@ EasyDataAugmenter, CheckListAugmenter, DeletionAugmenter, + CLAREAugmenter ) + diff --git a/textattack/augmentation/recipes.py b/textattack/augmentation/recipes.py index 9bc6497b4..e5355fcf4 100644 --- a/textattack/augmentation/recipes.py +++ b/textattack/augmentation/recipes.py @@ -185,3 +185,64 @@ def __init__(self, **kwargs): constraints = [DEFAULT_CONSTRAINTS[0]] super().__init__(transformation, constraints=constraints, **kwargs) + + +class CLAREAugmenter(Augmenter): + """Li, Zhang, Peng, Chen, Brockett, Sun, Dolan. + + "Contextualized Perturbation for Textual Adversarial Attack" (Li et al., 2020) + + https://arxiv.org/abs/2009.07502 + + This method uses greedy search with replace, merge, and insertion transformations that leverage a + pretrained language model. It also uses USE similarity constraint. 
+ """ + + def __init__(self, model="distilroberta-base", tokenizer="distilroberta-base", **kwargs): + from textattack.transformations import ( + CompositeTransformation, + WordInsertionMaskedLM, + WordMergeMaskedLM, + WordSwapMaskedLM, + ) + import transformers + from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder + + shared_masked_lm = transformers.AutoModelForCausalLM.from_pretrained(model) + shared_tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer) + + transformation = CompositeTransformation( + [ + WordSwapMaskedLM( + method="bae", + masked_language_model=shared_masked_lm, + tokenizer=shared_tokenizer, + max_candidates=20, + min_confidence=5e-4, + ), + WordInsertionMaskedLM( + masked_language_model=shared_masked_lm, + tokenizer=shared_tokenizer, + max_candidates=20, + min_confidence=0.0, + ), + WordMergeMaskedLM( + masked_language_model=shared_masked_lm, + tokenizer=shared_tokenizer, + max_candidates=20, + min_confidence=5e-3, + ), + ] + ) + + use_constraint = UniversalSentenceEncoder( + threshold=0.7, + metric="cosine", + compare_against_original=True, + window_size=15, + skip_text_shorter_than_window=True, + ) + + constraints = DEFAULT_CONSTRAINTS + [use_constraint] + + super().__init__(transformation, constraints=constraints, **kwargs) diff --git a/textattack/commands/augment.py b/textattack/commands/augment.py index 4dab1ba99..2ea297099 100644 --- a/textattack/commands/augment.py +++ b/textattack/commands/augment.py @@ -21,6 +21,7 @@ "charswap": "textattack.augmentation.CharSwapAugmenter", "eda": "textattack.augmentation.EasyDataAugmenter", "checklist": "textattack.augmentation.CheckListAugmenter", + "clare": "textattack.augmentation.CLAREAugmenter" } From adae04ac1a80865ae96e9cdf51d1e7c57a8ce495 Mon Sep 17 00:00:00 2001 From: Hanyu Liu Date: Thu, 17 Dec 2020 17:27:59 -0500 Subject: [PATCH 02/10] formating --- textattack/augmentation/__init__.py | 3 +-- textattack/augmentation/recipes.py | 13 +++++++++---- 
textattack/commands/augment.py | 2 +- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/textattack/augmentation/__init__.py b/textattack/augmentation/__init__.py index 3b648cbc3..593f946d6 100644 --- a/textattack/augmentation/__init__.py +++ b/textattack/augmentation/__init__.py @@ -13,6 +13,5 @@ EasyDataAugmenter, CheckListAugmenter, DeletionAugmenter, - CLAREAugmenter + CLAREAugmenter, ) - diff --git a/textattack/augmentation/recipes.py b/textattack/augmentation/recipes.py index e5355fcf4..e9108ad6b 100644 --- a/textattack/augmentation/recipes.py +++ b/textattack/augmentation/recipes.py @@ -194,11 +194,14 @@ class CLAREAugmenter(Augmenter): https://arxiv.org/abs/2009.07502 - This method uses greedy search with replace, merge, and insertion transformations that leverage a - pretrained language model. It also uses USE similarity constraint. + CLARE builds on a pre-trained masked language model and modifies the inputs in a context-aware manner. + We propose three contextualized perturbations, Replace, Insert and Merge, allowing for generating outputs + of varied lengths. 
""" - def __init__(self, model="distilroberta-base", tokenizer="distilroberta-base", **kwargs): + def __init__( + self, model="distilroberta-base", tokenizer="distilroberta-base", **kwargs + ): from textattack.transformations import ( CompositeTransformation, WordInsertionMaskedLM, @@ -206,7 +209,9 @@ def __init__(self, model="distilroberta-base", tokenizer="distilroberta-base", * WordSwapMaskedLM, ) import transformers - from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder + from textattack.constraints.semantics.sentence_encoders import ( + UniversalSentenceEncoder, + ) shared_masked_lm = transformers.AutoModelForCausalLM.from_pretrained(model) shared_tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer) diff --git a/textattack/commands/augment.py b/textattack/commands/augment.py index 2ea297099..1ca6dbb74 100644 --- a/textattack/commands/augment.py +++ b/textattack/commands/augment.py @@ -21,7 +21,7 @@ "charswap": "textattack.augmentation.CharSwapAugmenter", "eda": "textattack.augmentation.EasyDataAugmenter", "checklist": "textattack.augmentation.CheckListAugmenter", - "clare": "textattack.augmentation.CLAREAugmenter" + "clare": "textattack.augmentation.CLAREAugmenter", } From 77ae89495777e0b33949a13460cc64fd89266042 Mon Sep 17 00:00:00 2001 From: Hanyu-Liu-123 <65825971+Hanyu-Liu-123@users.noreply.github.com> Date: Thu, 17 Dec 2020 17:35:02 -0500 Subject: [PATCH 03/10] Update README.md --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index cfda3f413..7a0d5655e 100644 --- a/README.md +++ b/README.md @@ -301,6 +301,7 @@ for data augmentation: - `textattack.CharSwapAugmenter` augments text by substituting, deleting, inserting, and swapping adjacent characters - `textattack.EasyDataAugmenter` augments text with a combination of word insertions, substitutions and deletions. 
- `textattack.CheckListAugmenter` augments text by contraction/extension and by substituting names, locations, numbers. +- `textattack.CLAREAugmenter` augments text by replacing, inserting, and merging words with a pre-trained masked language model. #### Augmentation Command-Line Interface The easiest way to use our data augmentation tools is with `textattack augment `. `textattack augment` @@ -323,6 +324,9 @@ The command `textattack augment --csv examples.csv --input-column text --recipe will augment the `text` column by altering 10% of each example's words, generating twice as many augmentations as original inputs, and exclude the original inputs from the output CSV. (All of this will be saved to `augment.csv` by default.) +> **Tip:** Just as when running attacks interactively, you can also pass `--interactive` to augment samples entered by the user and quickly try out different augmentation recipes! + + After augmentation, here are the contents of `augment.csv`: ```csv text,label From 48a04b496951866d345687190523cd8cb7cf8a43 Mon Sep 17 00:00:00 2001 From: Hanyu Liu Date: Thu, 17 Dec 2020 17:36:05 -0500 Subject: [PATCH 04/10] Update recipes.py --- textattack/augmentation/recipes.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/textattack/augmentation/recipes.py b/textattack/augmentation/recipes.py index e9108ad6b..b92f5aa1f 100644 --- a/textattack/augmentation/recipes.py +++ b/textattack/augmentation/recipes.py @@ -222,19 +222,19 @@ def __init__( method="bae", masked_language_model=shared_masked_lm, tokenizer=shared_tokenizer, - max_candidates=20, + max_candidates=50, min_confidence=5e-4, ), WordInsertionMaskedLM( masked_language_model=shared_masked_lm, tokenizer=shared_tokenizer, - max_candidates=20, + max_candidates=50, min_confidence=0.0, ), WordMergeMaskedLM( masked_language_model=shared_masked_lm, tokenizer=shared_tokenizer, - max_candidates=20, + max_candidates=50, min_confidence=5e-3, ), ] From a6b3d8c16dce120699d030524ad1e261936049a6 Mon 
Sep 17 00:00:00 2001 From: Hanyu Liu Date: Fri, 18 Dec 2020 00:36:49 -0500 Subject: [PATCH 05/10] Fix errors --- tests/sample_outputs/list_augmentation_recipes.txt | 3 +++ textattack/augmentation/recipes.py | 8 ++++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/sample_outputs/list_augmentation_recipes.txt b/tests/sample_outputs/list_augmentation_recipes.txt index cbb2e83c5..b79e03a66 100644 --- a/tests/sample_outputs/list_augmentation_recipes.txt +++ b/tests/sample_outputs/list_augmentation_recipes.txt @@ -1,5 +1,8 @@ charswap (textattack.augmentation.CharSwapAugmenter) checklist (textattack.augmentation.CheckListAugmenter) +wordnet (textattack.augmentation.CLAREAugmenter) eda (textattack.augmentation.EasyDataAugmenter) embedding (textattack.augmentation.EmbeddingAugmenter) wordnet (textattack.augmentation.WordNetAugmenter) + + diff --git a/textattack/augmentation/recipes.py b/textattack/augmentation/recipes.py index b92f5aa1f..57f33d389 100644 --- a/textattack/augmentation/recipes.py +++ b/textattack/augmentation/recipes.py @@ -202,16 +202,16 @@ class CLAREAugmenter(Augmenter): def __init__( self, model="distilroberta-base", tokenizer="distilroberta-base", **kwargs ): + import transformers + from textattack.constraints.semantics.sentence_encoders import ( + UniversalSentenceEncoder, + ) from textattack.transformations import ( CompositeTransformation, WordInsertionMaskedLM, WordMergeMaskedLM, WordSwapMaskedLM, ) - import transformers - from textattack.constraints.semantics.sentence_encoders import ( - UniversalSentenceEncoder, - ) shared_masked_lm = transformers.AutoModelForCausalLM.from_pretrained(model) shared_tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer) From 1565374241f162187d5afd7b2431aceea3a2e69c Mon Sep 17 00:00:00 2001 From: Hanyu Liu Date: Fri, 18 Dec 2020 00:44:01 -0500 Subject: [PATCH 06/10] Update recipes.py --- textattack/augmentation/recipes.py | 2 ++ 1 file changed, 2 insertions(+) diff --git 
a/textattack/augmentation/recipes.py b/textattack/augmentation/recipes.py index 57f33d389..7845a4b31 100644 --- a/textattack/augmentation/recipes.py +++ b/textattack/augmentation/recipes.py @@ -203,9 +203,11 @@ def __init__( self, model="distilroberta-base", tokenizer="distilroberta-base", **kwargs ): import transformers + from textattack.constraints.semantics.sentence_encoders import ( UniversalSentenceEncoder, ) + from textattack.transformations import ( CompositeTransformation, WordInsertionMaskedLM, From 1909a874ee98f8151ac12e476c72eaddf331ef12 Mon Sep 17 00:00:00 2001 From: Hanyu Liu Date: Fri, 18 Dec 2020 00:55:57 -0500 Subject: [PATCH 07/10] Update recipes.py --- textattack/augmentation/recipes.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/textattack/augmentation/recipes.py b/textattack/augmentation/recipes.py index 7845a4b31..9f68b5285 100644 --- a/textattack/augmentation/recipes.py +++ b/textattack/augmentation/recipes.py @@ -7,6 +7,8 @@ """ import random +from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder + from textattack.constraints.pre_transformation import ( RepeatModification, StopwordModification, @@ -37,7 +39,7 @@ class EasyDataAugmenter(Augmenter): def __init__(self, pct_words_to_swap=0.1, transformations_per_example=4): assert ( - pct_words_to_swap >= 0.0 and pct_words_to_swap <= 1.0 + 0.0 <= pct_words_to_swap <= 1.0 ), "pct_words_to_swap must be in [0., 1.]" assert ( transformations_per_example > 0 @@ -204,10 +206,6 @@ def __init__( ): import transformers - from textattack.constraints.semantics.sentence_encoders import ( - UniversalSentenceEncoder, - ) - from textattack.transformations import ( CompositeTransformation, WordInsertionMaskedLM, From cb2237e9b09f468f9fe41f86b77ffd798f7fdab2 Mon Sep 17 00:00:00 2001 From: Hanyu Liu Date: Fri, 18 Dec 2020 00:56:47 -0500 Subject: [PATCH 08/10] Update recipes.py --- textattack/augmentation/recipes.py | 4 +--- 1 file changed, 1 
insertion(+), 3 deletions(-) diff --git a/textattack/augmentation/recipes.py b/textattack/augmentation/recipes.py index 9f68b5285..d815e98c1 100644 --- a/textattack/augmentation/recipes.py +++ b/textattack/augmentation/recipes.py @@ -38,9 +38,7 @@ class EasyDataAugmenter(Augmenter): """ def __init__(self, pct_words_to_swap=0.1, transformations_per_example=4): - assert ( - 0.0 <= pct_words_to_swap <= 1.0 - ), "pct_words_to_swap must be in [0., 1.]" + assert 0.0 <= pct_words_to_swap <= 1.0, "pct_words_to_swap must be in [0., 1.]" assert ( transformations_per_example > 0 ), "transformations_per_example must be a positive integer" From b619607255ddcf3df6616c76112b52bb02494df2 Mon Sep 17 00:00:00 2001 From: Hanyu Liu Date: Fri, 18 Dec 2020 01:32:40 -0500 Subject: [PATCH 09/10] Update recipes.py --- textattack/augmentation/recipes.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/textattack/augmentation/recipes.py b/textattack/augmentation/recipes.py index d815e98c1..ec009f822 100644 --- a/textattack/augmentation/recipes.py +++ b/textattack/augmentation/recipes.py @@ -7,12 +7,11 @@ """ import random -from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder - from textattack.constraints.pre_transformation import ( RepeatModification, StopwordModification, ) +from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder from . import Augmenter From eb51ca31bad720ee6e6e0474ccf49a8a75637d1d Mon Sep 17 00:00:00 2001 From: Hanyu Liu Date: Fri, 18 Dec 2020 01:41:14 -0500 Subject: [PATCH 10/10] more fixing... 
--- tests/sample_outputs/list_augmentation_recipes.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/sample_outputs/list_augmentation_recipes.txt b/tests/sample_outputs/list_augmentation_recipes.txt index b79e03a66..3fa3f9351 100644 --- a/tests/sample_outputs/list_augmentation_recipes.txt +++ b/tests/sample_outputs/list_augmentation_recipes.txt @@ -1,8 +1,7 @@ charswap (textattack.augmentation.CharSwapAugmenter) checklist (textattack.augmentation.CheckListAugmenter) -wordnet (textattack.augmentation.CLAREAugmenter) +clare (textattack.augmentation.CLAREAugmenter) eda (textattack.augmentation.EasyDataAugmenter) embedding (textattack.augmentation.EmbeddingAugmenter) wordnet (textattack.augmentation.WordNetAugmenter) -