17 changes: 8 additions & 9 deletions textattack/__init__.py
@@ -9,15 +9,6 @@
TextAttack provides components for common NLP tasks like sentence encoding, grammar-checking, and word replacement that can be used on their own.
"""

from .attack_args import AttackArgs, CommandLineAttackArgs
from .augment_args import AugmenterArgs
from .dataset_args import DatasetArgs
from .model_args import ModelArgs
from .training_args import TrainingArgs, CommandLineTrainingArgs
from .attack import Attack
from .attacker import Attacker
from .trainer import Trainer

from . import (
    attack_recipes,
    attack_results,
@@ -33,5 +24,13 @@
    shared,
    transformations,
)
from .attack import Attack
from .attack_args import AttackArgs, CommandLineAttackArgs
from .attacker import Attacker
from .augment_args import AugmenterArgs
from .dataset_args import DatasetArgs
from .model_args import ModelArgs
from .trainer import Trainer
from .training_args import CommandLineTrainingArgs, TrainingArgs

name = "textattack"
9 changes: 9 additions & 0 deletions textattack/attack_args.py
@@ -173,6 +173,8 @@ class AttackArgs:
            Disable displaying individual attack results to stdout.
        silent (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Disable all logging (except for errors). This is stronger than :obj:`disable_stdout`.
        enable_advance_metrics (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Enable calculation and display of optional advanced post-hoc metrics like perplexity, grammar errors, etc.
    """

    num_examples: int = 10
@@ -193,6 +195,7 @@
    log_to_wandb: str = None
    disable_stdout: bool = False
    silent: bool = False
    enable_advance_metrics: bool = False

    def __post_init__(self):
        if self.num_successful_examples:
@@ -350,6 +353,12 @@ def _add_parser_args(cls, parser):
            default=default_obj.silent,
            help="Disable all logging",
        )
        parser.add_argument(
            "--enable-advance-metrics",
            action="store_true",
            default=default_obj.enable_advance_metrics,
            help="Enable advanced metric calculations",
        )

        return parser

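For reference, a minimal end-to-end sketch of how the new flag is used (the recipe, model, and dataset here are illustrative placeholders, not part of this PR; the CLI equivalent is `--enable-advance-metrics`):

```python
import transformers

import textattack

# Illustrative victim model: any HuggingFace checkpoint works.
model = transformers.AutoModelForSequenceClassification.from_pretrained(
    "textattack/bert-base-uncased-imdb"
)
tokenizer = transformers.AutoTokenizer.from_pretrained("textattack/bert-base-uncased-imdb")
model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

attack = textattack.attack_recipes.TextFoolerJin2019.build(model_wrapper)
dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test")

# The new flag: opt in to the slower post-hoc quality metrics.
attack_args = textattack.AttackArgs(num_examples=10, enable_advance_metrics=True)
attacker = textattack.Attacker(attack, dataset, attack_args)
attacker.attack_dataset()
```

With the flag set, the summary table gains the perplexity rows added in attack_log_manager.py below.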
8 changes: 8 additions & 0 deletions textattack/attacker.py
@@ -219,6 +219,10 @@ def _attack(self):
        # Enable summary stdout
        if not self.attack_args.silent and self.attack_args.disable_stdout:
            self.attack_log_manager.enable_stdout()

        if self.attack_args.enable_advance_metrics:
            self.attack_log_manager.enable_advance_metrics = True

        self.attack_log_manager.log_summary()
        self.attack_log_manager.flush()
        print()
@@ -390,6 +394,10 @@ def _attack_parallel(self):
        # Enable summary stdout.
        if not self.attack_args.silent and self.attack_args.disable_stdout:
            self.attack_log_manager.enable_stdout()

        if self.attack_args.enable_advance_metrics:
            self.attack_log_manager.enable_advance_metrics = True

        self.attack_log_manager.log_summary()
        self.attack_log_manager.flush()
        print()
146 changes: 58 additions & 88 deletions textattack/loggers/attack_log_manager.py
@@ -3,9 +3,12 @@
========================
"""

import numpy as np

from textattack.attack_results import FailedAttackResult, SkippedAttackResult
from textattack.metrics.attack_metrics import (
    AttackQueries,
    AttackSuccessRate,
    WordsPerturbed,
)
from textattack.metrics.quality_metrics import Perplexity

from . import CSVLogger, FileLogger, VisdomLogger, WeightsAndBiasesLogger

@@ -16,6 +19,7 @@ class AttackLogManager:
    def __init__(self):
        self.loggers = []
        self.results = []
        self.enable_advance_metrics = False

    def enable_stdout(self):
        self.loggers.append(FileLogger(stdout=True))
@@ -72,103 +76,69 @@ def log_summary(self):
        total_attacks = len(self.results)
        if total_attacks == 0:
            return
        # Count things about attacks.
        all_num_words = np.zeros(len(self.results))
        perturbed_word_percentages = np.zeros(len(self.results))
        num_words_changed_until_success = np.zeros(
            2 ** 16
        )  # @ TODO: be smarter about this
        failed_attacks = 0
        skipped_attacks = 0
        successful_attacks = 0
        max_words_changed = 0
        for i, result in enumerate(self.results):
            all_num_words[i] = len(result.original_result.attacked_text.words)
            if isinstance(result, FailedAttackResult):
                failed_attacks += 1
                continue
            elif isinstance(result, SkippedAttackResult):
                skipped_attacks += 1
                continue
            else:
                successful_attacks += 1
            num_words_changed = result.original_result.attacked_text.words_diff_num(
                result.perturbed_result.attacked_text
            )
            # num_words_changed = len(
            #     result.original_result.attacked_text.all_words_diff(
            #         result.perturbed_result.attacked_text
            #     )
            # )
            num_words_changed_until_success[num_words_changed - 1] += 1
            max_words_changed = max(
                max_words_changed or num_words_changed, num_words_changed
            )
            if len(result.original_result.attacked_text.words) > 0:
                perturbed_word_percentage = (
                    num_words_changed
                    * 100.0
                    / len(result.original_result.attacked_text.words)
                )
            else:
                perturbed_word_percentage = 0
            perturbed_word_percentages[i] = perturbed_word_percentage

        # Original classifier success rate on these samples.
        original_accuracy = (total_attacks - skipped_attacks) * 100.0 / (total_attacks)
        original_accuracy = str(round(original_accuracy, 2)) + "%"

        # New classifier success rate on these samples.
        accuracy_under_attack = (failed_attacks) * 100.0 / (total_attacks)
        accuracy_under_attack = str(round(accuracy_under_attack, 2)) + "%"

        # Attack success rate.
        if successful_attacks + failed_attacks == 0:
            attack_success_rate = 0
        else:
            attack_success_rate = (
                successful_attacks * 100.0 / (successful_attacks + failed_attacks)
            )
        attack_success_rate = str(round(attack_success_rate, 2)) + "%"

        perturbed_word_percentages = perturbed_word_percentages[
            perturbed_word_percentages > 0
        ]
        average_perc_words_perturbed = perturbed_word_percentages.mean()
        average_perc_words_perturbed = str(round(average_perc_words_perturbed, 2)) + "%"

        average_num_words = all_num_words.mean()
        average_num_words = str(round(average_num_words, 2))
        # Default metrics - calculated on every attack
        attack_success_stats = AttackSuccessRate(self.results).calculate()
        words_perturbed_stats = WordsPerturbed(self.results).calculate()
        attack_query_stats = AttackQueries(self.results).calculate()

        # @TODO generate this table based on user input - each column in specific class
        # Example to demonstrate:
        # summary_table_rows = attack_success_stats.display_row() + words_perturbed_stats.display_row() + ...
        summary_table_rows = [
            ["Number of successful attacks:", str(successful_attacks)],
            ["Number of failed attacks:", str(failed_attacks)],
            ["Number of skipped attacks:", str(skipped_attacks)],
            ["Original accuracy:", original_accuracy],
            ["Accuracy under attack:", accuracy_under_attack],
            ["Attack success rate:", attack_success_rate],
            ["Average perturbed word %:", average_perc_words_perturbed],
            ["Average num. words per input:", average_num_words],
            [
                "Number of successful attacks:",
                attack_success_stats["successful_attacks"],
            ],
            ["Number of failed attacks:", attack_success_stats["failed_attacks"]],
            ["Number of skipped attacks:", attack_success_stats["skipped_attacks"]],
            [
                "Original accuracy:",
                str(attack_success_stats["original_accuracy"]) + "%",
            ],
            [
                "Accuracy under attack:",
                str(attack_success_stats["attack_accuracy_perc"]) + "%",
            ],
            [
                "Attack success rate:",
                str(attack_success_stats["attack_success_rate"]) + "%",
            ],
            [
                "Average perturbed word %:",
                str(words_perturbed_stats["avg_word_perturbed_perc"]) + "%",
            ],
            [
                "Average num. words per input:",
                words_perturbed_stats["avg_word_perturbed"],
            ],
        ]

        num_queries = np.array(
            [
                r.num_queries
                for r in self.results
                if not isinstance(r, SkippedAttackResult)
            ]
        )
        avg_num_queries = num_queries.mean()
        avg_num_queries = str(round(avg_num_queries, 2))
        summary_table_rows.append(["Avg num queries:", avg_num_queries])
        summary_table_rows.append(
            ["Avg num queries:", attack_query_stats["avg_num_queries"]]
        )

        if self.enable_advance_metrics:
            perplexity_stats = Perplexity(self.results).calculate()

            summary_table_rows.append(
                [
                    "Avg Original Perplexity:",
                    perplexity_stats["avg_original_perplexity"],
                ]
            )
            summary_table_rows.append(
                ["Avg Attack Perplexity:", perplexity_stats["avg_attack_perplexity"]]
            )

        self.log_summary_rows(
            summary_table_rows, "Attack Results", "attack_results_summary"
        )
        # Show histogram of words changed.
        numbins = max(max_words_changed, 10)
        numbins = max(words_perturbed_stats["max_words_changed"], 10)
        for logger in self.loggers:
            logger.log_hist(
                num_words_changed_until_success[:numbins],
                words_perturbed_stats["num_words_changed_until_success"][:numbins],
                numbins=numbins,
                title="Num Words Perturbed",
                window_id="num_words_perturbed",
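A side effect of this refactor is that each metric can also be used on its own, outside the log manager. A hedged sketch, assuming `results` is the list of AttackResult objects accumulated from a finished attack run:

```python
from textattack.metrics.attack_metrics import (
    AttackQueries,
    AttackSuccessRate,
    WordsPerturbed,
)
from textattack.metrics.quality_metrics import Perplexity

# `results`: list of AttackResult objects from a completed attack.
# Each metric takes the list and returns a plain dict of numbers.
success_stats = AttackSuccessRate(results).calculate()
print(success_stats["attack_success_rate"])  # e.g. 85.0

words_stats = WordsPerturbed(results).calculate()
print(words_stats["avg_word_perturbed_perc"])

query_stats = AttackQueries(results).calculate()
print(query_stats["avg_num_queries"])

# Kept opt-in because it loads a language model and is comparatively slow.
perplexity_stats = Perplexity(results).calculate()
print(perplexity_stats["avg_original_perplexity"])
print(perplexity_stats["avg_attack_perplexity"])
```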
10 changes: 10 additions & 0 deletions textattack/metrics/__init__.py
@@ -0,0 +1,10 @@
"""
"""

from .metric import Metric

from .attack_metrics import AttackSuccessRate
from .attack_metrics import WordsPerturbed
from .attack_metrics import AttackQueries

from .quality_metrics import Perplexity
13 changes: 13 additions & 0 deletions textattack/metrics/attack_metrics/__init__.py
@@ -0,0 +1,13 @@
"""

attack_metrics:
======================

TextAttack allows users to use their own metrics on adversarial examples or select common metrics to display.


"""

from .attack_queries import AttackQueries
from .attack_success_rate import AttackSuccessRate
from .words_perturbed import WordsPerturbed
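Since every metric here follows the same shape (a constructor taking the result list plus a `calculate()` that returns a dict), a user-defined metric is a small subclass. An illustrative sketch (the class name and dict key are made up, not part of this PR):

```python
from textattack.attack_results import SkippedAttackResult
from textattack.metrics import Metric


class MaxQueries(Metric):
    """Hypothetical custom metric: the largest number of queries any
    single non-skipped attack needed."""

    def __init__(self, results):
        self.results = results
        self.all_metrics = {}

    def calculate(self):
        queries = [
            r.num_queries
            for r in self.results
            if not isinstance(r, SkippedAttackResult)
        ]
        self.all_metrics["max_num_queries"] = max(queries) if queries else 0
        return self.all_metrics
```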
35 changes: 35 additions & 0 deletions textattack/metrics/attack_metrics/attack_queries.py
@@ -0,0 +1,35 @@
import numpy as np

from textattack.attack_results import SkippedAttackResult
from textattack.metrics import Metric


class AttackQueries(Metric):
    """Calculates all metrics related to the number of queries in an attack.

    Args:
        results (:obj:`list` of :class:`~textattack.attack_results.AttackResult`):
            Attack results for each instance in the dataset.
    """

    def __init__(self, results):
        self.results = results

        self.all_metrics = {}

    def calculate(self):
        self.num_queries = np.array(
            [
                r.num_queries
                for r in self.results
                if not isinstance(r, SkippedAttackResult)
            ]
        )
        self.all_metrics["avg_num_queries"] = self.avg_num_queries()

        return self.all_metrics

    def avg_num_queries(self):
        avg_num_queries = self.num_queries.mean()
        avg_num_queries = round(avg_num_queries, 2)
        return avg_num_queries
67 changes: 67 additions & 0 deletions textattack/metrics/attack_metrics/attack_success_rate.py
@@ -0,0 +1,67 @@
from textattack.attack_results import FailedAttackResult, SkippedAttackResult
from textattack.metrics import Metric


class AttackSuccessRate(Metric):
    """Calculates all metrics related to the number of successful, failed, and skipped results in an attack.

    Args:
        results (:obj:`list` of :class:`~textattack.attack_results.AttackResult`):
            Attack results for each instance in the dataset.
    """

    def __init__(self, results):
        self.results = results
        self.failed_attacks = 0
        self.skipped_attacks = 0
        self.successful_attacks = 0
        self.total_attacks = len(self.results)

        self.all_metrics = {}

    def calculate(self):
        for i, result in enumerate(self.results):
            if isinstance(result, FailedAttackResult):
                self.failed_attacks += 1
                continue
            elif isinstance(result, SkippedAttackResult):
                self.skipped_attacks += 1
                continue
            else:
                self.successful_attacks += 1

        # Calculated numbers
        self.all_metrics["successful_attacks"] = self.successful_attacks
        self.all_metrics["failed_attacks"] = self.failed_attacks
        self.all_metrics["skipped_attacks"] = self.skipped_attacks

        # Percentages wrt the calculations
        self.all_metrics["original_accuracy"] = self.original_accuracy_perc()
        self.all_metrics["attack_accuracy_perc"] = self.attack_accuracy_perc()
        self.all_metrics["attack_success_rate"] = self.attack_success_rate_perc()

        return self.all_metrics

    def original_accuracy_perc(self):
        original_accuracy = (
            (self.total_attacks - self.skipped_attacks) * 100.0 / (self.total_attacks)
        )
        original_accuracy = round(original_accuracy, 2)
        return original_accuracy

    def attack_accuracy_perc(self):
        accuracy_under_attack = (self.failed_attacks) * 100.0 / (self.total_attacks)
        accuracy_under_attack = round(accuracy_under_attack, 2)
        return accuracy_under_attack

    def attack_success_rate_perc(self):
        if self.successful_attacks + self.failed_attacks == 0:
            attack_success_rate = 0
        else:
            attack_success_rate = (
                self.successful_attacks
                * 100.0
                / (self.successful_attacks + self.failed_attacks)
            )
        attack_success_rate = round(attack_success_rate, 2)
        return attack_success_rate
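As a quick sanity check of the three percentage formulas above, with made-up counts:

```python
# Made-up example: 10 results total, of which 7 successful, 2 failed, 1 skipped.
total, successful, failed, skipped = 10, 7, 2, 1

original_accuracy = (total - skipped) * 100.0 / total  # 90.0
accuracy_under_attack = failed * 100.0 / total  # 20.0
attack_success_rate = successful * 100.0 / (successful + failed)  # 77.78 after rounding

print(round(original_accuracy, 2))  # 90.0
print(round(accuracy_under_attack, 2))  # 20.0
print(round(attack_success_rate, 2))  # 77.78
```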