From 53be760e0b52d4bd0d1fc756849d2a4ce8574093 Mon Sep 17 00:00:00 2001 From: Jin Yong Yoo Date: Sat, 11 Sep 2021 23:07:24 -0400 Subject: [PATCH] [CODE] Add new attack recipe A2T --- README.md | 9 +++ textattack/attack_args.py | 1 + textattack/attack_recipes/__init__.py | 1 + textattack/attack_recipes/a2t_yoo_2021.py | 68 +++++++++++++++++++++++ 4 files changed, 79 insertions(+) create mode 100644 textattack/attack_recipes/a2t_yoo_2021.py diff --git a/README.md b/README.md index 87c5a8455..34c5004de 100644 --- a/README.md +++ b/README.md @@ -128,6 +128,15 @@ To run an attack recipe: `textattack attack --recipe [recipe_name]`
Attacks on classification tasks, like sentiment classification and entailment:
+ +a2t + +Untargeted {Classification, Entailment} +Percentage of words perturbed, Word embedding distance, DistilBERT sentence encoding cosine similarity, part-of-speech consistency +Counter-fitted word embedding swap (or) BERT Masked Token Prediction +Greedy-WIR (gradient) +from (["Towards Improving Adversarial Training of NLP Models" (Yoo et al., 2021)](https://arxiv.org/abs/2109.00544)) + alzantot Untargeted {Classification, Entailment} diff --git a/textattack/attack_args.py b/textattack/attack_args.py index 72f5bdec1..ef510fdda 100644 --- a/textattack/attack_args.py +++ b/textattack/attack_args.py @@ -35,6 +35,7 @@ "pso": "textattack.attack_recipes.PSOZang2020", "checklist": "textattack.attack_recipes.CheckList2020", "clare": "textattack.attack_recipes.CLARE2020", + "a2t": "textattack.attack_recipes.A2TYoo2021", } diff --git a/textattack/attack_recipes/__init__.py b/textattack/attack_recipes/__init__.py index 882eb6f5b..853a2732f 100644 --- a/textattack/attack_recipes/__init__.py +++ b/textattack/attack_recipes/__init__.py @@ -20,6 +20,7 @@ from .attack_recipe import AttackRecipe +from .a2t_yoo_2021 import A2TYoo2021 from .bae_garg_2019 import BAEGarg2019 from .bert_attack_li_2020 import BERTAttackLi2020 from .genetic_algorithm_alzantot_2018 import GeneticAlgorithmAlzantot2018 diff --git a/textattack/attack_recipes/a2t_yoo_2021.py b/textattack/attack_recipes/a2t_yoo_2021.py new file mode 100644 index 000000000..82c813be1 --- /dev/null +++ b/textattack/attack_recipes/a2t_yoo_2021.py @@ -0,0 +1,68 @@ +from textattack import Attack +from textattack.constraints.grammaticality import PartOfSpeech +from textattack.constraints.pre_transformation import ( + InputColumnModification, + MaxModificationRate, + RepeatModification, + StopwordModification, +) +from textattack.constraints.semantics import WordEmbeddingDistance +from textattack.constraints.semantics.sentence_encoders import BERT +from textattack.goal_functions import UntargetedClassification +from 
textattack.search_methods import GreedyWordSwapWIR
+from textattack.transformations import WordSwapEmbedding, WordSwapMaskedLM
+
+from .attack_recipe import AttackRecipe
+
+
+class A2TYoo2021(AttackRecipe):
+    """Towards Improving Adversarial Training of NLP Models.
+
+    (Yoo et al., 2021)
+
+    https://arxiv.org/abs/2109.00544
+    """
+
+    @staticmethod
+    def build(model_wrapper, mlm=False):
+        """Build attack recipe.
+
+        Args:
+            model_wrapper (:class:`~textattack.models.wrappers.ModelWrapper`):
+                Model wrapper containing both the model and the tokenizer.
+            mlm (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                If :obj:`True`, load `A2T-MLM` attack. Otherwise, load regular `A2T` attack.
+
+        Returns:
+            :class:`~textattack.Attack`: A2T attack.
+        """
+        # Constraints applied to both the A2T and A2T-MLM variants.
+        constraints = [RepeatModification(), StopwordModification()]
+        # NOTE(review): presumably leaves the "premise" column unmodified; confirm.
+        input_column_modification = InputColumnModification(
+            ["premise", "hypothesis"], {"premise"}
+        )
+        constraints.append(input_column_modification)
+        constraints.append(PartOfSpeech(allow_verb_noun_swap=False))
+        constraints.append(MaxModificationRate(max_rate=0.1, min_threshold=4))
+        sent_encoder = BERT(
+            model_name="stsb-distilbert-base", threshold=0.9, metric="cosine"
+        )
+        constraints.append(sent_encoder)
+
+        # Only the transformation differs between the variants; the embedding
+        # swap additionally gets a word-embedding-distance constraint.
+        if mlm:
+            transformation = WordSwapMaskedLM(
+                method="bae", max_candidates=20, min_confidence=0.0, batch_size=16
+            )
+        else:
+            transformation = WordSwapEmbedding(max_candidates=20)
+            constraints.append(WordEmbeddingDistance(min_cos_sim=0.8))
+
+        # Goal is untargeted classification.
+        goal_function = UntargetedClassification(model_wrapper, model_batch_size=32)
+        # Greedily swap words with "Word Importance Ranking" (gradient-based).
+        search_method = GreedyWordSwapWIR(wir_method="gradient")
+
+        return Attack(goal_function, constraints, transformation, search_method)