@@ -13,23 +13,24 @@ class ChineseWordSwapMaskedLM(WordSwap):
     model."""
 
     def __init__(self, task="fill-mask", model="xlm-roberta-base", **kwargs):
-        self.unmasker = pipeline(task, model)
+        from transformers import BertTokenizer, BertForMaskedLM
+        import torch
+        self.tt = BertTokenizer.from_pretrained(model)  # tokenizer for the masked LM
+        self.mm = BertForMaskedLM.from_pretrained(model)  # masked LM used for predictions
+        self.mm.to("cuda")
         super().__init__(**kwargs)
 
     def get_replacement_words(self, current_text, indice_to_modify):
-        masked_text = current_text.replace_word_at_index(indice_to_modify, "<mask>")
-        outputs = self.unmasker(masked_text.text)
-        words = []
-        for dict in outputs:
-            take = True
-            for char in dict["token_str"]:
-                # accept only Chinese characters for potential substitutions
-                if not is_cjk(char):
-                    take = False
-            if take:
-                words.append(dict["token_str"])
-
-        return words
+        masked_text = current_text.replace_word_at_index(indice_to_modify, "[MASK]")  # was "<mask>" for the xlm-roberta model
+        tokens = self.tt.tokenize(masked_text.text)
+        input_ids = self.tt.convert_tokens_to_ids(tokens)
+        input_tensor = torch.tensor([input_ids]).to("cuda")
+        with torch.no_grad():
+            outputs = self.mm(input_tensor)
+        predictions = outputs.logits
+        predicted_token_ids = torch.argsort(predictions[0, indice_to_modify], descending=True)[:50]  # top 50 token ids at the masked position
+        predicted_tokens = self.tt.convert_ids_to_tokens(predicted_token_ids.tolist()[1:])  # drop the top-1 prediction
+        return predicted_tokens

     def _get_transformations(self, current_text, indices_to_modify):
         words = current_text.words
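A minimal usage sketch of the changed transformation, assuming a Chinese BERT checkpoint such as `bert-base-chinese` is passed in (the new `__init__` loads the model with `BertTokenizer`/`BertForMaskedLM`, which cannot read the default `xlm-roberta-base` vocabulary) and that a CUDA device is available, since the model is moved to `"cuda"`. The import path below is hypothetical and should be adjusted to wherever this class lives in the package.

```python
from textattack.shared import AttackedText

from chinese_word_swap_masked import ChineseWordSwapMaskedLM  # hypothetical import path

# Pass a Chinese BERT checkpoint; the default "xlm-roberta-base" cannot be
# loaded by BertTokenizer/BertForMaskedLM.
transformation = ChineseWordSwapMaskedLM(model="bert-base-chinese")

# How this text is split into words depends on TextAttack's word segmentation
# for Chinese; each word index is a candidate position to mask and replace.
text = AttackedText("今天天气很好")

# Calling the transformation returns candidate AttackedText objects with one
# word replaced by a masked-LM prediction.
for candidate in transformation(text):
    print(candidate.text)
```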