diff --git a/textattack/transformations/word_swaps/chn_transformations/chinese_word_swap_masked.py b/textattack/transformations/word_swaps/chn_transformations/chinese_word_swap_masked.py index 95cb9c3d..a0b33096 100644 --- a/textattack/transformations/word_swaps/chn_transformations/chinese_word_swap_masked.py +++ b/textattack/transformations/word_swaps/chn_transformations/chinese_word_swap_masked.py @@ -13,23 +13,24 @@ class ChineseWordSwapMaskedLM(WordSwap): model.""" def __init__(self, task="fill-mask", model="xlm-roberta-base", **kwargs): - self.unmasker = pipeline(task, model) + from transformers import BertTokenizer, BertForMaskedLM + import torch + self.tt = BertTokenizer.from_pretrained(model) + self.mm = BertForMaskedLM.from_pretrained(model) + self.mm.to("cuda") super().__init__(**kwargs) def get_replacement_words(self, current_text, indice_to_modify): - masked_text = current_text.replace_word_at_index(indice_to_modify, "") - outputs = self.unmasker(masked_text.text) - words = [] - for dict in outputs: - take = True - for char in dict["token_str"]: - # accept only Chinese characters for potential substitutions - if not is_cjk(char): - take = False - if take: - words.append(dict["token_str"]) - - return words + masked_text = current_text.replace_word_at_index(indice_to_modify, "[MASK]") # 修改前,xlmrberta的模型 + tokens = self.tt.tokenize(masked_text.text) + input_ids = self.tt.convert_tokens_to_ids(tokens) + input_tensor = torch.tensor([input_ids]).to("cuda") + with torch.no_grad(): + outputs = self.mm(input_tensor) + predictions = outputs.logits + predicted_token_ids = torch.argsort(predictions[0, indice_to_modify], descending=True)[:50] + predicted_tokens = self.tt.convert_ids_to_tokens(predicted_token_ids.tolist()[1:]) + return predicted_tokens def _get_transformations(self, current_text, indices_to_modify): words = current_text.words