Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
[CODE] Transfer from #475 + new structure
  • Loading branch information
sanchit97 committed Aug 20, 2021
commit c0e2993947094e16b2c072d9e60a2b767d0743a6
5 changes: 5 additions & 0 deletions textattack/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""
"""

from .attack_metrics import AttackMetric
# from .quality_metrics import QualityMetric
14 changes: 14 additions & 0 deletions textattack/metrics/attack_metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""

attack_metrics:
======================

TextAttack allows users to use their own metrics on adversarial examples or select common metrics to display.


"""

from .attack_metric import AttackMetric
from .attack_queries import AttackQueries
from .attack_success_rate import AttackSuccessRate
from .words_perturbed import WordsPerturbed
25 changes: 25 additions & 0 deletions textattack/metrics/attack_metrics/attack_metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""
Attack Metrics Class
========================

"""

from abc import ABC, abstractmethod

from textattack.attack_results import AttackResult


class AttackMetric(ABC):
    """Abstract base class for metrics evaluated over adversarial attack
    results.

    Subclasses implement ``__init__`` (accepting the list of attack results)
    and ``calculate`` (returning a dict of metric name -> value).
    """

    @abstractmethod
    def __init__(self, results, **kwargs):
        """Creates a pre-built :class:`~textattack.AttackMetric` that
        corresponds to evaluation metrics for adversarial examples.

        Args:
            results: list of attack results, one per dataset instance.
        """
        raise NotImplementedError()

    @abstractmethod
    def calculate(self):
        """Computes all values which are to be calculated as a whole.

        Returns:
            dict: mapping of metric name to computed value.
        """
        # `self` was missing in the original signature, which made the
        # abstract method uncallable on any instance.
        raise NotImplementedError()
36 changes: 36 additions & 0 deletions textattack/metrics/attack_metrics/attack_queries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import numpy as np

from textattack.attack_results import SkippedAttackResult

from .attack_metric import AttackMetric


class AttackQueries(AttackMetric):
    """Calculates metrics related to the number of queries used in an attack.

    Args:
        results (:obj:`list` of :class:`~textattack.attack_results.AttackResult`):
            Attack results for each instance in the dataset.
    """

    def __init__(self, results):
        self.results = results
        self.all_metrics = {}

    def calculate(self):
        """Computes query statistics over all non-skipped results.

        Returns:
            dict: ``{'avg_num_queries': float}``.
        """
        # Skipped results never queried the model, so they are excluded.
        self.num_queries = np.array(
            [
                r.num_queries
                for r in self.results
                if not isinstance(r, SkippedAttackResult)
            ]
        )
        self.all_metrics["avg_num_queries"] = self.avg_num_queries()

        return self.all_metrics

    def avg_num_queries(self):
        """Returns the mean query count, rounded to 2 decimal places."""
        # Guard the empty case: np.mean of an empty array emits a
        # RuntimeWarning and returns nan.
        if self.num_queries.size == 0:
            return 0
        return round(self.num_queries.mean(), 2)
69 changes: 69 additions & 0 deletions textattack/metrics/attack_metrics/attack_success_rate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from textattack.attack_results import FailedAttackResult, SkippedAttackResult

from .attack_metric import AttackMetric


class AttackSuccessRate(AttackMetric):
    """Calculates metrics related to the number of successful, failed and
    skipped results in an attack.

    Args:
        results (:obj:`list` of :class:`~textattack.attack_results.AttackResult`):
            Attack results for each instance in the dataset.
    """

    def __init__(self, results):
        self.results = results
        self.failed_attacks = 0
        self.skipped_attacks = 0
        self.successful_attacks = 0
        self.total_attacks = len(self.results)

        self.all_metrics = {}

    def calculate(self):
        """Tallies result types and derives the percentage metrics.

        Returns:
            dict: counts plus ``original_accuracy``, ``attack_accuracy_perc``
            and ``attack_success_rate`` percentages.
        """
        for result in self.results:
            if isinstance(result, FailedAttackResult):
                self.failed_attacks += 1
            elif isinstance(result, SkippedAttackResult):
                self.skipped_attacks += 1
            else:
                self.successful_attacks += 1

        # Raw counts
        self.all_metrics['successful_attacks'] = self.successful_attacks
        self.all_metrics['failed_attacks'] = self.failed_attacks
        self.all_metrics['skipped_attacks'] = self.skipped_attacks

        # Percentages derived from the counts
        self.all_metrics['original_accuracy'] = self.original_accuracy_perc()
        self.all_metrics['attack_accuracy_perc'] = self.attack_accuracy_perc()
        self.all_metrics['attack_success_rate'] = self.attack_success_rate_perc()

        return self.all_metrics

    def original_accuracy_perc(self):
        """Model accuracy on clean inputs: every non-skipped result was
        originally classified correctly."""
        # Guard empty result lists to avoid ZeroDivisionError, consistent
        # with attack_success_rate_perc.
        if self.total_attacks == 0:
            return 0
        original_accuracy = (
            (self.total_attacks - self.skipped_attacks) * 100.0 / self.total_attacks
        )
        return round(original_accuracy, 2)

    def attack_accuracy_perc(self):
        """Model accuracy under attack: only failed attacks leave the
        prediction correct."""
        if self.total_attacks == 0:
            return 0
        accuracy_under_attack = self.failed_attacks * 100.0 / self.total_attacks
        return round(accuracy_under_attack, 2)

    def attack_success_rate_perc(self):
        """Successful attacks as a percentage of all attempted (non-skipped)
        attacks."""
        attempted = self.successful_attacks + self.failed_attacks
        if attempted == 0:
            return 0
        attack_success_rate = self.successful_attacks * 100.0 / attempted
        return round(attack_success_rate, 2)
65 changes: 65 additions & 0 deletions textattack/metrics/attack_metrics/words_perturbed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import numpy as np

from textattack.attack_results import FailedAttackResult, SkippedAttackResult

from .attack_metric import AttackMetric


class WordsPerturbed(AttackMetric):
    """Calculates metrics on how many words an attack perturbed.

    Args:
        results (:obj:`list` of :class:`~textattack.attack_results.AttackResult`):
            Attack results for each instance in the dataset.
    """

    def __init__(self, results):
        self.results = results
        self.total_attacks = len(self.results)
        self.all_num_words = np.zeros(len(self.results))
        self.perturbed_word_percentages = np.zeros(len(self.results))
        # Histogram: index i counts successful attacks that changed i + 1 words.
        self.num_words_changed_until_success = np.zeros(2 ** 16)
        self.all_metrics = {}

    def calculate(self):
        """Computes per-result word counts and perturbation percentages.

        Returns:
            dict: averages, the maximum words changed, and the
            words-changed-until-success histogram.
        """
        self.max_words_changed = 0
        for i, result in enumerate(self.results):
            original_text = result.original_result.attacked_text
            num_orig_words = len(original_text.words)
            self.all_num_words[i] = num_orig_words

            # Only successful attacks have a meaningful perturbation.
            if isinstance(result, (FailedAttackResult, SkippedAttackResult)):
                continue

            num_words_changed = len(
                original_text.all_words_diff(result.perturbed_result.attacked_text)
            )
            # Guard: a zero-change result would index bin -1 (the last bin)
            # and silently corrupt the histogram.
            if num_words_changed > 0:
                self.num_words_changed_until_success[num_words_changed - 1] += 1
            self.max_words_changed = max(self.max_words_changed, num_words_changed)

            if num_orig_words > 0:
                perturbed_word_percentage = (
                    num_words_changed * 100.0 / num_orig_words
                )
            else:
                perturbed_word_percentage = 0
            self.perturbed_word_percentages[i] = perturbed_word_percentage

        self.all_metrics['avg_word_perturbed'] = self.avg_number_word_perturbed_num()
        self.all_metrics['avg_word_perturbed_perc'] = self.avg_perturbation_perc()
        self.all_metrics['max_words_changed'] = self.max_words_changed
        self.all_metrics[
            'num_words_changed_until_success'
        ] = self.num_words_changed_until_success

        return self.all_metrics

    def avg_number_word_perturbed_num(self):
        """Average word count of the *original* texts across all results
        (including skipped and failed), rounded to 2 decimal places."""
        # Guard empty input: np.mean([]) warns and returns nan.
        if self.all_num_words.size == 0:
            return 0
        return round(self.all_num_words.mean(), 2)

    def avg_perturbation_perc(self):
        """Average perturbed-word percentage over successful attacks only."""
        # Zero entries belong to skipped/failed results; filter into a local
        # instead of clobbering self.perturbed_word_percentages.
        nonzero = self.perturbed_word_percentages[self.perturbed_word_percentages > 0]
        if nonzero.size == 0:
            return 0
        return round(nonzero.mean(), 2)
12 changes: 12 additions & 0 deletions textattack/metrics/quality_metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""

quality_metrics:
======================

TextAttack allows users to use their own metrics on adversarial examples or select common metrics to display.


"""

from .quality_metric import QualityMetric

Empty file.