Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,15 @@ To run an attack recipe: `textattack attack --recipe [recipe_name]`
<tbody>
<tr><td style="text-align: center;" colspan="6"><strong><br>Attacks on classification tasks, like sentiment classification and entailment:<br></strong></td></tr>

<tr>
<td><code>a2t</code>
<span class="citation" data-cites="yoo2021a2t"></span></td>
<td><sub>Untargeted {Classification, Entailment}</sub></td>
<td><sub>Percentage of words perturbed, Word embedding distance, DistilBERT sentence encoding cosine similarity, part-of-speech consistency</sub></td>
<td><sub>Counter-fitted word embedding swap (or) BERT Masked Token Prediction</sub></td>
<td><sub>Greedy-WIR (gradient)</sub></td>
<td ><sub>from (["Towards Improving Adversarial Training of NLP Models" (Yoo et al., 2021)](https://arxiv.org/abs/2109.00544))</sub></td>
</tr>
<tr>
<td><code>alzantot</code> <span class="citation" data-cites="Alzantot2018GeneratingNL Jia2019CertifiedRT"></span></td>
<td><sub>Untargeted {Classification, Entailment}</sub></td>
Expand Down
1 change: 1 addition & 0 deletions textattack/attack_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
"pso": "textattack.attack_recipes.PSOZang2020",
"checklist": "textattack.attack_recipes.CheckList2020",
"clare": "textattack.attack_recipes.CLARE2020",
"a2t": "textattack.attack_recipes.A2TYoo2021",
}


Expand Down
1 change: 1 addition & 0 deletions textattack/attack_recipes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from .attack_recipe import AttackRecipe

from .a2t_yoo_2021 import A2TYoo2021
from .bae_garg_2019 import BAEGarg2019
from .bert_attack_li_2020 import BERTAttackLi2020
from .genetic_algorithm_alzantot_2018 import GeneticAlgorithmAlzantot2018
Expand Down
68 changes: 68 additions & 0 deletions textattack/attack_recipes/a2t_yoo_2021.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from textattack import Attack
from textattack.constraints.grammaticality import PartOfSpeech
from textattack.constraints.pre_transformation import (
InputColumnModification,
MaxModificationRate,
RepeatModification,
StopwordModification,
)
from textattack.constraints.semantics import WordEmbeddingDistance
from textattack.constraints.semantics.sentence_encoders import BERT
from textattack.goal_functions import UntargetedClassification
from textattack.search_methods import GreedyWordSwapWIR
from textattack.transformations import WordSwapEmbedding, WordSwapMaskedLM

from .attack_recipe import AttackRecipe


class A2TYoo2021(AttackRecipe):
    """Towards Improving Adversarial Training of NLP Models.

    (Yoo et al., 2021)

    https://arxiv.org/abs/2109.00544
    """

    @staticmethod
    def build(model_wrapper, mlm=False):
        """Build attack recipe.

        The attack combines an untargeted classification goal with a greedy
        word-swap search ordered by gradient-based word importance. Candidate
        words come either from a word-embedding swap (default) or from a
        masked language model (``mlm=True``).

        Args:
            model_wrapper (:class:`~textattack.models.wrappers.ModelWrapper`):
                Model wrapper containing both the model and the tokenizer.
            mlm (:obj:`bool`, `optional`, defaults to :obj:`False`):
                If :obj:`True`, load `A2T-MLM` attack. Otherwise, load regular `A2T` attack.

        Returns:
            :class:`~textattack.Attack`: A2T attack.
        """
        # Constraints shared by both variants of the attack.
        constraints = [RepeatModification(), StopwordModification()]
        # NOTE(review): presumably keeps the premise untouched for
        # premise/hypothesis (entailment) inputs -- confirm against
        # InputColumnModification's documentation.
        input_column_modification = InputColumnModification(
            ["premise", "hypothesis"], {"premise"}
        )
        constraints.append(input_column_modification)
        constraints.append(PartOfSpeech(allow_verb_noun_swap=False))
        constraints.append(MaxModificationRate(max_rate=0.1, min_threshold=4))
        # Sentence-encoder constraint: cosine metric, threshold 0.9, using the
        # "stsb-distilbert-base" model.
        sent_encoder = BERT(
            model_name="stsb-distilbert-base", threshold=0.9, metric="cosine"
        )
        constraints.append(sent_encoder)

        if mlm:
            # A2T-MLM: candidates are proposed by a masked language model.
            transformation = WordSwapMaskedLM(
                method="bae", max_candidates=20, min_confidence=0.0, batch_size=16
            )
        else:
            # Regular A2T: candidates come from a word-embedding swap; the
            # embedding-distance constraint is only added on this branch.
            transformation = WordSwapEmbedding(max_candidates=20)
            constraints.append(WordEmbeddingDistance(min_cos_sim=0.8))

        #
        # Goal is untargeted classification
        #
        goal_function = UntargetedClassification(model_wrapper, model_batch_size=32)
        #
        # Greedily swap words with "Word Importance Ranking".
        #
        search_method = GreedyWordSwapWIR(wir_method="gradient")

        return Attack(goal_function, constraints, transformation, search_method)